diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py index 031c2cd3..2744f595 100644 --- a/lerobot/common/datasets/aloha.py +++ b/lerobot/common/datasets/aloha.py @@ -1,26 +1,13 @@ import logging from pathlib import Path -from typing import Callable import einops import gdown import h5py import torch -import torchrl import tqdm -from tensordict import TensorDict -from torchrl.data.replay_buffers.samplers import Sampler -from torchrl.data.replay_buffers.storages import TensorStorage -from torchrl.data.replay_buffers.writers import Writer -from lerobot.common.datasets.abstract import AbstractDataset - -DATASET_IDS = [ - "aloha_sim_insertion_human", - "aloha_sim_insertion_scripted", - "aloha_sim_transfer_cube_human", - "aloha_sim_transfer_cube_scripted", -] +from lerobot.common.datasets.utils import load_data_with_delta_timestamps FOLDER_URLS = { "aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF", @@ -66,7 +53,6 @@ CAMERAS = { def download(data_dir, dataset_id): - assert dataset_id in DATASET_IDS assert dataset_id in FOLDER_URLS assert dataset_id in EP48_URLS assert dataset_id in EP49_URLS @@ -80,51 +66,78 @@ def download(data_dir, dataset_id): gdown.download(EP49_URLS[dataset_id], output=str(data_dir / "episode_49.hdf5"), fuzzy=True) -class AlohaDataset(AbstractDataset): - available_datasets = DATASET_IDS +class AlohaDataset(torch.utils.data.Dataset): + available_datasets = [ + "aloha_sim_insertion_human", + "aloha_sim_insertion_scripted", + "aloha_sim_transfer_cube_human", + "aloha_sim_transfer_cube_scripted", + ] + fps = 50 + image_keys = ["observation.images.top"] def __init__( self, dataset_id: str, version: str | None = "v1.2", - batch_size: int | None = None, - *, - shuffle: bool = True, root: Path | None = None, - pin_memory: bool = False, - prefetch: int = None, - sampler: Sampler | None = None, - collate_fn: Callable | None = None, - writer: Writer | None = None, - transform: "torchrl.envs.Transform" = None, + transform: callable = None, + delta_timestamps: dict[list[float]] | None = None, ): - super().__init__( - dataset_id, - version, - batch_size, - shuffle=shuffle, - root=root, - pin_memory=pin_memory, - prefetch=prefetch, - sampler=sampler, - collate_fn=collate_fn, - writer=writer, - transform=transform, - ) + super().__init__() + self.dataset_id = dataset_id + self.version = version + self.root = root + self.transform = transform + self.delta_timestamps = delta_timestamps + + data_dir = self.root / f"{self.dataset_id}" + if (data_dir / "data_dict.pth").exists() and (data_dir / "data_ids_per_episode.pth").exists(): + self.data_dict = torch.load(data_dir / "data_dict.pth") + self.data_ids_per_episode = torch.load(data_dir / "data_ids_per_episode.pth") + else: + self._download_and_preproc_obsolete() + data_dir.mkdir(parents=True, exist_ok=True) + torch.save(self.data_dict, data_dir / "data_dict.pth") + torch.save(self.data_ids_per_episode, data_dir / "data_ids_per_episode.pth") @property - def stats_patterns(self) -> dict: - d = { - ("observation", "state"): "b c -> c", - ("action",): "b c -> c", - } - for cam in CAMERAS[self.dataset_id]: - d[("observation", "image", cam)] = "b c h w -> c 1 1" - return d + def num_samples(self) -> int: + return len(self.data_dict["index"]) @property - def image_keys(self) -> list: - return [("observation", "image", cam) for cam in CAMERAS[self.dataset_id]] + def num_episodes(self) -> int: + return len(self.data_ids_per_episode) + + def __len__(self): + return 
self.num_samples + + def __getitem__(self, idx): + item = {} + + # get episode id and timestamp of the sampled frame + current_ts = self.data_dict["timestamp"][idx].item() + episode = self.data_dict["episode"][idx].item() + + for key in self.data_dict: + if self.delta_timestamps is not None and key in self.delta_timestamps: + data, is_pad = load_data_with_delta_timestamps( + self.data_dict, + self.data_ids_per_episode, + self.delta_timestamps, + key, + current_ts, + episode, + ) + item[key] = data + item[f"{key}_is_pad"] = is_pad + else: + item[key] = self.data_dict[key][idx] + + if self.transform is not None: + item = self.transform(item) + + return item def _download_and_preproc_obsolete(self): assert self.root is not None @@ -132,54 +145,55 @@ class AlohaDataset(AbstractDataset): if not raw_dir.is_dir(): download(raw_dir, self.dataset_id) - total_num_frames = 0 + total_frames = 0 logging.info("Compute total number of frames to initialize offline buffer") for ep_id in range(NUM_EPISODES[self.dataset_id]): ep_path = raw_dir / f"episode_{ep_id}.hdf5" with h5py.File(ep_path, "r") as ep: - total_num_frames += ep["/action"].shape[0] - 1 - logging.info(f"{total_num_frames=}") + total_frames += ep["/action"].shape[0] - 1 + logging.info(f"{total_frames=}") + + self.data_ids_per_episode = {} + ep_dicts = [] logging.info("Initialize and feed offline buffer") - idxtd = 0 for ep_id in tqdm.tqdm(range(NUM_EPISODES[self.dataset_id])): ep_path = raw_dir / f"episode_{ep_id}.hdf5" with h5py.File(ep_path, "r") as ep: - ep_num_frames = ep["/action"].shape[0] + num_frames = ep["/action"].shape[0] # last step of demonstration is considered done - done = torch.zeros(ep_num_frames, 1, dtype=torch.bool) + done = torch.zeros(num_frames, 1, dtype=torch.bool) done[-1] = True state = torch.from_numpy(ep["/observations/qpos"][:]) action = torch.from_numpy(ep["/action"][:]) - ep_td = TensorDict( - { - ("observation", "state"): state[:-1], - "action": action[:-1], - "episode": torch.tensor([ep_id] * (ep_num_frames - 1)), - "frame_id": torch.arange(0, ep_num_frames - 1, 1), - ("next", "observation", "state"): state[1:], - # TODO: compute reward and success - # ("next", "reward"): reward[1:], - ("next", "done"): done[1:], - # ("next", "success"): success[1:], - }, - batch_size=ep_num_frames - 1, - ) + ep_dict = { + "observation.state": state, + "action": action, + "episode": torch.tensor([ep_id] * num_frames), + "frame_id": torch.arange(0, num_frames, 1), + "timestamp": torch.arange(0, num_frames, 1) / self.fps, + # "next.observation.state": state, + # TODO(rcadene): compute reward and success + # "next.reward": reward[1:], + "next.done": done[1:], + # "next.success": success[1:], + } for cam in CAMERAS[self.dataset_id]: image = torch.from_numpy(ep[f"/observations/images/{cam}"][:]) image = einops.rearrange(image, "b h w c -> b c h w").contiguous() - ep_td["observation", "image", cam] = image[:-1] - ep_td["next", "observation", "image", cam] = image[1:] + ep_dict[f"observation.images.{cam}"] = image[:-1] + # ep_dict[f"next.observation.images.{cam}"] = image[1:] - if ep_id == 0: - # hack to initialize tensordict data structure to store episodes - td_data = ep_td[0].expand(total_num_frames).memmap_like(self.root / f"{self.dataset_id}") + ep_dicts.append(ep_dict) - td_data[idxtd : idxtd + len(ep_td)] = ep_td - idxtd = idxtd + len(ep_td) + self.data_dict = {} - return TensorStorage(td_data.lock_()) + keys = ep_dicts[0].keys() + for key in keys: + self.data_dict[key] = torch.cat([x[key] for x in ep_dicts]) + + 
self.data_dict["index"] = torch.arange(0, total_frames, 1) diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 04077034..94ac8ca4 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -1,11 +1,10 @@ -import logging import os from pathlib import Path import torch -from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler +from torchvision.transforms import v2 -from lerobot.common.transforms import NormalizeTransform, Prod +from lerobot.common.transforms import Prod # DATA_DIR specifies to location where datasets are loaded. By default, DATA_DIR is None and # we load from `$HOME/.cache/huggingface/hub/datasets`. For our unit tests, we set `DATA_DIR=tests/data` @@ -13,57 +12,12 @@ from lerobot.common.transforms import NormalizeTransform, Prod DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None -def make_offline_buffer( +def make_dataset( cfg, - overwrite_sampler=None, # set normalize=False to remove all transformations and keep images unnormalized in [0,255] normalize=True, - overwrite_batch_size=None, - overwrite_prefetch=None, stats_path=None, ): - if cfg.policy.balanced_sampling: - assert cfg.online_steps > 0 - batch_size = None - pin_memory = False - prefetch = None - else: - assert cfg.online_steps == 0 - num_slices = cfg.policy.batch_size - batch_size = cfg.policy.horizon * num_slices - pin_memory = cfg.device == "cuda" - prefetch = cfg.prefetch - - if overwrite_batch_size is not None: - batch_size = overwrite_batch_size - - if overwrite_prefetch is not None: - prefetch = overwrite_prefetch - - if overwrite_sampler is None: - # TODO(rcadene): move batch_size outside - num_traj_per_batch = cfg.policy.batch_size # // cfg.horizon - # TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size. - # We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size. 
- - if cfg.offline_prioritized_sampler: - logging.info("use prioritized sampler for offline dataset") - sampler = PrioritizedSliceSampler( - max_capacity=100_000, - alpha=cfg.policy.per_alpha, - beta=cfg.policy.per_beta, - num_slices=num_traj_per_batch, - strict_length=False, - ) - else: - logging.info("use simple sampler for offline dataset") - sampler = SliceSampler( - num_slices=num_traj_per_batch, - strict_length=False, - ) - else: - sampler = overwrite_sampler - if cfg.env.name == "simxarm": from lerobot.common.datasets.simxarm import SimxarmDataset @@ -81,56 +35,56 @@ def make_offline_buffer( else: raise ValueError(cfg.env.name) - offline_buffer = clsfunc( - dataset_id=cfg.dataset_id, - sampler=sampler, - batch_size=batch_size, - root=DATA_DIR, - pin_memory=pin_memory, - prefetch=prefetch if isinstance(prefetch, int) else None, - ) - - if cfg.policy.name == "tdmpc": - img_keys = [] - for key in offline_buffer.image_keys: - img_keys.append(("next", *key)) - img_keys += offline_buffer.image_keys - else: - img_keys = offline_buffer.image_keys - + transforms = None if normalize: - transforms = [Prod(in_keys=img_keys, prod=1 / 255)] - # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, # min_max_from_spec - stats = offline_buffer.compute_or_load_stats() if stats_path is None else torch.load(stats_path) + # stats = dataset.compute_or_load_stats() if stats_path is None else torch.load(stats_path) - # we only normalize the state and action, since the images are usually normalized inside the model for - # now (except for tdmpc: see the following) - in_keys = [("observation", "state"), ("action")] - - if cfg.policy.name == "tdmpc": - # TODO(rcadene): we add img_keys to the keys to normalize for tdmpc only, since diffusion and act policies normalize the image inside the model for now - in_keys += img_keys - # TODO(racdene): since we use next observations in tdmpc, we also add them to the normalization. We are wasting a bit of compute on this for now. - in_keys += [("next", *key) for key in img_keys] - in_keys.append(("next", "observation", "state")) + stats = {} if cfg.policy.name == "diffusion" and cfg.env.name == "pusht": # TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this - stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32) - stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32) - stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32) - stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32) + stats["observation.state"] = {} + stats["observation.state"]["min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32) + stats["observation.state"]["max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32) + stats["action"] = {} + stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32) + stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32) # TODO(rcadene): remove this and put it in config. 
Ideally we want to reproduce SOTA results just with mean_std - normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max" - transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode)) + # normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max" - offline_buffer.set_transform(transforms) + transforms = v2.Compose( + [ + # TODO(rcadene): we need to do something about image_keys + Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0), + # NormalizeTransform( + # stats, + # in_keys=[ + # "observation.state", + # "action", + # ], + # mode=normalization_mode, + # ), + ] + ) - if not overwrite_sampler: - index = torch.arange(0, offline_buffer.num_samples, 1) - sampler.extend(index) + if cfg.policy.name == "diffusion" and cfg.env.name == "pusht": + # TODO(rcadene): implement delta_timestamps in config + delta_timestamps = { + "observation.image": [-0.1, 0], + "observation.state": [-0.1, 0], + "action": [-0.1] + [i / clsfunc.fps for i in range(15)], + } + else: + delta_timestamps = None - return offline_buffer + dataset = clsfunc( + dataset_id=cfg.dataset_id, + root=DATA_DIR, + delta_timestamps=delta_timestamps, + transform=transforms, + ) + + return dataset diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index 624fb140..3de70b1f 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -1,20 +1,13 @@ from pathlib import Path -from typing import Callable import einops import numpy as np import pygame import pymunk import torch -import torchrl import tqdm -from tensordict import TensorDict -from torchrl.data.replay_buffers.samplers import Sampler -from torchrl.data.replay_buffers.storages import TensorStorage -from torchrl.data.replay_buffers.writers import Writer -from lerobot.common.datasets.abstract import AbstractDataset -from lerobot.common.datasets.utils import download_and_extract_zip +from lerobot.common.datasets.utils import download_and_extract_zip, load_data_with_delta_timestamps from lerobot.common.envs.pusht.pusht_env import pymunk_to_shapely from lerobot.common.policies.diffusion.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer @@ -83,37 +76,82 @@ def add_tee( return body -class PushtDataset(AbstractDataset): +class PushtDataset(torch.utils.data.Dataset): + """ + + Arguments + ---------- + delta_timestamps : dict[list[float]] | None, optional + Loads data from frames with a shift in timestamps with a different strategy for each data key (e.g. state, action or image) + If `None`, no shift is applied to current timestamp and the data from the current frame is loaded. 
+ """ + available_datasets = ["pusht"] + fps = 10 + image_keys = ["observation.image"] def __init__( self, dataset_id: str, version: str | None = "v1.2", - batch_size: int | None = None, - *, - shuffle: bool = True, root: Path | None = None, - pin_memory: bool = False, - prefetch: int = None, - sampler: Sampler | None = None, - collate_fn: Callable | None = None, - writer: Writer | None = None, - transform: "torchrl.envs.Transform" = None, + transform: callable = None, + delta_timestamps: dict[list[float]] | None = None, ): - super().__init__( - dataset_id, - version, - batch_size, - shuffle=shuffle, - root=root, - pin_memory=pin_memory, - prefetch=prefetch, - sampler=sampler, - collate_fn=collate_fn, - writer=writer, - transform=transform, - ) + super().__init__() + self.dataset_id = dataset_id + self.version = version + self.root = root + self.transform = transform + self.delta_timestamps = delta_timestamps + + data_dir = self.root / f"{self.dataset_id}" + if (data_dir / "data_dict.pth").exists() and (data_dir / "data_ids_per_episode.pth").exists(): + self.data_dict = torch.load(data_dir / "data_dict.pth") + self.data_ids_per_episode = torch.load(data_dir / "data_ids_per_episode.pth") + else: + self._download_and_preproc_obsolete() + data_dir.mkdir(parents=True, exist_ok=True) + torch.save(self.data_dict, data_dir / "data_dict.pth") + torch.save(self.data_ids_per_episode, data_dir / "data_ids_per_episode.pth") + + @property + def num_samples(self) -> int: + return len(self.data_dict["index"]) + + @property + def num_episodes(self) -> int: + return len(self.data_ids_per_episode) + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + item = {} + + # get episode id and timestamp of the sampled frame + current_ts = self.data_dict["timestamp"][idx].item() + episode = self.data_dict["episode"][idx].item() + + for key in self.data_dict: + if self.delta_timestamps is not None and key in self.delta_timestamps: + data, is_pad = load_data_with_delta_timestamps( + self.data_dict, + self.data_ids_per_episode, + self.delta_timestamps, + key, + current_ts, + episode, + ) + item[key] = data + item[f"{key}_is_pad"] = is_pad + else: + item[key] = self.data_dict[key][idx] + + if self.transform is not None: + item = self.transform(item) + + return item def _download_and_preproc_obsolete(self): assert self.root is not None @@ -147,8 +185,10 @@ class PushtDataset(AbstractDataset): states = torch.from_numpy(dataset_dict["state"]) actions = torch.from_numpy(dataset_dict["action"]) + self.data_ids_per_episode = {} + ep_dicts = [] + idx0 = 0 - idxtd = 0 for episode_id in tqdm.tqdm(range(num_episodes)): idx1 = dataset_dict.meta["episode_ends"][episode_id] # to create test artifact @@ -194,30 +234,45 @@ class PushtDataset(AbstractDataset): # last step of demonstration is considered done done[-1] = True - ep_td = TensorDict( - { - ("observation", "image"): image[:-1], - ("observation", "state"): agent_pos[:-1], - "action": actions[idx0:idx1][:-1], - "episode": episode_ids[idx0:idx1][:-1], - "frame_id": torch.arange(0, num_frames - 1, 1), - ("next", "observation", "image"): image[1:], - ("next", "observation", "state"): agent_pos[1:], - # TODO: verify that reward and done are aligned with image and agent_pos - ("next", "reward"): reward[1:], - ("next", "done"): done[1:], - ("next", "success"): success[1:], - }, - batch_size=num_frames - 1, - ) + ep_dict = { + "observation.image": image, + "observation.state": agent_pos, + "action": actions[idx0:idx1], + "episode": torch.tensor([episode_id] 
* num_frames, dtype=torch.int), + "frame_id": torch.arange(0, num_frames, 1), + "timestamp": torch.arange(0, num_frames, 1) / self.fps, + # "next.observation.image": image[1:], + # "next.observation.state": agent_pos[1:], + # TODO(rcadene): verify that reward and done are aligned with image and agent_pos + "next.reward": torch.cat([reward[1:], reward[[-1]]]), + "next.done": torch.cat([done[1:], done[[-1]]]), + "next.success": torch.cat([success[1:], success[[-1]]]), + } + ep_dicts.append(ep_dict) - if episode_id == 0: - # hack to initialize tensordict data structure to store episodes - td_data = ep_td[0].expand(total_frames).memmap_like(self.root / f"{self.dataset_id}") - - td_data[idxtd : idxtd + len(ep_td)] = ep_td + assert isinstance(episode_id, int) + self.data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1) + assert len(self.data_ids_per_episode[episode_id]) == num_frames idx0 = idx1 - idxtd = idxtd + len(ep_td) - return TensorStorage(td_data.lock_()) + self.data_dict = {} + + keys = ep_dicts[0].keys() + for key in keys: + self.data_dict[key] = torch.cat([x[key] for x in ep_dicts]) + + self.data_dict["index"] = torch.arange(0, total_frames, 1) + + +if __name__ == "__main__": + dataset = PushtDataset( + "pusht", + root=Path("data"), + delta_timestamps={ + "observation.image": [0, -1, -0.2, -0.1], + "observation.state": [0, -1, -0.2, -0.1], + "action": [-0.1, 0, 1, 2, 3], + }, + ) + dataset[10] diff --git a/lerobot/common/datasets/simxarm.py b/lerobot/common/datasets/simxarm.py index dc30e69e..4b2c68ad 100644 --- a/lerobot/common/datasets/simxarm.py +++ b/lerobot/common/datasets/simxarm.py @@ -1,75 +1,104 @@ import pickle import zipfile from pathlib import Path -from typing import Callable import torch -import torchrl import tqdm -from tensordict import TensorDict -from torchrl.data.replay_buffers.samplers import ( - Sampler, -) -from torchrl.data.replay_buffers.storages import TensorStorage -from torchrl.data.replay_buffers.writers import Writer -from lerobot.common.datasets.abstract import AbstractDataset +from lerobot.common.datasets.utils import load_data_with_delta_timestamps -def download(): - raise NotImplementedError() +def download(raw_dir): import gdown + raw_dir.mkdir(parents=True, exist_ok=True) url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya" - download_path = "data.zip" - gdown.download(url, download_path, quiet=False) + zip_path = raw_dir / "data.zip" + gdown.download(url, str(zip_path), quiet=False) print("Extracting...") - with zipfile.ZipFile(download_path, "r") as zip_f: + with zipfile.ZipFile(str(zip_path), "r") as zip_f: for member in zip_f.namelist(): if member.startswith("data/xarm") and member.endswith(".pkl"): print(member) zip_f.extract(member=member) - Path(download_path).unlink() + zip_path.unlink() -class SimxarmDataset(AbstractDataset): +class SimxarmDataset(torch.utils.data.Dataset): available_datasets = [ "xarm_lift_medium", ] + fps = 15 + image_keys = ["observation.image"] def __init__( self, dataset_id: str, version: str | None = "v1.1", - batch_size: int | None = None, - *, - shuffle: bool = True, root: Path | None = None, - pin_memory: bool = False, - prefetch: int = None, - sampler: Sampler | None = None, - collate_fn: Callable | None = None, - writer: Writer | None = None, - transform: "torchrl.envs.Transform" = None, + transform: callable = None, + delta_timestamps: dict[list[float]] | None = None, ): - super().__init__( - dataset_id, - version, - batch_size, - shuffle=shuffle, - root=root, - pin_memory=pin_memory, 
- prefetch=prefetch, - sampler=sampler, - collate_fn=collate_fn, - writer=writer, - transform=transform, - ) + super().__init__() + self.dataset_id = dataset_id + self.version = version + self.root = root + self.transform = transform + self.delta_timestamps = delta_timestamps + + data_dir = self.root / f"{self.dataset_id}" + if (data_dir / "data_dict.pth").exists() and (data_dir / "data_ids_per_episode.pth").exists(): + self.data_dict = torch.load(data_dir / "data_dict.pth") + self.data_ids_per_episode = torch.load(data_dir / "data_ids_per_episode.pth") + else: + self._download_and_preproc_obsolete() + data_dir.mkdir(parents=True, exist_ok=True) + torch.save(self.data_dict, data_dir / "data_dict.pth") + torch.save(self.data_ids_per_episode, data_dir / "data_ids_per_episode.pth") + + @property + def num_samples(self) -> int: + return len(self.data_dict["index"]) + + @property + def num_episodes(self) -> int: + return len(self.data_ids_per_episode) + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + item = {} + + # get episode id and timestamp of the sampled frame + current_ts = self.data_dict["timestamp"][idx].item() + episode = self.data_dict["episode"][idx].item() + + for key in self.data_dict: + if self.delta_timestamps is not None and key in self.delta_timestamps: + data, is_pad = load_data_with_delta_timestamps( + self.data_dict, + self.data_ids_per_episode, + self.delta_timestamps, + key, + current_ts, + episode, + ) + item[key] = data + item[f"{key}_is_pad"] = is_pad + else: + item[key] = self.data_dict[key][idx] + + if self.transform is not None: + item = self.transform(item) + + return item def _download_and_preproc_obsolete(self): - # assert self.root is not None - # TODO(rcadene): finish download - # download() + assert self.root is not None + raw_dir = self.root / f"{self.dataset_id}_raw" + if not raw_dir.exists(): + download(raw_dir) dataset_path = self.root / f"{self.dataset_id}" / "buffer.pkl" print(f"Using offline dataset '{dataset_path}'") @@ -78,6 +107,9 @@ class SimxarmDataset(AbstractDataset): total_frames = dataset_dict["actions"].shape[0] + self.data_ids_per_episode = {} + ep_dicts = [] + idx0 = 0 idx1 = 0 episode_id = 0 @@ -91,37 +123,38 @@ class SimxarmDataset(AbstractDataset): image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1]) state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1]) - next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1]) - next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1]) + action = torch.tensor(dataset_dict["actions"][idx0:idx1]) + # TODO(rcadene): concat the last "next_observations" to "observations" + # next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1]) + # next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1]) next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1]) next_done = torch.tensor(dataset_dict["dones"][idx0:idx1]) - episode = TensorDict( - { - ("observation", "image"): image, - ("observation", "state"): state, - "action": torch.tensor(dataset_dict["actions"][idx0:idx1]), - "episode": torch.tensor([episode_id] * num_frames, dtype=torch.int), - "frame_id": torch.arange(0, num_frames, 1), - ("next", "observation", "image"): next_image, - ("next", "observation", "state"): next_state, - ("next", "reward"): next_reward, - ("next", "done"): next_done, - }, - batch_size=num_frames, - ) + ep_dict = { + "observation.image": image, + "observation.state": state, + 
"action": action, + "episode": torch.tensor([episode_id] * num_frames, dtype=torch.int), + "frame_id": torch.arange(0, num_frames, 1), + "timestamp": torch.arange(0, num_frames, 1) / self.fps, + # "next.observation.image": next_image, + # "next.observation.state": next_state, + "next.reward": next_reward, + "next.done": next_done, + } + ep_dicts.append(ep_dict) - if episode_id == 0: - # hack to initialize tensordict data structure to store episodes - td_data = ( - episode[0] - .expand(total_frames) - .memmap_like(self.root / f"{self.dataset_id}" / "replay_buffer") - ) + assert isinstance(episode_id, int) + self.data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1) + assert len(self.data_ids_per_episode[episode_id]) == num_frames - td_data[idx0:idx1] = episode - - episode_id += 1 idx0 = idx1 + episode_id += 1 - return TensorStorage(td_data.lock_()) + self.data_dict = {} + + keys = ep_dicts[0].keys() + for key in keys: + self.data_dict[key] = torch.cat([x[key] for x in ep_dicts]) + + self.data_dict["index"] = torch.arange(0, total_frames, 1) diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 0ad43a65..c8840169 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -3,6 +3,7 @@ import zipfile from pathlib import Path import requests +import torch import tqdm @@ -28,3 +29,71 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool: return True else: return False + + +def euclidean_distance_matrix(mat0, mat1): + # Compute the square of the distance matrix + sq0 = torch.sum(mat0**2, dim=1, keepdim=True) + sq1 = torch.sum(mat1**2, dim=1, keepdim=True) + distance_sq = sq0 + sq1.transpose(0, 1) - 2 * mat0 @ mat1.transpose(0, 1) + + # Taking the square root to get the euclidean distance + distance = torch.sqrt(torch.clamp(distance_sq, min=0)) + return distance + + +def is_contiguously_true_or_false(bool_vector): + assert bool_vector.ndim == 1 + assert bool_vector.dtype == torch.bool + + # Compare each element with its neighbor to find changes + changes = bool_vector[1:] != bool_vector[:-1] + + # Count the number of changes + num_changes = changes.sum().item() + + # If there's more than one change, the list is not contiguous + return num_changes <= 1 + + # examples = [ + # ([True, False, True, False, False, False], False), + # ([True, True, True, False, False, False], True), + # ([False, False, False, False, False, False], True) + # ] + # for bool_list, expected in examples: + # result = is_contiguously_true_or_false(bool_list) + + +def load_data_with_delta_timestamps( + data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode +): + # get indices of the frames associated to the episode, and their timestamps + ep_data_ids = data_ids_per_episode[episode] + ep_timestamps = data_dict["timestamp"][ep_data_ids] + + # get timestamps used as query to retrieve data of previous/future frames + delta_ts = delta_timestamps[key] + query_ts = current_ts + torch.tensor(delta_ts) + + # compute distances between each query timestamp and all timestamps of all the frames belonging to the episode + dist = euclidean_distance_matrix(query_ts[:, None], ep_timestamps[:, None]) + min_, argmin_ = dist.min(1) + + # get the indices of the data that are closest to the query timestamps + data_ids = ep_data_ids[argmin_] + # closest_ts = ep_timestamps[argmin_] + + # get the data + data = data_dict[key][data_ids].clone() + + # TODO(rcadene): synchronize timestamps + interpolation if needed + + tol = 0.02 + is_pad = 
min_ > tol + + assert is_contiguously_true_or_false(is_pad), ( + "One or several timestamps unexpectedly violate the tolerance." + "This might be due to synchronization issues with timestamps during data collection." + ) + + return data, is_pad diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index 855e073b..788af3cb 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -1,64 +1,40 @@ -from torchrl.envs import SerialEnv -from torchrl.envs.transforms import Compose, StepCounter, Transform, TransformedEnv +import gymnasium as gym -def make_env(cfg, transform=None): +def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv: """ - Note: The returned environment is wrapped in a torchrl.SerialEnv with cfg.rollout_batch_size underlying - environments. The env therefore returns batches.` + Note: When `num_parallel_envs > 0`, this function returns a `SyncVectorEnv` which takes batched action as input and + returns batched observation, reward, terminated, truncated of `num_parallel_envs` items. """ - - kwargs = { - "frame_skip": cfg.env.action_repeat, - "from_pixels": cfg.env.from_pixels, - "pixels_only": cfg.env.pixels_only, - "image_size": cfg.env.image_size, - "num_prev_obs": cfg.n_obs_steps - 1, - } + kwargs = {} if cfg.env.name == "simxarm": - from lerobot.common.envs.simxarm.env import SimxarmEnv - kwargs["task"] = cfg.env.task - clsfunc = SimxarmEnv elif cfg.env.name == "pusht": - from lerobot.common.envs.pusht.env import PushtEnv + import gym_pusht # noqa # assert kwargs["seed"] > 200, "Seed 0-200 are used for the demonstration dataset, so we don't want to seed the eval env with this range." - - clsfunc = PushtEnv + kwargs.update( + { + "obs_type": "pixels_agent_pos", + "render_action": False, + } + ) + env_fn = lambda: gym.make( # noqa: E731 + "gym_pusht/PushTPixels-v0", + render_mode="rgb_array", + max_episode_steps=cfg.env.episode_length, + **kwargs, + ) elif cfg.env.name == "aloha": - from lerobot.common.envs.aloha.env import AlohaEnv - kwargs["task"] = cfg.env.task - clsfunc = AlohaEnv else: raise ValueError(cfg.env.name) - def _make_env(seed): - nonlocal kwargs - kwargs["seed"] = seed - env = clsfunc(**kwargs) - - # limit rollout to max_steps - env = TransformedEnv(env, StepCounter(max_steps=cfg.env.episode_length)) - - if transform is not None: - # useful to add normalization - if isinstance(transform, Compose): - for tf in transform: - env.append_transform(tf.clone()) - elif isinstance(transform, Transform): - env.append_transform(transform.clone()) - else: - raise NotImplementedError() - - return env - - return SerialEnv( - cfg.rollout_batch_size, - create_env_fn=_make_env, - create_env_kwargs=[ - {"seed": env_seed} for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size) - ], - ) + if num_parallel_envs == 0: + # non-batched version of the env that returns an observation of shape (c) + env = env_fn() + else: + # batched version of the env that returns an observation of shape (b, c) + env = gym.vector.SyncVectorEnv([env_fn for _ in range(num_parallel_envs)]) + return env diff --git a/lerobot/common/envs/simxarm/env.py b/lerobot/common/envs/simxarm/env.py index b81bf499..8ce6b24c 100644 --- a/lerobot/common/envs/simxarm/env.py +++ b/lerobot/common/envs/simxarm/env.py @@ -55,7 +55,7 @@ class SimxarmEnv(AbstractEnv): if not _has_gym: raise ImportError("Cannot import gymnasium.") - import gymnasium + import gymnasium as gym from lerobot.common.envs.simxarm.simxarm import TASKS @@ -65,7 +65,7 @@ class 
SimxarmEnv(AbstractEnv): self._env = TASKS[self.task]["env"]() num_actions = len(TASKS[self.task]["action_space"]) - self._action_space = gymnasium.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,)) + self._action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,)) self._action_padding = np.zeros((MAX_NUM_ACTIONS - num_actions), dtype=np.float32) if "w" not in TASKS[self.task]["action_space"]: self._action_padding[-1] = 1.0 diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py index 82f39b28..a0fe0eba 100644 --- a/lerobot/common/policies/diffusion/policy.py +++ b/lerobot/common/policies/diffusion/policy.py @@ -1,18 +1,20 @@ import copy import logging import time +from collections import deque import hydra import torch +from torch import nn -from lerobot.common.policies.abstract import AbstractPolicy from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder, RgbEncoder +from lerobot.common.policies.utils import populate_queues from lerobot.common.utils import get_safe_torch_device -class DiffusionPolicy(AbstractPolicy): +class DiffusionPolicy(nn.Module): name = "diffusion" def __init__( @@ -38,8 +40,12 @@ class DiffusionPolicy(AbstractPolicy): # parameters passed to step **kwargs, ): - super().__init__(n_action_steps) + super().__init__() self.cfg = cfg + self.n_obs_steps = n_obs_steps + self.n_action_steps = n_action_steps + # queues are populated during rollout of the policy, they contain the n latest observations and actions + self._queues = None noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler) rgb_model_input_shape = copy.deepcopy(shape_meta.obs.image.shape) @@ -100,76 +106,58 @@ class DiffusionPolicy(AbstractPolicy): last_epoch=self.global_step - 1, ) + def reset(self): + """ + Clear observation and action queues. Should be called on `env.reset()` + """ + self._queues = { + "observation.image": deque(maxlen=self.n_obs_steps), + "observation.state": deque(maxlen=self.n_obs_steps), + "action": deque(maxlen=self.n_action_steps), + } + @torch.no_grad() - def select_actions(self, observation, step_count): + def select_action(self, batch, step): """ Note: this uses the ema model weights if self.training == False, otherwise the non-ema model weights. 
""" - # TODO(rcadene): remove unused step_count - del step_count + # TODO(rcadene): remove unused step + del step + assert "observation.image" in batch + assert "observation.state" in batch + assert len(batch) == 2 - obs_dict = { - "image": observation["image"], - "agent_pos": observation["state"], - } - if self.training: - out = self.diffusion.predict_action(obs_dict) - else: - out = self.ema_diffusion.predict_action(obs_dict) - action = out["action"] + self._queues = populate_queues(self._queues, batch) + + if len(self._queues["action"]) == 0: + # stack n latest observations from the queue + batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch} + + obs_dict = { + "image": batch["observation.image"], + "agent_pos": batch["observation.state"], + } + if self.training: + out = self.diffusion.predict_action(obs_dict) + else: + out = self.ema_diffusion.predict_action(obs_dict) + self._queues["action"].extend(out["action"].transpose(0, 1)) + + action = self._queues["action"].popleft() return action - def update(self, replay_buffer, step): + def forward(self, batch, step): start_time = time.time() self.diffusion.train() - num_slices = self.cfg.batch_size - batch_size = self.cfg.horizon * num_slices - - assert batch_size % self.cfg.horizon == 0 - assert batch_size % num_slices == 0 - - def process_batch(batch, horizon, num_slices): - # trajectory t = 64, horizon h = 16 - # (t h) ... -> t h ... - batch = batch.reshape(num_slices, horizon) # .transpose(1, 0).contiguous() - - # |-1|0|1|2|3|4|5|6|7|8|9|10|11|12|13|14| timestamps: 16 - # |o|o| observations: 2 - # | |a|a|a|a|a|a|a|a| actions executed: 8 - # |p|p|p|p|p|p|p|p|p|p|p| p| p| p| p| p| actions predicted: 16 - # note: we predict the action needed to go from t=-1 to t=0 similarly to an inverse kinematic model - - image = batch["observation", "image"] - state = batch["observation", "state"] - action = batch["action"] - assert image.shape[1] == horizon - assert state.shape[1] == horizon - assert action.shape[1] == horizon - - if not (horizon == 16 and self.cfg.n_obs_steps == 2): - raise NotImplementedError() - - # keep first 2 observations of the slice corresponding to t=[-1,0] - image = image[:, : self.cfg.n_obs_steps] - state = state[:, : self.cfg.n_obs_steps] - - out = { - "obs": { - "image": image.to(self.device, non_blocking=True), - "agent_pos": state.to(self.device, non_blocking=True), - }, - "action": action.to(self.device, non_blocking=True), - } - return out - - batch = replay_buffer.sample(batch_size) - batch = process_batch(batch, self.cfg.horizon, num_slices) - data_s = time.time() - start_time - loss = self.diffusion.compute_loss(batch) + obs_dict = { + "image": batch["observation.image"], + "agent_pos": batch["observation.state"], + } + loss = self.diffusion.compute_loss(obs_dict) loss.backward() grad_norm = torch.nn.utils.clip_grad_norm_( diff --git a/lerobot/common/transforms.py b/lerobot/common/transforms.py index 4832c91b..ec967614 100644 --- a/lerobot/common/transforms.py +++ b/lerobot/common/transforms.py @@ -1,53 +1,49 @@ -from typing import Sequence - import torch -from tensordict import TensorDictBase -from tensordict.nn import dispatch -from tensordict.utils import NestedKey -from torchrl.envs.transforms import ObservationTransform, Transform +from torchvision.transforms.v2 import Compose, Transform -class Prod(ObservationTransform): +def apply_inverse_transform(item, transform): + transforms = transform.transforms if isinstance(transform, Compose) else [transform] + for tf in transforms[::-1]: + if 
tf.invertible: + item = tf.inverse_transform(item) + else: + raise ValueError(f"Inverse transform called on a non invertible transform ({tf}).") + return item + + +class Prod(Transform): invertible = True - def __init__(self, in_keys: Sequence[NestedKey], prod: float): + def __init__(self, in_keys: list[str], prod: float): super().__init__() self.in_keys = in_keys self.prod = prod self.original_dtypes = {} - def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase: - # _reset is called once when the environment reset to normalize the first observation - tensordict_reset = self._call(tensordict_reset) - return tensordict_reset - - @dispatch(source="in_keys", dest="out_keys") - def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - return self._call(tensordict) - - def _call(self, td): + def forward(self, item): for key in self.in_keys: - if td.get(key, None) is None: + if key not in item: continue - self.original_dtypes[key] = td[key].dtype - td[key] = td[key].type(torch.float32) * self.prod - return td + self.original_dtypes[key] = item[key].dtype + item[key] = item[key].type(torch.float32) * self.prod + return item - def _inv_call(self, td: TensorDictBase) -> TensorDictBase: + def inverse_transform(self, item): for key in self.in_keys: - if td.get(key, None) is None: + if key not in item: continue - td[key] = (td[key] / self.prod).type(self.original_dtypes[key]) - return td + item[key] = (item[key] / self.prod).type(self.original_dtypes[key]) + return item - def transform_observation_spec(self, obs_spec): - for key in self.in_keys: - if obs_spec.get(key, None) is None: - continue - obs_spec[key].space.high = obs_spec[key].space.high.type(torch.float32) * self.prod - obs_spec[key].space.low = obs_spec[key].space.low.type(torch.float32) * self.prod - obs_spec[key].dtype = torch.float32 - return obs_spec + # def transform_observation_spec(self, obs_spec): + # for key in self.in_keys: + # if obs_spec.get(key, None) is None: + # continue + # obs_spec[key].space.high = obs_spec[key].space.high.type(torch.float32) * self.prod + # obs_spec[key].space.low = obs_spec[key].space.low.type(torch.float32) * self.prod + # obs_spec[key].dtype = torch.float32 + # return obs_spec class NormalizeTransform(Transform): @@ -55,65 +51,50 @@ class NormalizeTransform(Transform): def __init__( self, - stats: TensorDictBase, - in_keys: Sequence[NestedKey] = None, - out_keys: Sequence[NestedKey] | None = None, - in_keys_inv: Sequence[NestedKey] | None = None, - out_keys_inv: Sequence[NestedKey] | None = None, + stats: dict, + in_keys: list[str] = None, + out_keys: list[str] | None = None, + in_keys_inv: list[str] | None = None, + out_keys_inv: list[str] | None = None, mode="mean_std", ): - if out_keys is None: - out_keys = in_keys - if in_keys_inv is None: - in_keys_inv = out_keys - if out_keys_inv is None: - out_keys_inv = in_keys - super().__init__( - in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, out_keys_inv=out_keys_inv - ) + super().__init__() + self.in_keys = in_keys + self.out_keys = in_keys if out_keys is None else out_keys + self.in_keys_inv = self.out_keys if in_keys_inv is None else in_keys_inv + self.out_keys_inv = self.in_keys if out_keys_inv is None else out_keys_inv self.stats = stats assert mode in ["mean_std", "min_max"] self.mode = mode - def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase: - # _reset is called once when the environment reset to normalize the first observation - 
tensordict_reset = self._call(tensordict_reset) - return tensordict_reset - - @dispatch(source="in_keys", dest="out_keys") - def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - return self._call(tensordict) - - def _call(self, td: TensorDictBase) -> TensorDictBase: + def forward(self, item): for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False): - # TODO(rcadene): don't know how to do `inkey not in td` - if td.get(inkey, None) is None: + if inkey not in item: continue if self.mode == "mean_std": mean = self.stats[inkey]["mean"] std = self.stats[inkey]["std"] - td[outkey] = (td[inkey] - mean) / (std + 1e-8) + item[outkey] = (item[inkey] - mean) / (std + 1e-8) else: min = self.stats[inkey]["min"] max = self.stats[inkey]["max"] # normalize to [0,1] - td[outkey] = (td[inkey] - min) / (max - min) + item[outkey] = (item[inkey] - min) / (max - min) # normalize to [-1, 1] - td[outkey] = td[outkey] * 2 - 1 - return td + item[outkey] = item[outkey] * 2 - 1 + return item - def _inv_call(self, td: TensorDictBase) -> TensorDictBase: + def inverse_transform(self, item): for inkey, outkey in zip(self.in_keys_inv, self.out_keys_inv, strict=False): - # TODO(rcadene): don't know how to do `inkey not in td` - if td.get(inkey, None) is None: + if inkey not in item: continue if self.mode == "mean_std": mean = self.stats[inkey]["mean"] std = self.stats[inkey]["std"] - td[outkey] = td[inkey] * std + mean + item[outkey] = item[inkey] * std + mean else: min = self.stats[inkey]["min"] max = self.stats[inkey]["max"] - td[outkey] = (td[inkey] + 1) / 2 - td[outkey] = td[outkey] * (max - min) + min - return td + item[outkey] = (item[inkey] + 1) / 2 + item[outkey] = item[outkey] * (max - min) + min + return item diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 216769d6..fe0f7bb2 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -36,111 +36,196 @@ from datetime import datetime as dt from pathlib import Path import einops +import gymnasium as gym +import hydra import imageio import numpy as np import torch -import tqdm from huggingface_hub import snapshot_download -from tensordict.nn import TensorDictModule -from torchrl.envs import EnvBase -from torchrl.envs.batched_envs import BatchedEnvBase -from lerobot.common.datasets.factory import make_offline_buffer +from lerobot.common.datasets.factory import make_dataset from lerobot.common.envs.factory import make_env from lerobot.common.logger import log_output_dir -from lerobot.common.policies.abstract import AbstractPolicy from lerobot.common.policies.factory import make_policy from lerobot.common.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed +from lerobot.common.transforms import apply_inverse_transform def write_video(video_path, stacked_frames, fps): imageio.mimsave(video_path, stacked_frames, fps=fps) +def preprocess_observation(observation, transform=None): + # map to expected inputs for the policy + obs = { + "observation.image": torch.from_numpy(observation["pixels"]).float(), + "observation.state": torch.from_numpy(observation["agent_pos"]).float(), + } + # convert to (b c h w) torch format + obs["observation.image"] = einops.rearrange(obs["observation.image"], "b h w c -> b c h w") + + # apply same transforms as in training + if transform is not None: + for key in obs: + obs[key] = torch.stack([transform({key: item})[key] for item in obs[key]]) + + return obs + + +def postprocess_action(action, transform=None): + action = action.to("cpu") + # action is a batch 
(num_env,action_dim) instead of an item (action_dim), +    # we assume applying inverse transform on a batch works the same +    action = apply_inverse_transform({"action": action}, transform)["action"].numpy() +    assert ( +        action.ndim == 2 +    ), "we assume dimensions are respectively the number of parallel envs, action dimensions" +    return action + + def eval_policy( -    env: BatchedEnvBase, -    policy: AbstractPolicy, -    num_episodes: int = 10, -    max_steps: int = 30, +    env: gym.vector.VectorEnv, +    policy, save_video: bool = False, video_dir: Path = None, +    # TODO(rcadene): make it possible to overwrite fps? we should use env.fps fps: int = 15, return_first_video: bool = False, +    transform: callable = None, ): if policy is not None: policy.eval() start = time.time() sum_rewards = [] max_rewards = [] -    successes = [] +    all_successes = [] seeds = [] threads = []  # for video saving threads episode_counter = 0  # for saving the correct number of videos + num_episodes = len(env.envs) + # TODO(alexander-soare): if num_episodes is not evenly divisible by the batch size, this will do more work than # needed as I'm currently taking a ceil. -    for i in tqdm.tqdm(range(-(-num_episodes // env.batch_size[0]))): -        ep_frames = [] +    ep_frames = []  -        def maybe_render_frame(env: EnvBase, _): -            if save_video or (return_first_video and i == 0):  # noqa: B023 -                ep_frames.append(env.render())  # noqa: B023 +    def maybe_render_frame(env): +        if save_video:  # noqa: B023 +            if return_first_video: +                visu = env.envs[0].render() +                visu = visu[None, ...]  # add batch dim +            else: +                visu = np.stack([env.render() for env in env.envs]) +            ep_frames.append(visu)  # noqa: B023  -        # Clear the policy's action queue before the start of a new rollout. -        if policy is not None: -            policy.clear_action_queue() +    for _ in range(num_episodes): +        seeds.append("TODO")  -        if env.is_closed: -            env.start()  # needed to be able to get the seeds the first time as BatchedEnvs are lazy -        seeds.extend(env._next_seed) +        if hasattr(policy, "reset"): +            policy.reset() +        else: +            logging.warning( +                f"Policy {policy} doesn't have a `reset` method. This is fine if the policy doesn't rely on an internal state during rollout." +            ) + +        # reset the environment +        observation, info = env.reset(seed=cfg.seed) +        maybe_render_frame(env) + +        rewards = [] +        successes = [] +        dones = [] + +        done = torch.tensor([False for _ in env.envs]) +        step = 0 +        do_rollout = True +        while do_rollout: +            # apply transform to normalize the observations +            observation = preprocess_observation(observation, transform) + +            # send observation to device/gpu +            observation = {key: observation[key].to(cfg.device, non_blocking=True) for key in observation} + +            # get the next action for the environment with torch.inference_mode(): -            # TODO(alexander-soare): When `break_when_any_done == False` this rolls out for max_steps even when all -            # envs are done the first time. But we only use the first rollout. This is a waste of compute. -            rollout = env.rollout( -                max_steps=max_steps, -                policy=policy, -                auto_cast_to_device=True, -                callback=maybe_render_frame, -                break_when_any_done=env.batch_size[0] == 1, -            ) -        # Figure out where in each rollout sequence the first done condition was encountered (results after -        # this won't be included). -        # Note: this assumes that the shape of the done key is (batch_size, max_steps, 1). -        # Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker. 
- rollout_steps = rollout["next", "done"].shape[1] - done_indices = torch.argmax(rollout["next", "done"].to(int), axis=1) # (batch_size, rollout_steps) - mask = (torch.arange(rollout_steps) <= done_indices).unsqueeze(-1) # (batch_size, rollout_steps, 1) - batch_sum_reward = einops.reduce((rollout["next", "reward"] * mask), "b n 1 -> b", "sum") - batch_max_reward = einops.reduce((rollout["next", "reward"] * mask), "b n 1 -> b", "max") - batch_success = einops.reduce((rollout["next", "success"] * mask), "b n 1 -> b", "any") - sum_rewards.extend(batch_sum_reward.tolist()) - max_rewards.extend(batch_max_reward.tolist()) - successes.extend(batch_success.tolist()) + action = policy.select_action(observation, step) - if save_video or (return_first_video and i == 0): - batch_stacked_frames = np.stack(ep_frames) # (t, b, *) - batch_stacked_frames = batch_stacked_frames.transpose( - 1, 0, *range(2, batch_stacked_frames.ndim) - ) # (b, t, *) + # apply inverse transform to unnormalize the action + action = postprocess_action(action, transform) - if save_video: - for stacked_frames, done_index in zip( - batch_stacked_frames, done_indices.flatten().tolist(), strict=False - ): - if episode_counter >= num_episodes: - continue - video_dir.mkdir(parents=True, exist_ok=True) - video_path = video_dir / f"eval_episode_{episode_counter}.mp4" - thread = threading.Thread( - target=write_video, - args=(str(video_path), stacked_frames[:done_index], fps), - ) - thread.start() - threads.append(thread) - episode_counter += 1 + # apply the next + observation, reward, terminated, truncated, info = env.step(action) + maybe_render_frame(env) - if return_first_video and i == 0: - first_video = batch_stacked_frames[0].transpose(0, 3, 1, 2) + # TODO(rcadene): implement a wrapper over env to return torch tensors in float32 (and cuda?) + reward = torch.from_numpy(reward) + terminated = torch.from_numpy(terminated) + truncated = torch.from_numpy(truncated) + # environment is considered done (no more steps), when success state is reached (terminated is True), + # or time limit is reached (truncated is True), or it was previsouly done. + done = terminated | truncated | done + + if "final_info" in info: + # VectorEnv stores is_success into `info["final_info"][env_id]["is_success"]` instead of `info["is_success"]` + success = [ + env_info["is_success"] if env_info is not None else False for env_info in info["final_info"] + ] + else: + success = [False for _ in env.envs] + success = torch.tensor(success) + + rewards.append(reward) + dones.append(done) + successes.append(success) + + step += 1 + + if done.all(): + do_rollout = False + break + + rewards = torch.stack(rewards, dim=1) + successes = torch.stack(successes, dim=1) + dones = torch.stack(dones, dim=1) + + # Figure out where in each rollout sequence the first done condition was encountered (results after + # this won't be included). + # Note: this assumes that the shape of the done key is (batch_size, max_steps). + # Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker. 
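A toy illustration of the first-done masking computed in the next lines (standalone sketch with made-up values, not part of the patch):

import torch

dones = torch.tensor([
    [False, False, True, True, True],    # env 0 first done at step 2
    [False, False, False, False, True],  # env 1 first done at step 4
])
step = dones.shape[1]
done_indices = torch.argmax(dones.to(int), axis=1)  # tensor([2, 4]); argmax returns the first max
expand_done_indices = done_indices[:, None].expand(-1, step)
expand_step_indices = torch.arange(step)[None, :].expand(dones.shape[0], -1)
mask = (expand_step_indices <= expand_done_indices).int()
# mask == tensor([[1, 1, 1, 0, 0],
#                 [1, 1, 1, 1, 1]])
# rewards/successes after each env's first done are zeroed out before the reductions below
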
+ done_indices = torch.argmax(dones.to(int), axis=1) # (batch_size, rollout_steps) + expand_done_indices = done_indices[:, None].expand(-1, step) + expand_step_indices = torch.arange(step)[None, :].expand(num_episodes, -1) + mask = (expand_step_indices <= expand_done_indices).int() # (batch_size, rollout_steps) + batch_sum_reward = einops.reduce((rewards * mask), "b n -> b", "sum") + batch_max_reward = einops.reduce((rewards * mask), "b n -> b", "max") + batch_success = einops.reduce((successes * mask), "b n -> b", "any") + sum_rewards.extend(batch_sum_reward.tolist()) + max_rewards.extend(batch_max_reward.tolist()) + all_successes.extend(batch_success.tolist()) + + env.close() + + if save_video or return_first_video: + batch_stacked_frames = np.stack(ep_frames, 1) # (b, t, *) + + if save_video: + for stacked_frames, done_index in zip( + batch_stacked_frames, done_indices.flatten().tolist(), strict=False + ): + if episode_counter >= num_episodes: + continue + video_dir.mkdir(parents=True, exist_ok=True) + video_path = video_dir / f"eval_episode_{episode_counter}.mp4" + thread = threading.Thread( + target=write_video, + args=(str(video_path), stacked_frames[:done_index], fps), + ) + thread.start() + threads.append(thread) + episode_counter += 1 + + if return_first_video: + first_video = batch_stacked_frames[0].transpose(0, 3, 1, 2) for thread in threads: thread.join() @@ -158,16 +243,16 @@ def eval_policy( zip( sum_rewards[:num_episodes], max_rewards[:num_episodes], - successes[:num_episodes], + all_successes[:num_episodes], seeds[:num_episodes], strict=True, ) ) ], "aggregated": { - "avg_sum_reward": np.nanmean(sum_rewards[:num_episodes]), - "avg_max_reward": np.nanmean(max_rewards[:num_episodes]), - "pc_success": np.nanmean(successes[:num_episodes]) * 100, + "avg_sum_reward": float(np.nanmean(sum_rewards[:num_episodes])), + "avg_max_reward": float(np.nanmean(max_rewards[:num_episodes])), + "pc_success": float(np.nanmean(all_successes[:num_episodes]) * 100), "eval_s": time.time() - start, "eval_ep_s": (time.time() - start) / num_episodes, }, @@ -194,21 +279,13 @@ def eval(cfg: dict, out_dir=None, stats_path=None): logging.info("Making transforms.") # TODO(alexander-soare): Completely decouple datasets from evaluation. - offline_buffer = make_offline_buffer(cfg, stats_path=stats_path) + dataset = make_dataset(cfg, stats_path=stats_path) logging.info("Making environment.") - env = make_env(cfg, transform=offline_buffer.transform) + env = make_env(cfg, num_parallel_envs=cfg.eval_episodes) - if cfg.policy.pretrained_model_path: - policy = make_policy(cfg) - policy = TensorDictModule( - policy, - in_keys=["observation", "step_count"], - out_keys=["action"], - ) - else: - # when policy is None, rollout a random policy - policy = None + # when policy is None, rollout a random policy + policy = make_policy(cfg) if cfg.policy.pretrained_model_path else None info = eval_policy( env, @@ -216,8 +293,8 @@ def eval(cfg: dict, out_dir=None, stats_path=None): save_video=True, video_dir=Path(out_dir) / "eval", fps=cfg.env.fps, - max_steps=cfg.env.episode_length, - num_episodes=cfg.eval_episodes, + # TODO(rcadene): what should we do with the transform? 
+ transform=dataset.transform, ) print(info["aggregated"]) diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 18c3715b..5e9cd361 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -1,14 +1,12 @@ import logging +from itertools import cycle from pathlib import Path import hydra import numpy as np import torch -from tensordict.nn import TensorDictModule -from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer -from torchrl.data.replay_buffers import PrioritizedSliceSampler -from lerobot.common.datasets.factory import make_offline_buffer +from lerobot.common.datasets.factory import make_dataset from lerobot.common.envs.factory import make_env from lerobot.common.logger import Logger, log_output_dir from lerobot.common.policies.factory import make_policy @@ -34,7 +32,7 @@ def train_notebook(out_dir=None, job_name=None, config_name="default", config_pa train(cfg, out_dir=out_dir, job_name=job_name) -def log_train_info(logger, info, step, cfg, offline_buffer, is_offline): +def log_train_info(logger, info, step, cfg, dataset, is_offline): loss = info["loss"] grad_norm = info["grad_norm"] lr = info["lr"] @@ -44,9 +42,9 @@ def log_train_info(logger, info, step, cfg, offline_buffer, is_offline): # A sample is an (observation,action) pair, where observation and action # can be on multiple timestamps. In a batch, we have `batch_size`` number of samples. num_samples = (step + 1) * cfg.policy.batch_size - avg_samples_per_ep = offline_buffer.num_samples / offline_buffer.num_episodes + avg_samples_per_ep = dataset.num_samples / dataset.num_episodes num_episodes = num_samples / avg_samples_per_ep - num_epochs = num_samples / offline_buffer.num_samples + num_epochs = num_samples / dataset.num_samples log_items = [ f"step:{format_big_number(step)}", # number of samples seen during training @@ -73,7 +71,7 @@ def log_train_info(logger, info, step, cfg, offline_buffer, is_offline): logger.log_dict(info, step, mode="train") -def log_eval_info(logger, info, step, cfg, offline_buffer, is_offline): +def log_eval_info(logger, info, step, cfg, dataset, is_offline): eval_s = info["eval_s"] avg_sum_reward = info["avg_sum_reward"] pc_success = info["pc_success"] @@ -81,9 +79,9 @@ def log_eval_info(logger, info, step, cfg, offline_buffer, is_offline): # A sample is an (observation,action) pair, where observation and action # can be on multiple timestamps. In a batch, we have `batch_size`` number of samples. 
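For a concrete sense of this bookkeeping in `log_train_info`/`log_eval_info`, with invented figures:

batch_size = 64                                   # cfg.policy.batch_size
num_samples = (99 + 1) * batch_size               # 6_400 samples seen after the 100th update
avg_samples_per_ep = 25_600 / 50                  # dataset of 25_600 samples over 50 episodes -> 512
num_episodes = num_samples / avg_samples_per_ep   # 12.5 episodes' worth of samples
num_epochs = num_samples / 25_600                 # 0.25 epochs
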
     num_samples = (step + 1) * cfg.policy.batch_size
-    avg_samples_per_ep = offline_buffer.num_samples / offline_buffer.num_episodes
+    avg_samples_per_ep = dataset.num_samples / dataset.num_episodes
     num_episodes = num_samples / avg_samples_per_ep
-    num_epochs = num_samples / offline_buffer.num_samples
+    num_epochs = num_samples / dataset.num_samples
     log_items = [
         f"step:{format_big_number(step)}",
         # number of samples seen during training
@@ -124,30 +122,30 @@ def train(cfg: dict, out_dir=None, job_name=None):
     torch.backends.cuda.matmul.allow_tf32 = True
     set_global_seed(cfg.seed)

-    logging.info("make_offline_buffer")
-    offline_buffer = make_offline_buffer(cfg)
+    logging.info("make_dataset")
+    dataset = make_dataset(cfg)

     # TODO(rcadene): move balanced_sampling, per_alpha, per_beta outside policy
-    if cfg.policy.balanced_sampling:
-        logging.info("make online_buffer")
-        num_traj_per_batch = cfg.policy.batch_size
+    # if cfg.policy.balanced_sampling:
+    #     logging.info("make online_buffer")
+    #     num_traj_per_batch = cfg.policy.batch_size

-        online_sampler = PrioritizedSliceSampler(
-            max_capacity=100_000,
-            alpha=cfg.policy.per_alpha,
-            beta=cfg.policy.per_beta,
-            num_slices=num_traj_per_batch,
-            strict_length=True,
-        )
+    #     online_sampler = PrioritizedSliceSampler(
+    #         max_capacity=100_000,
+    #         alpha=cfg.policy.per_alpha,
+    #         beta=cfg.policy.per_beta,
+    #         num_slices=num_traj_per_batch,
+    #         strict_length=True,
+    #     )

-        online_buffer = TensorDictReplayBuffer(
-            storage=LazyMemmapStorage(100_000),
-            sampler=online_sampler,
-            transform=offline_buffer.transform,
-        )
+    #     online_buffer = TensorDictReplayBuffer(
+    #         storage=LazyMemmapStorage(100_000),
+    #         sampler=online_sampler,
+    #         transform=dataset.transform,
+    #     )

     logging.info("make_env")
-    env = make_env(cfg, transform=offline_buffer.transform)
+    env = make_env(cfg)

     logging.info("make_policy")
     policy = make_policy(cfg)
@@ -155,8 +153,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
     num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())

-    td_policy = TensorDictModule(policy, in_keys=["observation", "step_count"], out_keys=["action"])
-
     # log metrics to terminal and wandb
     logger = Logger(out_dir, job_name, cfg)
@@ -165,8 +161,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
     logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})")
     logging.info(f"{cfg.online_steps=}")
     logging.info(f"{cfg.env.action_repeat=}")
-    logging.info(f"{offline_buffer.num_samples=} ({format_big_number(offline_buffer.num_samples)})")
-    logging.info(f"{offline_buffer.num_episodes=}")
+    logging.info(f"{dataset.num_samples=} ({format_big_number(dataset.num_samples)})")
+    logging.info(f"{dataset.num_episodes=}")
     logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
     logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
@@ -176,14 +172,15 @@ def train(cfg: dict, out_dir=None, job_name=None):
            logging.info(f"Eval policy at step {step}")
            eval_info, first_video = eval_policy(
                env,
-                td_policy,
+                policy,
                num_episodes=cfg.eval_episodes,
                max_steps=cfg.env.episode_length,
                return_first_video=True,
                video_dir=Path(out_dir) / "eval",
                save_video=True,
+                transform=dataset.transform,
            )
-            log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_buffer, is_offline)
+            log_eval_info(logger, eval_info["aggregated"], step, cfg, dataset, is_offline)
            if cfg.wandb.enable:
                logger.log_video(first_video, step, mode="eval")
            logging.info("Resume training")
@@ -196,14 +193,29 @@ def train(cfg: dict, out_dir=None, job_name=None):
     step = 0  # number of policy update (forward + backward + optim)
     is_offline = True
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        num_workers=4,
+        batch_size=cfg.policy.batch_size,
+        shuffle=True,
+        pin_memory=cfg.device != "cpu",
+        drop_last=True,
+    )
+    dl_iter = cycle(dataloader)
     for offline_step in range(cfg.offline_steps):
         if offline_step == 0:
             logging.info("Start offline training on a fixed dataset")
-        # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
         policy.train()
-        train_info = policy.update(offline_buffer, step)
+        batch = next(dl_iter)
+
+        for key in batch:
+            batch[key] = batch[key].to(cfg.device, non_blocking=True)
+
+        train_info = policy.update(batch, step)
+
+        # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
         if step % cfg.log_freq == 0:
-            log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline)
+            log_train_info(logger, train_info, step, cfg, dataset, is_offline)

         # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass in
         # step + 1.
@@ -211,7 +223,7 @@ def train(cfg: dict, out_dir=None, job_name=None):

         step += 1

-    demo_buffer = offline_buffer if cfg.policy.balanced_sampling else None
+    demo_buffer = dataset if cfg.policy.balanced_sampling else None
     online_step = 0
     is_offline = False
     for env_step in range(cfg.online_steps):
@@ -221,7 +233,7 @@ def train(cfg: dict, out_dir=None, job_name=None):

         with torch.no_grad():
             rollout = env.rollout(
                 max_steps=cfg.env.episode_length,
-                policy=td_policy,
+                policy=policy,
                 auto_cast_to_device=True,
             )
@@ -242,7 +254,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
         # set same episode index for all time steps contained in this rollout
         rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int)

-        online_buffer.extend(rollout)
+        # online_buffer.extend(rollout)

         ep_sum_reward = rollout["next", "reward"].sum()
         ep_max_reward = rollout["next", "reward"].max()
@@ -257,13 +269,13 @@ def train(cfg: dict, out_dir=None, job_name=None):

         for _ in range(cfg.policy.utd):
             train_info = policy.update(
-                online_buffer,
+                # online_buffer,
                 step,
                 demo_buffer=demo_buffer,
             )
             if step % cfg.log_freq == 0:
                 train_info.update(rollout_info)
-                log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline)
+                log_train_info(logger, train_info, step, cfg, dataset, is_offline)

             # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass
             # in step + 1.
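The offline loop above swaps the torchrl replay buffer for a plain PyTorch DataLoader wrapped in itertools.cycle, so training can run for an arbitrary number of optimization steps without worrying about epoch boundaries. A rough sketch of that pattern, using a toy TensorDataset and a hypothetical device string in place of make_dataset(cfg) and cfg.device:

    from itertools import cycle

    import torch

    # toy stand-in for the LeRobot dataset; any torch.utils.data.Dataset works here
    dataset = torch.utils.data.TensorDataset(torch.randn(256, 4), torch.randn(256, 2))

    device = "cpu"  # assumption: the real script reads this from cfg.device
    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=0,
        batch_size=32,
        shuffle=True,
        pin_memory=device != "cpu",
        drop_last=True,
    )
    dl_iter = cycle(dataloader)  # never raises StopIteration, so a step-based loop can run indefinitely

    for step in range(3):
        batch = next(dl_iter)
        # TensorDataset yields a tuple; the LeRobot dataset yields a dict of tensors keyed by strings
        batch = [x.to(device, non_blocking=True) for x in batch]

One caveat of this pattern: itertools.cycle caches the items it yields on the first pass, so with shuffle=True the shuffling order of the first epoch is replayed on later passes rather than being redrawn.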
diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py
index 3dd7cdfa..93315e90 100644
--- a/lerobot/scripts/visualize_dataset.py
+++ b/lerobot/scripts/visualize_dataset.py
@@ -10,7 +10,7 @@ from torchrl.data.replay_buffers import (
     SamplerWithoutReplacement,
 )

-from lerobot.common.datasets.factory import make_offline_buffer
+from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.logger import log_output_dir
 from lerobot.common.utils import init_logging

@@ -44,8 +44,8 @@ def visualize_dataset(cfg: dict, out_dir=None):
         shuffle=False,
     )

-    logging.info("make_offline_buffer")
-    offline_buffer = make_offline_buffer(
+    logging.info("make_dataset")
+    dataset = make_dataset(
         cfg,
         overwrite_sampler=sampler,
         # remove all transformations such as rescale images from [0,255] to [0,1] or normalization
@@ -55,12 +55,12 @@ def visualize_dataset(cfg: dict, out_dir=None):
     )

     logging.info("Start rendering episodes from offline buffer")
-    video_paths = render_dataset(offline_buffer, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER, cfg.fps)
+    video_paths = render_dataset(dataset, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER, cfg.fps)
     for video_path in video_paths:
         logging.info(video_path)


-def render_dataset(offline_buffer, out_dir, max_num_samples, fps):
+def render_dataset(dataset, out_dir, max_num_samples, fps):
     out_dir = Path(out_dir)
     video_paths = []
     threads = []
@@ -69,17 +69,17 @@ def render_dataset(offline_buffer, out_dir, max_num_samples, fps):
     logging.info(f"Visualizing episode {current_ep_idx}")
     for i in range(max_num_samples):
         # TODO(rcadene): make it work with bsize > 1
-        ep_td = offline_buffer.sample(1)
+        ep_td = dataset.sample(1)
         ep_idx = ep_td["episode"][FIRST_FRAME].item()

-        # TODO(rcadene): modify offline_buffer._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames
-        num_frames_left = offline_buffer._sampler._sample_list.numel()
+        # TODO(rcadene): modify dataset._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames
+        num_frames_left = dataset._sampler._sample_list.numel()
         episode_is_done = ep_idx != current_ep_idx

         if episode_is_done:
             logging.info(f"Rendering episode {current_ep_idx}")

-        for im_key in offline_buffer.image_keys:
+        for im_key in dataset.image_keys:
             if not episode_is_done and num_frames_left > 0 and i < (max_num_samples - 1):
                 # when first frame of episode, initialize frames dict
                 if im_key not in frames:
@@ -93,7 +93,7 @@ def render_dataset(offline_buffer, out_dir, max_num_samples, fps):
                 frames[im_key].append(ep_td["next"][im_key])

                 out_dir.mkdir(parents=True, exist_ok=True)
-                if len(offline_buffer.image_keys) > 1:
+                if len(dataset.image_keys) > 1:
                     camera = im_key[-1]
                     video_path = out_dir / f"episode_{current_ep_idx}_{camera}.mp4"
                 else:
diff --git a/poetry.lock b/poetry.lock
index 72397001..8fb6b7a7 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -879,6 +879,29 @@ files = [
 [package.extras]
 protobuf = ["grpcio-tools (>=1.62.1)"]

+[[package]]
+name = "gym-pusht"
+version = "0.1.0"
+description = "PushT environment for LeRobot"
+optional = true
+python-versions = "^3.10"
+files = []
+develop = false
+
+[package.dependencies]
+gymnasium = "^0.29.1"
+opencv-python = "^4.9.0.80"
+pygame = "^2.5.2"
+pymunk = "^6.6.0"
+scikit-image = "^0.22.0"
+shapely = "^2.0.3"
+
+[package.source]
+type = "git"
+url = "git@github.com:huggingface/gym-pusht.git"
+reference = "HEAD"
+resolved_reference = "d7e1a39a31b1368741e9674791007d7cccf046a3"
+
 [[package]]
"gymnasium" version = "0.29.1" @@ -3586,7 +3609,10 @@ files = [ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] +[extras] +pusht = ["gym_pusht"] + [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "174c7d42f8039eedd2c447a4e6cae5169782cbd94346b5606572a0010194ca05" +content-hash = "3eee17e4bf2b7a570f41ef9c400ec5a24a3113f62a13162229cf43504ca0d005" diff --git a/pyproject.toml b/pyproject.toml index 972c1b61..d0fc7c0d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,10 @@ robomimic = "0.2.0" gymnasium-robotics = "^1.2.4" gymnasium = "^0.29.1" cmake = "^3.29.0.1" +gym_pusht = { git = "git@github.com:huggingface/gym-pusht.git", optional = true} +[tool.poetry.extras] +pusht = ["gym_pusht"] [tool.poetry.group.dev.dependencies] pre-commit = "^3.6.2" diff --git a/tests/test_datasets.py b/tests/test_datasets.py index df41b03f..f7f80a42 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -6,6 +6,8 @@ from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from lerobot.common.datasets.factory import make_offline_buffer from lerobot.common.utils import init_hydra_config +import logging +from lerobot.common.datasets.factory import make_dataset from .utils import DEVICE, DEFAULT_CONFIG_PATH @@ -26,14 +28,29 @@ def test_factory(env_name, dataset_id): DEFAULT_CONFIG_PATH, overrides=[f"env={env_name}", f"env.task={dataset_id}", f"device={DEVICE}"] ) - offline_buffer = make_offline_buffer(cfg) - for key in offline_buffer.image_keys: - img = offline_buffer[0].get(key) + dataset = make_dataset(cfg) + + item = dataset[0] + + assert "action" in item + assert "episode" in item + assert "frame_id" in item + assert "timestamp" in item + assert "next.done" in item + # TODO(rcadene): should we rename it agent_pos? + assert "observation.state" in item + for key in dataset.image_keys: + img = item.get(key) assert img.dtype == torch.float32 # TODO(rcadene): we assume for now that image normalization takes place in the model assert img.max() <= 1.0 assert img.min() >= 0.0 + if "next.reward" not in item: + logging.warning(f'Missing "next.reward" key in dataset {dataset}.') + if "next.done" not in item: + logging.warning(f'Missing "next.done" key in dataset {dataset}.') + def test_compute_stats(): """Check that the statistics are computed correctly according to the stats_patterns property. 
diff --git a/tests/test_envs.py b/tests/test_envs.py
index eb3746db..0c56f4fc 100644
--- a/tests/test_envs.py
+++ b/tests/test_envs.py
@@ -2,7 +2,7 @@ import pytest
 from tensordict import TensorDict
 import torch
 from torchrl.envs.utils import check_env_specs, step_mdp

-from lerobot.common.datasets.factory import make_offline_buffer
+from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.envs.aloha.env import AlohaEnv
 from lerobot.common.envs.factory import make_env
@@ -116,15 +116,15 @@ def test_factory(env_name):
         overrides=[f"env={env_name}", f"device={DEVICE}"],
     )

-    offline_buffer = make_offline_buffer(cfg)
+    dataset = make_dataset(cfg)

     env = make_env(cfg)
-    for key in offline_buffer.image_keys:
+    for key in dataset.image_keys:
         assert env.reset().get(key).dtype == torch.uint8
     check_env_specs(env)

-    env = make_env(cfg, transform=offline_buffer.transform)
-    for key in offline_buffer.image_keys:
+    env = make_env(cfg, transform=dataset.transform)
+    for key in dataset.image_keys:
         img = env.reset().get(key)
         assert img.dtype == torch.float32
         # TODO(rcadene): we assume for now that image normalization takes place in the model
diff --git a/tests/test_policies.py b/tests/test_policies.py
index 5d6b46d0..a46c6025 100644
--- a/tests/test_policies.py
+++ b/tests/test_policies.py
@@ -7,7 +7,7 @@ from torchrl.envs import EnvBase

 from lerobot.common.policies.factory import make_policy
 from lerobot.common.envs.factory import make_env
-from lerobot.common.datasets.factory import make_offline_buffer
+from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.policies.abstract import AbstractPolicy
 from lerobot.common.utils import init_hydra_config
 from .utils import DEVICE, DEFAULT_CONFIG_PATH
@@ -45,13 +45,13 @@ def test_concrete_policy(env_name, policy_name, extra_overrides):
     # Check that we can make the policy object.
     policy = make_policy(cfg)
     # Check that we run select_actions and get the appropriate output.
-    offline_buffer = make_offline_buffer(cfg)
-    env = make_env(cfg, transform=offline_buffer.transform)
+    dataset = make_dataset(cfg)
+    env = make_env(cfg, transform=dataset.transform)

     if env_name != "aloha":
         # TODO(alexander-soare): Fix this part of the test. PrioritizedSliceSampler raises NotImplementedError:
         # seq_length as a list is not supported for now.
-        policy.update(offline_buffer, torch.tensor(0, device=DEVICE))
+        policy.update(dataset, torch.tensor(0, device=DEVICE))

     action = policy(
         env.observation_spec.rand()["observation"].to(DEVICE),