Merge remote-tracking branch 'upstream/main' into refactor_dp

This commit is contained in:
Alexander Soare 2024-04-11 17:52:10 +01:00
commit 94cc22da9e
29 changed files with 545 additions and 603 deletions

.github/poetry/cpu/poetry.lock

@@ -940,7 +940,7 @@ mujoco = "^2.3.7"
 type = "git"
 url = "git@github.com:huggingface/gym-xarm.git"
 reference = "HEAD"
-resolved_reference = "08ddd5a9400783a6898bbf3c3014fc5da3961b9d"
+resolved_reference = "6a88f7d63833705dfbec4b997bf36cac6b4a448c"

 [[package]]
 name = "gymnasium"


@ -142,6 +142,7 @@ jobs:
wandb.enable=False \ wandb.enable=False \
offline_steps=2 \ offline_steps=2 \
online_steps=0 \ online_steps=0 \
eval_episodes=1 \
device=cpu \ device=cpu \
save_model=true \ save_model=true \
save_freq=2 \ save_freq=2 \
@ -159,17 +160,6 @@ jobs:
device=cpu \ device=cpu \
policy.pretrained_model_path=tests/outputs/act/models/2.pt policy.pretrained_model_path=tests/outputs/act/models/2.pt
# TODO(aliberts): This takes ~2mn to run, needs to be improved
# - name: Test eval ACT on ALOHA end-to-end (policy is None)
# run: |
# source .venv/bin/activate
# python lerobot/scripts/eval.py \
# --config lerobot/configs/default.yaml \
# policy=act \
# env=aloha \
# eval_episodes=1 \
# device=cpu
- name: Test train Diffusion on PushT end-to-end - name: Test train Diffusion on PushT end-to-end
run: | run: |
source .venv/bin/activate source .venv/bin/activate
@ -179,9 +169,11 @@ jobs:
wandb.enable=False \ wandb.enable=False \
offline_steps=2 \ offline_steps=2 \
online_steps=0 \ online_steps=0 \
eval_episodes=1 \
device=cpu \ device=cpu \
save_model=true \ save_model=true \
save_freq=2 \ save_freq=2 \
policy.batch_size=2 \
hydra.run.dir=tests/outputs/diffusion/ hydra.run.dir=tests/outputs/diffusion/
- name: Test eval Diffusion on PushT end-to-end - name: Test eval Diffusion on PushT end-to-end
@ -194,16 +186,6 @@ jobs:
device=cpu \ device=cpu \
policy.pretrained_model_path=tests/outputs/diffusion/models/2.pt policy.pretrained_model_path=tests/outputs/diffusion/models/2.pt
- name: Test eval Diffusion on PushT end-to-end (policy is None)
run: |
source .venv/bin/activate
python lerobot/scripts/eval.py \
--config lerobot/configs/default.yaml \
policy=diffusion \
env=pusht \
eval_episodes=1 \
device=cpu
- name: Test train TDMPC on Simxarm end-to-end - name: Test train TDMPC on Simxarm end-to-end
run: | run: |
source .venv/bin/activate source .venv/bin/activate
@ -213,9 +195,11 @@ jobs:
wandb.enable=False \ wandb.enable=False \
offline_steps=1 \ offline_steps=1 \
online_steps=1 \ online_steps=1 \
eval_episodes=1 \
device=cpu \ device=cpu \
save_model=true \ save_model=true \
save_freq=2 \ save_freq=2 \
policy.batch_size=2 \
hydra.run.dir=tests/outputs/tdmpc/ hydra.run.dir=tests/outputs/tdmpc/
- name: Test eval TDMPC on Simxarm end-to-end - name: Test eval TDMPC on Simxarm end-to-end
@ -227,13 +211,3 @@ jobs:
env.episode_length=8 \ env.episode_length=8 \
device=cpu \ device=cpu \
policy.pretrained_model_path=tests/outputs/tdmpc/models/2.pt policy.pretrained_model_path=tests/outputs/tdmpc/models/2.pt
- name: Test eval TDPMC on Simxarm end-to-end (policy is None)
run: |
source .venv/bin/activate
python lerobot/scripts/eval.py \
--config lerobot/configs/default.yaml \
policy=tdmpc \
env=xarm \
eval_episodes=1 \
device=cpu

.gitignore

@ -11,6 +11,9 @@ rl
nautilus/*.yaml nautilus/*.yaml
*.key *.key
# Slurm
sbatch*.sh
# Byte-compiled / optimized / DLL files # Byte-compiled / optimized / DLL files
__pycache__/ __pycache__/
*.py[cod] *.py[cod]


@@ -120,34 +120,32 @@ wandb login
 You can import our dataset class, download the data from the HuggingFace hub and use our rendering utilities:
 ```python
 """ Copy pasted from `examples/1_visualize_dataset.py` """
+import os
+from pathlib import Path
 import lerobot
 from lerobot.common.datasets.aloha import AlohaDataset
-from torchrl.data.replay_buffers import SamplerWithoutReplacement
 from lerobot.scripts.visualize_dataset import render_dataset

 print(lerobot.available_datasets)
 # >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted', 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted', 'pusht', 'xarm_lift_medium']

-# we use this sampler to sample 1 frame after the other
-sampler = SamplerWithoutReplacement(shuffle=False)
-dataset = AlohaDataset("aloha_sim_transfer_cube_human", sampler=sampler)
+# TODO(rcadene): remove DATA_DIR
+dataset = AlohaDataset("pusht", root=Path(os.environ.get("DATA_DIR")))

 video_paths = render_dataset(
     dataset,
     out_dir="outputs/visualize_dataset/example",
-    max_num_samples=300,
-    fps=50,
+    max_num_episodes=1,
 )
 print(video_paths)
-# >>> ['outputs/visualize_dataset/example/episode_0.mp4']
+# ['outputs/visualize_dataset/example/episode_0.mp4']
 ```

 Or you can achieve the same result by executing our script from the command line:
 ```bash
 python lerobot/scripts/visualize_dataset.py \
-env=aloha \
-task=sim_sim_transfer_cube_human \
+env=pusht \
 hydra.run.dir=outputs/visualize_dataset/example
 # >>> ['outputs/visualize_dataset/example/episode_0.mp4']
 ```


@@ -1,24 +1,20 @@
 import os
+from pathlib import Path

-from torchrl.data.replay_buffers import SamplerWithoutReplacement
 import lerobot
-from lerobot.common.datasets.aloha import AlohaDataset
+from lerobot.common.datasets.pusht import PushtDataset
 from lerobot.scripts.visualize_dataset import render_dataset

 print(lerobot.available_datasets)
 # >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted', 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted', 'pusht', 'xarm_lift_medium']

-# we use this sampler to sample 1 frame after the other
-sampler = SamplerWithoutReplacement(shuffle=False)
-dataset = AlohaDataset("aloha_sim_transfer_cube_human", sampler=sampler, root=os.environ.get("DATA_DIR"))
+# TODO(rcadene): remove DATA_DIR
+dataset = PushtDataset("pusht", root=Path(os.environ.get("DATA_DIR")))

 video_paths = render_dataset(
     dataset,
     out_dir="outputs/visualize_dataset/example",
-    max_num_samples=300,
-    fps=50,
+    max_num_episodes=1,
 )
 print(video_paths)
 # ['outputs/visualize_dataset/example/episode_0.mp4']


@@ -9,9 +9,8 @@ from pathlib import Path

 import torch
 from omegaconf import OmegaConf
-from tqdm import trange

-from lerobot.common.datasets.factory import make_offline_buffer
+from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.policies.diffusion.policy import DiffusionPolicy
 from lerobot.common.utils import init_hydra_config
@ -37,19 +36,33 @@ policy = DiffusionPolicy(
cfg_obs_encoder=cfg.obs_encoder, cfg_obs_encoder=cfg.obs_encoder,
cfg_optimizer=cfg.optimizer, cfg_optimizer=cfg.optimizer,
cfg_ema=cfg.ema, cfg_ema=cfg.ema,
n_action_steps=cfg.n_action_steps,
**cfg.policy, **cfg.policy,
) )
policy.train() policy.train()
-offline_buffer = make_offline_buffer(cfg)
+dataset = make_dataset(cfg)
+
+# create dataloader for offline training
+dataloader = torch.utils.data.DataLoader(
+    dataset,
+    num_workers=4,
+    batch_size=cfg.policy.batch_size,
+    shuffle=True,
+    pin_memory=cfg.device != "cpu",
+    drop_last=True,
+)
+
+for step, batch in enumerate(dataloader):
+    info = policy(batch, step)
+    if step % cfg.log_freq == 0:
+        num_samples = (step + 1) * cfg.policy.batch_size
+        loss = info["loss"]
+        update_s = info["update_s"]
+        print(f"step:{step} samples:{num_samples} loss:{loss:.3f} update_time:{update_s:.3f}(seconds)")

-for offline_step in trange(cfg.offline_steps):
-    train_info = policy.update(offline_buffer, offline_step)
-    if offline_step % cfg.log_freq == 0:
-        print(train_info)

 # Save the policy, configuration, and normalization stats for later use.
 policy.save(output_directory / "model.pt")
 OmegaConf.save(cfg, output_directory / "config.yaml")
-torch.save(offline_buffer.transform[-1].stats, output_directory / "stats.pth")
+torch.save(dataset.transform.transforms[-1].stats, output_directory / "stats.pth")


@@ -12,14 +12,11 @@ Example:
     print(lerobot.available_policies)
     ```

-Note:
-When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
-1. set the required class attributes:
-    - for classes inheriting from `AbstractDataset`: `available_datasets`
-    - for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
-    - for classes inheriting from `AbstractPolicy`: `name`
-2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
-3. update variables in `tests/test_available.py` by importing your new class
+When implementing a new dataset (e.g. `AlohaDataset`), policy (e.g. `DiffusionPolicy`), or environment, follow these steps:
+- Set the required class attributes: `available_datasets`.
+- Set the required class attributes: `name`.
+- Update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
+- Update variables in `tests/test_available.py` by importing your new class
 """

 from lerobot.__version__ import __version__  # noqa: F401
@@ -32,11 +29,11 @@ available_envs = [
 available_tasks_per_env = {
     "aloha": [
-        "sim_insertion",
-        "sim_transfer_cube",
+        "AlohaInsertion-v0",
+        "AlohaTransferCube-v0",
     ],
-    "pusht": ["pusht"],
-    "xarm": ["lift"],
+    "pusht": ["PushT-v0"],
+    "xarm": ["XarmLift-v0"],
 }

 available_datasets_per_env = {
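To make the checklist in the module docstring above concrete, here is a minimal sketch of where the required class attribute and the registry updates go; the class name and dataset id are hypothetical and only for illustration:

```python
import torch


class MyRobotDataset(torch.utils.data.Dataset):
    """Hypothetical dataset class illustrating the checklist above."""

    # required class attribute: the dataset ids this class can load
    available_datasets = ["myrobot_pick_cube"]

    def __init__(self, dataset_id: str, root=None, transform=None):
        self.dataset_id = dataset_id
        self.root = root
        self.transform = transform
        self.data_dict = {}  # filled by a download/preprocessing step

    def __len__(self):
        return len(self.data_dict["index"]) if "index" in self.data_dict else 0

    def __getitem__(self, idx):
        item = {key: self.data_dict[key][idx] for key in self.data_dict}
        if self.transform is not None:
            item = self.transform(item)
        return item


# then register the new names in `lerobot/__init__.py`, e.g.
# available_datasets_per_env["myrobot"] = ["myrobot_pick_cube"]
# and import the class in `tests/test_available.py`.
```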


@ -105,7 +105,7 @@ class AlohaDataset(torch.utils.data.Dataset):
@property @property
def num_samples(self) -> int: def num_samples(self) -> int:
return len(self.data_dict["index"]) return len(self.data_dict["index"]) if "index" in self.data_dict else 0
@property @property
def num_episodes(self) -> int: def num_episodes(self) -> int:


@@ -1,10 +1,11 @@
+import logging
 import os
 from pathlib import Path

 import torch
 from torchvision.transforms import v2

-from lerobot.common.datasets.utils import compute_or_load_stats
+from lerobot.common.datasets.utils import compute_stats
 from lerobot.common.transforms import NormalizeTransform, Prod

 # DATA_DIR specifies to location where datasets are loaded. By default, DATA_DIR is None and
@@ -40,7 +41,8 @@ def make_dataset(
     if normalize:
         # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max,
         # min_max_from_spec
-        # stats = dataset.compute_or_load_stats() if stats_path is None else torch.load(stats_path)
+        # TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
+        normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"

         if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
             stats = {}
@@ -51,21 +53,27 @@ def make_dataset(
             stats["action"] = {}
             stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
             stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
-        else:
+        elif stats_path is None:
             # instantiate a one frame dataset with light transform
             stats_dataset = clsfunc(
                 dataset_id=cfg.dataset_id,
                 root=DATA_DIR,
                 transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
             )
-            stats = compute_or_load_stats(stats_dataset)

-        # TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
-        normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
+            # load stats if the file exists already or compute stats and save it
+            precomputed_stats_path = stats_dataset.data_dir / "stats.pth"
+            if precomputed_stats_path.exists():
+                stats = torch.load(precomputed_stats_path)
+            else:
+                logging.info(f"compute_stats and save to {precomputed_stats_path}")
+                stats = compute_stats(stats_dataset)
+                torch.save(stats, stats_path)
+        else:
+            stats = torch.load(stats_path)

         transforms = v2.Compose(
             [
+                # TODO(rcadene): we need to do something about image_keys
                 Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
                 NormalizeTransform(
                     stats,


@ -2,11 +2,8 @@ from pathlib import Path
import einops import einops
import numpy as np import numpy as np
import pygame
import pymunk
import torch import torch
import tqdm import tqdm
from gym_pusht.envs.pusht import pymunk_to_shapely
from lerobot.common.datasets._diffusion_policy_replay_buffer import ( from lerobot.common.datasets._diffusion_policy_replay_buffer import (
ReplayBuffer as DiffusionPolicyReplayBuffer, ReplayBuffer as DiffusionPolicyReplayBuffer,
@ -20,64 +17,6 @@ PUSHT_URL = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
PUSHT_ZARR = Path("pusht/pusht_cchi_v7_replay.zarr") PUSHT_ZARR = Path("pusht/pusht_cchi_v7_replay.zarr")
def get_goal_pose_body(pose):
mass = 1
inertia = pymunk.moment_for_box(mass, (50, 100))
body = pymunk.Body(mass, inertia)
# preserving the legacy assignment order for compatibility
# the order here doesn't matter somehow, maybe because CoM is aligned with body origin
body.position = pose[:2].tolist()
body.angle = pose[2]
return body
def add_segment(space, a, b, radius):
shape = pymunk.Segment(space.static_body, a, b, radius)
shape.color = pygame.Color("LightGray") # https://htmlcolorcodes.com/color-names
return shape
def add_tee(
space,
position,
angle,
scale=30,
color="LightSlateGray",
mask=None,
):
if mask is None:
mask = pymunk.ShapeFilter.ALL_MASKS()
mass = 1
length = 4
vertices1 = [
(-length * scale / 2, scale),
(length * scale / 2, scale),
(length * scale / 2, 0),
(-length * scale / 2, 0),
]
inertia1 = pymunk.moment_for_poly(mass, vertices=vertices1)
vertices2 = [
(-scale / 2, scale),
(-scale / 2, length * scale),
(scale / 2, length * scale),
(scale / 2, scale),
]
inertia2 = pymunk.moment_for_poly(mass, vertices=vertices1)
body = pymunk.Body(mass, inertia1 + inertia2)
shape1 = pymunk.Poly(body, vertices1)
shape2 = pymunk.Poly(body, vertices2)
shape1.color = pygame.Color(color)
shape2.color = pygame.Color(color)
shape1.filter = pymunk.ShapeFilter(mask=mask)
shape2.filter = pymunk.ShapeFilter(mask=mask)
body.center_of_gravity = (shape1.center_of_gravity + shape2.center_of_gravity) / 2
body.position = position
body.angle = angle
body.friction = 1
space.add(body, shape1, shape2)
return body
class PushtDataset(torch.utils.data.Dataset): class PushtDataset(torch.utils.data.Dataset):
""" """
@ -121,7 +60,7 @@ class PushtDataset(torch.utils.data.Dataset):
@property @property
def num_samples(self) -> int: def num_samples(self) -> int:
return len(self.data_dict["index"]) return len(self.data_dict["index"]) if "index" in self.data_dict else 0
@property @property
def num_episodes(self) -> int: def num_episodes(self) -> int:
@ -158,6 +97,13 @@ class PushtDataset(torch.utils.data.Dataset):
return item return item
def _download_and_preproc_obsolete(self): def _download_and_preproc_obsolete(self):
try:
import pymunk
from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
except ModuleNotFoundError as e:
print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
raise e
assert self.root is not None assert self.root is not None
raw_dir = self.root / f"{self.dataset_id}_raw" raw_dir = self.root / f"{self.dataset_id}_raw"
zarr_path = (raw_dir / PUSHT_ZARR).resolve() zarr_path = (raw_dir / PUSHT_ZARR).resolve()
@ -182,7 +128,7 @@ class PushtDataset(torch.utils.data.Dataset):
# TODO: verify that goal pose is expected to be fixed # TODO: verify that goal pose is expected to be fixed
goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians) goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians)
goal_body = get_goal_pose_body(goal_pos_angle) goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle)
imgs = torch.from_numpy(dataset_dict["img"]) imgs = torch.from_numpy(dataset_dict["img"])
imgs = einops.rearrange(imgs, "b h w c -> b c h w") imgs = einops.rearrange(imgs, "b h w c -> b c h w")
@ -201,6 +147,9 @@ class PushtDataset(torch.utils.data.Dataset):
assert (episode_ids[idx0:idx1] == episode_id).all() assert (episode_ids[idx0:idx1] == episode_id).all()
image = imgs[idx0:idx1] image = imgs[idx0:idx1]
assert image.min() >= 0.0
assert image.max() <= 255.0
image = image.type(torch.uint8)
state = states[idx0:idx1] state = states[idx0:idx1]
agent_pos = state[:, :2] agent_pos = state[:, :2]
@ -217,14 +166,14 @@ class PushtDataset(torch.utils.data.Dataset):
# Add walls. # Add walls.
walls = [ walls = [
add_segment(space, (5, 506), (5, 5), 2), PushTEnv.add_segment(space, (5, 506), (5, 5), 2),
add_segment(space, (5, 5), (506, 5), 2), PushTEnv.add_segment(space, (5, 5), (506, 5), 2),
add_segment(space, (506, 5), (506, 506), 2), PushTEnv.add_segment(space, (506, 5), (506, 506), 2),
add_segment(space, (5, 506), (506, 506), 2), PushTEnv.add_segment(space, (5, 506), (506, 506), 2),
] ]
space.add(*walls) space.add(*walls)
block_body = add_tee(space, block_pos[i].tolist(), block_angle[i].item()) block_body = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes) goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
block_geom = pymunk_to_shapely(block_body, block_body.shapes) block_geom = pymunk_to_shapely(block_body, block_body.shapes)
intersection_area = goal_geom.intersection(block_geom).area intersection_area = goal_geom.intersection(block_geom).area
@ -265,16 +214,3 @@ class PushtDataset(torch.utils.data.Dataset):
self.data_dict[key] = torch.cat([x[key] for x in ep_dicts]) self.data_dict[key] = torch.cat([x[key] for x in ep_dicts])
self.data_dict["index"] = torch.arange(0, total_frames, 1) self.data_dict["index"] = torch.arange(0, total_frames, 1)
if __name__ == "__main__":
dataset = PushtDataset(
"pusht",
root=Path("data"),
delta_timestamps={
"observation.image": [0, -1, -0.2, -0.1],
"observation.state": [0, -1, -0.2, -0.1],
"action": [-0.1, 0, 1, 2, 3],
},
)
dataset[10]


@ -1,5 +1,4 @@
import io import io
import logging
import zipfile import zipfile
from copy import deepcopy from copy import deepcopy
from math import ceil from math import ceil
@ -35,52 +34,56 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
return False return False
def euclidean_distance_matrix(mat0, mat1):
# Compute the square of the distance matrix
sq0 = torch.sum(mat0**2, dim=1, keepdim=True)
sq1 = torch.sum(mat1**2, dim=1, keepdim=True)
distance_sq = sq0 + sq1.transpose(0, 1) - 2 * mat0 @ mat1.transpose(0, 1)
# Taking the square root to get the euclidean distance
distance = torch.sqrt(torch.clamp(distance_sq, min=0))
return distance
def is_contiguously_true_or_false(bool_vector):
assert bool_vector.ndim == 1
assert bool_vector.dtype == torch.bool
# Compare each element with its neighbor to find changes
changes = bool_vector[1:] != bool_vector[:-1]
# Count the number of changes
num_changes = changes.sum().item()
# If there's more than one change, the list is not contiguous
return num_changes <= 1
# examples = [
# ([True, False, True, False, False, False], False),
# ([True, True, True, False, False, False], True),
# ([False, False, False, False, False, False], True)
# ]
# for bool_list, expected in examples:
# result = is_contiguously_true_or_false(bool_list)
 def load_data_with_delta_timestamps(
-    data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode
+    data_dict: dict[torch.Tensor],
+    data_ids_per_episode: dict[torch.Tensor],
+    delta_timestamps: list[float],
+    key: str,
+    current_ts: float,
+    episode: int,
+    tol: float = 0.04,
 ):
+    """
+    Given a current timestamp (e.g. current_ts=0.6) and a list of timestamp differences (e.g. delta_timestamps=[-0.8, -0.2, 0, 0.2]),
+    this function computes the query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest frames of the specified modality (e.g. key="observation.image").
+
+    Importantly, when no frame can be found around a query timestamp within a specified tolerance window (e.g. tol=0.04), this function raises an AssertionError.
+    When a timestamp is queried before the first available timestamp of the episode or after the last available timestamp,
+    the violation of the tolerance doesn't raise an AssertionError, and the function populates a boolean array indicating which frames are outside of the episode range.
+    For instance, this boolean array is useful during batched training to not supervise actions associated with timestamps coming after the end of the episode,
+    or to pad the observations in a specific way. Note that by default the observation frames before the start of the episode are the same as the first frame of the episode.
+
+    Parameters:
+    - data_dict (dict): A dictionary containing the data, where each key corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
+    - data_ids_per_episode (dict): A dictionary where keys are episode identifiers and values are lists of indices corresponding to frames associated with each episode.
+    - delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible key to be retrieved. These deltas are added to the current_ts to form the query timestamps.
+    - key (str): The key specifying which data modality is to be retrieved from the data_dict.
+    - current_ts (float): The current timestamp to which the delta timestamps are added to form the query timestamps.
+    - episode (int): The identifier of the episode from which frames are to be retrieved.
+    - tol (float, optional): The tolerance level used to determine if a data point is close enough to the query timestamp. Defaults to 0.04.
+
+    Returns:
+    - tuple: A tuple containing two elements:
+      - The first element is the data retrieved from the specified modality based on the closest match to the query timestamps.
+      - The second element is a boolean array indicating which frames were considered as padding (True if the distance to the closest timestamp was greater than the tolerance level).
+
+    Raises:
+    - AssertionError: If any of the frames unexpectedly violate the tolerance level. This could indicate synchronization issues with timestamps during data collection.
+    """
     # get indices of the frames associated to the episode, and their timestamps
     ep_data_ids = data_ids_per_episode[episode]
     ep_timestamps = data_dict["timestamp"][ep_data_ids]

+    # we make the assumption that the timestamps are sorted
+    ep_first_ts = ep_timestamps[0]
+    ep_last_ts = ep_timestamps[-1]
+
     # get timestamps used as query to retrieve data of previous/future frames
     delta_ts = delta_timestamps[key]
     query_ts = current_ts + torch.tensor(delta_ts)

     # compute distances between each query timestamp and all timestamps of all the frames belonging to the episode
-    dist = euclidean_distance_matrix(query_ts[:, None], ep_timestamps[:, None])
+    dist = torch.cdist(query_ts[:, None], ep_timestamps[:, None], p=1)
     min_, argmin_ = dist.min(1)

     # get the indices of the data that are closest to the query timestamps
@@ -92,24 +95,29 @@ def load_data_with_delta_timestamps(

     # TODO(rcadene): synchronize timestamps + interpolation if needed

-    tol = 0.04
     is_pad = min_ > tol

-    assert is_contiguously_true_or_false(is_pad), (
-        f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tol=})."
+    # check violated query timestamps are all outside the episode range
+    assert ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all(), (
+        f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tol=}) inside episode range."
         "This might be due to synchronization issues with timestamps during data collection."
     )

     return data, is_pad
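A small standalone sketch of the matching logic documented above (not the library function itself): with current_ts=0.6 and delta_timestamps=[-0.8, -0.2, 0, 0.2], the query timestamps become [-0.2, 0.4, 0.6, 0.8]; the -0.2 query has no frame within the tolerance, but since it falls before the episode start it is flagged as padding instead of raising.

```python
import torch

# hypothetical episode with frames at 0.0, 0.1, ..., 0.9 (10 fps)
ep_timestamps = torch.arange(0, 1.0, 0.1)

current_ts = 0.6
delta_ts = torch.tensor([-0.8, -0.2, 0.0, 0.2])
query_ts = current_ts + delta_ts  # tensor([-0.2, 0.4, 0.6, 0.8])

# L1 distance between every query timestamp and every frame timestamp, as in the torch.cdist call above
dist = torch.cdist(query_ts[:, None], ep_timestamps[:, None], p=1)
min_, argmin_ = dist.min(dim=1)

tol = 0.04
is_pad = min_ > tol  # tensor([ True, False, False, False])

# the only out-of-tolerance query (-0.2) lies before the first frame of the episode,
# so it is treated as padding rather than a synchronization error
assert ((query_ts[is_pad] < ep_timestamps[0]) | (query_ts[is_pad] > ep_timestamps[-1])).all()
print(argmin_)  # indices of the closest frames, e.g. tensor([0, 4, 6, 8])
```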
-def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
-    stats_path = dataset.data_dir / "stats.pth"
-    if stats_path.exists():
-        return torch.load(stats_path)
-
-    logging.info(f"compute_stats and save to {stats_path}")
+def get_stats_einops_patterns(dataset):
+    """These einops patterns will be used to aggregate batches and compute statistics."""
+    stats_patterns = {
+        "action": "b c -> c",
+        "observation.state": "b c -> c",
+    }
+    for key in dataset.image_keys:
+        stats_patterns[key] = "b c h w -> c 1 1"
+    return stats_patterns
+
+
+def compute_stats(dataset, batch_size=32, max_num_samples=None):
     if max_num_samples is None:
         max_num_samples = len(dataset)
     else:
@@ -124,13 +132,8 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
         drop_last=False,
     )

-    # these einops patterns will be used to aggregate batches and compute statistics
-    stats_patterns = {
-        "action": "b c -> c",
-        "observation.state": "b c -> c",
-    }
-    for key in dataset.image_keys:
-        stats_patterns[key] = "b c h w -> c 1 1"
+    # get einops patterns to aggregate batches and compute statistics
+    stats_patterns = get_stats_einops_patterns(dataset)

     # mean and std will be computed incrementally while max and min will track the running value.
     mean, std, max, min = {}, {}, {}, {}
@ -201,7 +204,6 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
"min": min[key], "min": min[key],
} }
torch.save(stats, stats_path)
return stats return stats
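The comment above ("mean and std will be computed incrementally while max and min will track the running value") describes a running aggregation over batches; a minimal sketch of that pattern, independent of the actual `compute_stats` implementation and assuming equally sized batches:

```python
import torch


def incremental_stats(batches):
    """Toy running mean/std/max/min over a list of equally sized 1D tensors."""
    mean = torch.zeros(())
    maximum = torch.tensor(float("-inf"))
    minimum = torch.tensor(float("inf"))
    for i, batch in enumerate(batches):
        # incremental mean update: m_i = m_{i-1} + (batch_mean - m_{i-1}) / (i + 1)
        mean = mean + (batch.mean() - mean) / (i + 1)
        maximum = torch.maximum(maximum, batch.max())
        minimum = torch.minimum(minimum, batch.min())

    # second pass for std, against the final mean
    var = torch.zeros(())
    for i, batch in enumerate(batches):
        var = var + (((batch - mean) ** 2).mean() - var) / (i + 1)

    return {"mean": mean, "std": var.sqrt(), "max": maximum, "min": minimum}


stats = incremental_stats([torch.tensor([0.0, 2.0]), torch.tensor([4.0, 6.0])])
print(stats["mean"], stats["std"])  # tensor(3.) tensor(2.2361)
```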


@ -60,7 +60,7 @@ class XarmDataset(torch.utils.data.Dataset):
@property @property
def num_samples(self) -> int: def num_samples(self) -> int:
return len(self.data_dict["index"]) return len(self.data_dict["index"]) if "index" in self.data_dict else 0
@property @property
def num_episodes(self) -> int: def num_episodes(self) -> int:
@ -126,7 +126,8 @@ class XarmDataset(torch.utils.data.Dataset):
image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1]) image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1])
state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1]) state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
action = torch.tensor(dataset_dict["actions"][idx0:idx1]) action = torch.tensor(dataset_dict["actions"][idx0:idx1])
# TODO(rcadene): concat the last "next_observations" to "observations" # TODO(rcadene): we have a missing last frame which is the observation when the env is done
# it is critical to have this frame for tdmpc to predict a "done observation/state"
# next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1]) # next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1])
# next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1]) # next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1])
next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1]) next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1])


@ -19,6 +19,7 @@ def preprocess_observation(observation, transform=None):
img = einops.rearrange(img, "b h w c -> b c h w") img = einops.rearrange(img, "b h w c -> b c h w")
obs[imgkey] = img obs[imgkey] = img
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing requirement for "agent_pos"
obs["observation.state"] = torch.from_numpy(observation["agent_pos"]).float() obs["observation.state"] = torch.from_numpy(observation["agent_pos"]).float()
# apply same transforms as in training # apply same transforms as in training


@@ -29,9 +29,9 @@ def make_policy(cfg):
     if cfg.policy.pretrained_model_path:
         # TODO(rcadene): hack for old pretrained models from fowm
         if cfg.policy.name == "tdmpc" and "fowm" in cfg.policy.pretrained_model_path:
-            if "offline" in cfg.pretrained_model_path:
+            if "offline" in cfg.policy.pretrained_model_path:
                 policy.step[0] = 25000
-            elif "final" in cfg.pretrained_model_path:
+            elif "final" in cfg.policy.pretrained_model_path:
                 policy.step[0] = 100000
             else:
                 raise NotImplementedError()


@ -333,94 +333,6 @@ class TDMPCPolicy(nn.Module):
"""Main update function. Corresponds to one iteration of the model learning.""" """Main update function. Corresponds to one iteration of the model learning."""
start_time = time.time() start_time = time.time()
# num_slices = self.cfg.batch_size
# batch_size = self.cfg.horizon * num_slices
# if demo_buffer is None:
# demo_batch_size = 0
# else:
# # Update oversampling ratio
# demo_pc_batch = h.linear_schedule(self.cfg.demo_schedule, step)
# demo_num_slices = int(demo_pc_batch * self.batch_size)
# demo_batch_size = self.cfg.horizon * demo_num_slices
# batch_size -= demo_batch_size
# num_slices -= demo_num_slices
# replay_buffer._sampler.num_slices = num_slices
# demo_buffer._sampler.num_slices = demo_num_slices
# assert demo_batch_size % self.cfg.horizon == 0
# assert demo_batch_size % demo_num_slices == 0
# assert batch_size % self.cfg.horizon == 0
# assert batch_size % num_slices == 0
# # Sample from interaction dataset
# def process_batch(batch, horizon, num_slices):
# # trajectory t = 256, horizon h = 5
# # (t h) ... -> h t ...
# batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
# obs = {
# "rgb": batch["observation", "image"][FIRST_FRAME].to(self.device, non_blocking=True),
# "state": batch["observation", "state"][FIRST_FRAME].to(self.device, non_blocking=True),
# }
# action = batch["action"].to(self.device, non_blocking=True)
# next_obses = {
# "rgb": batch["next", "observation", "image"].to(self.device, non_blocking=True),
# "state": batch["next", "observation", "state"].to(self.device, non_blocking=True),
# }
# reward = batch["next", "reward"].to(self.device, non_blocking=True)
# idxs = batch["index"][FIRST_FRAME].to(self.device, non_blocking=True)
# weights = batch["_weight"][FIRST_FRAME, :, None].to(self.device, non_blocking=True)
# # TODO(rcadene): rearrange directly in offline dataset
# if reward.ndim == 2:
# reward = einops.rearrange(reward, "h t -> h t 1")
# assert reward.ndim == 3
# assert reward.shape == (horizon, num_slices, 1)
# # We dont use `batch["next", "done"]` since it only indicates the end of an
# # episode, but not the end of the trajectory of an episode.
# # Neither does `batch["next", "terminated"]`
# done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
# mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
# return obs, action, next_obses, reward, mask, done, idxs, weights
# batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample()
# obs, action, next_obses, reward, mask, done, idxs, weights = process_batch(
# batch, self.cfg.horizon, num_slices
# )
# Sample from demonstration dataset
# if demo_batch_size > 0:
# demo_batch = demo_buffer.sample(demo_batch_size)
# (
# demo_obs,
# demo_action,
# demo_next_obses,
# demo_reward,
# demo_mask,
# demo_done,
# demo_idxs,
# demo_weights,
# ) = process_batch(demo_batch, self.cfg.horizon, demo_num_slices)
# if isinstance(obs, dict):
# obs = {k: torch.cat([obs[k], demo_obs[k]]) for k in obs}
# next_obses = {k: torch.cat([next_obses[k], demo_next_obses[k]], dim=1) for k in next_obses}
# else:
# obs = torch.cat([obs, demo_obs])
# next_obses = torch.cat([next_obses, demo_next_obses], dim=1)
# action = torch.cat([action, demo_action], dim=1)
# reward = torch.cat([reward, demo_reward], dim=1)
# mask = torch.cat([mask, demo_mask], dim=1)
# done = torch.cat([done, demo_done], dim=1)
# idxs = torch.cat([idxs, demo_idxs])
# weights = torch.cat([weights, demo_weights])
batch_size = batch["index"].shape[0] batch_size = batch["index"].shape[0]
# TODO(rcadene): convert tdmpc with (batch size, time/horizon, channels) # TODO(rcadene): convert tdmpc with (batch size, time/horizon, channels)
@ -534,6 +446,7 @@ class TDMPCPolicy(nn.Module):
) )
self.optim.step() self.optim.step()
# TODO(rcadene): implement PrioritizedSampling by modifying sampler.weights with priorities computed by a criterion
# if self.cfg.per: # if self.cfg.per:
# # Update priorities # # Update priorities
# priorities = priority_loss.clamp(max=1e4).detach() # priorities = priority_loss.clamp(max=1e4).detach()


@ -99,6 +99,7 @@ def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> D
def print_cuda_memory_usage(): def print_cuda_memory_usage():
"""Use this function to locate and debug memory leak."""
import gc import gc
gc.collect() gc.collect()


@ -18,7 +18,6 @@ env:
from_pixels: True from_pixels: True
pixels_only: False pixels_only: False
image_size: [3, 480, 640] image_size: [3, 480, 640]
action_repeat: 1
episode_length: 400 episode_length: 400
fps: ${fps} fps: ${fps}


@ -18,7 +18,6 @@ env:
from_pixels: True from_pixels: True
pixels_only: False pixels_only: False
image_size: 96 image_size: 96
action_repeat: 1
episode_length: 300 episode_length: 300
fps: ${fps} fps: ${fps}


@ -17,7 +17,6 @@ env:
from_pixels: True from_pixels: True
pixels_only: False pixels_only: False
image_size: 84 image_size: 84
# action_repeat: 2 # we can remove if policy has n_action_steps=2
episode_length: 25 episode_length: 25
fps: ${fps} fps: ${fps}


@ -36,6 +36,7 @@ policy:
log_std_max: 2 log_std_max: 2
# learning # learning
batch_size: 256
max_buffer_size: 10000 max_buffer_size: 10000
horizon: 5 horizon: 5
reward_coef: 0.5 reward_coef: 0.5


@ -32,6 +32,7 @@ import json
import logging import logging
import threading import threading
import time import time
from copy import deepcopy
from datetime import datetime as dt from datetime import datetime as dt
from pathlib import Path from pathlib import Path
@@ -56,15 +57,15 @@ def write_video(video_path, stacked_frames, fps):

 def eval_policy(
     env: gym.vector.VectorEnv,
-    policy,
-    save_video: bool = False,
+    policy: torch.nn.Module,
+    max_episodes_rendered: int = 0,
     video_dir: Path = None,
     # TODO(rcadene): make it possible to overwrite fps? we should use env.fps
-    fps: int = 15,
-    return_first_video: bool = False,
     transform: callable = None,
     seed=None,
 ):
+    fps = env.unwrapped.metadata["render_fps"]
     if policy is not None:
         policy.eval()
     device = "cpu" if policy is None else next(policy.parameters()).device
@ -83,14 +84,11 @@ def eval_policy(
# needed as I'm currently taking a ceil. # needed as I'm currently taking a ceil.
ep_frames = [] ep_frames = []
-    def maybe_render_frame(env):
-        if save_video:  # noqa: B023
-            if return_first_video:
-                visu = env.envs[0].render()
-                visu = visu[None, ...]  # add batch dim
-            else:
-                visu = np.stack([env.render() for env in env.envs])
-            ep_frames.append(visu)  # noqa: B023
+    def render_frame(env):
+        # noqa: B023
+        eps_rendered = min(max_episodes_rendered, len(env.envs))
+        visu = np.stack([env.envs[i].render() for i in range(eps_rendered)])
+        ep_frames.append(visu)  # noqa: B023
for _ in range(num_episodes): for _ in range(num_episodes):
seeds.append("TODO") seeds.append("TODO")
@ -104,8 +102,14 @@ def eval_policy(
# reset the environment # reset the environment
observation, info = env.reset(seed=seed) observation, info = env.reset(seed=seed)
maybe_render_frame(env) if max_episodes_rendered > 0:
render_frame(env)
observations = []
actions = []
# episode
# frame_id
# timestamp
rewards = [] rewards = []
successes = [] successes = []
dones = [] dones = []
@ -113,8 +117,13 @@ def eval_policy(
done = torch.tensor([False for _ in env.envs]) done = torch.tensor([False for _ in env.envs])
step = 0 step = 0
while not done.all(): while not done.all():
# format from env keys to lerobot keys
observation = preprocess_observation(observation)
observations.append(deepcopy(observation))
# apply transform to normalize the observations # apply transform to normalize the observations
observation = preprocess_observation(observation, transform) for key in observation:
observation[key] = torch.stack([transform({key: item})[key] for item in observation[key]])
# send observation to device/gpu # send observation to device/gpu
observation = {key: observation[key].to(device, non_blocking=True) for key in observation} observation = {key: observation[key].to(device, non_blocking=True) for key in observation}
@ -126,11 +135,13 @@ def eval_policy(
# apply inverse transform to unnormalize the action # apply inverse transform to unnormalize the action
action = postprocess_action(action, transform) action = postprocess_action(action, transform)
# apply the next # apply the next action
observation, reward, terminated, truncated, info = env.step(action) observation, reward, terminated, truncated, info = env.step(action)
maybe_render_frame(env) if max_episodes_rendered > 0:
render_frame(env)
# TODO(rcadene): implement a wrapper over env to return torch tensors in float32 (and cuda?) # TODO(rcadene): implement a wrapper over env to return torch tensors in float32 (and cuda?)
action = torch.from_numpy(action)
reward = torch.from_numpy(reward) reward = torch.from_numpy(reward)
terminated = torch.from_numpy(terminated) terminated = torch.from_numpy(terminated)
truncated = torch.from_numpy(truncated) truncated = torch.from_numpy(truncated)
@ -147,12 +158,24 @@ def eval_policy(
success = [False for _ in env.envs] success = [False for _ in env.envs]
success = torch.tensor(success) success = torch.tensor(success)
actions.append(action)
rewards.append(reward) rewards.append(reward)
dones.append(done) dones.append(done)
successes.append(success) successes.append(success)
step += 1 step += 1
env.close()
# add the last observation when the env is done
observation = preprocess_observation(observation)
observations.append(deepcopy(observation))
new_obses = {}
for key in observations[0].keys(): # noqa: SIM118
new_obses[key] = torch.stack([obs[key] for obs in observations], dim=1)
observations = new_obses
actions = torch.stack(actions, dim=1)
rewards = torch.stack(rewards, dim=1) rewards = torch.stack(rewards, dim=1)
successes = torch.stack(successes, dim=1) successes = torch.stack(successes, dim=1)
dones = torch.stack(dones, dim=1) dones = torch.stack(dones, dim=1)
@ -172,29 +195,61 @@ def eval_policy(
max_rewards.extend(batch_max_reward.tolist()) max_rewards.extend(batch_max_reward.tolist())
all_successes.extend(batch_success.tolist()) all_successes.extend(batch_success.tolist())
env.close() # similar logic is implemented in dataset preprocessing
ep_dicts = []
num_episodes = dones.shape[0]
total_frames = 0
idx0 = idx1 = 0
data_ids_per_episode = {}
for ep_id in range(num_episodes):
num_frames = done_indices[ep_id].item() + 1
# TODO(rcadene): We need to add a missing last frame which is the observation
# of a done state. it is critical to have this frame for tdmpc to predict a "done observation/state"
ep_dict = {
"action": actions[ep_id, :num_frames],
"episode": torch.tensor([ep_id] * num_frames),
"frame_id": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / fps,
"next.done": dones[ep_id, :num_frames],
"next.reward": rewards[ep_id, :num_frames].type(torch.float32),
}
for key in observations:
ep_dict[key] = observations[key][ep_id, :num_frames]
ep_dicts.append(ep_dict)
if save_video or return_first_video: total_frames += num_frames
idx1 += num_frames
data_ids_per_episode[ep_id] = torch.arange(idx0, idx1, 1)
idx0 = idx1
# similar logic is implemented in dataset preprocessing
data_dict = {}
keys = ep_dicts[0].keys()
for key in keys:
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
data_dict["index"] = torch.arange(0, total_frames, 1)
if max_episodes_rendered > 0:
batch_stacked_frames = np.stack(ep_frames, 1) # (b, t, *) batch_stacked_frames = np.stack(ep_frames, 1) # (b, t, *)
-        if save_video:
-            for stacked_frames, done_index in zip(
-                batch_stacked_frames, done_indices.flatten().tolist(), strict=False
-            ):
-                if episode_counter >= num_episodes:
-                    continue
-                video_dir.mkdir(parents=True, exist_ok=True)
-                video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
-                thread = threading.Thread(
-                    target=write_video,
-                    args=(str(video_path), stacked_frames[:done_index], fps),
-                )
-                thread.start()
-                threads.append(thread)
-                episode_counter += 1
-
-        if return_first_video:
-            first_video = batch_stacked_frames[0].transpose(0, 3, 1, 2)
+        for stacked_frames, done_index in zip(
+            batch_stacked_frames, done_indices.flatten().tolist(), strict=False
+        ):
+            if episode_counter >= num_episodes:
+                continue
+            video_dir.mkdir(parents=True, exist_ok=True)
+            video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
+            thread = threading.Thread(
+                target=write_video,
+                args=(str(video_path), stacked_frames[:done_index], fps),
+            )
+            thread.start()
+            threads.append(thread)
+            episode_counter += 1
+
+        videos = einops.rearrange(batch_stacked_frames, "b t h w c -> b t c h w")
for thread in threads: for thread in threads:
thread.join() thread.join()
@ -225,9 +280,13 @@ def eval_policy(
"eval_s": time.time() - start, "eval_s": time.time() - start,
"eval_ep_s": (time.time() - start) / num_episodes, "eval_ep_s": (time.time() - start) / num_episodes,
}, },
"episodes": {
"data_dict": data_dict,
"data_ids_per_episode": data_ids_per_episode,
},
} }
-    if return_first_video:
-        return info, first_video
+    if max_episodes_rendered > 0:
+        info["videos"] = videos

     return info
@ -253,16 +312,14 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
logging.info("Making environment.") logging.info("Making environment.")
env = make_env(cfg, num_parallel_envs=cfg.eval_episodes) env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
-    # when policy is None, rollout a random policy
-    policy = make_policy(cfg) if cfg.policy.pretrained_model_path else None
+    logging.info("Making policy.")
+    policy = make_policy(cfg)

     info = eval_policy(
         env,
-        policy=policy,
-        save_video=True,
+        policy,
+        max_episodes_rendered=10,
         video_dir=Path(out_dir) / "eval",
-        fps=cfg.env.fps,
-        # TODO(rcadene): what should we do with the transform?
         transform=transform,
         seed=cfg.seed,
     )
@ -270,6 +327,9 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
# Save info # Save info
with open(Path(out_dir) / "eval_info.json", "w") as f: with open(Path(out_dir) / "eval_info.json", "w") as f:
# remove pytorch tensors which are not serializable to save the evaluation results only
del info["episodes"]
del info["videos"]
json.dump(info, f, indent=2) json.dump(info, f, indent=2)
logging.info("End of eval") logging.info("End of eval")


@ -1,8 +1,8 @@
import logging import logging
from copy import deepcopy
from pathlib import Path from pathlib import Path
import hydra import hydra
import numpy as np
import torch import torch
from lerobot.common.datasets.factory import make_dataset from lerobot.common.datasets.factory import make_dataset
@ -108,6 +108,64 @@ def log_eval_info(logger, info, step, cfg, dataset, is_offline):
logger.log_dict(info, step, mode="eval") logger.log_dict(info, step, mode="eval")
def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float):
"""
Calculate the sampling weight to be assigned to samples so that a specified percentage of the batch comes from the online dataset (on average).
Parameters:
- n_off (int): Number of offline samples, each with a sampling weight of 1.
- n_on (int): Number of online samples.
- pc_on (float): Desired percentage of online samples in decimal form (e.g., 50% as 0.5).
The total weight of offline samples is n_off * 1.0.
The total weight of online samples is n_on * w.
The total combined weight of all samples is n_off + n_on * w.
The fraction of the weight that is online is n_on * w / (n_off + n_on * w).
We want this fraction to equal pc_on, so we set up the equation n_on * w / (n_off + n_on * w) = pc_on.
The solution is w = - (n_off * pc_on) / (n_on * (pc_on - 1))
"""
assert 0.0 <= pc_on <= 1.0
return -(n_off * pc_on) / (n_on * (pc_on - 1))
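Plugging hypothetical numbers into the formula above makes it easy to sanity-check: with 1000 offline samples, 100 online samples and pc_on=0.5, each online sample needs a weight of 10 for the online half of the probability mass to match the offline half (the function below just restates the one from this hunk for a standalone check):

```python
def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float):
    assert 0.0 <= pc_on <= 1.0
    return -(n_off * pc_on) / (n_on * (pc_on - 1))


w = calculate_online_sample_weight(n_off=1000, n_on=100, pc_on=0.5)
print(w)  # 10.0
# fraction of the total sampling weight that is online:
print(100 * w / (1000 + 100 * w))  # 0.5
```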
def add_episodes_inplace(episodes, online_dataset, concat_dataset, sampler, pc_online_samples):
data_dict = episodes["data_dict"]
data_ids_per_episode = episodes["data_ids_per_episode"]
if len(online_dataset) == 0:
# initialize online dataset
online_dataset.data_dict = data_dict
online_dataset.data_ids_per_episode = data_ids_per_episode
else:
# find episode index and data frame indices according to previous episode in online_dataset
start_episode = max(online_dataset.data_ids_per_episode.keys()) + 1
start_index = online_dataset.data_dict["index"][-1].item() + 1
data_dict["episode"] += start_episode
data_dict["index"] += start_index
# extend online dataset
for key in data_dict:
# TODO(rcadene): avoid reallocating memory at every step by preallocating memory or changing our data structure
online_dataset.data_dict[key] = torch.cat([online_dataset.data_dict[key], data_dict[key]])
for ep_id in data_ids_per_episode:
online_dataset.data_ids_per_episode[ep_id + start_episode] = (
data_ids_per_episode[ep_id] + start_index
)
# update the concatenated dataset length used during sampling
concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
# update the sampling weights for each frame so that online frames get sampled a certain percentage of times
len_online = len(online_dataset)
len_offline = len(concat_dataset) - len_online
weight_offline = 1.0
weight_online = calculate_online_sample_weight(len_offline, len_online, pc_online_samples)
sampler.weights = torch.tensor([weight_offline] * len_offline + [weight_online] * len(online_dataset))
# update the total number of samples used during sampling
sampler.num_samples = len(concat_dataset)
def train(cfg: dict, out_dir=None, job_name=None): def train(cfg: dict, out_dir=None, job_name=None):
if out_dir is None: if out_dir is None:
raise NotImplementedError() raise NotImplementedError()
@ -126,26 +184,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
set_global_seed(cfg.seed) set_global_seed(cfg.seed)
logging.info("make_dataset") logging.info("make_dataset")
dataset = make_dataset(cfg) offline_dataset = make_dataset(cfg)
# TODO(rcadene): move balanced_sampling, per_alpha, per_beta outside policy
# if cfg.policy.balanced_sampling:
# logging.info("make online_buffer")
# num_traj_per_batch = cfg.policy.batch_size
# online_sampler = PrioritizedSliceSampler(
# max_capacity=100_000,
# alpha=cfg.policy.per_alpha,
# beta=cfg.policy.per_beta,
# num_slices=num_traj_per_batch,
# strict_length=True,
# )
# online_buffer = TensorDictReplayBuffer(
# storage=LazyMemmapStorage(100_000),
# sampler=online_sampler,
# transform=dataset.transform,
# )
logging.info("make_env") logging.info("make_env")
env = make_env(cfg, num_parallel_envs=cfg.eval_episodes) env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
@ -163,9 +202,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
logging.info(f"{cfg.env.task=}") logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})") logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})")
logging.info(f"{cfg.online_steps=}") logging.info(f"{cfg.online_steps=}")
logging.info(f"{cfg.env.action_repeat=}") logging.info(f"{offline_dataset.num_samples=} ({format_big_number(offline_dataset.num_samples)})")
logging.info(f"{dataset.num_samples=} ({format_big_number(dataset.num_samples)})") logging.info(f"{offline_dataset.num_episodes=}")
logging.info(f"{dataset.num_episodes=}")
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})") logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
@ -173,18 +211,17 @@ def train(cfg: dict, out_dir=None, job_name=None):
def _maybe_eval_and_maybe_save(step): def _maybe_eval_and_maybe_save(step):
if step % cfg.eval_freq == 0: if step % cfg.eval_freq == 0:
logging.info(f"Eval policy at step {step}") logging.info(f"Eval policy at step {step}")
eval_info, first_video = eval_policy( eval_info = eval_policy(
env, env,
policy, policy,
return_first_video=True,
video_dir=Path(out_dir) / "eval", video_dir=Path(out_dir) / "eval",
save_video=True, max_episodes_rendered=4,
transform=dataset.transform, transform=offline_dataset.transform,
seed=cfg.seed, seed=cfg.seed,
) )
log_eval_info(logger, eval_info["aggregated"], step, cfg, dataset, is_offline) log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
if cfg.wandb.enable: if cfg.wandb.enable:
logger.log_video(first_video, step, mode="eval") logger.log_video(eval_info["videos"][0], step, mode="eval")
logging.info("Resume training") logging.info("Resume training")
if cfg.save_model and step % cfg.save_freq == 0: if cfg.save_model and step % cfg.save_freq == 0:
@ -192,18 +229,19 @@ def train(cfg: dict, out_dir=None, job_name=None):
logger.save_model(policy, identifier=step) logger.save_model(policy, identifier=step)
logging.info("Resume training") logging.info("Resume training")
step = 0 # number of policy update (forward + backward + optim) # create dataloader for offline training
is_offline = True
dataloader = torch.utils.data.DataLoader( dataloader = torch.utils.data.DataLoader(
dataset, offline_dataset,
num_workers=4, num_workers=4,
batch_size=cfg.policy.batch_size, batch_size=cfg.policy.batch_size,
shuffle=True, shuffle=True,
pin_memory=cfg.device != "cpu", pin_memory=cfg.device != "cpu",
drop_last=True, drop_last=False,
) )
dl_iter = cycle(dataloader) dl_iter = cycle(dataloader)
step = 0 # number of policy update (forward + backward + optim)
is_offline = True
for offline_step in range(cfg.offline_steps): for offline_step in range(cfg.offline_steps):
if offline_step == 0: if offline_step == 0:
logging.info("Start offline training on a fixed dataset") logging.info("Start offline training on a fixed dataset")
@ -217,7 +255,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done? # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
if step % cfg.log_freq == 0: if step % cfg.log_freq == 0:
log_train_info(logger, train_info, step, cfg, dataset, is_offline) log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline)
# Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass in # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass in
# step + 1. # step + 1.
@ -225,61 +263,60 @@ def train(cfg: dict, out_dir=None, job_name=None):
step += 1 step += 1
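Both dataloaders in this script are wrapped with `cycle(...)` so the step-based loops can keep drawing batches past a single epoch. The helper itself is imported elsewhere and not shown in this diff; a minimal sketch of such a helper (an assumption about its behavior, not the project's actual implementation) is an endless generator that restarts the iterable whenever it is exhausted:

```python
def cycle(iterable):
    """Yield items forever, restarting the iterator each time it runs out."""
    while True:
        for item in iterable:
            yield item


# usage: dl_iter = cycle(dataloader); batch = next(dl_iter)
```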
raise NotImplementedError() # create an env dedicated to online episodes collection from policy rollout
rollout_env = make_env(cfg, num_parallel_envs=1)
# create an empty online dataset similar to offline dataset
online_dataset = deepcopy(offline_dataset)
online_dataset.data_dict = {}
online_dataset.data_ids_per_episode = {}
# create dataloader for online training
concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
weights = [1.0] * len(concat_dataset)
sampler = torch.utils.data.WeightedRandomSampler(
weights, num_samples=len(concat_dataset), replacement=True
)
dataloader = torch.utils.data.DataLoader(
concat_dataset,
num_workers=4,
batch_size=cfg.policy.batch_size,
sampler=sampler,
pin_memory=cfg.device != "cpu",
drop_last=False,
)
dl_iter = cycle(dataloader)
demo_buffer = dataset if cfg.policy.balanced_sampling else None
online_step = 0 online_step = 0
is_offline = False is_offline = False
     for env_step in range(cfg.online_steps):
         if env_step == 0:
             logging.info("Start online training by interacting with environment")
-        # TODO: add configurable number of rollout? (default=1)
         with torch.no_grad():
-            rollout = env.rollout(
-                max_steps=cfg.env.episode_length,
-                policy=policy,
-                auto_cast_to_device=True,
+            eval_info = eval_policy(
+                rollout_env,
+                policy,
+                transform=offline_dataset.transform,
+                seed=cfg.seed,
             )

-        assert (
-            len(rollout.batch_size) == 2
-        ), "2 dimensions expected: number of env in parallel x max number of steps during rollout"
-
-        num_parallel_env = rollout.batch_size[0]
-        if num_parallel_env != 1:
-            # TODO(rcadene): when num_parallel_env > 1, rollout["episode"] needs to be properly set and we need to add tests
-            raise NotImplementedError()
-
-        num_max_steps = rollout.batch_size[1]
-        assert num_max_steps <= cfg.env.episode_length
-
-        # reshape to have a list of steps to insert into online_buffer
-        rollout = rollout.reshape(num_parallel_env * num_max_steps)
-
-        # set same episode index for all time steps contained in this rollout
-        rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int)
-        # online_buffer.extend(rollout)
-
-        ep_sum_reward = rollout["next", "reward"].sum()
-        ep_max_reward = rollout["next", "reward"].max()
-        ep_success = rollout["next", "success"].any()
-        rollout_info = {
-            "avg_sum_reward": np.nanmean(ep_sum_reward),
-            "avg_max_reward": np.nanmean(ep_max_reward),
-            "pc_success": np.nanmean(ep_success) * 100,
-            "env_step": env_step,
-            "ep_length": len(rollout),
-        }
+        online_pc_sampling = cfg.get("demo_schedule", 0.5)
+        add_episodes_inplace(
+            eval_info["episodes"], online_dataset, concat_dataset, sampler, online_pc_sampling
+        )

         for _ in range(cfg.policy.utd):
-            train_info = policy.update(
-                # online_buffer,
-                step,
-                demo_buffer=demo_buffer,
-            )
+            policy.train()
+            batch = next(dl_iter)
+
+            for key in batch:
+                batch[key] = batch[key].to(cfg.device, non_blocking=True)
+
+            train_info = policy(batch, step)
+
             if step % cfg.log_freq == 0:
-                train_info.update(rollout_info)
-                log_train_info(logger, train_info, step, cfg, dataset, is_offline)
+                log_train_info(logger, train_info, step, cfg, online_dataset, is_offline)
# Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass
# in step + 1. # in step + 1.
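The body of `add_episodes_inplace` is not part of this hunk. The sketch below is only an assumption of how the sampler could be rebalanced after rollout episodes are appended to `online_dataset`, so that roughly `online_pc_sampling` of each batch comes from online data; the helper name and attribute access are hypothetical, not the repository's implementation.

import torch

def rebalance_sampler_weights(sampler, num_offline, num_online, pc_online):
    # Hypothetical helper: give every offline index one weight and every
    # online index another, chosen so that a fraction `pc_online` of sampled
    # items comes from the online dataset (when both datasets are non-empty).
    w_offline = (1.0 - pc_online) / max(num_offline, 1)
    w_online = pc_online / max(num_online, 1)
    sampler.weights = torch.as_tensor(
        [w_offline] * num_offline + [w_online] * num_online, dtype=torch.double
    )
    sampler.num_samples = num_offline + num_online

Because the sampler draws with replacement, updating its weights between passes changes the offline/online mix without rebuilding the dataloader.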

View File

@ -6,9 +6,6 @@ import einops
import hydra import hydra
import imageio import imageio
import torch import torch
from torchrl.data.replay_buffers import (
SamplerWithoutReplacement,
)
from lerobot.common.datasets.factory import make_dataset from lerobot.common.datasets.factory import make_dataset
from lerobot.common.logger import log_output_dir from lerobot.common.logger import log_output_dir
@ -39,19 +36,11 @@ def visualize_dataset(cfg: dict, out_dir=None):
init_logging() init_logging()
log_output_dir(out_dir) log_output_dir(out_dir)
# we expect frames of each episode to be stored next to each other sequentially
sampler = SamplerWithoutReplacement(
shuffle=False,
)
logging.info("make_dataset") logging.info("make_dataset")
dataset = make_dataset( dataset = make_dataset(
cfg, cfg,
overwrite_sampler=sampler,
# remove all transformations such as rescaling images from [0, 255] to [0, 1] or normalization # remove all transformations such as rescaling images from [0, 255] to [0, 1] or normalization
normalize=False, normalize=False,
overwrite_batch_size=1,
overwrite_prefetch=12,
) )
logging.info("Start rendering episodes from offline buffer") logging.info("Start rendering episodes from offline buffer")
@ -60,64 +49,49 @@ def visualize_dataset(cfg: dict, out_dir=None):
logging.info(video_path) logging.info(video_path)
def render_dataset(dataset, out_dir, max_num_samples, fps): def render_dataset(dataset, out_dir, max_num_episodes):
out_dir = Path(out_dir) out_dir = Path(out_dir)
video_paths = [] video_paths = []
threads = [] threads = []
frames = {}
current_ep_idx = 0
logging.info(f"Visualizing episode {current_ep_idx}")
for i in range(max_num_samples):
# TODO(rcadene): make it work with bsize > 1
ep_td = dataset.sample(1)
ep_idx = ep_td["episode"][FIRST_FRAME].item()
# TODO(rcadene): modify dataset._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames dataloader = torch.utils.data.DataLoader(
num_frames_left = dataset._sampler._sample_list.numel() dataset,
episode_is_done = ep_idx != current_ep_idx num_workers=4,
batch_size=1,
shuffle=False,
)
dl_iter = iter(dataloader)
if episode_is_done: num_episodes = len(dataset.data_ids_per_episode)
logging.info(f"Rendering episode {current_ep_idx}") for ep_id in range(min(max_num_episodes, num_episodes)):
logging.info(f"Rendering episode {ep_id}")
for im_key in dataset.image_keys: frames = {}
if not episode_is_done and num_frames_left > 0 and i < (max_num_samples - 1): for _ in dataset.data_ids_per_episode[ep_id]:
item = next(dl_iter)
for im_key in dataset.image_keys:
# when first frame of episode, initialize frames dict # when first frame of episode, initialize frames dict
if im_key not in frames: if im_key not in frames:
frames[im_key] = [] frames[im_key] = []
# add current frame to list of frames to render # add current frame to list of frames to render
frames[im_key].append(ep_td[im_key]) frames[im_key].append(item[im_key])
out_dir.mkdir(parents=True, exist_ok=True)
for im_key in dataset.image_keys:
if len(dataset.image_keys) > 1:
im_name = im_key.replace("observation.images.", "")
video_path = out_dir / f"episode_{ep_id}_{im_name}.mp4"
else: else:
# When episode has no more frame in its list of observation, video_path = out_dir / f"episode_{ep_id}.mp4"
# one frame still remains. It is the result of the last action taken. video_paths.append(video_path)
# It is stored in `"next"`, so we add it to the list of frames to render.
frames[im_key].append(ep_td["next"][im_key])
out_dir.mkdir(parents=True, exist_ok=True) thread = threading.Thread(
if len(dataset.image_keys) > 1: target=cat_and_write_video,
camera = im_key[-1] args=(str(video_path), frames[im_key], dataset.fps),
video_path = out_dir / f"episode_{current_ep_idx}_{camera}.mp4" )
else: thread.start()
video_path = out_dir / f"episode_{current_ep_idx}.mp4" threads.append(thread)
video_paths.append(str(video_path))
thread = threading.Thread(
target=cat_and_write_video,
args=(str(video_path), frames[im_key], fps),
)
thread.start()
threads.append(thread)
current_ep_idx = ep_idx
# reset list of frames
del frames[im_key]
if num_frames_left == 0:
logging.info("Ran out of frames")
break
if current_ep_idx == NUM_EPISODES_TO_RENDER:
break
for thread in threads: for thread in threads:
thread.join() thread.join()
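`cat_and_write_video` is referenced here but defined elsewhere. A rough sketch of what such a helper might look like, assuming each frame arrives as an image tensor with a leading batch dimension of 1 from the batch_size=1 dataloader; the exact signature and handling in the repository may differ.

import imageio
import torch

def cat_and_write_video(video_path, frames, fps):
    # Stack per-step frames, drop the batch dim added by batch_size=1, and
    # write the episode as one video. Handles both float [0, 1] and uint8 frames.
    video = torch.stack(frames).squeeze(1)             # (T, C, H, W)
    if video.dtype != torch.uint8:
        video = (video.clamp(0, 1) * 255).to(torch.uint8)
    video = video.permute(0, 2, 3, 1).cpu().numpy()    # (T, H, W, C)
    imageio.mimsave(video_path, video, fps=fps)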

6
poetry.lock generated
View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. # This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.
[[package]] [[package]]
name = "absl-py" name = "absl-py"
@ -921,7 +921,7 @@ shapely = "^2.0.3"
type = "git" type = "git"
url = "git@github.com:huggingface/gym-pusht.git" url = "git@github.com:huggingface/gym-pusht.git"
reference = "HEAD" reference = "HEAD"
resolved_reference = "824b22832cc8d71a4b4e96a57563510cf47e30c1" resolved_reference = "080d4ce4d8d3140b2fd204ed628bda14dc58ff06"
[[package]] [[package]]
name = "gym-xarm" name = "gym-xarm"
@ -941,7 +941,7 @@ mujoco = "^2.3.7"
type = "git" type = "git"
url = "git@github.com:huggingface/gym-xarm.git" url = "git@github.com:huggingface/gym-xarm.git"
reference = "HEAD" reference = "HEAD"
resolved_reference = "ce294c0d30def08414d9237e2bf9f373d448ca07" resolved_reference = "6a88f7d63833705dfbec4b997bf36cac6b4a448c"
[[package]] [[package]]
name = "gymnasium" name = "gymnasium"

View File

@ -1,25 +0,0 @@
#!/bin/bash
#SBATCH --nodes=1 # total number of nodes (N to be defined)
#SBATCH --ntasks-per-node=1 # number of tasks per node (here 8 tasks, or 1 task per GPU)
#SBATCH --gres=gpu:1 # number of GPUs reserved per node (here 8, or all the GPUs)
#SBATCH --cpus-per-task=8 # number of cores per task (8x8 = 64 cores, or all the cores)
#SBATCH --time=2-00:00:00
#SBATCH --output=/home/rcadene/slurm/%j.out
#SBATCH --error=/home/rcadene/slurm/%j.err
#SBATCH --qos=low
#SBATCH --mail-user=re.cadene@gmail.com
#SBATCH --mail-type=ALL
CMD=$@
echo "command: $CMD"
apptainer exec --nv \
~/apptainer/nvidia_cuda:12.2.2-devel-ubuntu22.04.sif $SHELL
source ~/.bashrc
#conda activate fowm
conda activate lerobot
export DATA_DIR="data"
srun $CMD

View File

@ -1,17 +0,0 @@
#!/bin/bash
#SBATCH --nodes=1 # total number of nodes (N to be defined)
#SBATCH --ntasks-per-node=1 # number of tasks per node (here 8 tasks, or 1 task per GPU)
#SBATCH --qos=normal # number of GPUs reserved per node (here 8, or all the GPUs)
#SBATCH --partition=hopper-prod
#SBATCH --gres=gpu:1 # number of GPUs reserved per node (here 8, or all the GPUs)
#SBATCH --cpus-per-task=12 # number of cores per task
#SBATCH --mem-per-cpu=11G
#SBATCH --time=12:00:00
#SBATCH --output=/admin/home/remi_cadene/slurm/%j.out
#SBATCH --error=/admin/home/remi_cadene/slurm/%j.err
#SBATCH --mail-user=remi_cadene@huggingface.co
#SBATCH --mail-type=ALL
CMD=$@
echo "command: $CMD"
srun $CMD

Binary file not shown.

View File

@ -1,64 +1,53 @@
""" """
This test verifies that all environments, datasets, policies listed in `lerobot/__init__.py` can be successfully This test verifies that all environments, datasets, policies listed in `lerobot/__init__.py` can be successfully
imported and that their class attributes (e.g. `available_datasets`, `name`, `available_tasks`) correspond. imported and that their class attributes (e.g. `available_datasets`, `name`, `available_tasks`) are valid.
Note: When implementing a new dataset (e.g. `AlohaDataset`), policy (e.g. `DiffusionPolicy`), or environment, follow these steps:
When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to: - Set the required class attributes: `available_datasets`.
1. set the required class attributes: - Set the required class attributes: `name`.
- for classes inheriting from `AbstractDataset`: `available_datasets` - Update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
- for classes inheriting from `AbstractEnv`: `name`, `available_tasks` - Update variables in `tests/test_available.py` by importing your new class
- for classes inheriting from `AbstractPolicy`: `name`
2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
3. update variables in `tests/test_available.py` by importing your new class
""" """
import importlib
import pytest import pytest
import lerobot import lerobot
import gymnasium as gym
# from lerobot.common.envs.aloha.env import AlohaEnv from lerobot.common.datasets.xarm import XarmDataset
# from gym_pusht.envs import PushtEnv from lerobot.common.datasets.aloha import AlohaDataset
# from gym_xarm.envs import SimxarmEnv from lerobot.common.datasets.pusht import PushtDataset
# from lerobot.common.datasets.xarm import SimxarmDataset from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
# from lerobot.common.datasets.aloha import AlohaDataset from lerobot.common.policies.diffusion.policy import DiffusionPolicy
# from lerobot.common.datasets.pusht import PushtDataset from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
# from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
# from lerobot.common.policies.diffusion.policy import DiffusionPolicy
# from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
# def test_available(): def test_available():
# pol_classes = [ policy_classes = [
# ActionChunkingTransformerPolicy, ActionChunkingTransformerPolicy,
# DiffusionPolicy, DiffusionPolicy,
# TDMPCPolicy, TDMPCPolicy,
# ] ]
# env_classes = [ dataset_class_per_env = {
# AlohaEnv, "aloha": AlohaDataset,
# PushtEnv, "pusht": PushtDataset,
# SimxarmEnv, "xarm": XarmDataset,
# ] }
# dat_classes = [
# AlohaDataset,
# PushtDataset,
# SimxarmDataset,
# ]
# policies = [pol_cls.name for pol_cls in pol_classes] policies = [pol_cls.name for pol_cls in policy_classes]
# assert set(policies) == set(lerobot.available_policies) assert set(policies) == set(lerobot.available_policies), policies
# envs = [env_cls.name for env_cls in env_classes] for env_name in lerobot.available_envs:
# assert set(envs) == set(lerobot.available_envs) for task_name in lerobot.available_tasks_per_env[env_name]:
package_name = f"gym_{env_name}"
importlib.import_module(package_name)
gym_handle = f"{package_name}/{task_name}"
assert gym_handle in gym.envs.registry.keys(), gym_handle
# tasks_per_env = {env_cls.name: env_cls.available_tasks for env_cls in env_classes} dataset_class = dataset_class_per_env[env_name]
# for env in envs: available_datasets = lerobot.available_datasets_per_env[env_name]
# assert set(tasks_per_env[env]) == set(lerobot.available_tasks_per_env[env]) assert set(available_datasets) == set(dataset_class.available_datasets), f"{env_name=} {available_datasets=}"
# datasets_per_env = {env_cls.name: dat_cls.available_datasets for env_cls, dat_cls in zip(env_classes, dat_classes)}
# for env in envs:
# assert set(datasets_per_env[env]) == set(lerobot.available_datasets_per_env[env])
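As a usage note for the registry check above: importing the environment package is what triggers the gymnasium registration, so the assertion only holds after `importlib.import_module`. A minimal standalone illustration; the handle name below is illustrative and assumes the gym_pusht package registers it on import.

import importlib

import gymnasium as gym

# Importing the package runs its register() calls as a side effect.
importlib.import_module("gym_pusht")

gym_handle = "gym_pusht/PushT-v0"  # illustrative task handle
assert gym_handle in gym.envs.registry, f"{gym_handle} was not registered on import"

env = gym.make(gym_handle)
observation, info = env.reset(seed=0)
env.close()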

View File

@ -1,6 +1,12 @@
import os
from pathlib import Path
import einops
import pytest import pytest
import torch import torch
from lerobot.common.datasets.utils import compute_stats, get_stats_einops_patterns, load_data_with_delta_timestamps
from lerobot.common.datasets.xarm import XarmDataset
from lerobot.common.transforms import Prod
from lerobot.common.utils import init_hydra_config from lerobot.common.utils import init_hydra_config
import logging import logging
from lerobot.common.datasets.factory import make_dataset from lerobot.common.datasets.factory import make_dataset
@ -45,6 +51,7 @@ def test_factory(env_name, dataset_id, policy_name):
keys_ndim_required.append( keys_ndim_required.append(
(key, 3, True), (key, 3, True),
) )
assert dataset.data_dict[key].dtype == torch.uint8, f"{key}"
# test number of dimensions # test number of dimensions
for key, ndim, required in keys_ndim_required: for key, ndim, required in keys_ndim_required:
@ -81,28 +88,104 @@ def test_factory(env_name, dataset_id, policy_name):
assert key in item, f"{key}" assert key in item, f"{key}"
# def test_compute_stats(): def test_compute_stats():
# """Check that the statistics are computed correctly according to the stats_patterns property. """Check that the statistics are computed correctly according to the stats_patterns property.
We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
because we are working with a small dataset).
"""
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
# get transform to convert images from uint8 [0,255] to float32 [0,1]
transform = Prod(in_keys=XarmDataset.image_keys, prod=1 / 255.0)
dataset = XarmDataset(
dataset_id="xarm_lift_medium",
root=DATA_DIR,
transform=transform,
)
# Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
# computation of the statistics. While doing this, we also make sure it works when we don't divide the
# dataset into even batches.
computed_stats = compute_stats(dataset, batch_size=int(len(dataset) * 0.25))
# get einops patterns to aggregate batches and compute statistics
stats_patterns = get_stats_einops_patterns(dataset)
# get all frames from the dataset in the same dtype and range as during compute_stats
data_dict = transform(dataset.data_dict)
# compute stats based on all frames from the dataset without any batching
expected_stats = {}
for k, pattern in stats_patterns.items():
expected_stats[k] = {}
expected_stats[k]["mean"] = einops.reduce(data_dict[k], pattern, "mean")
expected_stats[k]["std"] = torch.sqrt(einops.reduce((data_dict[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean"))
expected_stats[k]["min"] = einops.reduce(data_dict[k], pattern, "min")
expected_stats[k]["max"] = einops.reduce(data_dict[k], pattern, "max")
# test computed stats match expected stats
for k in stats_patterns:
assert torch.allclose(computed_stats[k]["mean"], expected_stats[k]["mean"])
assert torch.allclose(computed_stats[k]["std"], expected_stats[k]["std"])
assert torch.allclose(computed_stats[k]["min"], expected_stats[k]["min"])
assert torch.allclose(computed_stats[k]["max"], expected_stats[k]["max"])
# TODO(rcadene): check that the stats used for training are correct too
# # load stats that are expected to match the ones returned by computed_stats
# assert (dataset.data_dir / "stats.pth").exists()
# loaded_stats = torch.load(dataset.data_dir / "stats.pth")
# # test loaded stats match expected stats
# for k in stats_patterns:
# assert torch.allclose(loaded_stats[k]["mean"], expected_stats[k]["mean"])
# assert torch.allclose(loaded_stats[k]["std"], expected_stats[k]["std"])
# assert torch.allclose(loaded_stats[k]["min"], expected_stats[k]["min"])
# assert torch.allclose(loaded_stats[k]["max"], expected_stats[k]["max"])
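For intuition on why the batched and single-pass statistics should agree, here is a self-contained sketch of one way to aggregate per-batch sums; the key shapes and pattern are illustrative, and the library's `compute_stats` may aggregate differently (e.g. with a separate pass for the standard deviation).

import einops
import torch

torch.manual_seed(0)
data = torch.rand(101, 3)   # 101 frames of a 3-dim state; odd size -> ragged last batch
pattern = "b c -> c"
batch_size = 25

n = 0
sum_x = torch.zeros(3)
sum_x2 = torch.zeros(3)
min_x = torch.full((3,), float("inf"))
max_x = torch.full((3,), float("-inf"))
for start in range(0, len(data), batch_size):
    batch = data[start:start + batch_size]
    n += batch.shape[0]
    sum_x += einops.reduce(batch, pattern, "sum")
    sum_x2 += einops.reduce(batch**2, pattern, "sum")
    min_x = torch.minimum(min_x, einops.reduce(batch, pattern, "min"))
    max_x = torch.maximum(max_x, einops.reduce(batch, pattern, "max"))

mean = sum_x / n
std = torch.sqrt(sum_x2 / n - mean**2)  # population std, matching the test above
assert torch.allclose(mean, data.mean(dim=0), atol=1e-6)
assert torch.allclose(std, data.std(dim=0, unbiased=False), atol=1e-5)
assert torch.equal(min_x, data.min(dim=0).values)
assert torch.equal(max_x, data.max(dim=0).values)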
def test_load_data_with_delta_timestamps_within_tolerance():
data_dict = {
"timestamp": torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]),
"index": torch.tensor([0, 1, 2, 3, 4]),
}
data_ids_per_episode = {0: torch.tensor([0, 1, 2, 3, 4])}
delta_timestamps = {"index": [-0.2, 0, 0.139]}
key = "index"
current_ts = 0.3
episode = 0
tol = 0.04
data, is_pad = load_data_with_delta_timestamps(data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode, tol)
assert not is_pad.any(), "Unexpected padding detected"
assert torch.equal(data, torch.tensor([0, 2, 3])), "Data does not match expected values"
def test_load_data_with_delta_timestamps_outside_tolerance_inside_episode_range():
data_dict = {
"timestamp": torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]),
"index": torch.tensor([0, 1, 2, 3, 4]),
}
data_ids_per_episode = {0: torch.tensor([0, 1, 2, 3, 4])}
delta_timestamps = {"index": [-0.2, 0, 0.141]}
key = "index"
current_ts = 0.3
episode = 0
tol = 0.04
with pytest.raises(AssertionError):
load_data_with_delta_timestamps(data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode, tol)
def test_load_data_with_delta_timestamps_outside_tolerance_outside_episode_range():
data_dict = {
"timestamp": torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]),
"index": torch.tensor([0, 1, 2, 3, 4]),
}
data_ids_per_episode = {0: torch.tensor([0, 1, 2, 3, 4])}
delta_timestamps = {"index": [-0.3, -0.24, 0, 0.26, 0.3]}
key = "index"
current_ts = 0.3
episode = 0
tol = 0.04
data, is_pad = load_data_with_delta_timestamps(data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode, tol)
assert torch.equal(is_pad, torch.tensor([True, False, False, True, True])), "Padding does not match expected values"
assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
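The three tests above pin down the expected behaviour of `load_data_with_delta_timestamps`: each query is matched to the nearest frame in the episode, queries outside the episode range are padded with the boundary frame, and a query inside the episode with no frame within the tolerance is an error. A hedged reconstruction of that logic, not the library implementation, that satisfies these three cases:

import torch

def load_data_with_delta_timestamps_sketch(
    data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode, tol
):
    # Indices and timestamps of the frames belonging to this episode.
    ep_ids = data_ids_per_episode[episode]
    ep_ts = data_dict["timestamp"][ep_ids]
    # Timestamps we would like to fetch, relative to the current frame.
    query_ts = current_ts + torch.tensor(delta_timestamps[key])
    # Distance from every query to every frame; pick the closest frame per query.
    dist = torch.abs(query_ts[:, None] - ep_ts[None, :])
    min_dist, argmin = dist.min(dim=1)
    is_pad = min_dist > tol
    # Padding is only legitimate for queries outside the episode range; a query
    # inside the episode that is too far from every frame is an error.
    assert ((query_ts[is_pad] < ep_ts.min()) | (query_ts[is_pad] > ep_ts.max())).all(), (
        "One of the requested timestamps is inside the episode but not within tolerance of any frame."
    )
    data = data_dict[key][ep_ids[argmin]]
    return data, is_pad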
# We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
# because we are working with a small dataset).
# """
# cfg = init_hydra_config(
# DEFAULT_CONFIG_PATH, overrides=["env=aloha", "env.task=sim_transfer_cube_human"]
# )
# dataset = make_dataset(cfg)
# # Get all of the data.
# all_data = dataset.data_dict
# # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
# # computation of the statistics. While doing this, we also make sure it works when we don't divide the
# # dataset into even batches.
# computed_stats = buffer._compute_stats(batch_size=int(len(all_data) * 0.75))
# for k, pattern in buffer.stats_patterns.items():
# expected_mean = einops.reduce(all_data[k], pattern, "mean")
# assert torch.allclose(computed_stats[k]["mean"], expected_mean)
# assert torch.allclose(
# computed_stats[k]["std"],
# torch.sqrt(einops.reduce((all_data[k] - expected_mean) ** 2, pattern, "mean"))
# )
# assert torch.allclose(computed_stats[k]["min"], einops.reduce(all_data[k], pattern, "min"))
# assert torch.allclose(computed_stats[k]["max"], einops.reduce(all_data[k], pattern, "max"))