From 1cdfbc8b52fc7a14d148db74acc29cd02f087982 Mon Sep 17 00:00:00 2001 From: Cadene Date: Sun, 31 Mar 2024 15:05:25 +0000 Subject: [PATCH 1/3] WIP WIP WIP train.py works, loss going down WIP eval.py Fix WIP (eval running, TODO: verify results reproduced) Eval works! (testing reproducibility) WIP pretrained model pusht reproduces same results as torchrl pretrained model pusht reproduces same results as torchrl Remove AbstractPolicy, Move all queues in select_action WIP test_datasets passed (TODO: re-enable NormalizeTransform) --- lerobot/common/datasets/aloha.py | 168 ++++++++------ lerobot/common/datasets/factory.py | 134 ++++------- lerobot/common/datasets/pusht.py | 165 ++++++++----- lerobot/common/datasets/simxarm.py | 169 ++++++++------ lerobot/common/datasets/utils.py | 69 ++++++ lerobot/common/envs/factory.py | 74 ++---- lerobot/common/envs/simxarm/env.py | 4 +- lerobot/common/policies/diffusion/policy.py | 108 ++++----- lerobot/common/transforms.py | 123 +++++----- lerobot/scripts/eval.py | 245 +++++++++++++------- lerobot/scripts/train.py | 96 ++++---- lerobot/scripts/visualize_dataset.py | 20 +- poetry.lock | 28 ++- pyproject.toml | 3 + tests/test_datasets.py | 23 +- tests/test_envs.py | 10 +- tests/test_policies.py | 8 +- 17 files changed, 826 insertions(+), 621 deletions(-) diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py index 031c2cd3..2744f595 100644 --- a/lerobot/common/datasets/aloha.py +++ b/lerobot/common/datasets/aloha.py @@ -1,26 +1,13 @@ import logging from pathlib import Path -from typing import Callable import einops import gdown import h5py import torch -import torchrl import tqdm -from tensordict import TensorDict -from torchrl.data.replay_buffers.samplers import Sampler -from torchrl.data.replay_buffers.storages import TensorStorage -from torchrl.data.replay_buffers.writers import Writer -from lerobot.common.datasets.abstract import AbstractDataset - -DATASET_IDS = [ - "aloha_sim_insertion_human", - "aloha_sim_insertion_scripted", - "aloha_sim_transfer_cube_human", - "aloha_sim_transfer_cube_scripted", -] +from lerobot.common.datasets.utils import load_data_with_delta_timestamps FOLDER_URLS = { "aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF", @@ -66,7 +53,6 @@ CAMERAS = { def download(data_dir, dataset_id): - assert dataset_id in DATASET_IDS assert dataset_id in FOLDER_URLS assert dataset_id in EP48_URLS assert dataset_id in EP49_URLS @@ -80,51 +66,78 @@ def download(data_dir, dataset_id): gdown.download(EP49_URLS[dataset_id], output=str(data_dir / "episode_49.hdf5"), fuzzy=True) -class AlohaDataset(AbstractDataset): - available_datasets = DATASET_IDS +class AlohaDataset(torch.utils.data.Dataset): + available_datasets = [ + "aloha_sim_insertion_human", + "aloha_sim_insertion_scripted", + "aloha_sim_transfer_cube_human", + "aloha_sim_transfer_cube_scripted", + ] + fps = 50 + image_keys = ["observation.images.top"] def __init__( self, dataset_id: str, version: str | None = "v1.2", - batch_size: int | None = None, - *, - shuffle: bool = True, root: Path | None = None, - pin_memory: bool = False, - prefetch: int = None, - sampler: Sampler | None = None, - collate_fn: Callable | None = None, - writer: Writer | None = None, - transform: "torchrl.envs.Transform" = None, + transform: callable = None, + delta_timestamps: dict[list[float]] | None = None, ): - super().__init__( - dataset_id, - version, - batch_size, - shuffle=shuffle, - root=root, - pin_memory=pin_memory, - 
prefetch=prefetch, - sampler=sampler, - collate_fn=collate_fn, - writer=writer, - transform=transform, - ) + super().__init__() + self.dataset_id = dataset_id + self.version = version + self.root = root + self.transform = transform + self.delta_timestamps = delta_timestamps + + data_dir = self.root / f"{self.dataset_id}" + if (data_dir / "data_dict.pth").exists() and (data_dir / "data_ids_per_episode.pth").exists(): + self.data_dict = torch.load(data_dir / "data_dict.pth") + self.data_ids_per_episode = torch.load(data_dir / "data_ids_per_episode.pth") + else: + self._download_and_preproc_obsolete() + data_dir.mkdir(parents=True, exist_ok=True) + torch.save(self.data_dict, data_dir / "data_dict.pth") + torch.save(self.data_ids_per_episode, data_dir / "data_ids_per_episode.pth") @property - def stats_patterns(self) -> dict: - d = { - ("observation", "state"): "b c -> c", - ("action",): "b c -> c", - } - for cam in CAMERAS[self.dataset_id]: - d[("observation", "image", cam)] = "b c h w -> c 1 1" - return d + def num_samples(self) -> int: + return len(self.data_dict["index"]) @property - def image_keys(self) -> list: - return [("observation", "image", cam) for cam in CAMERAS[self.dataset_id]] + def num_episodes(self) -> int: + return len(self.data_ids_per_episode) + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + item = {} + + # get episode id and timestamp of the sampled frame + current_ts = self.data_dict["timestamp"][idx].item() + episode = self.data_dict["episode"][idx].item() + + for key in self.data_dict: + if self.delta_timestamps is not None and key in self.delta_timestamps: + data, is_pad = load_data_with_delta_timestamps( + self.data_dict, + self.data_ids_per_episode, + self.delta_timestamps, + key, + current_ts, + episode, + ) + item[key] = data + item[f"{key}_is_pad"] = is_pad + else: + item[key] = self.data_dict[key][idx] + + if self.transform is not None: + item = self.transform(item) + + return item def _download_and_preproc_obsolete(self): assert self.root is not None @@ -132,54 +145,55 @@ class AlohaDataset(AbstractDataset): if not raw_dir.is_dir(): download(raw_dir, self.dataset_id) - total_num_frames = 0 + total_frames = 0 logging.info("Compute total number of frames to initialize offline buffer") for ep_id in range(NUM_EPISODES[self.dataset_id]): ep_path = raw_dir / f"episode_{ep_id}.hdf5" with h5py.File(ep_path, "r") as ep: - total_num_frames += ep["/action"].shape[0] - 1 - logging.info(f"{total_num_frames=}") + total_frames += ep["/action"].shape[0] - 1 + logging.info(f"{total_frames=}") + + self.data_ids_per_episode = {} + ep_dicts = [] logging.info("Initialize and feed offline buffer") - idxtd = 0 for ep_id in tqdm.tqdm(range(NUM_EPISODES[self.dataset_id])): ep_path = raw_dir / f"episode_{ep_id}.hdf5" with h5py.File(ep_path, "r") as ep: - ep_num_frames = ep["/action"].shape[0] + num_frames = ep["/action"].shape[0] # last step of demonstration is considered done - done = torch.zeros(ep_num_frames, 1, dtype=torch.bool) + done = torch.zeros(num_frames, 1, dtype=torch.bool) done[-1] = True state = torch.from_numpy(ep["/observations/qpos"][:]) action = torch.from_numpy(ep["/action"][:]) - ep_td = TensorDict( - { - ("observation", "state"): state[:-1], - "action": action[:-1], - "episode": torch.tensor([ep_id] * (ep_num_frames - 1)), - "frame_id": torch.arange(0, ep_num_frames - 1, 1), - ("next", "observation", "state"): state[1:], - # TODO: compute reward and success - # ("next", "reward"): reward[1:], - ("next", "done"): done[1:], - # 
("next", "success"): success[1:], - }, - batch_size=ep_num_frames - 1, - ) + ep_dict = { + "observation.state": state, + "action": action, + "episode": torch.tensor([ep_id] * num_frames), + "frame_id": torch.arange(0, num_frames, 1), + "timestamp": torch.arange(0, num_frames, 1) / self.fps, + # "next.observation.state": state, + # TODO(rcadene): compute reward and success + # "next.reward": reward[1:], + "next.done": done[1:], + # "next.success": success[1:], + } for cam in CAMERAS[self.dataset_id]: image = torch.from_numpy(ep[f"/observations/images/{cam}"][:]) image = einops.rearrange(image, "b h w c -> b c h w").contiguous() - ep_td["observation", "image", cam] = image[:-1] - ep_td["next", "observation", "image", cam] = image[1:] + ep_dict[f"observation.images.{cam}"] = image[:-1] + # ep_dict[f"next.observation.images.{cam}"] = image[1:] - if ep_id == 0: - # hack to initialize tensordict data structure to store episodes - td_data = ep_td[0].expand(total_num_frames).memmap_like(self.root / f"{self.dataset_id}") + ep_dicts.append(ep_dict) - td_data[idxtd : idxtd + len(ep_td)] = ep_td - idxtd = idxtd + len(ep_td) + self.data_dict = {} - return TensorStorage(td_data.lock_()) + keys = ep_dicts[0].keys() + for key in keys: + self.data_dict[key] = torch.cat([x[key] for x in ep_dicts]) + + self.data_dict["index"] = torch.arange(0, total_frames, 1) diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 04077034..94ac8ca4 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -1,11 +1,10 @@ -import logging import os from pathlib import Path import torch -from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler +from torchvision.transforms import v2 -from lerobot.common.transforms import NormalizeTransform, Prod +from lerobot.common.transforms import Prod # DATA_DIR specifies to location where datasets are loaded. By default, DATA_DIR is None and # we load from `$HOME/.cache/huggingface/hub/datasets`. For our unit tests, we set `DATA_DIR=tests/data` @@ -13,57 +12,12 @@ from lerobot.common.transforms import NormalizeTransform, Prod DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None -def make_offline_buffer( +def make_dataset( cfg, - overwrite_sampler=None, # set normalize=False to remove all transformations and keep images unnormalized in [0,255] normalize=True, - overwrite_batch_size=None, - overwrite_prefetch=None, stats_path=None, ): - if cfg.policy.balanced_sampling: - assert cfg.online_steps > 0 - batch_size = None - pin_memory = False - prefetch = None - else: - assert cfg.online_steps == 0 - num_slices = cfg.policy.batch_size - batch_size = cfg.policy.horizon * num_slices - pin_memory = cfg.device == "cuda" - prefetch = cfg.prefetch - - if overwrite_batch_size is not None: - batch_size = overwrite_batch_size - - if overwrite_prefetch is not None: - prefetch = overwrite_prefetch - - if overwrite_sampler is None: - # TODO(rcadene): move batch_size outside - num_traj_per_batch = cfg.policy.batch_size # // cfg.horizon - # TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size. - # We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size. 
- - if cfg.offline_prioritized_sampler: - logging.info("use prioritized sampler for offline dataset") - sampler = PrioritizedSliceSampler( - max_capacity=100_000, - alpha=cfg.policy.per_alpha, - beta=cfg.policy.per_beta, - num_slices=num_traj_per_batch, - strict_length=False, - ) - else: - logging.info("use simple sampler for offline dataset") - sampler = SliceSampler( - num_slices=num_traj_per_batch, - strict_length=False, - ) - else: - sampler = overwrite_sampler - if cfg.env.name == "simxarm": from lerobot.common.datasets.simxarm import SimxarmDataset @@ -81,56 +35,56 @@ def make_offline_buffer( else: raise ValueError(cfg.env.name) - offline_buffer = clsfunc( - dataset_id=cfg.dataset_id, - sampler=sampler, - batch_size=batch_size, - root=DATA_DIR, - pin_memory=pin_memory, - prefetch=prefetch if isinstance(prefetch, int) else None, - ) - - if cfg.policy.name == "tdmpc": - img_keys = [] - for key in offline_buffer.image_keys: - img_keys.append(("next", *key)) - img_keys += offline_buffer.image_keys - else: - img_keys = offline_buffer.image_keys - + transforms = None if normalize: - transforms = [Prod(in_keys=img_keys, prod=1 / 255)] - # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, # min_max_from_spec - stats = offline_buffer.compute_or_load_stats() if stats_path is None else torch.load(stats_path) + # stats = dataset.compute_or_load_stats() if stats_path is None else torch.load(stats_path) - # we only normalize the state and action, since the images are usually normalized inside the model for - # now (except for tdmpc: see the following) - in_keys = [("observation", "state"), ("action")] - - if cfg.policy.name == "tdmpc": - # TODO(rcadene): we add img_keys to the keys to normalize for tdmpc only, since diffusion and act policies normalize the image inside the model for now - in_keys += img_keys - # TODO(racdene): since we use next observations in tdmpc, we also add them to the normalization. We are wasting a bit of compute on this for now. - in_keys += [("next", *key) for key in img_keys] - in_keys.append(("next", "observation", "state")) + stats = {} if cfg.policy.name == "diffusion" and cfg.env.name == "pusht": # TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this - stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32) - stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32) - stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32) - stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32) + stats["observation.state"] = {} + stats["observation.state"]["min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32) + stats["observation.state"]["max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32) + stats["action"] = {} + stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32) + stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32) # TODO(rcadene): remove this and put it in config. 
Ideally we want to reproduce SOTA results just with mean_std
-        normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
-        transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode))
+        # normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
 
-    offline_buffer.set_transform(transforms)
+        transforms = v2.Compose(
+            [
+                # TODO(rcadene): we need to do something about image_keys
+                Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
+                # NormalizeTransform(
+                #     stats,
+                #     in_keys=[
+                #         "observation.state",
+                #         "action",
+                #     ],
+                #     mode=normalization_mode,
+                # ),
+            ]
+        )
 
-    if not overwrite_sampler:
-        index = torch.arange(0, offline_buffer.num_samples, 1)
-        sampler.extend(index)
+    if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
+        # TODO(rcadene): implement delta_timestamps in config
+        delta_timestamps = {
+            "observation.image": [-0.1, 0],
+            "observation.state": [-0.1, 0],
+            "action": [-0.1] + [i / clsfunc.fps for i in range(15)],
+        }
+    else:
+        delta_timestamps = None
 
-    return offline_buffer
+    dataset = clsfunc(
+        dataset_id=cfg.dataset_id,
+        root=DATA_DIR,
+        delta_timestamps=delta_timestamps,
+        transform=transforms,
+    )
+
+    return dataset
diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py
index 624fb140..3de70b1f 100644
--- a/lerobot/common/datasets/pusht.py
+++ b/lerobot/common/datasets/pusht.py
@@ -1,20 +1,13 @@
 from pathlib import Path
-from typing import Callable
 
 import einops
 import numpy as np
 import pygame
 import pymunk
 import torch
-import torchrl
 import tqdm
-from tensordict import TensorDict
-from torchrl.data.replay_buffers.samplers import Sampler
-from torchrl.data.replay_buffers.storages import TensorStorage
-from torchrl.data.replay_buffers.writers import Writer
 
-from lerobot.common.datasets.abstract import AbstractDataset
-from lerobot.common.datasets.utils import download_and_extract_zip
+from lerobot.common.datasets.utils import download_and_extract_zip, load_data_with_delta_timestamps
 from lerobot.common.envs.pusht.pusht_env import pymunk_to_shapely
 from lerobot.common.policies.diffusion.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
 
@@ -83,37 +76,82 @@ def add_tee(
     return body
 
 
-class PushtDataset(AbstractDataset):
+class PushtDataset(torch.utils.data.Dataset):
+    """
+
+    Arguments
+    ----------
+    delta_timestamps : dict[list[float]] | None, optional
+        Loads data from frames with a shift in timestamps, with a different strategy for each data key (e.g. state, action or image).
+        If `None`, no shift is applied to the current timestamp and the data from the current frame is loaded.
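+
+        Example, mirroring the configuration used for the diffusion policy on pusht in
+        `lerobot/common/datasets/factory.py` (fps is 10, so one frame every 0.1 second):
+
+            delta_timestamps = {
+                # load the previous frame and the current frame
+                "observation.image": [-0.1, 0],
+                "observation.state": [-0.1, 0],
+                # load the previous action and the 15 next actions
+                "action": [-0.1] + [i / 10 for i in range(15)],
+            }
+
+        Each selected entry of a returned item then gains a leading dimension of size
+        len(delta_timestamps[key]), and a boolean "{key}_is_pad" tensor flags the query
+        timestamps that fall outside the episode boundaries.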
+ """ + available_datasets = ["pusht"] + fps = 10 + image_keys = ["observation.image"] def __init__( self, dataset_id: str, version: str | None = "v1.2", - batch_size: int | None = None, - *, - shuffle: bool = True, root: Path | None = None, - pin_memory: bool = False, - prefetch: int = None, - sampler: Sampler | None = None, - collate_fn: Callable | None = None, - writer: Writer | None = None, - transform: "torchrl.envs.Transform" = None, + transform: callable = None, + delta_timestamps: dict[list[float]] | None = None, ): - super().__init__( - dataset_id, - version, - batch_size, - shuffle=shuffle, - root=root, - pin_memory=pin_memory, - prefetch=prefetch, - sampler=sampler, - collate_fn=collate_fn, - writer=writer, - transform=transform, - ) + super().__init__() + self.dataset_id = dataset_id + self.version = version + self.root = root + self.transform = transform + self.delta_timestamps = delta_timestamps + + data_dir = self.root / f"{self.dataset_id}" + if (data_dir / "data_dict.pth").exists() and (data_dir / "data_ids_per_episode.pth").exists(): + self.data_dict = torch.load(data_dir / "data_dict.pth") + self.data_ids_per_episode = torch.load(data_dir / "data_ids_per_episode.pth") + else: + self._download_and_preproc_obsolete() + data_dir.mkdir(parents=True, exist_ok=True) + torch.save(self.data_dict, data_dir / "data_dict.pth") + torch.save(self.data_ids_per_episode, data_dir / "data_ids_per_episode.pth") + + @property + def num_samples(self) -> int: + return len(self.data_dict["index"]) + + @property + def num_episodes(self) -> int: + return len(self.data_ids_per_episode) + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + item = {} + + # get episode id and timestamp of the sampled frame + current_ts = self.data_dict["timestamp"][idx].item() + episode = self.data_dict["episode"][idx].item() + + for key in self.data_dict: + if self.delta_timestamps is not None and key in self.delta_timestamps: + data, is_pad = load_data_with_delta_timestamps( + self.data_dict, + self.data_ids_per_episode, + self.delta_timestamps, + key, + current_ts, + episode, + ) + item[key] = data + item[f"{key}_is_pad"] = is_pad + else: + item[key] = self.data_dict[key][idx] + + if self.transform is not None: + item = self.transform(item) + + return item def _download_and_preproc_obsolete(self): assert self.root is not None @@ -147,8 +185,10 @@ class PushtDataset(AbstractDataset): states = torch.from_numpy(dataset_dict["state"]) actions = torch.from_numpy(dataset_dict["action"]) + self.data_ids_per_episode = {} + ep_dicts = [] + idx0 = 0 - idxtd = 0 for episode_id in tqdm.tqdm(range(num_episodes)): idx1 = dataset_dict.meta["episode_ends"][episode_id] # to create test artifact @@ -194,30 +234,45 @@ class PushtDataset(AbstractDataset): # last step of demonstration is considered done done[-1] = True - ep_td = TensorDict( - { - ("observation", "image"): image[:-1], - ("observation", "state"): agent_pos[:-1], - "action": actions[idx0:idx1][:-1], - "episode": episode_ids[idx0:idx1][:-1], - "frame_id": torch.arange(0, num_frames - 1, 1), - ("next", "observation", "image"): image[1:], - ("next", "observation", "state"): agent_pos[1:], - # TODO: verify that reward and done are aligned with image and agent_pos - ("next", "reward"): reward[1:], - ("next", "done"): done[1:], - ("next", "success"): success[1:], - }, - batch_size=num_frames - 1, - ) + ep_dict = { + "observation.image": image, + "observation.state": agent_pos, + "action": actions[idx0:idx1], + "episode": torch.tensor([episode_id] 
* num_frames, dtype=torch.int), + "frame_id": torch.arange(0, num_frames, 1), + "timestamp": torch.arange(0, num_frames, 1) / self.fps, + # "next.observation.image": image[1:], + # "next.observation.state": agent_pos[1:], + # TODO(rcadene): verify that reward and done are aligned with image and agent_pos + "next.reward": torch.cat([reward[1:], reward[[-1]]]), + "next.done": torch.cat([done[1:], done[[-1]]]), + "next.success": torch.cat([success[1:], success[[-1]]]), + } + ep_dicts.append(ep_dict) - if episode_id == 0: - # hack to initialize tensordict data structure to store episodes - td_data = ep_td[0].expand(total_frames).memmap_like(self.root / f"{self.dataset_id}") - - td_data[idxtd : idxtd + len(ep_td)] = ep_td + assert isinstance(episode_id, int) + self.data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1) + assert len(self.data_ids_per_episode[episode_id]) == num_frames idx0 = idx1 - idxtd = idxtd + len(ep_td) - return TensorStorage(td_data.lock_()) + self.data_dict = {} + + keys = ep_dicts[0].keys() + for key in keys: + self.data_dict[key] = torch.cat([x[key] for x in ep_dicts]) + + self.data_dict["index"] = torch.arange(0, total_frames, 1) + + +if __name__ == "__main__": + dataset = PushtDataset( + "pusht", + root=Path("data"), + delta_timestamps={ + "observation.image": [0, -1, -0.2, -0.1], + "observation.state": [0, -1, -0.2, -0.1], + "action": [-0.1, 0, 1, 2, 3], + }, + ) + dataset[10] diff --git a/lerobot/common/datasets/simxarm.py b/lerobot/common/datasets/simxarm.py index dc30e69e..4b2c68ad 100644 --- a/lerobot/common/datasets/simxarm.py +++ b/lerobot/common/datasets/simxarm.py @@ -1,75 +1,104 @@ import pickle import zipfile from pathlib import Path -from typing import Callable import torch -import torchrl import tqdm -from tensordict import TensorDict -from torchrl.data.replay_buffers.samplers import ( - Sampler, -) -from torchrl.data.replay_buffers.storages import TensorStorage -from torchrl.data.replay_buffers.writers import Writer -from lerobot.common.datasets.abstract import AbstractDataset +from lerobot.common.datasets.utils import load_data_with_delta_timestamps -def download(): - raise NotImplementedError() +def download(raw_dir): import gdown + raw_dir.mkdir(parents=True, exist_ok=True) url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya" - download_path = "data.zip" - gdown.download(url, download_path, quiet=False) + zip_path = raw_dir / "data.zip" + gdown.download(url, str(zip_path), quiet=False) print("Extracting...") - with zipfile.ZipFile(download_path, "r") as zip_f: + with zipfile.ZipFile(str(zip_path), "r") as zip_f: for member in zip_f.namelist(): if member.startswith("data/xarm") and member.endswith(".pkl"): print(member) zip_f.extract(member=member) - Path(download_path).unlink() + zip_path.unlink() -class SimxarmDataset(AbstractDataset): +class SimxarmDataset(torch.utils.data.Dataset): available_datasets = [ "xarm_lift_medium", ] + fps = 15 + image_keys = ["observation.image"] def __init__( self, dataset_id: str, version: str | None = "v1.1", - batch_size: int | None = None, - *, - shuffle: bool = True, root: Path | None = None, - pin_memory: bool = False, - prefetch: int = None, - sampler: Sampler | None = None, - collate_fn: Callable | None = None, - writer: Writer | None = None, - transform: "torchrl.envs.Transform" = None, + transform: callable = None, + delta_timestamps: dict[list[float]] | None = None, ): - super().__init__( - dataset_id, - version, - batch_size, - shuffle=shuffle, - root=root, - pin_memory=pin_memory, 
- prefetch=prefetch, - sampler=sampler, - collate_fn=collate_fn, - writer=writer, - transform=transform, - ) + super().__init__() + self.dataset_id = dataset_id + self.version = version + self.root = root + self.transform = transform + self.delta_timestamps = delta_timestamps + + data_dir = self.root / f"{self.dataset_id}" + if (data_dir / "data_dict.pth").exists() and (data_dir / "data_ids_per_episode.pth").exists(): + self.data_dict = torch.load(data_dir / "data_dict.pth") + self.data_ids_per_episode = torch.load(data_dir / "data_ids_per_episode.pth") + else: + self._download_and_preproc_obsolete() + data_dir.mkdir(parents=True, exist_ok=True) + torch.save(self.data_dict, data_dir / "data_dict.pth") + torch.save(self.data_ids_per_episode, data_dir / "data_ids_per_episode.pth") + + @property + def num_samples(self) -> int: + return len(self.data_dict["index"]) + + @property + def num_episodes(self) -> int: + return len(self.data_ids_per_episode) + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + item = {} + + # get episode id and timestamp of the sampled frame + current_ts = self.data_dict["timestamp"][idx].item() + episode = self.data_dict["episode"][idx].item() + + for key in self.data_dict: + if self.delta_timestamps is not None and key in self.delta_timestamps: + data, is_pad = load_data_with_delta_timestamps( + self.data_dict, + self.data_ids_per_episode, + self.delta_timestamps, + key, + current_ts, + episode, + ) + item[key] = data + item[f"{key}_is_pad"] = is_pad + else: + item[key] = self.data_dict[key][idx] + + if self.transform is not None: + item = self.transform(item) + + return item def _download_and_preproc_obsolete(self): - # assert self.root is not None - # TODO(rcadene): finish download - # download() + assert self.root is not None + raw_dir = self.root / f"{self.dataset_id}_raw" + if not raw_dir.exists(): + download(raw_dir) dataset_path = self.root / f"{self.dataset_id}" / "buffer.pkl" print(f"Using offline dataset '{dataset_path}'") @@ -78,6 +107,9 @@ class SimxarmDataset(AbstractDataset): total_frames = dataset_dict["actions"].shape[0] + self.data_ids_per_episode = {} + ep_dicts = [] + idx0 = 0 idx1 = 0 episode_id = 0 @@ -91,37 +123,38 @@ class SimxarmDataset(AbstractDataset): image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1]) state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1]) - next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1]) - next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1]) + action = torch.tensor(dataset_dict["actions"][idx0:idx1]) + # TODO(rcadene): concat the last "next_observations" to "observations" + # next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1]) + # next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1]) next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1]) next_done = torch.tensor(dataset_dict["dones"][idx0:idx1]) - episode = TensorDict( - { - ("observation", "image"): image, - ("observation", "state"): state, - "action": torch.tensor(dataset_dict["actions"][idx0:idx1]), - "episode": torch.tensor([episode_id] * num_frames, dtype=torch.int), - "frame_id": torch.arange(0, num_frames, 1), - ("next", "observation", "image"): next_image, - ("next", "observation", "state"): next_state, - ("next", "reward"): next_reward, - ("next", "done"): next_done, - }, - batch_size=num_frames, - ) + ep_dict = { + "observation.image": image, + "observation.state": state, + 
"action": action, + "episode": torch.tensor([episode_id] * num_frames, dtype=torch.int), + "frame_id": torch.arange(0, num_frames, 1), + "timestamp": torch.arange(0, num_frames, 1) / self.fps, + # "next.observation.image": next_image, + # "next.observation.state": next_state, + "next.reward": next_reward, + "next.done": next_done, + } + ep_dicts.append(ep_dict) - if episode_id == 0: - # hack to initialize tensordict data structure to store episodes - td_data = ( - episode[0] - .expand(total_frames) - .memmap_like(self.root / f"{self.dataset_id}" / "replay_buffer") - ) + assert isinstance(episode_id, int) + self.data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1) + assert len(self.data_ids_per_episode[episode_id]) == num_frames - td_data[idx0:idx1] = episode - - episode_id += 1 idx0 = idx1 + episode_id += 1 - return TensorStorage(td_data.lock_()) + self.data_dict = {} + + keys = ep_dicts[0].keys() + for key in keys: + self.data_dict[key] = torch.cat([x[key] for x in ep_dicts]) + + self.data_dict["index"] = torch.arange(0, total_frames, 1) diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 0ad43a65..c8840169 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -3,6 +3,7 @@ import zipfile from pathlib import Path import requests +import torch import tqdm @@ -28,3 +29,71 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool: return True else: return False + + +def euclidean_distance_matrix(mat0, mat1): + # Compute the square of the distance matrix + sq0 = torch.sum(mat0**2, dim=1, keepdim=True) + sq1 = torch.sum(mat1**2, dim=1, keepdim=True) + distance_sq = sq0 + sq1.transpose(0, 1) - 2 * mat0 @ mat1.transpose(0, 1) + + # Taking the square root to get the euclidean distance + distance = torch.sqrt(torch.clamp(distance_sq, min=0)) + return distance + + +def is_contiguously_true_or_false(bool_vector): + assert bool_vector.ndim == 1 + assert bool_vector.dtype == torch.bool + + # Compare each element with its neighbor to find changes + changes = bool_vector[1:] != bool_vector[:-1] + + # Count the number of changes + num_changes = changes.sum().item() + + # If there's more than one change, the list is not contiguous + return num_changes <= 1 + + # examples = [ + # ([True, False, True, False, False, False], False), + # ([True, True, True, False, False, False], True), + # ([False, False, False, False, False, False], True) + # ] + # for bool_list, expected in examples: + # result = is_contiguously_true_or_false(bool_list) + + +def load_data_with_delta_timestamps( + data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode +): + # get indices of the frames associated to the episode, and their timestamps + ep_data_ids = data_ids_per_episode[episode] + ep_timestamps = data_dict["timestamp"][ep_data_ids] + + # get timestamps used as query to retrieve data of previous/future frames + delta_ts = delta_timestamps[key] + query_ts = current_ts + torch.tensor(delta_ts) + + # compute distances between each query timestamp and all timestamps of all the frames belonging to the episode + dist = euclidean_distance_matrix(query_ts[:, None], ep_timestamps[:, None]) + min_, argmin_ = dist.min(1) + + # get the indices of the data that are closest to the query timestamps + data_ids = ep_data_ids[argmin_] + # closest_ts = ep_timestamps[argmin_] + + # get the data + data = data_dict[key][data_ids].clone() + + # TODO(rcadene): synchronize timestamps + interpolation if needed + + tol = 0.02 + is_pad = 
min_ > tol + + assert is_contiguously_true_or_false(is_pad), ( + "One or several timestamps unexpectedly violate the tolerance." + "This might be due to synchronization issues with timestamps during data collection." + ) + + return data, is_pad diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index 855e073b..788af3cb 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -1,64 +1,40 @@ -from torchrl.envs import SerialEnv -from torchrl.envs.transforms import Compose, StepCounter, Transform, TransformedEnv +import gymnasium as gym -def make_env(cfg, transform=None): +def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv: """ - Note: The returned environment is wrapped in a torchrl.SerialEnv with cfg.rollout_batch_size underlying - environments. The env therefore returns batches.` + Note: When `num_parallel_envs > 0`, this function returns a `SyncVectorEnv` which takes batched action as input and + returns batched observation, reward, terminated, truncated of `num_parallel_envs` items. """ - - kwargs = { - "frame_skip": cfg.env.action_repeat, - "from_pixels": cfg.env.from_pixels, - "pixels_only": cfg.env.pixels_only, - "image_size": cfg.env.image_size, - "num_prev_obs": cfg.n_obs_steps - 1, - } + kwargs = {} if cfg.env.name == "simxarm": - from lerobot.common.envs.simxarm.env import SimxarmEnv - kwargs["task"] = cfg.env.task - clsfunc = SimxarmEnv elif cfg.env.name == "pusht": - from lerobot.common.envs.pusht.env import PushtEnv + import gym_pusht # noqa # assert kwargs["seed"] > 200, "Seed 0-200 are used for the demonstration dataset, so we don't want to seed the eval env with this range." - - clsfunc = PushtEnv + kwargs.update( + { + "obs_type": "pixels_agent_pos", + "render_action": False, + } + ) + env_fn = lambda: gym.make( # noqa: E731 + "gym_pusht/PushTPixels-v0", + render_mode="rgb_array", + max_episode_steps=cfg.env.episode_length, + **kwargs, + ) elif cfg.env.name == "aloha": - from lerobot.common.envs.aloha.env import AlohaEnv - kwargs["task"] = cfg.env.task - clsfunc = AlohaEnv else: raise ValueError(cfg.env.name) - def _make_env(seed): - nonlocal kwargs - kwargs["seed"] = seed - env = clsfunc(**kwargs) - - # limit rollout to max_steps - env = TransformedEnv(env, StepCounter(max_steps=cfg.env.episode_length)) - - if transform is not None: - # useful to add normalization - if isinstance(transform, Compose): - for tf in transform: - env.append_transform(tf.clone()) - elif isinstance(transform, Transform): - env.append_transform(transform.clone()) - else: - raise NotImplementedError() - - return env - - return SerialEnv( - cfg.rollout_batch_size, - create_env_fn=_make_env, - create_env_kwargs=[ - {"seed": env_seed} for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size) - ], - ) + if num_parallel_envs == 0: + # non-batched version of the env that returns an observation of shape (c) + env = env_fn() + else: + # batched version of the env that returns an observation of shape (b, c) + env = gym.vector.SyncVectorEnv([env_fn for _ in range(num_parallel_envs)]) + return env diff --git a/lerobot/common/envs/simxarm/env.py b/lerobot/common/envs/simxarm/env.py index b81bf499..8ce6b24c 100644 --- a/lerobot/common/envs/simxarm/env.py +++ b/lerobot/common/envs/simxarm/env.py @@ -55,7 +55,7 @@ class SimxarmEnv(AbstractEnv): if not _has_gym: raise ImportError("Cannot import gymnasium.") - import gymnasium + import gymnasium as gym from lerobot.common.envs.simxarm.simxarm import TASKS @@ -65,7 +65,7 @@ class 
SimxarmEnv(AbstractEnv): self._env = TASKS[self.task]["env"]() num_actions = len(TASKS[self.task]["action_space"]) - self._action_space = gymnasium.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,)) + self._action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,)) self._action_padding = np.zeros((MAX_NUM_ACTIONS - num_actions), dtype=np.float32) if "w" not in TASKS[self.task]["action_space"]: self._action_padding[-1] = 1.0 diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py index 82f39b28..a0fe0eba 100644 --- a/lerobot/common/policies/diffusion/policy.py +++ b/lerobot/common/policies/diffusion/policy.py @@ -1,18 +1,20 @@ import copy import logging import time +from collections import deque import hydra import torch +from torch import nn -from lerobot.common.policies.abstract import AbstractPolicy from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder, RgbEncoder +from lerobot.common.policies.utils import populate_queues from lerobot.common.utils import get_safe_torch_device -class DiffusionPolicy(AbstractPolicy): +class DiffusionPolicy(nn.Module): name = "diffusion" def __init__( @@ -38,8 +40,12 @@ class DiffusionPolicy(AbstractPolicy): # parameters passed to step **kwargs, ): - super().__init__(n_action_steps) + super().__init__() self.cfg = cfg + self.n_obs_steps = n_obs_steps + self.n_action_steps = n_action_steps + # queues are populated during rollout of the policy, they contain the n latest observations and actions + self._queues = None noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler) rgb_model_input_shape = copy.deepcopy(shape_meta.obs.image.shape) @@ -100,76 +106,58 @@ class DiffusionPolicy(AbstractPolicy): last_epoch=self.global_step - 1, ) + def reset(self): + """ + Clear observation and action queues. Should be called on `env.reset()` + """ + self._queues = { + "observation.image": deque(maxlen=self.n_obs_steps), + "observation.state": deque(maxlen=self.n_obs_steps), + "action": deque(maxlen=self.n_action_steps), + } + @torch.no_grad() - def select_actions(self, observation, step_count): + def select_action(self, batch, step): """ Note: this uses the ema model weights if self.training == False, otherwise the non-ema model weights. 
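+
+        A sketch of the queueing logic (the diagram follows the horizon comment in the diffusion
+        training code, with n_obs_steps=2 and n_action_steps=8):
+
+        |-1|0|1|2|3|4|5|6|7|8|9|10|11|12|13|14|  timestamps: 16
+        |o|o|                                    observations kept in the queues: 2
+        | |a|a|a|a|a|a|a|a|                      actions popped from the action queue: 8
+
+        When the action queue is empty, the n_obs_steps latest observations are stacked, the model
+        predicts a new action sequence which refills the queue, and a single action is popped per
+        call; the model is therefore only invoked every n_action_steps environment steps.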
""" - # TODO(rcadene): remove unused step_count - del step_count + # TODO(rcadene): remove unused step + del step + assert "observation.image" in batch + assert "observation.state" in batch + assert len(batch) == 2 - obs_dict = { - "image": observation["image"], - "agent_pos": observation["state"], - } - if self.training: - out = self.diffusion.predict_action(obs_dict) - else: - out = self.ema_diffusion.predict_action(obs_dict) - action = out["action"] + self._queues = populate_queues(self._queues, batch) + + if len(self._queues["action"]) == 0: + # stack n latest observations from the queue + batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch} + + obs_dict = { + "image": batch["observation.image"], + "agent_pos": batch["observation.state"], + } + if self.training: + out = self.diffusion.predict_action(obs_dict) + else: + out = self.ema_diffusion.predict_action(obs_dict) + self._queues["action"].extend(out["action"].transpose(0, 1)) + + action = self._queues["action"].popleft() return action - def update(self, replay_buffer, step): + def forward(self, batch, step): start_time = time.time() self.diffusion.train() - num_slices = self.cfg.batch_size - batch_size = self.cfg.horizon * num_slices - - assert batch_size % self.cfg.horizon == 0 - assert batch_size % num_slices == 0 - - def process_batch(batch, horizon, num_slices): - # trajectory t = 64, horizon h = 16 - # (t h) ... -> t h ... - batch = batch.reshape(num_slices, horizon) # .transpose(1, 0).contiguous() - - # |-1|0|1|2|3|4|5|6|7|8|9|10|11|12|13|14| timestamps: 16 - # |o|o| observations: 2 - # | |a|a|a|a|a|a|a|a| actions executed: 8 - # |p|p|p|p|p|p|p|p|p|p|p| p| p| p| p| p| actions predicted: 16 - # note: we predict the action needed to go from t=-1 to t=0 similarly to an inverse kinematic model - - image = batch["observation", "image"] - state = batch["observation", "state"] - action = batch["action"] - assert image.shape[1] == horizon - assert state.shape[1] == horizon - assert action.shape[1] == horizon - - if not (horizon == 16 and self.cfg.n_obs_steps == 2): - raise NotImplementedError() - - # keep first 2 observations of the slice corresponding to t=[-1,0] - image = image[:, : self.cfg.n_obs_steps] - state = state[:, : self.cfg.n_obs_steps] - - out = { - "obs": { - "image": image.to(self.device, non_blocking=True), - "agent_pos": state.to(self.device, non_blocking=True), - }, - "action": action.to(self.device, non_blocking=True), - } - return out - - batch = replay_buffer.sample(batch_size) - batch = process_batch(batch, self.cfg.horizon, num_slices) - data_s = time.time() - start_time - loss = self.diffusion.compute_loss(batch) + obs_dict = { + "image": batch["observation.image"], + "agent_pos": batch["observation.state"], + } + loss = self.diffusion.compute_loss(obs_dict) loss.backward() grad_norm = torch.nn.utils.clip_grad_norm_( diff --git a/lerobot/common/transforms.py b/lerobot/common/transforms.py index 4832c91b..ec967614 100644 --- a/lerobot/common/transforms.py +++ b/lerobot/common/transforms.py @@ -1,53 +1,49 @@ -from typing import Sequence - import torch -from tensordict import TensorDictBase -from tensordict.nn import dispatch -from tensordict.utils import NestedKey -from torchrl.envs.transforms import ObservationTransform, Transform +from torchvision.transforms.v2 import Compose, Transform -class Prod(ObservationTransform): +def apply_inverse_transform(item, transform): + transforms = transform.transforms if isinstance(transform, Compose) else [transform] + for tf in transforms[::-1]: + if 
tf.invertible: + item = tf.inverse_transform(item) + else: + raise ValueError(f"Inverse transform called on a non invertible transform ({tf}).") + return item + + +class Prod(Transform): invertible = True - def __init__(self, in_keys: Sequence[NestedKey], prod: float): + def __init__(self, in_keys: list[str], prod: float): super().__init__() self.in_keys = in_keys self.prod = prod self.original_dtypes = {} - def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase: - # _reset is called once when the environment reset to normalize the first observation - tensordict_reset = self._call(tensordict_reset) - return tensordict_reset - - @dispatch(source="in_keys", dest="out_keys") - def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - return self._call(tensordict) - - def _call(self, td): + def forward(self, item): for key in self.in_keys: - if td.get(key, None) is None: + if key not in item: continue - self.original_dtypes[key] = td[key].dtype - td[key] = td[key].type(torch.float32) * self.prod - return td + self.original_dtypes[key] = item[key].dtype + item[key] = item[key].type(torch.float32) * self.prod + return item - def _inv_call(self, td: TensorDictBase) -> TensorDictBase: + def inverse_transform(self, item): for key in self.in_keys: - if td.get(key, None) is None: + if key not in item: continue - td[key] = (td[key] / self.prod).type(self.original_dtypes[key]) - return td + item[key] = (item[key] / self.prod).type(self.original_dtypes[key]) + return item - def transform_observation_spec(self, obs_spec): - for key in self.in_keys: - if obs_spec.get(key, None) is None: - continue - obs_spec[key].space.high = obs_spec[key].space.high.type(torch.float32) * self.prod - obs_spec[key].space.low = obs_spec[key].space.low.type(torch.float32) * self.prod - obs_spec[key].dtype = torch.float32 - return obs_spec + # def transform_observation_spec(self, obs_spec): + # for key in self.in_keys: + # if obs_spec.get(key, None) is None: + # continue + # obs_spec[key].space.high = obs_spec[key].space.high.type(torch.float32) * self.prod + # obs_spec[key].space.low = obs_spec[key].space.low.type(torch.float32) * self.prod + # obs_spec[key].dtype = torch.float32 + # return obs_spec class NormalizeTransform(Transform): @@ -55,65 +51,50 @@ class NormalizeTransform(Transform): def __init__( self, - stats: TensorDictBase, - in_keys: Sequence[NestedKey] = None, - out_keys: Sequence[NestedKey] | None = None, - in_keys_inv: Sequence[NestedKey] | None = None, - out_keys_inv: Sequence[NestedKey] | None = None, + stats: dict, + in_keys: list[str] = None, + out_keys: list[str] | None = None, + in_keys_inv: list[str] | None = None, + out_keys_inv: list[str] | None = None, mode="mean_std", ): - if out_keys is None: - out_keys = in_keys - if in_keys_inv is None: - in_keys_inv = out_keys - if out_keys_inv is None: - out_keys_inv = in_keys - super().__init__( - in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, out_keys_inv=out_keys_inv - ) + super().__init__() + self.in_keys = in_keys + self.out_keys = in_keys if out_keys is None else out_keys + self.in_keys_inv = self.out_keys if in_keys_inv is None else in_keys_inv + self.out_keys_inv = self.in_keys if out_keys_inv is None else out_keys_inv self.stats = stats assert mode in ["mean_std", "min_max"] self.mode = mode - def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase: - # _reset is called once when the environment reset to normalize the first observation - 
tensordict_reset = self._call(tensordict_reset) - return tensordict_reset - - @dispatch(source="in_keys", dest="out_keys") - def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - return self._call(tensordict) - - def _call(self, td: TensorDictBase) -> TensorDictBase: + def forward(self, item): for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False): - # TODO(rcadene): don't know how to do `inkey not in td` - if td.get(inkey, None) is None: + if inkey not in item: continue if self.mode == "mean_std": mean = self.stats[inkey]["mean"] std = self.stats[inkey]["std"] - td[outkey] = (td[inkey] - mean) / (std + 1e-8) + item[outkey] = (item[inkey] - mean) / (std + 1e-8) else: min = self.stats[inkey]["min"] max = self.stats[inkey]["max"] # normalize to [0,1] - td[outkey] = (td[inkey] - min) / (max - min) + item[outkey] = (item[inkey] - min) / (max - min) # normalize to [-1, 1] - td[outkey] = td[outkey] * 2 - 1 - return td + item[outkey] = item[outkey] * 2 - 1 + return item - def _inv_call(self, td: TensorDictBase) -> TensorDictBase: + def inverse_transform(self, item): for inkey, outkey in zip(self.in_keys_inv, self.out_keys_inv, strict=False): - # TODO(rcadene): don't know how to do `inkey not in td` - if td.get(inkey, None) is None: + if inkey not in item: continue if self.mode == "mean_std": mean = self.stats[inkey]["mean"] std = self.stats[inkey]["std"] - td[outkey] = td[inkey] * std + mean + item[outkey] = item[inkey] * std + mean else: min = self.stats[inkey]["min"] max = self.stats[inkey]["max"] - td[outkey] = (td[inkey] + 1) / 2 - td[outkey] = td[outkey] * (max - min) + min - return td + item[outkey] = (item[inkey] + 1) / 2 + item[outkey] = item[outkey] * (max - min) + min + return item diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 216769d6..fe0f7bb2 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -36,111 +36,196 @@ from datetime import datetime as dt from pathlib import Path import einops +import gymnasium as gym +import hydra import imageio import numpy as np import torch -import tqdm from huggingface_hub import snapshot_download -from tensordict.nn import TensorDictModule -from torchrl.envs import EnvBase -from torchrl.envs.batched_envs import BatchedEnvBase -from lerobot.common.datasets.factory import make_offline_buffer +from lerobot.common.datasets.factory import make_dataset from lerobot.common.envs.factory import make_env from lerobot.common.logger import log_output_dir -from lerobot.common.policies.abstract import AbstractPolicy from lerobot.common.policies.factory import make_policy from lerobot.common.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed +from lerobot.common.transforms import apply_inverse_transform def write_video(video_path, stacked_frames, fps): imageio.mimsave(video_path, stacked_frames, fps=fps) +def preprocess_observation(observation, transform=None): + # map to expected inputs for the policy + obs = { + "observation.image": torch.from_numpy(observation["pixels"]).float(), + "observation.state": torch.from_numpy(observation["agent_pos"]).float(), + } + # convert to (b c h w) torch format + obs["observation.image"] = einops.rearrange(obs["observation.image"], "b h w c -> b c h w") + + # apply same transforms as in training + if transform is not None: + for key in obs: + obs[key] = torch.stack([transform({key: item})[key] for item in obs[key]]) + + return obs + + +def postprocess_action(action, transform=None): + action = action.to("cpu") + # action is a batch 
(num_env,action_dim) instead of an item (action_dim),
+    # we assume applying inverse transform on a batch works the same
+    action = apply_inverse_transform({"action": action}, transform)["action"].numpy()
+    assert (
+        action.ndim == 2
+    ), "we assume dimensions are respectively the number of parallel envs, action dimensions"
+    return action
+
+
 def eval_policy(
-    env: BatchedEnvBase,
-    policy: AbstractPolicy,
-    num_episodes: int = 10,
-    max_steps: int = 30,
+    env: gym.vector.VectorEnv,
+    policy,
     save_video: bool = False,
     video_dir: Path = None,
+    # TODO(rcadene): make it possible to overwrite fps? we should use env.fps
    fps: int = 15,
     return_first_video: bool = False,
+    transform: callable = None,
 ):
     if policy is not None:
         policy.eval()
     start = time.time()
     sum_rewards = []
     max_rewards = []
-    successes = []
+    all_successes = []
     seeds = []
     threads = []  # for video saving threads
     episode_counter = 0  # for saving the correct number of videos
+    num_episodes = len(env.envs)
+
     # TODO(alexander-soare): if num_episodes is not evenly divisible by the batch size, this will do more work than
     # needed as I'm currently taking a ceil.
-    for i in tqdm.tqdm(range(-(-num_episodes // env.batch_size[0]))):
-        ep_frames = []
+    ep_frames = []
 
-        def maybe_render_frame(env: EnvBase, _):
-            if save_video or (return_first_video and i == 0):  # noqa: B023
-                ep_frames.append(env.render())  # noqa: B023
+    def maybe_render_frame(env):
+        if save_video or return_first_video:  # noqa: B023
+            if return_first_video:
+                visu = env.envs[0].render()
+                visu = visu[None, ...]  # add batch dim
+            else:
+                visu = np.stack([env.render() for env in env.envs])
+            ep_frames.append(visu)  # noqa: B023
 
-        # Clear the policy's action queue before the start of a new rollout.
-        if policy is not None:
-            policy.clear_action_queue()
+    for _ in range(num_episodes):
+        seeds.append("TODO")
 
-        if env.is_closed:
-            env.start()  # needed to be able to get the seeds the first time as BatchedEnvs are lazy
-        seeds.extend(env._next_seed)
+        if hasattr(policy, "reset"):
+            policy.reset()
+        else:
+            logging.warning(
+                f"Policy {policy} doesn't have a `reset` method. This is fine if the policy doesn't rely on an internal state during rollout."
+            )
+
+        # reset the environment
+        observation, info = env.reset(seed=cfg.seed)
+        maybe_render_frame(env)
+
+        rewards = []
+        successes = []
+        dones = []
+
+        done = torch.tensor([False for _ in env.envs])
+        step = 0
+        do_rollout = True
+        while do_rollout:
+            # apply transform to normalize the observations
+            observation = preprocess_observation(observation, transform)
+
+            # send observation to device/gpu
+            observation = {key: observation[key].to(cfg.device, non_blocking=True) for key in observation}
+
+            # get the next action for the environment
             with torch.inference_mode():
-                # TODO(alexander-soare): When `break_when_any_done == False` this rolls out for max_steps even when all
-                # envs are done the first time. But we only use the first rollout. This is a waste of compute.
-                rollout = env.rollout(
-                    max_steps=max_steps,
-                    policy=policy,
-                    auto_cast_to_device=True,
-                    callback=maybe_render_frame,
-                    break_when_any_done=env.batch_size[0] == 1,
-                )
-            # Figure out where in each rollout sequence the first done condition was encountered (results after
-            # this won't be included).
-            # Note: this assumes that the shape of the done key is (batch_size, max_steps, 1).
-            # Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker.
-            rollout_steps = rollout["next", "done"].shape[1]
-            done_indices = torch.argmax(rollout["next", "done"].to(int), axis=1)  # (batch_size, rollout_steps)
-            mask = (torch.arange(rollout_steps) <= done_indices).unsqueeze(-1)  # (batch_size, rollout_steps, 1)
-            batch_sum_reward = einops.reduce((rollout["next", "reward"] * mask), "b n 1 -> b", "sum")
-            batch_max_reward = einops.reduce((rollout["next", "reward"] * mask), "b n 1 -> b", "max")
-            batch_success = einops.reduce((rollout["next", "success"] * mask), "b n 1 -> b", "any")
-            sum_rewards.extend(batch_sum_reward.tolist())
-            max_rewards.extend(batch_max_reward.tolist())
-            successes.extend(batch_success.tolist())
+                action = policy.select_action(observation, step)
 
-            if save_video or (return_first_video and i == 0):
-                batch_stacked_frames = np.stack(ep_frames)  # (t, b, *)
-                batch_stacked_frames = batch_stacked_frames.transpose(
-                    1, 0, *range(2, batch_stacked_frames.ndim)
-                )  # (b, t, *)
+            # apply inverse transform to unnormalize the action
+            action = postprocess_action(action, transform)
 
-                if save_video:
-                    for stacked_frames, done_index in zip(
-                        batch_stacked_frames, done_indices.flatten().tolist(), strict=False
-                    ):
-                        if episode_counter >= num_episodes:
-                            continue
-                        video_dir.mkdir(parents=True, exist_ok=True)
-                        video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
-                        thread = threading.Thread(
-                            target=write_video,
-                            args=(str(video_path), stacked_frames[:done_index], fps),
-                        )
-                        thread.start()
-                        threads.append(thread)
-                        episode_counter += 1
+            # apply the next action to the environment
+            observation, reward, terminated, truncated, info = env.step(action)
+            maybe_render_frame(env)
 
-                if return_first_video and i == 0:
-                    first_video = batch_stacked_frames[0].transpose(0, 3, 1, 2)
+            # TODO(rcadene): implement a wrapper over env to return torch tensors in float32 (and cuda?)
+            reward = torch.from_numpy(reward)
+            terminated = torch.from_numpy(terminated)
+            truncated = torch.from_numpy(truncated)
+            # the environment is considered done (no more steps) when a success state is reached (terminated is
+            # True), when the time limit is reached (truncated is True), or when it was previously done.
+            done = terminated | truncated | done
+
+            if "final_info" in info:
+                # VectorEnv stores is_success into `info["final_info"][env_id]["is_success"]` instead of `info["is_success"]`
+                success = [
+                    env_info["is_success"] if env_info is not None else False for env_info in info["final_info"]
+                ]
+            else:
+                success = [False for _ in env.envs]
+            success = torch.tensor(success)
+
+            rewards.append(reward)
+            dones.append(done)
+            successes.append(success)
+
+            step += 1
+
+            if done.all():
+                do_rollout = False
+                break
+
+        rewards = torch.stack(rewards, dim=1)
+        successes = torch.stack(successes, dim=1)
+        dones = torch.stack(dones, dim=1)
+
+        # Figure out where in each rollout sequence the first done condition was encountered (results after
+        # this won't be included).
+        # Note: this assumes that the shape of the done key is (batch_size, max_steps).
+        # Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker.
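+        # An illustrative example of the masking below (hypothetical values, batch_size=2, step=4);
+        # once an env is done it stays done, so:
+        #   dones = [[0, 0, 1, 1],   ->  done_indices = [2, 3]
+        #            [0, 0, 0, 1]]
+        #   mask  = [[1, 1, 1, 0],   (frames after the first done are excluded from the metrics)
+        #            [1, 1, 1, 1]]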
+ done_indices = torch.argmax(dones.to(int), axis=1) # (batch_size, rollout_steps) + expand_done_indices = done_indices[:, None].expand(-1, step) + expand_step_indices = torch.arange(step)[None, :].expand(num_episodes, -1) + mask = (expand_step_indices <= expand_done_indices).int() # (batch_size, rollout_steps) + batch_sum_reward = einops.reduce((rewards * mask), "b n -> b", "sum") + batch_max_reward = einops.reduce((rewards * mask), "b n -> b", "max") + batch_success = einops.reduce((successes * mask), "b n -> b", "any") + sum_rewards.extend(batch_sum_reward.tolist()) + max_rewards.extend(batch_max_reward.tolist()) + all_successes.extend(batch_success.tolist()) + + env.close() + + if save_video or return_first_video: + batch_stacked_frames = np.stack(ep_frames, 1) # (b, t, *) + + if save_video: + for stacked_frames, done_index in zip( + batch_stacked_frames, done_indices.flatten().tolist(), strict=False + ): + if episode_counter >= num_episodes: + continue + video_dir.mkdir(parents=True, exist_ok=True) + video_path = video_dir / f"eval_episode_{episode_counter}.mp4" + thread = threading.Thread( + target=write_video, + args=(str(video_path), stacked_frames[:done_index], fps), + ) + thread.start() + threads.append(thread) + episode_counter += 1 + + if return_first_video: + first_video = batch_stacked_frames[0].transpose(0, 3, 1, 2) for thread in threads: thread.join() @@ -158,16 +243,16 @@ def eval_policy( zip( sum_rewards[:num_episodes], max_rewards[:num_episodes], - successes[:num_episodes], + all_successes[:num_episodes], seeds[:num_episodes], strict=True, ) ) ], "aggregated": { - "avg_sum_reward": np.nanmean(sum_rewards[:num_episodes]), - "avg_max_reward": np.nanmean(max_rewards[:num_episodes]), - "pc_success": np.nanmean(successes[:num_episodes]) * 100, + "avg_sum_reward": float(np.nanmean(sum_rewards[:num_episodes])), + "avg_max_reward": float(np.nanmean(max_rewards[:num_episodes])), + "pc_success": float(np.nanmean(all_successes[:num_episodes]) * 100), "eval_s": time.time() - start, "eval_ep_s": (time.time() - start) / num_episodes, }, @@ -194,21 +279,13 @@ def eval(cfg: dict, out_dir=None, stats_path=None): logging.info("Making transforms.") # TODO(alexander-soare): Completely decouple datasets from evaluation. - offline_buffer = make_offline_buffer(cfg, stats_path=stats_path) + dataset = make_dataset(cfg, stats_path=stats_path) logging.info("Making environment.") - env = make_env(cfg, transform=offline_buffer.transform) + env = make_env(cfg, num_parallel_envs=cfg.eval_episodes) - if cfg.policy.pretrained_model_path: - policy = make_policy(cfg) - policy = TensorDictModule( - policy, - in_keys=["observation", "step_count"], - out_keys=["action"], - ) - else: - # when policy is None, rollout a random policy - policy = None + # when policy is None, rollout a random policy + policy = make_policy(cfg) if cfg.policy.pretrained_model_path else None info = eval_policy( env, @@ -216,8 +293,8 @@ def eval(cfg: dict, out_dir=None, stats_path=None): save_video=True, video_dir=Path(out_dir) / "eval", fps=cfg.env.fps, - max_steps=cfg.env.episode_length, - num_episodes=cfg.eval_episodes, + # TODO(rcadene): what should we do with the transform? 
+ transform=dataset.transform, ) print(info["aggregated"]) diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 18c3715b..5e9cd361 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -1,14 +1,12 @@ import logging +from itertools import cycle from pathlib import Path import hydra import numpy as np import torch -from tensordict.nn import TensorDictModule -from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer -from torchrl.data.replay_buffers import PrioritizedSliceSampler -from lerobot.common.datasets.factory import make_offline_buffer +from lerobot.common.datasets.factory import make_dataset from lerobot.common.envs.factory import make_env from lerobot.common.logger import Logger, log_output_dir from lerobot.common.policies.factory import make_policy @@ -34,7 +32,7 @@ def train_notebook(out_dir=None, job_name=None, config_name="default", config_pa train(cfg, out_dir=out_dir, job_name=job_name) -def log_train_info(logger, info, step, cfg, offline_buffer, is_offline): +def log_train_info(logger, info, step, cfg, dataset, is_offline): loss = info["loss"] grad_norm = info["grad_norm"] lr = info["lr"] @@ -44,9 +42,9 @@ def log_train_info(logger, info, step, cfg, offline_buffer, is_offline): # A sample is an (observation,action) pair, where observation and action # can be on multiple timestamps. In a batch, we have `batch_size`` number of samples. num_samples = (step + 1) * cfg.policy.batch_size - avg_samples_per_ep = offline_buffer.num_samples / offline_buffer.num_episodes + avg_samples_per_ep = dataset.num_samples / dataset.num_episodes num_episodes = num_samples / avg_samples_per_ep - num_epochs = num_samples / offline_buffer.num_samples + num_epochs = num_samples / dataset.num_samples log_items = [ f"step:{format_big_number(step)}", # number of samples seen during training @@ -73,7 +71,7 @@ def log_train_info(logger, info, step, cfg, offline_buffer, is_offline): logger.log_dict(info, step, mode="train") -def log_eval_info(logger, info, step, cfg, offline_buffer, is_offline): +def log_eval_info(logger, info, step, cfg, dataset, is_offline): eval_s = info["eval_s"] avg_sum_reward = info["avg_sum_reward"] pc_success = info["pc_success"] @@ -81,9 +79,9 @@ def log_eval_info(logger, info, step, cfg, offline_buffer, is_offline): # A sample is an (observation,action) pair, where observation and action # can be on multiple timestamps. In a batch, we have `batch_size`` number of samples. 
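+    # illustrative numbers: with batch_size=256 at step=999, num_samples = 256_000; for a dataset of
+    # 25_650 samples over 50 episodes (avg 513 per episode) that is ~499 episodes and ~10 epochs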
num_samples = (step + 1) * cfg.policy.batch_size - avg_samples_per_ep = offline_buffer.num_samples / offline_buffer.num_episodes + avg_samples_per_ep = dataset.num_samples / dataset.num_episodes num_episodes = num_samples / avg_samples_per_ep - num_epochs = num_samples / offline_buffer.num_samples + num_epochs = num_samples / dataset.num_samples log_items = [ f"step:{format_big_number(step)}", # number of samples seen during training @@ -124,30 +122,30 @@ def train(cfg: dict, out_dir=None, job_name=None): torch.backends.cuda.matmul.allow_tf32 = True set_global_seed(cfg.seed) - logging.info("make_offline_buffer") - offline_buffer = make_offline_buffer(cfg) + logging.info("make_dataset") + dataset = make_dataset(cfg) # TODO(rcadene): move balanced_sampling, per_alpha, per_beta outside policy - if cfg.policy.balanced_sampling: - logging.info("make online_buffer") - num_traj_per_batch = cfg.policy.batch_size + # if cfg.policy.balanced_sampling: + # logging.info("make online_buffer") + # num_traj_per_batch = cfg.policy.batch_size - online_sampler = PrioritizedSliceSampler( - max_capacity=100_000, - alpha=cfg.policy.per_alpha, - beta=cfg.policy.per_beta, - num_slices=num_traj_per_batch, - strict_length=True, - ) + # online_sampler = PrioritizedSliceSampler( + # max_capacity=100_000, + # alpha=cfg.policy.per_alpha, + # beta=cfg.policy.per_beta, + # num_slices=num_traj_per_batch, + # strict_length=True, + # ) - online_buffer = TensorDictReplayBuffer( - storage=LazyMemmapStorage(100_000), - sampler=online_sampler, - transform=offline_buffer.transform, - ) + # online_buffer = TensorDictReplayBuffer( + # storage=LazyMemmapStorage(100_000), + # sampler=online_sampler, + # transform=dataset.transform, + # ) logging.info("make_env") - env = make_env(cfg, transform=offline_buffer.transform) + env = make_env(cfg) logging.info("make_policy") policy = make_policy(cfg) @@ -155,8 +153,6 @@ def train(cfg: dict, out_dir=None, job_name=None): num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) num_total_params = sum(p.numel() for p in policy.parameters()) - td_policy = TensorDictModule(policy, in_keys=["observation", "step_count"], out_keys=["action"]) - # log metrics to terminal and wandb logger = Logger(out_dir, job_name, cfg) @@ -165,8 +161,8 @@ def train(cfg: dict, out_dir=None, job_name=None): logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})") logging.info(f"{cfg.online_steps=}") logging.info(f"{cfg.env.action_repeat=}") - logging.info(f"{offline_buffer.num_samples=} ({format_big_number(offline_buffer.num_samples)})") - logging.info(f"{offline_buffer.num_episodes=}") + logging.info(f"{dataset.num_samples=} ({format_big_number(dataset.num_samples)})") + logging.info(f"{dataset.num_episodes=}") logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})") logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") @@ -176,14 +172,15 @@ def train(cfg: dict, out_dir=None, job_name=None): logging.info(f"Eval policy at step {step}") eval_info, first_video = eval_policy( env, - td_policy, + policy, num_episodes=cfg.eval_episodes, max_steps=cfg.env.episode_length, return_first_video=True, video_dir=Path(out_dir) / "eval", save_video=True, + transform=dataset.transform, ) - log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_buffer, is_offline) + log_eval_info(logger, eval_info["aggregated"], step, cfg, dataset, is_offline) if cfg.wandb.enable: logger.log_video(first_video, step, mode="eval") 
logging.info("Resume training") @@ -196,14 +193,29 @@ def train(cfg: dict, out_dir=None, job_name=None): step = 0 # number of policy update (forward + backward + optim) is_offline = True + dataloader = torch.utils.data.DataLoader( + dataset, + num_workers=4, + batch_size=cfg.policy.batch_size, + shuffle=True, + pin_memory=cfg.device != "cpu", + drop_last=True, + ) + dl_iter = cycle(dataloader) for offline_step in range(cfg.offline_steps): if offline_step == 0: logging.info("Start offline training on a fixed dataset") - # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done? policy.train() - train_info = policy.update(offline_buffer, step) + batch = next(dl_iter) + + for key in batch: + batch[key] = batch[key].to(cfg.device, non_blocking=True) + + train_info = policy.update(batch, step) + + # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done? if step % cfg.log_freq == 0: - log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline) + log_train_info(logger, train_info, step, cfg, dataset, is_offline) # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass in # step + 1. @@ -211,7 +223,7 @@ def train(cfg: dict, out_dir=None, job_name=None): step += 1 - demo_buffer = offline_buffer if cfg.policy.balanced_sampling else None + demo_buffer = dataset if cfg.policy.balanced_sampling else None online_step = 0 is_offline = False for env_step in range(cfg.online_steps): @@ -221,7 +233,7 @@ def train(cfg: dict, out_dir=None, job_name=None): with torch.no_grad(): rollout = env.rollout( max_steps=cfg.env.episode_length, - policy=td_policy, + policy=policy, auto_cast_to_device=True, ) @@ -242,7 +254,7 @@ def train(cfg: dict, out_dir=None, job_name=None): # set same episode index for all time steps contained in this rollout rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int) - online_buffer.extend(rollout) + # online_buffer.extend(rollout) ep_sum_reward = rollout["next", "reward"].sum() ep_max_reward = rollout["next", "reward"].max() @@ -257,13 +269,13 @@ def train(cfg: dict, out_dir=None, job_name=None): for _ in range(cfg.policy.utd): train_info = policy.update( - online_buffer, + # online_buffer, step, demo_buffer=demo_buffer, ) if step % cfg.log_freq == 0: train_info.update(rollout_info) - log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline) + log_train_info(logger, train_info, step, cfg, dataset, is_offline) # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass # in step + 1. 
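
Note on the offline training loop above: `itertools.cycle` keeps a copy of every batch it yields, so after the first epoch it replays the same shuffled batch order and retains all batches in memory. A minimal alternative that re-creates the DataLoader iterator each epoch (a sketch, not part of this patch) could look like:

    def cycle(dataloader):
        # Restarting the iterator re-draws the shuffle each epoch and
        # avoids caching batches, unlike itertools.cycle.
        while True:
            yield from dataloader
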
diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py index 3dd7cdfa..93315e90 100644 --- a/lerobot/scripts/visualize_dataset.py +++ b/lerobot/scripts/visualize_dataset.py @@ -10,7 +10,7 @@ from torchrl.data.replay_buffers import ( SamplerWithoutReplacement, ) -from lerobot.common.datasets.factory import make_offline_buffer +from lerobot.common.datasets.factory import make_dataset from lerobot.common.logger import log_output_dir from lerobot.common.utils import init_logging @@ -44,8 +44,8 @@ def visualize_dataset(cfg: dict, out_dir=None): shuffle=False, ) - logging.info("make_offline_buffer") - offline_buffer = make_offline_buffer( + logging.info("make_dataset") + dataset = make_dataset( cfg, overwrite_sampler=sampler, # remove all transformations such as rescale images from [0,255] to [0,1] or normalization @@ -55,12 +55,12 @@ def visualize_dataset(cfg: dict, out_dir=None): ) logging.info("Start rendering episodes from offline buffer") - video_paths = render_dataset(offline_buffer, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER, cfg.fps) + video_paths = render_dataset(dataset, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER, cfg.fps) for video_path in video_paths: logging.info(video_path) -def render_dataset(offline_buffer, out_dir, max_num_samples, fps): +def render_dataset(dataset, out_dir, max_num_samples, fps): out_dir = Path(out_dir) video_paths = [] threads = [] @@ -69,17 +69,17 @@ def render_dataset(offline_buffer, out_dir, max_num_samples, fps): logging.info(f"Visualizing episode {current_ep_idx}") for i in range(max_num_samples): # TODO(rcadene): make it work with bsize > 1 - ep_td = offline_buffer.sample(1) + ep_td = dataset.sample(1) ep_idx = ep_td["episode"][FIRST_FRAME].item() - # TODO(rcadene): modify offline_buffer._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames - num_frames_left = offline_buffer._sampler._sample_list.numel() + # TODO(rcadene): modify dataset._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames + num_frames_left = dataset._sampler._sample_list.numel() episode_is_done = ep_idx != current_ep_idx if episode_is_done: logging.info(f"Rendering episode {current_ep_idx}") - for im_key in offline_buffer.image_keys: + for im_key in dataset.image_keys: if not episode_is_done and num_frames_left > 0 and i < (max_num_samples - 1): # when first frame of episode, initialize frames dict if im_key not in frames: @@ -93,7 +93,7 @@ def render_dataset(offline_buffer, out_dir, max_num_samples, fps): frames[im_key].append(ep_td["next"][im_key]) out_dir.mkdir(parents=True, exist_ok=True) - if len(offline_buffer.image_keys) > 1: + if len(dataset.image_keys) > 1: camera = im_key[-1] video_path = out_dir / f"episode_{current_ep_idx}_{camera}.mp4" else: diff --git a/poetry.lock b/poetry.lock index 72397001..8fb6b7a7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -879,6 +879,29 @@ files = [ [package.extras] protobuf = ["grpcio-tools (>=1.62.1)"] +[[package]] +name = "gym-pusht" +version = "0.1.0" +description = "PushT environment for LeRobot" +optional = true +python-versions = "^3.10" +files = [] +develop = false + +[package.dependencies] +gymnasium = "^0.29.1" +opencv-python = "^4.9.0.80" +pygame = "^2.5.2" +pymunk = "^6.6.0" +scikit-image = "^0.22.0" +shapely = "^2.0.3" + +[package.source] +type = "git" +url = "git@github.com:huggingface/gym-pusht.git" +reference = "HEAD" +resolved_reference = "d7e1a39a31b1368741e9674791007d7cccf046a3" + [[package]] name = 
"gymnasium" version = "0.29.1" @@ -3586,7 +3609,10 @@ files = [ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] +[extras] +pusht = ["gym_pusht"] + [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "174c7d42f8039eedd2c447a4e6cae5169782cbd94346b5606572a0010194ca05" +content-hash = "3eee17e4bf2b7a570f41ef9c400ec5a24a3113f62a13162229cf43504ca0d005" diff --git a/pyproject.toml b/pyproject.toml index 972c1b61..d0fc7c0d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,10 @@ robomimic = "0.2.0" gymnasium-robotics = "^1.2.4" gymnasium = "^0.29.1" cmake = "^3.29.0.1" +gym_pusht = { git = "git@github.com:huggingface/gym-pusht.git", optional = true} +[tool.poetry.extras] +pusht = ["gym_pusht"] [tool.poetry.group.dev.dependencies] pre-commit = "^3.6.2" diff --git a/tests/test_datasets.py b/tests/test_datasets.py index df41b03f..f7f80a42 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -6,6 +6,8 @@ from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from lerobot.common.datasets.factory import make_offline_buffer from lerobot.common.utils import init_hydra_config +import logging +from lerobot.common.datasets.factory import make_dataset from .utils import DEVICE, DEFAULT_CONFIG_PATH @@ -26,14 +28,29 @@ def test_factory(env_name, dataset_id): DEFAULT_CONFIG_PATH, overrides=[f"env={env_name}", f"env.task={dataset_id}", f"device={DEVICE}"] ) - offline_buffer = make_offline_buffer(cfg) - for key in offline_buffer.image_keys: - img = offline_buffer[0].get(key) + dataset = make_dataset(cfg) + + item = dataset[0] + + assert "action" in item + assert "episode" in item + assert "frame_id" in item + assert "timestamp" in item + assert "next.done" in item + # TODO(rcadene): should we rename it agent_pos? + assert "observation.state" in item + for key in dataset.image_keys: + img = item.get(key) assert img.dtype == torch.float32 # TODO(rcadene): we assume for now that image normalization takes place in the model assert img.max() <= 1.0 assert img.min() >= 0.0 + if "next.reward" not in item: + logging.warning(f'Missing "next.reward" key in dataset {dataset}.') + if "next.done" not in item: + logging.warning(f'Missing "next.done" key in dataset {dataset}.') + def test_compute_stats(): """Check that the statistics are computed correctly according to the stats_patterns property. 
diff --git a/tests/test_envs.py b/tests/test_envs.py index eb3746db..0c56f4fc 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -2,7 +2,7 @@ import pytest from tensordict import TensorDict import torch from torchrl.envs.utils import check_env_specs, step_mdp -from lerobot.common.datasets.factory import make_offline_buffer +from lerobot.common.datasets.factory import make_dataset from lerobot.common.envs.aloha.env import AlohaEnv from lerobot.common.envs.factory import make_env @@ -116,15 +116,15 @@ def test_factory(env_name): overrides=[f"env={env_name}", f"device={DEVICE}"], ) - offline_buffer = make_offline_buffer(cfg) + dataset = make_dataset(cfg) env = make_env(cfg) - for key in offline_buffer.image_keys: + for key in dataset.image_keys: assert env.reset().get(key).dtype == torch.uint8 check_env_specs(env) - env = make_env(cfg, transform=offline_buffer.transform) - for key in offline_buffer.image_keys: + env = make_env(cfg, transform=dataset.transform) + for key in dataset.image_keys: img = env.reset().get(key) assert img.dtype == torch.float32 # TODO(rcadene): we assume for now that image normalization takes place in the model diff --git a/tests/test_policies.py b/tests/test_policies.py index 5d6b46d0..a46c6025 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -7,7 +7,7 @@ from torchrl.envs import EnvBase from lerobot.common.policies.factory import make_policy from lerobot.common.envs.factory import make_env -from lerobot.common.datasets.factory import make_offline_buffer +from lerobot.common.datasets.factory import make_dataset from lerobot.common.policies.abstract import AbstractPolicy from lerobot.common.utils import init_hydra_config from .utils import DEVICE, DEFAULT_CONFIG_PATH @@ -45,13 +45,13 @@ def test_concrete_policy(env_name, policy_name, extra_overrides): # Check that we can make the policy object. policy = make_policy(cfg) # Check that we run select_actions and get the appropriate output. - offline_buffer = make_offline_buffer(cfg) - env = make_env(cfg, transform=offline_buffer.transform) + dataset = make_dataset(cfg) + env = make_env(cfg, transform=dataset.transform) if env_name != "aloha": # TODO(alexander-soare): Fix this part of the test. PrioritizedSliceSampler raises NotImplementedError: # seq_length as a list is not supported for now. 
-    policy.update(offline_buffer, torch.tensor(0, device=DEVICE))
+    policy.update(dataset, torch.tensor(0, device=DEVICE))

     action = policy(
         env.observation_spec.rand()["observation"].to(DEVICE),

From c93ce35d8c403db9933ae3bcf1fe23683e485d99 Mon Sep 17 00:00:00 2001
From: Cadene
Date: Thu, 4 Apr 2024 16:36:03 +0000
Subject: [PATCH 2/3] WIP stats (TODO: run tests on stats + compute them)

---
 lerobot/common/datasets/abstract.py | 234 ----------------------------
 lerobot/common/datasets/factory.py  |  32 ++--
 lerobot/common/datasets/utils.py    | 101 ++++++++++++
 lerobot/common/transforms.py        |  16 +-
 tests/test_datasets.py              |  60 ++++---
 5 files changed, 157 insertions(+), 286 deletions(-)
 delete mode 100644 lerobot/common/datasets/abstract.py

diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py
deleted file mode 100644
index e9e9c610..00000000
--- a/lerobot/common/datasets/abstract.py
+++ /dev/null
@@ -1,234 +0,0 @@
-import logging
-from copy import deepcopy
-from math import ceil
-from pathlib import Path
-from typing import Callable
-
-import einops
-import torch
-import torchrl
-import tqdm
-from huggingface_hub import snapshot_download
-from tensordict import TensorDict
-from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
-from torchrl.data.replay_buffers.samplers import Sampler, SamplerWithoutReplacement
-from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
-from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
-from torchrl.envs.transforms.transforms import Compose
-
-HF_USER = "lerobot"
-
-
-class AbstractDataset(TensorDictReplayBuffer):
-    """
-    AbstractDataset represents a dataset in the context of imitation learning or reinforcement learning.
-    This class is designed to be subclassed by concrete implementations that specify particular types of datasets.
-    These implementations can vary based on the source of the data, the environment the data pertains to,
-    or the specific kind of data manipulation applied.
-
-    Note:
-        - `TensorDictReplayBuffer` is the base class from which `AbstractDataset` inherits. It provides the foundational
-          functionality for storing and retrieving `TensorDict`-like data.
-        - `available_datasets` should be overridden by concrete subclasses to list the specific dataset variants supported.
-          It is expected that these variants correspond to a HuggingFace dataset on the hub.
-          For instance, the `AlohaDataset` which inherites from `AbstractDataset` has 4 available dataset variants:
-            - [aloha_sim_transfer_cube_scripted](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted)
-            - [aloha_sim_insertion_scripted](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted)
-            - [aloha_sim_transfer_cube_human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human)
-            - [aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human)
-
-    When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
-        1. set the required class attributes:
-            - for classes inheriting from `AbstractDataset`: `available_datasets`
-            - for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
-            - for classes inheriting from `AbstractPolicy`: `name`
-        2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
-        3.
update variables in `tests/test_available.py` by importing your new class - """ - - available_datasets: list[str] | None = None - - def __init__( - self, - dataset_id: str, - version: str | None = None, - batch_size: int | None = None, - *, - shuffle: bool = True, - root: Path | None = None, - pin_memory: bool = False, - prefetch: int = None, - sampler: Sampler | None = None, - collate_fn: Callable | None = None, - writer: Writer | None = None, - transform: "torchrl.envs.Transform" = None, - ): - assert ( - self.available_datasets is not None - ), "Subclasses of `AbstractDataset` should set the `available_datasets` class attribute." - assert ( - dataset_id in self.available_datasets - ), f"The provided dataset ({dataset_id}) is not on the list of available datasets {self.available_datasets}." - - self.dataset_id = dataset_id - self.version = version - self.shuffle = shuffle - self.root = root if root is None else Path(root) - - if self.root is not None and self.version is not None: - logging.warning( - f"The version of the dataset ({self.version}) is not enforced when root is provided ({self.root})." - ) - - storage = self._download_or_load_dataset() - - super().__init__( - storage=storage, - sampler=sampler, - writer=ImmutableDatasetWriter() if writer is None else writer, - collate_fn=_collate_id if collate_fn is None else collate_fn, - pin_memory=pin_memory, - prefetch=prefetch, - batch_size=batch_size, - transform=transform, - ) - - @property - def stats_patterns(self) -> dict: - return { - ("observation", "state"): "b c -> c", - ("observation", "image"): "b c h w -> c 1 1", - ("action",): "b c -> c", - } - - @property - def image_keys(self) -> list: - return [("observation", "image")] - - @property - def num_cameras(self) -> int: - return len(self.image_keys) - - @property - def num_samples(self) -> int: - return len(self) - - @property - def num_episodes(self) -> int: - return len(self._storage._storage["episode"].unique()) - - @property - def transform(self): - return self._transform - - def set_transform(self, transform): - if not isinstance(transform, Compose): - # required since torchrl calls `len(self._transform)` downstream - if isinstance(transform, list): - self._transform = Compose(*transform) - else: - self._transform = Compose(transform) - else: - self._transform = transform - - def compute_or_load_stats(self, batch_size: int = 32) -> TensorDict: - stats_path = self.data_dir / "stats.pth" - if stats_path.exists(): - stats = torch.load(stats_path) - else: - logging.info(f"compute_stats and save to {stats_path}") - stats = self._compute_stats(batch_size) - torch.save(stats, stats_path) - return stats - - def _download_or_load_dataset(self) -> torch.StorageBase: - if self.root is None: - self.data_dir = Path( - snapshot_download( - repo_id=f"{HF_USER}/{self.dataset_id}", repo_type="dataset", revision=self.version - ) - ) - else: - self.data_dir = self.root / self.dataset_id - return TensorStorage(TensorDict.load_memmap(self.data_dir / "replay_buffer")) - - def _compute_stats(self, batch_size: int = 32): - """Compute dataset statistics including minimum, maximum, mean, and standard deviation. - - TODO(alexander-soare): Add a num_batches argument which essentially allows one to use a subset of the - full dataset (for handling very large datasets). The sampling would then have to be random - (preferably without replacement). Both stats computation loops would ideally sample the same - items. 
- """ - rb = TensorDictReplayBuffer( - storage=self._storage, - batch_size=32, - prefetch=True, - # Note: Due to be refactored soon. The point is that we should go through the whole dataset. - sampler=SamplerWithoutReplacement(drop_last=False, shuffle=False), - ) - - # mean and std will be computed incrementally while max and min will track the running value. - mean, std, max, min = {}, {}, {}, {} - for key in self.stats_patterns: - mean[key] = torch.tensor(0.0).float() - std[key] = torch.tensor(0.0).float() - max[key] = torch.tensor(-float("inf")).float() - min[key] = torch.tensor(float("inf")).float() - - # Compute mean, min, max. - # Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get - # surprises when rerunning the sampler. - first_batch = None - running_item_count = 0 # for online mean computation - for _ in tqdm.tqdm(range(ceil(len(rb) / batch_size))): - batch = rb.sample() - this_batch_size = batch.batch_size[0] - running_item_count += this_batch_size - if first_batch is None: - first_batch = deepcopy(batch) - for key, pattern in self.stats_patterns.items(): - batch[key] = batch[key].float() - # Numerically stable update step for mean computation. - batch_mean = einops.reduce(batch[key], pattern, "mean") - # Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents - # the update step, N is the running item count, B is this batch size, x̄ is the running mean, - # and x is the current batch mean. Some rearrangement is then required to avoid risking - # numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields - # x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ - mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count - max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max")) - min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min")) - - # Compute std. - first_batch_ = None - running_item_count = 0 # for online std computation - for _ in tqdm.tqdm(range(ceil(len(rb) / batch_size))): - batch = rb.sample() - this_batch_size = batch.batch_size[0] - running_item_count += this_batch_size - # Sanity check to make sure the batches are still in the same order as before. - if first_batch_ is None: - first_batch_ = deepcopy(batch) - for key in self.stats_patterns: - assert torch.equal(first_batch_[key], first_batch[key]) - for key, pattern in self.stats_patterns.items(): - batch[key] = batch[key].float() - # Numerically stable update step for mean computation (where the mean is over squared - # residuals).See notes in the mean computation loop above. 
- batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean") - std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count - - for key in self.stats_patterns: - std[key] = torch.sqrt(std[key]) - - stats = TensorDict({}, batch_size=[]) - for key in self.stats_patterns: - stats[(*key, "mean")] = mean[key] - stats[(*key, "std")] = std[key] - stats[(*key, "max")] = max[key] - stats[(*key, "min")] = min[key] - - if key[0] == "observation": - # use same stats for the next observations - stats[("next", *key)] = stats[key] - return stats diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 94ac8ca4..32d76a50 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -4,7 +4,8 @@ from pathlib import Path import torch from torchvision.transforms import v2 -from lerobot.common.transforms import Prod +from lerobot.common.datasets.utils import compute_or_load_stats +from lerobot.common.transforms import NormalizeTransform, Prod # DATA_DIR specifies to location where datasets are loaded. By default, DATA_DIR is None and # we load from `$HOME/.cache/huggingface/hub/datasets`. For our unit tests, we set `DATA_DIR=tests/data` @@ -41,9 +42,8 @@ def make_dataset( # min_max_from_spec # stats = dataset.compute_or_load_stats() if stats_path is None else torch.load(stats_path) - stats = {} - if cfg.policy.name == "diffusion" and cfg.env.name == "pusht": + stats = {} # TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this stats["observation.state"] = {} stats["observation.state"]["min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32) @@ -51,22 +51,30 @@ def make_dataset( stats["action"] = {} stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32) stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32) + else: + # instantiate a one frame dataset with light transform + stats_dataset = clsfunc( + dataset_id=cfg.dataset_id, + root=DATA_DIR, + transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0), + ) + stats = compute_or_load_stats(stats_dataset) # TODO(rcadene): remove this and put it in config. 
Ideally we want to reproduce SOTA results just with mean_std - # normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max" + normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max" transforms = v2.Compose( [ # TODO(rcadene): we need to do something about image_keys Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0), - # NormalizeTransform( - # stats, - # in_keys=[ - # "observation.state", - # "action", - # ], - # mode=normalization_mode, - # ), + NormalizeTransform( + stats, + in_keys=[ + "observation.state", + "action", + ], + mode=normalization_mode, + ), ] ) diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index c8840169..522227d7 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -1,7 +1,11 @@ import io +import logging import zipfile +from copy import deepcopy +from math import ceil from pathlib import Path +import einops import requests import torch import tqdm @@ -97,3 +101,100 @@ def load_data_with_delta_timestamps( ) return data, is_pad + + +def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None): + stats_path = dataset.data_dir / "stats.pth" + if stats_path.exists(): + return torch.load(stats_path) + + logging.info(f"compute_stats and save to {stats_path}") + + if max_num_samples is None: + max_num_samples = len(dataset) + + dataloader = torch.utils.data.DataLoader( + dataset, + num_workers=4, + batch_size=batch_size, + shuffle=True, + # pin_memory=cfg.device != "cpu", + drop_last=False, + ) + + stats_patterns = { + "action": "b c -> c", + "observation.state": "b c -> c", + } + for key in dataset.image_keys: + stats_patterns[key] = "b c h w -> c 1 1" + + # mean and std will be computed incrementally while max and min will track the running value. + mean, std, max, min = {}, {}, {}, {} + for key in stats_patterns: + mean[key] = torch.tensor(0.0).float() + std[key] = torch.tensor(0.0).float() + max[key] = torch.tensor(-float("inf")).float() + min[key] = torch.tensor(float("inf")).float() + + # Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get + # surprises when rerunning the sampler. + first_batch = None + running_item_count = 0 # for online mean computation + for i, batch in enumerate( + tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max") + ): + this_batch_size = batch.batch_size[0] + running_item_count += this_batch_size + if first_batch is None: + first_batch = deepcopy(batch) + for key, pattern in stats_patterns.items(): + batch[key] = batch[key].float() + # Numerically stable update step for mean computation. + batch_mean = einops.reduce(batch[key], pattern, "mean") + # Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents + # the update step, N is the running item count, B is this batch size, x̄ is the running mean, + # and x is the current batch mean. Some rearrangement is then required to avoid risking + # numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. 
Rearrangement yields + # x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ + mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count + max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max")) + min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min")) + + if i == ceil(max_num_samples / batch_size) - 1: + break + + first_batch_ = None + running_item_count = 0 # for online std computation + for i, batch in enumerate(tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")): + this_batch_size = batch.batch_size[0] + running_item_count += this_batch_size + # Sanity check to make sure the batches are still in the same order as before. + if first_batch_ is None: + first_batch_ = deepcopy(batch) + for key in stats_patterns: + assert torch.equal(first_batch_[key], first_batch[key]) + for key, pattern in stats_patterns.items(): + batch[key] = batch[key].float() + # Numerically stable update step for mean computation (where the mean is over squared + # residuals).See notes in the mean computation loop above. + batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean") + std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count + + if i == ceil(max_num_samples / batch_size) - 1: + break + + for key in stats_patterns: + std[key] = torch.sqrt(std[key]) + + stats = {} + for key in stats_patterns: + stats[key] = { + "mean": mean[key], + "std": std[key], + "max": max[key], + "min": min[key], + } + + torch.save(stats, stats_path) + return stats diff --git a/lerobot/common/transforms.py b/lerobot/common/transforms.py index ec967614..4974c086 100644 --- a/lerobot/common/transforms.py +++ b/lerobot/common/transforms.py @@ -72,12 +72,12 @@ class NormalizeTransform(Transform): if inkey not in item: continue if self.mode == "mean_std": - mean = self.stats[inkey]["mean"] - std = self.stats[inkey]["std"] + mean = self.stats[f"{inkey}.mean"] + std = self.stats[f"{inkey}.std"] item[outkey] = (item[inkey] - mean) / (std + 1e-8) else: - min = self.stats[inkey]["min"] - max = self.stats[inkey]["max"] + min = self.stats[f"{inkey}.min"] + max = self.stats[f"{inkey}.max"] # normalize to [0,1] item[outkey] = (item[inkey] - min) / (max - min) # normalize to [-1, 1] @@ -89,12 +89,12 @@ class NormalizeTransform(Transform): if inkey not in item: continue if self.mode == "mean_std": - mean = self.stats[inkey]["mean"] - std = self.stats[inkey]["std"] + mean = self.stats[f"{inkey}.mean"] + std = self.stats[f"{inkey}.std"] item[outkey] = item[inkey] * std + mean else: - min = self.stats[inkey]["min"] - max = self.stats[inkey]["max"] + min = self.stats[f"{inkey}.min"] + max = self.stats[f"{inkey}.max"] item[outkey] = (item[inkey] + 1) / 2 item[outkey] = item[outkey] * (max - min) + min return item diff --git a/tests/test_datasets.py b/tests/test_datasets.py index f7f80a42..00008259 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -1,10 +1,6 @@ -import einops import pytest import torch -from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer -from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement -from lerobot.common.datasets.factory import make_offline_buffer from lerobot.common.utils import init_hydra_config import logging from lerobot.common.datasets.factory import make_dataset @@ -52,32 +48,32 @@ def test_factory(env_name, dataset_id): logging.warning(f'Missing "next.done" key in dataset {dataset}.') -def test_compute_stats(): - """Check that the 
statistics are computed correctly according to the stats_patterns property. +# def test_compute_stats(): +# """Check that the statistics are computed correctly according to the stats_patterns property. - We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do - because we are working with a small dataset). - """ - cfg = init_hydra_config( - DEFAULT_CONFIG_PATH, overrides=["env=aloha", "env.task=sim_transfer_cube_human"] - ) - buffer = make_offline_buffer(cfg) - # Get all of the data. - all_data = TensorDictReplayBuffer( - storage=buffer._storage, - batch_size=len(buffer), - sampler=SamplerWithoutReplacement(), - ).sample().float() - # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched - # computation of the statistics. While doing this, we also make sure it works when we don't divide the - # dataset into even batches. - computed_stats = buffer._compute_stats(batch_size=int(len(all_data) * 0.75)) - for k, pattern in buffer.stats_patterns.items(): - expected_mean = einops.reduce(all_data[k], pattern, "mean") - assert torch.allclose(computed_stats[k]["mean"], expected_mean) - assert torch.allclose( - computed_stats[k]["std"], - torch.sqrt(einops.reduce((all_data[k] - expected_mean) ** 2, pattern, "mean")) - ) - assert torch.allclose(computed_stats[k]["min"], einops.reduce(all_data[k], pattern, "min")) - assert torch.allclose(computed_stats[k]["max"], einops.reduce(all_data[k], pattern, "max")) +# We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do +# because we are working with a small dataset). +# """ +# cfg = init_hydra_config( +# DEFAULT_CONFIG_PATH, overrides=["env=aloha", "env.task=sim_transfer_cube_human"] +# ) +# dataset = make_dataset(cfg) +# # Get all of the data. +# all_data = TensorDictReplayBuffer( +# storage=buffer._storage, +# batch_size=len(buffer), +# sampler=SamplerWithoutReplacement(), +# ).sample().float() +# # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched +# # computation of the statistics. While doing this, we also make sure it works when we don't divide the +# # dataset into even batches. 
+# computed_stats = buffer._compute_stats(batch_size=int(len(all_data) * 0.75)) +# for k, pattern in buffer.stats_patterns.items(): +# expected_mean = einops.reduce(all_data[k], pattern, "mean") +# assert torch.allclose(computed_stats[k]["mean"], expected_mean) +# assert torch.allclose( +# computed_stats[k]["std"], +# torch.sqrt(einops.reduce((all_data[k] - expected_mean) ** 2, pattern, "mean")) +# ) +# assert torch.allclose(computed_stats[k]["min"], einops.reduce(all_data[k], pattern, "min")) +# assert torch.allclose(computed_stats[k]["max"], einops.reduce(all_data[k], pattern, "max")) From 5af00d0c1ee0aa3d9a90e6afe646474073ff5065 Mon Sep 17 00:00:00 2001 From: Cadene Date: Fri, 5 Apr 2024 09:31:39 +0000 Subject: [PATCH 3/3] fix train.py, stats, eval.py (training is running) --- lerobot/common/datasets/aloha.py | 16 +++++++------ lerobot/common/datasets/pusht.py | 16 +++++++------ lerobot/common/datasets/simxarm.py | 16 +++++++------ lerobot/common/datasets/utils.py | 15 ++++++++---- .../diffusion/diffusion_unet_image_policy.py | 7 +++--- lerobot/common/policies/diffusion/policy.py | 3 ++- lerobot/common/transforms.py | 16 ++++++------- lerobot/scripts/eval.py | 20 +++++++--------- lerobot/scripts/train.py | 9 +++---- tests/scripts/mock_dataset.py | 24 +++++++++---------- tests/test_datasets.py | 6 +---- 11 files changed, 76 insertions(+), 72 deletions(-) diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py index 2744f595..102de08e 100644 --- a/lerobot/common/datasets/aloha.py +++ b/lerobot/common/datasets/aloha.py @@ -91,15 +91,17 @@ class AlohaDataset(torch.utils.data.Dataset): self.transform = transform self.delta_timestamps = delta_timestamps - data_dir = self.root / f"{self.dataset_id}" - if (data_dir / "data_dict.pth").exists() and (data_dir / "data_ids_per_episode.pth").exists(): - self.data_dict = torch.load(data_dir / "data_dict.pth") - self.data_ids_per_episode = torch.load(data_dir / "data_ids_per_episode.pth") + self.data_dir = self.root / f"{self.dataset_id}" + if (self.data_dir / "data_dict.pth").exists() and ( + self.data_dir / "data_ids_per_episode.pth" + ).exists(): + self.data_dict = torch.load(self.data_dir / "data_dict.pth") + self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth") else: self._download_and_preproc_obsolete() - data_dir.mkdir(parents=True, exist_ok=True) - torch.save(self.data_dict, data_dir / "data_dict.pth") - torch.save(self.data_ids_per_episode, data_dir / "data_ids_per_episode.pth") + self.data_dir.mkdir(parents=True, exist_ok=True) + torch.save(self.data_dict, self.data_dir / "data_dict.pth") + torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth") @property def num_samples(self) -> int: diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index 3de70b1f..9b73b101 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -105,15 +105,17 @@ class PushtDataset(torch.utils.data.Dataset): self.transform = transform self.delta_timestamps = delta_timestamps - data_dir = self.root / f"{self.dataset_id}" - if (data_dir / "data_dict.pth").exists() and (data_dir / "data_ids_per_episode.pth").exists(): - self.data_dict = torch.load(data_dir / "data_dict.pth") - self.data_ids_per_episode = torch.load(data_dir / "data_ids_per_episode.pth") + self.data_dir = self.root / f"{self.dataset_id}" + if (self.data_dir / "data_dict.pth").exists() and ( + self.data_dir / "data_ids_per_episode.pth" + ).exists(): + self.data_dict = 
torch.load(self.data_dir / "data_dict.pth")
+            self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
         else:
             self._download_and_preproc_obsolete()
-            data_dir.mkdir(parents=True, exist_ok=True)
-            torch.save(self.data_dict, data_dir / "data_dict.pth")
-            torch.save(self.data_ids_per_episode, data_dir / "data_ids_per_episode.pth")
+            self.data_dir.mkdir(parents=True, exist_ok=True)
+            torch.save(self.data_dict, self.data_dir / "data_dict.pth")
+            torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")

     @property
     def num_samples(self) -> int:
diff --git a/lerobot/common/datasets/simxarm.py b/lerobot/common/datasets/simxarm.py
index 4b2c68ad..7bddf608 100644
--- a/lerobot/common/datasets/simxarm.py
+++ b/lerobot/common/datasets/simxarm.py
@@ -46,15 +46,17 @@ class SimxarmDataset(torch.utils.data.Dataset):
         self.transform = transform
         self.delta_timestamps = delta_timestamps

-        data_dir = self.root / f"{self.dataset_id}"
-        if (data_dir / "data_dict.pth").exists() and (data_dir / "data_ids_per_episode.pth").exists():
-            self.data_dict = torch.load(data_dir / "data_dict.pth")
-            self.data_ids_per_episode = torch.load(data_dir / "data_ids_per_episode.pth")
+        self.data_dir = self.root / f"{self.dataset_id}"
+        if (self.data_dir / "data_dict.pth").exists() and (
+            self.data_dir / "data_ids_per_episode.pth"
+        ).exists():
+            self.data_dict = torch.load(self.data_dir / "data_dict.pth")
+            self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
         else:
             self._download_and_preproc_obsolete()
-            data_dir.mkdir(parents=True, exist_ok=True)
-            torch.save(self.data_dict, data_dir / "data_dict.pth")
-            torch.save(self.data_ids_per_episode, data_dir / "data_ids_per_episode.pth")
+            self.data_dir.mkdir(parents=True, exist_ok=True)
+            torch.save(self.data_dict, self.data_dir / "data_dict.pth")
+            torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")

     @property
     def num_samples(self) -> int:
diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index 522227d7..6b207b4d 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -112,16 +112,19 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):

     if max_num_samples is None:
         max_num_samples = len(dataset)
+    else:
+        raise NotImplementedError("We need to set shuffle=True, but this violates an assert for now.")

     dataloader = torch.utils.data.DataLoader(
         dataset,
         num_workers=4,
         batch_size=batch_size,
-        shuffle=True,
+        shuffle=False,
         # pin_memory=cfg.device != "cpu",
         drop_last=False,
     )

+    # these einops patterns will be used to aggregate batches and compute statistics
     stats_patterns = {
         "action": "b c -> c",
         "observation.state": "b c -> c",
@@ -142,9 +145,9 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
     first_batch = None
     running_item_count = 0  # for online mean computation
     for i, batch in enumerate(
-        tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
+        tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
     ):
-        this_batch_size = batch.batch_size[0]
+        this_batch_size = len(batch["index"])
         running_item_count += this_batch_size
         if first_batch is None:
             first_batch = deepcopy(batch)
@@ -166,8 +169,10 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):

     first_batch_ = None
     running_item_count = 0  # for online std computation
-    for i, batch in enumerate(tqdm(dataloader,
total=ceil(max_num_samples / batch_size), desc="Compute std")): - this_batch_size = batch.batch_size[0] + for i, batch in enumerate( + tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std") + ): + this_batch_size = len(batch["index"]) running_item_count += this_batch_size # Sanity check to make sure the batches are still in the same order as before. if first_batch_ is None: diff --git a/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py b/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py index 7719fdde..373e4b6c 100644 --- a/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py +++ b/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py @@ -243,10 +243,9 @@ class DiffusionUnetImagePolicy(BaseImagePolicy): result = {"action": action, "action_pred": action_pred} return result - def compute_loss(self, batch): - assert "valid_mask" not in batch - nobs = batch["obs"] - nactions = batch["action"] + def compute_loss(self, obs_dict, action): + nobs = obs_dict + nactions = action batch_size = nactions.shape[0] horizon = nactions.shape[1] diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py index a0fe0eba..de8796ab 100644 --- a/lerobot/common/policies/diffusion/policy.py +++ b/lerobot/common/policies/diffusion/policy.py @@ -157,7 +157,8 @@ class DiffusionPolicy(nn.Module): "image": batch["observation.image"], "agent_pos": batch["observation.state"], } - loss = self.diffusion.compute_loss(obs_dict) + action = batch["action"] + loss = self.diffusion.compute_loss(obs_dict, action) loss.backward() grad_norm = torch.nn.utils.clip_grad_norm_( diff --git a/lerobot/common/transforms.py b/lerobot/common/transforms.py index 4974c086..ec967614 100644 --- a/lerobot/common/transforms.py +++ b/lerobot/common/transforms.py @@ -72,12 +72,12 @@ class NormalizeTransform(Transform): if inkey not in item: continue if self.mode == "mean_std": - mean = self.stats[f"{inkey}.mean"] - std = self.stats[f"{inkey}.std"] + mean = self.stats[inkey]["mean"] + std = self.stats[inkey]["std"] item[outkey] = (item[inkey] - mean) / (std + 1e-8) else: - min = self.stats[f"{inkey}.min"] - max = self.stats[f"{inkey}.max"] + min = self.stats[inkey]["min"] + max = self.stats[inkey]["max"] # normalize to [0,1] item[outkey] = (item[inkey] - min) / (max - min) # normalize to [-1, 1] @@ -89,12 +89,12 @@ class NormalizeTransform(Transform): if inkey not in item: continue if self.mode == "mean_std": - mean = self.stats[f"{inkey}.mean"] - std = self.stats[f"{inkey}.std"] + mean = self.stats[inkey]["mean"] + std = self.stats[inkey]["std"] item[outkey] = item[inkey] * std + mean else: - min = self.stats[f"{inkey}.min"] - max = self.stats[f"{inkey}.max"] + min = self.stats[inkey]["min"] + max = self.stats[inkey]["max"] item[outkey] = (item[inkey] + 1) / 2 item[outkey] = item[outkey] * (max - min) + min return item diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index fe0f7bb2..09399878 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -37,7 +37,6 @@ from pathlib import Path import einops import gymnasium as gym -import hydra import imageio import numpy as np import torch @@ -47,8 +46,8 @@ from lerobot.common.datasets.factory import make_dataset from lerobot.common.envs.factory import make_env from lerobot.common.logger import log_output_dir from lerobot.common.policies.factory import make_policy -from lerobot.common.utils import get_safe_torch_device, init_hydra_config, init_logging, 
set_global_seed
 from lerobot.common.transforms import apply_inverse_transform
+from lerobot.common.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed


 def write_video(video_path, stacked_frames, fps):
@@ -92,9 +91,12 @@ def eval_policy(
     fps: int = 15,
     return_first_video: bool = False,
     transform: callable = None,
+    seed=None,
 ):
     if policy is not None:
         policy.eval()
+    device = "cpu" if policy is None else next(policy.parameters()).device
+
     start = time.time()
     sum_rewards = []
     max_rewards = []
@@ -125,11 +127,11 @@ def eval_policy(
             policy.reset()
         else:
             logging.warning(
-                f"Policy {policy} doesnt have a `reset` method. This find if the policy doesnt rely on an internal state during rollout."
+                f"Policy {policy} doesn't have a `reset` method. It is required if the policy relies on an internal state during rollout."
             )

         # reset the environment
-        observation, info = env.reset(seed=cfg.seed)
+        observation, info = env.reset(seed=seed)
         maybe_render_frame(env)

         rewards = []
@@ -138,13 +140,12 @@ def eval_policy(
         done = torch.tensor([False for _ in env.envs])
         step = 0
-        do_rollout = True
-        while do_rollout:
+        while not done.all():
             # apply transform to normalize the observations
             observation = preprocess_observation(observation, transform)

             # send observation to device/gpu
-            observation = {key: observation[key].to(cfg.device, non_blocking=True) for key in observation}
+            observation = {key: observation[key].to(device, non_blocking=True) for key in observation}

             # get the next action for the environment
             with torch.inference_mode():
@@ -180,10 +181,6 @@ def eval_policy(

             step += 1

-            if done.all():
-                do_rollout = False
-                break
-
     rewards = torch.stack(rewards, dim=1)
     successes = torch.stack(successes, dim=1)
     dones = torch.stack(dones, dim=1)
@@ -295,6 +292,7 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
         fps=cfg.env.fps,
         # TODO(rcadene): what should we do with the transform?
         transform=dataset.transform,
+        seed=cfg.seed,
     )
     print(info["aggregated"])
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index 5e9cd361..602fa5ab 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -145,7 +145,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
     #     )

     logging.info("make_env")
-    env = make_env(cfg)
+    env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)

     logging.info("make_policy")
     policy = make_policy(cfg)
@@ -173,12 +173,11 @@ def train(cfg: dict, out_dir=None, job_name=None):
         eval_info, first_video = eval_policy(
             env,
             policy,
-            num_episodes=cfg.eval_episodes,
-            max_steps=cfg.env.episode_length,
             return_first_video=True,
             video_dir=Path(out_dir) / "eval",
             save_video=True,
             transform=dataset.transform,
+            seed=cfg.seed,
         )
         log_eval_info(logger, eval_info["aggregated"], step, cfg, dataset, is_offline)
         if cfg.wandb.enable:
@@ -211,7 +210,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
         for key in batch:
             batch[key] = batch[key].to(cfg.device, non_blocking=True)

-        train_info = policy.update(batch, step)
+        train_info = policy(batch, step)

         # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
if step % cfg.log_freq == 0: @@ -223,6 +222,8 @@ def train(cfg: dict, out_dir=None, job_name=None): step += 1 + raise NotImplementedError() + demo_buffer = dataset if cfg.policy.balanced_sampling else None online_step = 0 is_offline = False diff --git a/tests/scripts/mock_dataset.py b/tests/scripts/mock_dataset.py index d9c86464..044417aa 100644 --- a/tests/scripts/mock_dataset.py +++ b/tests/scripts/mock_dataset.py @@ -18,28 +18,26 @@ Example: import argparse import shutil -from tensordict import TensorDict from pathlib import Path +import torch + def mock_dataset(in_data_dir, out_data_dir, num_frames): in_data_dir = Path(in_data_dir) out_data_dir = Path(out_data_dir) - # load full dataset as a tensor dict - in_td_data = TensorDict.load_memmap(in_data_dir / "replay_buffer") + # copy the first `n` frames for each data key so that we have real data + in_data_dict = torch.load(in_data_dir / "data_dict.pth") + out_data_dict = {key: in_data_dict[key][:num_frames].clone() for key in in_data_dict} + torch.save(out_data_dict, out_data_dir / "data_dict.pth") - # use 1 frame to know the specification of the dataset - # and copy it over `n` frames in the test artifact directory - out_td_data = in_td_data[0].expand(num_frames).memmap_like(out_data_dir / "replay_buffer") + # copy the full mapping between data_id and episode since it's small + in_ids_per_ep_path = in_data_dir / "data_ids_per_episode.pth" + out_ids_per_ep_path = out_data_dir / "data_ids_per_episode.pth" + shutil.copy(in_ids_per_ep_path, out_ids_per_ep_path) - # copy the first `n` frames so that we have real data - out_td_data[:num_frames] = in_td_data[:num_frames].clone() - - # make sure everything has been properly written - out_td_data.lock_() - - # copy the full statistics of dataset since it's pretty small + # copy the full statistics of dataset since it's small in_stats_path = in_data_dir / "stats.pth" out_stats_path = out_data_dir / "stats.pth" shutil.copy(in_stats_path, out_stats_path) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 00008259..e5ca0099 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -59,11 +59,7 @@ def test_factory(env_name, dataset_id): # ) # dataset = make_dataset(cfg) # # Get all of the data. -# all_data = TensorDictReplayBuffer( -# storage=buffer._storage, -# batch_size=len(buffer), -# sampler=SamplerWithoutReplacement(), -# ).sample().float() +# all_data = dataset.data_dict # # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched # # computation of the statistics. While doing this, we also make sure it works when we don't divide the # # dataset into even batches.
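
For reference, the two-pass statistics in `compute_or_load_stats` rely on the running-mean recursion x̄ₙ = x̄ₙ₋₁ + Bₙ(xₙ - x̄ₙ₋₁) / Nₙ noted in the code comments, applied once to the raw values and once to the squared residuals. A toy verification of that recursion on uneven batches (a sketch mirroring drop_last=False; not part of this patch series):

    import torch

    x = torch.arange(10.0)
    mean = torch.tensor(0.0)
    n = 0
    for batch in x.split(3):  # batch sizes 3, 3, 3, 1, as with drop_last=False
        b = len(batch)
        n += b
        mean = mean + b * (batch.mean() - mean) / n  # x̄ₙ = x̄ₙ₋₁ + Bₙ(xₙ - x̄ₙ₋₁) / Nₙ
    assert torch.isclose(mean, x.mean())

    var = torch.tensor(0.0)
    n = 0
    for batch in x.split(3):  # second pass over squared residuals, as in the std loop
        b = len(batch)
        n += b
        var = var + b * (((batch - mean) ** 2).mean() - var) / n
    assert torch.isclose(var.sqrt(), x.std(unbiased=False))
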