diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py
new file mode 100644
index 00000000..af30cf8c
--- /dev/null
+++ b/lerobot/common/datasets/abstract.py
@@ -0,0 +1,157 @@
+import abc
+import logging
+from pathlib import Path
+from typing import Callable
+
+import einops
+import torch
+import torchrl
+import tqdm
+from tensordict import TensorDict
+from torchrl.data.datasets.utils import _get_root_dir
+from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
+from torchrl.data.replay_buffers.samplers import SliceSampler
+from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
+from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
+
+
+class AbstractExperienceReplay(TensorDictReplayBuffer):
+    def __init__(
+        self,
+        dataset_id: str,
+        batch_size: int = None,
+        *,
+        shuffle: bool = True,
+        root: Path = None,
+        pin_memory: bool = False,
+        prefetch: int = None,
+        sampler: SliceSampler = None,
+        collate_fn: Callable = None,
+        writer: Writer = None,
+        transform: "torchrl.envs.Transform" = None,
+    ):
+        self.dataset_id = dataset_id
+        self.shuffle = shuffle
+        self.root = _get_root_dir(self.dataset_id) if root is None else root
+        self.root = Path(self.root)
+        self.data_dir = self.root / self.dataset_id
+
+        storage = self._download_or_load_storage()
+
+        super().__init__(
+            storage=storage,
+            sampler=sampler,
+            writer=ImmutableDatasetWriter() if writer is None else writer,
+            collate_fn=_collate_id if collate_fn is None else collate_fn,
+            pin_memory=pin_memory,
+            prefetch=prefetch,
+            batch_size=batch_size,
+            transform=transform,
+        )
+
+    @property
+    def stats_patterns(self) -> dict:
+        return {
+            ("observation", "state"): "b c -> 1 c",
+            ("observation", "image"): "b c h w -> 1 c 1 1",
+            ("action",): "b c -> 1 c",
+        }
+
+    @property
+    def image_keys(self) -> list:
+        return [("observation", "image")]
+
+    @property
+    def num_cameras(self) -> int:
+        return len(self.image_keys)
+
+    @property
+    def num_samples(self) -> int:
+        return len(self)
+
+    @property
+    def num_episodes(self) -> int:
+        return len(self._storage._storage["episode"].unique())
+
+    def set_transform(self, transform):
+        self.transform = transform
+
+    def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict:
+        stats_path = self.data_dir / "stats.pth"
+        if stats_path.exists():
+            stats = torch.load(stats_path)
+        else:
+            logging.info(f"compute_stats and save to {stats_path}")
+            stats = self._compute_stats(num_batch, batch_size)
+            torch.save(stats, stats_path)
+        return stats
+
+    @abc.abstractmethod
+    def _download_and_preproc(self) -> torch.StorageBase:
+        raise NotImplementedError()
+
+    def _download_or_load_storage(self):
+        if not self._is_downloaded():
+            storage = self._download_and_preproc()
+        else:
+            storage = TensorStorage(TensorDict.load_memmap(self.data_dir))
+        return storage
+
+    def _is_downloaded(self) -> bool:
+        return self.data_dir.is_dir()
+
+    def _compute_stats(self, num_batch=100, batch_size=32):
+        rb = TensorDictReplayBuffer(
+            storage=self._storage,
+            batch_size=batch_size,
+            prefetch=True,
+        )
+
+        mean, std, max, min = {}, {}, {}, {}
+
+        # compute mean, min, max
+        for _ in tqdm.tqdm(range(num_batch)):
+            batch = rb.sample()
+            for key, pattern in self.stats_patterns.items():
+                batch[key] = batch[key].float()
+                if key not in mean:
+                    # first batch initializes mean, min, max
+                    mean[key] = einops.reduce(batch[key], pattern, "mean")
+                    max[key] = einops.reduce(batch[key], pattern, "max")
+                    min[key] = einops.reduce(batch[key], pattern, "min")
+                else:
+                    mean[key] += einops.reduce(batch[key], pattern, "mean")
+                    max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
+                    min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
+
+        for key in self.stats_patterns:
+            mean[key] /= num_batch
+
+        # compute std, and keep accumulating min, max
+        for _ in tqdm.tqdm(range(num_batch)):
+            batch = rb.sample()
+            for key, pattern in self.stats_patterns.items():
+                batch[key] = batch[key].float()
+                batch_mean = einops.reduce(batch[key], pattern, "mean")
+                if key not in std:
+                    # first batch initializes std
+                    std[key] = (batch_mean - mean[key]) ** 2
+                else:
+                    std[key] += (batch_mean - mean[key]) ** 2
+                max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
+                min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
+
+        for key in self.stats_patterns:
+            std[key] = torch.sqrt(std[key] / num_batch)
+
+        stats = TensorDict({}, batch_size=[])
+        for key in self.stats_patterns:
+            stats[(*key, "mean")] = mean[key]
+            stats[(*key, "std")] = std[key]
+            stats[(*key, "max")] = max[key]
+            stats[(*key, "min")] = min[key]
+
+            if key[0] == "observation":
+                # use same stats for the next observations
+                stats[("next", *key)] = stats[key]
+        return stats
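To make the new base-class contract concrete, here is a minimal, hypothetical subclass (not part of this diff): the class name, keys, and shapes are made up, and `_download_and_preproc` just writes a zero-filled memory-mapped TensorDict where a real dataset would write converted frames, as the pusht and simxarm classes below do.

```python
import torch
from tensordict import TensorDict
from torchrl.data.replay_buffers.storages import TensorStorage

from lerobot.common.datasets.abstract import AbstractExperienceReplay


class MyRobotExperienceReplay(AbstractExperienceReplay):
    """Hypothetical dataset: only `_download_and_preproc` must be implemented."""

    def _download_and_preproc(self) -> TensorStorage:
        num_frames = 10  # a real implementation would download/convert data here
        td_data = TensorDict(
            {
                ("observation", "state"): torch.zeros(num_frames, 2),
                ("observation", "image"): torch.zeros(num_frames, 3, 96, 96, dtype=torch.uint8),
                "action": torch.zeros(num_frames, 2),
                "episode": torch.zeros(num_frames, dtype=torch.int64),
            },
            batch_size=[num_frames],
        )
        # memory-map under self.data_dir, which is what the base class reloads
        # on subsequent runs via TensorDict.load_memmap(self.data_dir)
        td_data = td_data.memmap_(self.data_dir)
        return TensorStorage(td_data.lock_())
```

On first instantiation `_is_downloaded()` is false, so `_download_and_preproc` runs; afterwards the memory-mapped data is reloaded directly, and `compute_or_load_stats` caches its statistics in `data_dir/stats.pth`.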
pattern, "min") + else: + mean[key] += einops.reduce(batch[key], pattern, "mean") + max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max")) + min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min")) + batch = rb.sample() + + for key in self.stats_patterns: + mean[key] /= num_batch + + # compute std, min, max + for _ in tqdm.tqdm(range(num_batch)): + batch = rb.sample() + for key, pattern in self.stats_patterns.items(): + batch[key] = batch[key].float() + batch_mean = einops.reduce(batch[key], pattern, "mean") + if key not in std: + # first batch initialize std + std[key] = (batch_mean - mean[key]) ** 2 + else: + std[key] += (batch_mean - mean[key]) ** 2 + max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max")) + min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min")) + + for key in self.stats_patterns: + std[key] = torch.sqrt(std[key] / num_batch) + + stats = TensorDict({}, batch_size=[]) + for key in self.stats_patterns: + stats[(*key, "mean")] = mean[key] + stats[(*key, "std")] = std[key] + stats[(*key, "max")] = max[key] + stats[(*key, "min")] = min[key] + + if key[0] == "observation": + # use same stats for the next observations + stats[("next", *key)] = stats[key] + return stats diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index e05fb926..e054682e 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -5,33 +5,14 @@ from pathlib import Path import torch from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler -from lerobot.common.datasets.pusht import PushtExperienceReplay -from lerobot.common.datasets.simxarm import SimxarmExperienceReplay +from lerobot.common.envs.transforms import NormalizeTransform DATA_DIR = Path(os.environ.get("DATA_DIR", "data")) -# TODO(rcadene): implement -# dataset_d4rl = D4RLExperienceReplay( -# dataset_id="maze2d-umaze-v1", -# split_trajs=False, -# batch_size=1, -# sampler=SamplerWithoutReplacement(drop_last=False), -# prefetch=4, -# direct_download=True, -# ) - -# dataset_openx = OpenXExperienceReplay( -# "cmu_stretch", -# batch_size=1, -# num_slices=1, -# #download="force", -# streaming=False, -# root="data", -# ) - - -def make_offline_buffer(cfg, sampler=None): +def make_offline_buffer( + cfg, overwrite_sampler=None, normalize=True, overwrite_batch_size=None, overwrite_prefetch=None +): if cfg.policy.balanced_sampling: assert cfg.online_steps > 0 batch_size = None @@ -44,9 +25,13 @@ def make_offline_buffer(cfg, sampler=None): pin_memory = cfg.device == "cuda" prefetch = cfg.prefetch - overwrite_sampler = sampler is not None + if overwrite_batch_size is not None: + batch_size = overwrite_batch_size - if not overwrite_sampler: + if overwrite_prefetch is not None: + prefetch = overwrite_prefetch + + if overwrite_sampler is None: # TODO(rcadene): move batch_size outside num_traj_per_batch = cfg.policy.batch_size # // cfg.horizon # TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size. 
@@ -67,36 +52,57 @@ def make_offline_buffer(cfg, sampler=None):
             num_slices=num_traj_per_batch,
             strict_length=False,
         )
+    else:
+        sampler = overwrite_sampler
 
     if cfg.env.name == "simxarm":
-        # TODO(rcadene): add PrioritizedSliceSampler inside Simxarm to not have to `sampler.extend(index)` here
-        offline_buffer = SimxarmExperienceReplay(
-            f"xarm_{cfg.env.task}_medium",
-            # download="force",
-            download=True,
-            streaming=False,
-            root=str(DATA_DIR),
-            sampler=sampler,
-            batch_size=batch_size,
-            pin_memory=pin_memory,
-            prefetch=prefetch if isinstance(prefetch, int) else None,
-        )
+        from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
+
+        clsfunc = SimxarmExperienceReplay
+        dataset_id = f"xarm_{cfg.env.task}_medium"
+
     elif cfg.env.name == "pusht":
-        offline_buffer = PushtExperienceReplay(
-            "pusht",
-            streaming=False,
-            root=DATA_DIR,
-            sampler=sampler,
-            batch_size=batch_size,
-            pin_memory=pin_memory,
-            prefetch=prefetch if isinstance(prefetch, int) else None,
-        )
+        from lerobot.common.datasets.pusht import PushtExperienceReplay
+
+        clsfunc = PushtExperienceReplay
+        dataset_id = "pusht"
     else:
         raise ValueError(cfg.env.name)
 
+    offline_buffer = clsfunc(
+        dataset_id=dataset_id,
+        root=DATA_DIR,
+        sampler=sampler,
+        batch_size=batch_size,
+        pin_memory=pin_memory,
+        prefetch=prefetch if isinstance(prefetch, int) else None,
+    )
+
+    if normalize:
+        # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
+        stats = offline_buffer.compute_or_load_stats()
+        in_keys = [("observation", "state"), ("action")]
+
+        if cfg.policy.name == "tdmpc":
+            for key in offline_buffer.image_keys:
+                # TODO(rcadene): imagenet normalization is applied inside diffusion policy, but no normalization inside tdmpc
+                in_keys.append(key)
+                # since we use next observations in tdmpc
+                in_keys.append(("next", *key))
+            in_keys.append(("next", "observation", "state"))
+
+        if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
+            # TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
+            stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
+            stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
+            stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
+            stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
+
+        transform = NormalizeTransform(stats, in_keys, mode="min_max")
+        offline_buffer.set_transform(transform)
+
     if not overwrite_sampler:
-        num_steps = len(offline_buffer)
-        index = torch.arange(0, num_steps, 1)
+        index = torch.arange(0, offline_buffer.num_samples, 1)
         sampler.extend(index)
 
     return offline_buffer
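For reference, a usage sketch of the new keyword-only overrides, assuming a Hydra `cfg` like the one the scripts pass in (override values are illustrative; the second call mirrors the one `visualize_dataset.py` makes further down in this diff):

```python
from torchrl.data.replay_buffers import SamplerWithoutReplacement

from lerobot.common.datasets.factory import make_offline_buffer

# default training call: sampler built from cfg, min_max normalization applied
offline_buffer = make_offline_buffer(cfg)

# visualization-style call: sequential sampler, raw (unnormalized) frames,
# batch size and prefetch overridden without touching cfg
offline_buffer = make_offline_buffer(
    cfg,
    overwrite_sampler=SamplerWithoutReplacement(shuffle=False),
    normalize=False,
    overwrite_batch_size=1,
    overwrite_prefetch=12,
)
```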
diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py
index 11569ee2..b93b519b 100644
--- a/lerobot/common/datasets/pusht.py
+++ b/lerobot/common/datasets/pusht.py
@@ -1,6 +1,3 @@
-import logging
-import math
-import os
 from pathlib import Path
 from typing import Callable
 
@@ -12,16 +9,14 @@ import torch
 import torchrl
 import tqdm
 from tensordict import TensorDict
-from torchrl.data.datasets.utils import _get_root_dir
-from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
-from torchrl.data.replay_buffers.samplers import Sampler
-from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
-from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
+from torchrl.data.replay_buffers.samplers import SliceSampler
+from torchrl.data.replay_buffers.storages import TensorStorage
+from torchrl.data.replay_buffers.writers import Writer
 
 from diffusion_policy.common.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
 from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
+from lerobot.common.datasets.abstract import AbstractExperienceReplay
 from lerobot.common.datasets.utils import download_and_extract_zip
-from lerobot.common.envs.transforms import NormalizeTransform
 
 # as define in env
 SUCCESS_THRESHOLD = 0.95  # 95% coverage,
@@ -87,114 +82,36 @@ def add_tee(
     return body
 
 
-class PushtExperienceReplay(TensorDictReplayBuffer):
+class PushtExperienceReplay(AbstractExperienceReplay):
     def __init__(
         self,
         dataset_id: str,
         batch_size: int = None,
         *,
         shuffle: bool = True,
-        num_slices: int = None,
-        slice_len: int = None,
-        pad: float = None,
-        replacement: bool = None,
-        streaming: bool = False,
         root: Path = None,
-        sampler: Sampler = None,
-        writer: Writer = None,
-        collate_fn: Callable = None,
         pin_memory: bool = False,
         prefetch: int = None,
-        transform: "torchrl.envs.Transform" = None,  # noqa: F821
-        split_trajs: bool = False,
-        strict_length: bool = True,
+        sampler: SliceSampler = None,
+        collate_fn: Callable = None,
+        writer: Writer = None,
+        transform: "torchrl.envs.Transform" = None,
    ):
-        if streaming:
-            raise NotImplementedError
-        self.streaming = streaming
-        self.dataset_id = dataset_id
-        self.split_trajs = split_trajs
-        self.shuffle = shuffle
-        self.num_slices = num_slices
-        self.slice_len = slice_len
-        self.pad = pad
-
-        self.strict_length = strict_length
-        if (self.num_slices is not None) and (self.slice_len is not None):
-            raise ValueError("num_slices or slice_len can be not None, but not both.")
-        if split_trajs:
-            raise NotImplementedError
-
-        if root is None:
-            root = _get_root_dir("pusht")
-            os.makedirs(root, exist_ok=True)
-
-        self.root = root
-        if not self._is_downloaded():
-            storage = self._download_and_preproc()
-        else:
-            storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
-
-        stats = self._compute_or_load_stats(storage)
-        transform = NormalizeTransform(
-            stats,
-            in_keys=[
-                # TODO(rcadene): imagenet normalization is applied inside diffusion policy
-                # We need to automate this for tdmpc and others
-                # ("observation", "image"),
-                ("observation", "state"),
-                # TODO(rcadene): for tdmpc, we might want next image and state
-                # ("next", "observation", "image"),
-                # ("next", "observation", "state"),
-                ("action"),
-            ],
-            mode="min_max",
-        )
-
-        # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
-        transform.stats["observation", "state", "min"] = torch.tensor(
-            [13.456424, 32.938293], dtype=torch.float32
-        )
-        transform.stats["observation", "state", "max"] = torch.tensor(
-            [496.14618, 510.9579], dtype=torch.float32
-        )
-        transform.stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
-        transform.stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
-
-        if writer is None:
-            writer = ImmutableDatasetWriter()
-        if collate_fn is None:
-            collate_fn = _collate_id
-
         super().__init__(
-            storage=storage,
-            sampler=sampler,
-            writer=writer,
-            collate_fn=collate_fn,
+            dataset_id,
+            batch_size,
+            shuffle=shuffle,
+            root=root,
             pin_memory=pin_memory,
             prefetch=prefetch,
-            batch_size=batch_size,
+            sampler=sampler,
+            collate_fn=collate_fn,
+            writer=writer,
             transform=transform,
         )
 
-    @property
-    def num_samples(self) -> int:
-        return len(self)
-
-    @property
-    def num_episodes(self) -> int:
-        return len(self._storage._storage["episode"].unique())
-
-    @property
-    def data_path_root(self) -> Path:
-        return None if self.streaming else self.root / self.dataset_id
-
-    def _is_downloaded(self) -> bool:
-        return self.data_path_root.is_dir()
-
     def _download_and_preproc(self):
-        # download
-        raw_dir = self.root / "raw"
+        raw_dir = self.data_dir / "raw"
         zarr_path = (raw_dir / PUSHT_ZARR).resolve()
         if not zarr_path.is_dir():
             raw_dir.mkdir(parents=True, exist_ok=True)
@@ -266,8 +183,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
             # last step of demonstration is considered done
             done[-1] = True
 
-            print("before " + """episode = TensorDict(""")
-            episode = TensorDict(
+            ep_td = TensorDict(
                 {
                     ("observation", "image"): image[:-1],
                     ("observation", "state"): agent_pos[:-1],
@@ -286,120 +202,11 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
 
             if episode_id == 0:
                 # hack to initialize tensordict data structure to store episodes
-                td_data = episode[0].expand(total_frames).memmap_like(self.root / self.dataset_id)
+                td_data = ep_td[0].expand(total_frames).memmap_like(self.data_dir)
 
-            td_data[idxtd : idxtd + len(episode)] = episode
+            td_data[idxtd : idxtd + len(ep_td)] = ep_td
             idx0 = idx1
-            idxtd = idxtd + len(episode)
+            idxtd = idxtd + len(ep_td)
 
         return TensorStorage(td_data.lock_())
-
-    def _compute_stats(self, storage, num_batch=100, batch_size=32):
-        rb = TensorDictReplayBuffer(
-            storage=storage,
-            batch_size=batch_size,
-            prefetch=True,
-        )
-        batch = rb.sample()
-
-        image_channels = batch["observation", "image"].shape[1]
-        image_mean = torch.zeros(image_channels)
-        image_std = torch.zeros(image_channels)
-        image_max = torch.tensor([-math.inf] * image_channels)
-        image_min = torch.tensor([math.inf] * image_channels)
-
-        state_channels = batch["observation", "state"].shape[1]
-        state_mean = torch.zeros(state_channels)
-        state_std = torch.zeros(state_channels)
-        state_max = torch.tensor([-math.inf] * state_channels)
-        state_min = torch.tensor([math.inf] * state_channels)
-
-        action_channels = batch["action"].shape[1]
-        action_mean = torch.zeros(action_channels)
-        action_std = torch.zeros(action_channels)
-        action_max = torch.tensor([-math.inf] * action_channels)
-        action_min = torch.tensor([math.inf] * action_channels)
-
-        for _ in tqdm.tqdm(range(num_batch)):
-            image_mean += einops.reduce(batch["observation", "image"], "b c h w -> c", "mean")
-            state_mean += einops.reduce(batch["observation", "state"], "b c -> c", "mean")
-            action_mean += einops.reduce(batch["action"], "b c -> c", "mean")
-
-            b_image_max = einops.reduce(batch["observation", "image"], "b c h w -> c", "max")
-            b_image_min = einops.reduce(batch["observation", "image"], "b c h w -> c", "min")
-            b_state_max = einops.reduce(batch["observation", "state"], "b c -> c", "max")
-            b_state_min = einops.reduce(batch["observation", "state"], "b c -> c", "min")
-            b_action_max = einops.reduce(batch["action"], "b c -> c", "max")
-            b_action_min = einops.reduce(batch["action"], "b c -> c", "min")
-            image_max = torch.maximum(image_max, b_image_max)
-            image_min = torch.maximum(image_min, b_image_min)
-            state_max = torch.maximum(state_max, b_state_max)
-            state_min = torch.maximum(state_min, b_state_min)
-            action_max = torch.maximum(action_max, b_action_max)
-            action_min = torch.maximum(action_min, b_action_min)
-
-            batch = rb.sample()
-
-        image_mean /= num_batch
-        state_mean /= num_batch
-        action_mean /= num_batch
-
-        for i in tqdm.tqdm(range(num_batch)):
-            b_image_mean = einops.reduce(batch["observation", "image"], "b c h w -> c", "mean")
"image"], "b c h w -> c", "mean") - b_state_mean = einops.reduce(batch["observation", "state"], "b c -> c", "mean") - b_action_mean = einops.reduce(batch["action"], "b c -> c", "mean") - image_std += (b_image_mean - image_mean) ** 2 - state_std += (b_state_mean - state_mean) ** 2 - action_std += (b_action_mean - action_mean) ** 2 - - b_image_max = einops.reduce(batch["observation", "image"], "b c h w -> c", "max") - b_image_min = einops.reduce(batch["observation", "image"], "b c h w -> c", "min") - b_state_max = einops.reduce(batch["observation", "state"], "b c -> c", "max") - b_state_min = einops.reduce(batch["observation", "state"], "b c -> c", "min") - b_action_max = einops.reduce(batch["action"], "b c -> c", "max") - b_action_min = einops.reduce(batch["action"], "b c -> c", "min") - image_max = torch.maximum(image_max, b_image_max) - image_min = torch.maximum(image_min, b_image_min) - state_max = torch.maximum(state_max, b_state_max) - state_min = torch.maximum(state_min, b_state_min) - action_max = torch.maximum(action_max, b_action_max) - action_min = torch.maximum(action_min, b_action_min) - - if i < num_batch - 1: - batch = rb.sample() - - image_std = torch.sqrt(image_std / num_batch) - state_std = torch.sqrt(state_std / num_batch) - action_std = torch.sqrt(action_std / num_batch) - - stats = TensorDict( - { - ("observation", "image", "mean"): image_mean[None, :, None, None], - ("observation", "image", "std"): image_std[None, :, None, None], - ("observation", "image", "max"): image_max[None, :, None, None], - ("observation", "image", "min"): image_min[None, :, None, None], - ("observation", "state", "mean"): state_mean[None, :], - ("observation", "state", "std"): state_std[None, :], - ("observation", "state", "max"): state_max[None, :], - ("observation", "state", "min"): state_min[None, :], - ("action", "mean"): action_mean[None, :], - ("action", "std"): action_std[None, :], - ("action", "max"): action_max[None, :], - ("action", "min"): action_min[None, :], - }, - batch_size=[], - ) - stats["next", "observation", "image"] = stats["observation", "image"] - stats["next", "observation", "state"] = stats["observation", "state"] - return stats - - def _compute_or_load_stats(self, storage) -> TensorDict: - stats_path = self.root / self.dataset_id / "stats.pth" - if stats_path.exists(): - stats = torch.load(stats_path) - else: - logging.info(f"compute_stats and save to {stats_path}") - stats = self._compute_stats(storage) - torch.save(stats, stats_path) - return stats diff --git a/lerobot/common/datasets/simxarm.py b/lerobot/common/datasets/simxarm.py index 1a8ba4de..784242cc 100644 --- a/lerobot/common/datasets/simxarm.py +++ b/lerobot/common/datasets/simxarm.py @@ -1,4 +1,3 @@ -import os import pickle from pathlib import Path from typing import Callable @@ -7,130 +6,52 @@ import torch import torchrl import tqdm from tensordict import TensorDict -from torchrl.data.datasets.utils import _get_root_dir -from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import ( - Sampler, SliceSampler, - SliceSamplerWithoutReplacement, ) -from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id -from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer +from torchrl.data.replay_buffers.storages import TensorStorage +from torchrl.data.replay_buffers.writers import Writer + +from lerobot.common.datasets.abstract import AbstractExperienceReplay -class SimxarmExperienceReplay(TensorDictReplayBuffer): 
diff --git a/lerobot/common/datasets/simxarm.py b/lerobot/common/datasets/simxarm.py
index 1a8ba4de..784242cc 100644
--- a/lerobot/common/datasets/simxarm.py
+++ b/lerobot/common/datasets/simxarm.py
@@ -1,4 +1,3 @@
-import os
 import pickle
 from pathlib import Path
 from typing import Callable
@@ -7,130 +6,52 @@ import torch
 import torchrl
 import tqdm
 from tensordict import TensorDict
-from torchrl.data.datasets.utils import _get_root_dir
-from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
 from torchrl.data.replay_buffers.samplers import (
-    Sampler,
     SliceSampler,
-    SliceSamplerWithoutReplacement,
 )
-from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
-from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
+from torchrl.data.replay_buffers.storages import TensorStorage
+from torchrl.data.replay_buffers.writers import Writer
+
+from lerobot.common.datasets.abstract import AbstractExperienceReplay
 
 
-class SimxarmExperienceReplay(TensorDictReplayBuffer):
+class SimxarmExperienceReplay(AbstractExperienceReplay):
     available_datasets = [
         "xarm_lift_medium",
     ]
 
     def __init__(
         self,
-        dataset_id,
+        dataset_id: str,
         batch_size: int = None,
         *,
         shuffle: bool = True,
-        num_slices: int = None,
-        slice_len: int = None,
-        pad: float = None,
-        replacement: bool = None,
-        streaming: bool = False,
         root: Path = None,
-        download: bool = False,
-        sampler: Sampler = None,
-        writer: Writer = None,
-        collate_fn: Callable = None,
        pin_memory: bool = False,
         prefetch: int = None,
-        transform: "torchrl.envs.Transform" = None,  # noqa-F821
-        split_trajs: bool = False,
-        strict_length: bool = True,
+        sampler: SliceSampler = None,
+        collate_fn: Callable = None,
+        writer: Writer = None,
+        transform: "torchrl.envs.Transform" = None,
     ):
-        self.download = download
-        if streaming:
-            raise NotImplementedError
-        self.streaming = streaming
-        self.dataset_id = dataset_id
-        self.split_trajs = split_trajs
-        self.shuffle = shuffle
-        self.num_slices = num_slices
-        self.slice_len = slice_len
-        self.pad = pad
-
-        self.strict_length = strict_length
-        if (self.num_slices is not None) and (self.slice_len is not None):
-            raise ValueError("num_slices or slice_len can be not None, but not both.")
-        if split_trajs:
-            raise NotImplementedError
-
-        if root is None:
-            root = _get_root_dir("simxarm")
-            os.makedirs(root, exist_ok=True)
-        self.root = Path(root)
-        if self.download == "force" or (self.download and not self._is_downloaded()):
-            storage = self._download_and_preproc()
-        else:
-            storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
-
-        if num_slices is not None or slice_len is not None:
-            if sampler is not None:
-                raise ValueError("`num_slices` and `slice_len` are exclusive with the `sampler` argument.")
-
-            if replacement:
-                if not self.shuffle:
-                    raise RuntimeError("shuffle=False can only be used when replacement=False.")
-                sampler = SliceSampler(
-                    num_slices=num_slices,
-                    slice_len=slice_len,
-                    strict_length=strict_length,
-                )
-            else:
-                sampler = SliceSamplerWithoutReplacement(
-                    num_slices=num_slices,
-                    slice_len=slice_len,
-                    strict_length=strict_length,
-                    shuffle=self.shuffle,
-                )
-
-        if writer is None:
-            writer = ImmutableDatasetWriter()
-        if collate_fn is None:
-            collate_fn = _collate_id
-
         super().__init__(
-            storage=storage,
-            sampler=sampler,
-            writer=writer,
-            collate_fn=collate_fn,
+            dataset_id,
+            batch_size,
+            shuffle=shuffle,
+            root=root,
             pin_memory=pin_memory,
             prefetch=prefetch,
-            batch_size=batch_size,
+            sampler=sampler,
+            collate_fn=collate_fn,
+            writer=writer,
             transform=transform,
         )
 
-    @property
-    def num_samples(self):
-        return len(self)
-
-    @property
-    def num_episodes(self):
-        return len(self._storage._storage["episode"].unique())
-
-    @property
-    def data_path_root(self):
-        if self.streaming:
-            return None
-        return self.root / self.dataset_id
-
-    def _is_downloaded(self):
-        return os.path.exists(self.data_path_root)
-
     def _download_and_preproc(self):
         # download
         # TODO(rcadene)
 
-        # load
-        dataset_dir = Path("data") / self.dataset_id
-        dataset_path = dataset_dir / "buffer.pkl"
+        dataset_path = self.data_dir / "buffer.pkl"
         print(f"Using offline dataset '{dataset_path}'")
         with open(dataset_path, "rb") as f:
             dataset_dict = pickle.load(f)
@@ -172,7 +93,7 @@ class SimxarmExperienceReplay(TensorDictReplayBuffer):
 
         if episode_id == 0:
             # hack to initialize tensordict data structure to store episodes
-            td_data = episode[0].expand(total_frames).memmap_like(self.root / self.dataset_id)
+            td_data = episode[0].expand(total_frames).memmap_like(self.data_dir)
 
         td_data[idx0:idx1] = episode
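Both datasets now share the same storage-building idiom inside `_download_and_preproc`: allocate one flat memory-mapped TensorDict from the first episode, then fill it slice by slice. A toy, self-contained version of that pattern (shapes and the output path are made up):

```python
import torch
from tensordict import TensorDict

total_frames = 6
episodes = [
    TensorDict({"action": torch.rand(3, 2), "episode": torch.full((3,), i)}, batch_size=[3])
    for i in range(2)
]

idx = 0
td_data = None
for episode_id, episode in enumerate(episodes):
    if episode_id == 0:
        # use the first frame as a template to allocate one memory-mapped
        # TensorDict sized for the full dataset
        td_data = episode[0].expand(total_frames).memmap_like("/tmp/toy_dataset")
    td_data[idx : idx + len(episode)] = episode
    idx += len(episode)

storage_td = td_data.lock_()  # the datasets then wrap this in TensorStorage(...)
```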
diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py
index 54325bd4..3d98d726 100644
--- a/lerobot/common/logger.py
+++ b/lerobot/common/logger.py
@@ -6,6 +6,10 @@ from omegaconf import OmegaConf
 from termcolor import colored
 
 
+def log_output_dir(out_dir):
+    logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
+
+
 def cfg_to_group(cfg, return_list=False):
     """Return a wandb-safe group name for logging. Optionally returns group name as list."""
     # lst = [cfg.task, cfg.modality, re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)]
diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index abe4645a..c9338dca 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -9,13 +9,13 @@ import numpy as np
 import torch
 import tqdm
 from tensordict.nn import TensorDictModule
-from termcolor import colored
 from torchrl.envs import EnvBase
 
 from lerobot.common.datasets.factory import make_offline_buffer
 from lerobot.common.envs.factory import make_env
+from lerobot.common.logger import log_output_dir
 from lerobot.common.policies.factory import make_policy
-from lerobot.common.utils import set_seed
+from lerobot.common.utils import init_logging, set_seed
 
 
 def write_video(video_path, stacked_frames, fps):
@@ -109,10 +109,18 @@ def eval(cfg: dict, out_dir=None):
     if out_dir is None:
         raise NotImplementedError()
 
-    assert torch.cuda.is_available()
+    init_logging()
+
+    if cfg.device == "cuda":
+        assert torch.cuda.is_available()
+    else:
+        logging.warning("Using CPU, this will be slow.")
+
     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
     set_seed(cfg.seed)
-    print(colored("Log dir:", "yellow", attrs=["bold"]), out_dir)
+
+    log_output_dir(out_dir)
 
     logging.info("make_offline_buffer")
     offline_buffer = make_offline_buffer(cfg)
@@ -142,6 +150,8 @@ def eval(cfg: dict, out_dir=None):
     )
     print(metrics)
 
+    logging.info("End of eval")
+
 
 if __name__ == "__main__":
     eval_cli()
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index c169b49b..be3bef8b 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -4,13 +4,12 @@ import hydra
 import numpy as np
 import torch
 from tensordict.nn import TensorDictModule
-from termcolor import colored
 from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
 from torchrl.data.replay_buffers import PrioritizedSliceSampler
 
 from lerobot.common.datasets.factory import make_offline_buffer
 from lerobot.common.envs.factory import make_env
-from lerobot.common.logger import Logger
+from lerobot.common.logger import Logger, log_output_dir
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.utils import format_big_number, init_logging, set_seed
 from lerobot.scripts.eval import eval_policy
@@ -164,7 +163,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
     # log metrics to terminal and wandb
     logger = Logger(out_dir, job_name, cfg)
 
-    logging.info(colored("Work dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
+    log_output_dir(out_dir)
     logging.info(f"{cfg.env.task=}")
     logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})")
     logging.info(f"{cfg.online_steps=}")
@@ -212,7 +211,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
     for env_step in range(cfg.online_steps):
         if env_step == 0:
             logging.info("Start online training by interacting with environment")
-        # TODO: use SyncDataCollector for that?
         # TODO: add configurable number of rollout? (default=1)
         with torch.no_grad():
             rollout = env.rollout(
@@ -268,6 +266,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
         step += 1
         online_step += 1
 
+    logging.info("End of training")
+
 
 if __name__ == "__main__":
     train_cli()
diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py
index 5aa0a278..1bd63f6e 100644
--- a/lerobot/scripts/visualize_dataset.py
+++ b/lerobot/scripts/visualize_dataset.py
@@ -1,13 +1,20 @@
+import logging
+import threading
 from pathlib import Path
 
+import einops
 import hydra
 import imageio
 import torch
-from torchrl.data.replay_buffers import SliceSamplerWithoutReplacement
+from torchrl.data.replay_buffers import (
+    SamplerWithoutReplacement,
+)
 
 from lerobot.common.datasets.factory import make_offline_buffer
+from lerobot.common.logger import log_output_dir
+from lerobot.common.utils import init_logging
 
-NUM_EPISODES_TO_RENDER = 10
+NUM_EPISODES_TO_RENDER = 50
 MAX_NUM_STEPS = 1000
 FIRST_FRAME = 0
@@ -17,45 +24,88 @@ def visualize_dataset_cli(cfg: dict):
     visualize_dataset(cfg, out_dir=hydra.core.hydra_config.HydraConfig.get().runtime.output_dir)
 
 
+def cat_and_write_video(video_path, frames, fps):
+    frames = torch.cat(frames)
+    assert frames.dtype == torch.uint8
+    frames = einops.rearrange(frames, "b c h w -> b h w c").numpy()
+    imageio.mimsave(video_path, frames, fps=fps)
+
+
 def visualize_dataset(cfg: dict, out_dir=None):
     if out_dir is None:
         raise NotImplementedError()
 
-    sampler = SliceSamplerWithoutReplacement(
-        num_slices=1,
-        strict_length=False,
+    init_logging()
+    log_output_dir(out_dir)
+
+    # we expect the frames of each episode to be stored next to each other sequentially
+    sampler = SamplerWithoutReplacement(
         shuffle=False,
     )
 
-    offline_buffer = make_offline_buffer(cfg, sampler)
+    logging.info("make_offline_buffer")
+    offline_buffer = make_offline_buffer(
+        cfg, overwrite_sampler=sampler, normalize=False, overwrite_batch_size=1, overwrite_prefetch=12
+    )
 
-    for _ in range(NUM_EPISODES_TO_RENDER):
-        episode = offline_buffer.sample(MAX_NUM_STEPS)
+    logging.info("Start rendering episodes from offline buffer")
 
-        ep_idx = episode["episode"][FIRST_FRAME].item()
-        ep_frames = torch.cat(
-            [
-                episode["observation"]["image"][FIRST_FRAME][None, ...],
-                episode["next", "observation"]["image"],
-            ],
-            dim=0,
-        )
+    threads = []
+    frames = {}
+    current_ep_idx = 0
+    logging.info(f"Visualizing episode {current_ep_idx}")
+    for _ in range(MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER):
+        # TODO(rcadene): make it work with bsize > 1
+        ep_td = offline_buffer.sample(1)
+        ep_idx = ep_td["episode"][FIRST_FRAME].item()
 
-        video_dir = Path(out_dir) / "visualize_dataset"
-        video_dir.mkdir(parents=True, exist_ok=True)
-        # TODO(rcadene): make fps configurable
-        video_path = video_dir / f"episode_{ep_idx}.mp4"
+        # TODO(rcadene): modify offline_buffer._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames
+        no_more_frames = offline_buffer._sampler._sample_list.numel() == 0
+        new_episode = ep_idx != current_ep_idx
 
-        assert ep_frames.min().item() >= 0
-        assert ep_frames.max().item() > 1, "Not mendatory, but sanity check"
-        assert ep_frames.max().item() <= 255
-        ep_frames = ep_frames.type(torch.uint8)
-        imageio.mimsave(video_path, ep_frames.numpy().transpose(0, 2, 3, 1), fps=cfg.fps)
+        if new_episode:
+            logging.info(f"Visualizing episode {current_ep_idx}")
 
-        # ran out of episodes
-        if offline_buffer._sampler._sample_list.numel() == 0:
+        for im_key in offline_buffer.image_keys:
+            if new_episode or no_more_frames:
+                # append the last observed frames (the ones after the last action taken)
+                frames[im_key].append(ep_td[("next", *im_key)])
+
+                video_dir = Path(out_dir) / "visualize_dataset"
+                video_dir.mkdir(parents=True, exist_ok=True)
+
+                if len(offline_buffer.image_keys) > 1:
+                    camera = im_key[-1]
+                    video_path = video_dir / f"episode_{current_ep_idx}_{camera}.mp4"
+                else:
+                    video_path = video_dir / f"episode_{current_ep_idx}.mp4"
+
+                thread = threading.Thread(
+                    target=cat_and_write_video,
+                    args=(str(video_path), frames[im_key], cfg.fps),
+                )
+                thread.start()
+                threads.append(thread)
+
+                current_ep_idx = ep_idx
+
+                # reset the list of frames
+                del frames[im_key]
+
+            # append the current camera images to the list of frames
+            if im_key not in frames:
+                frames[im_key] = []
+            frames[im_key].append(ep_td[im_key])
+
+        if no_more_frames:
+            logging.info("Ran out of frames")
             break
 
+    for thread in threads:
+        thread.join()
+
+    logging.info("End of visualize_dataset")
+
 
 if __name__ == "__main__":
     visualize_dataset_cli()
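The new `cat_and_write_video` helper can be exercised on its own; a small sketch with toy frames and an illustrative output path (the helper expects a list of uint8 `b c h w` tensors, as the loop above accumulates them):

```python
import torch

from lerobot.scripts.visualize_dataset import cat_and_write_video

# two "chunks" of 5 RGB uint8 frames each, mimicking frames[im_key]
frames = [torch.randint(0, 256, (5, 3, 96, 96), dtype=torch.uint8) for _ in range(2)]
cat_and_write_video("/tmp/episode_0.mp4", frames, fps=10)
```

Offloading each `imageio.mimsave` call to a `threading.Thread` keeps the sampling loop from stalling on video encoding; the `thread.join()` loop at the end makes sure all videos are flushed before the script exits.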