From d4e08499706f00f5c6c19413c00d5df24158bb01 Mon Sep 17 00:00:00 2001
From: Remi Cadene
Date: Tue, 5 Mar 2024 10:20:57 +0000
Subject: [PATCH] Refactor datasets with abstract class

---
 lerobot/common/datasets/abstract.py | 185 ++++++++++++++++++++++
 lerobot/common/datasets/factory.py  |  83 +++++-----
 lerobot/common/datasets/pusht.py    | 228 +++-------------------------
 lerobot/common/datasets/simxarm.py  | 117 +++-----------
 4 files changed, 262 insertions(+), 351 deletions(-)
 create mode 100644 lerobot/common/datasets/abstract.py

diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py
new file mode 100644
index 00000000..0407c4f6
--- /dev/null
+++ b/lerobot/common/datasets/abstract.py
@@ -0,0 +1,185 @@
+import abc
+import logging
+import math
+from pathlib import Path
+from typing import Callable
+
+import einops
+import torch
+import torchrl
+import tqdm
+from tensordict import TensorDict
+from torchrl.data.datasets.utils import _get_root_dir
+from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
+from torchrl.data.replay_buffers.samplers import SliceSampler
+from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
+from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
+
+
+class AbstractExperienceReplay(TensorDictReplayBuffer):
+    def __init__(
+        self,
+        dataset_id: str,
+        batch_size: int = None,
+        *,
+        shuffle: bool = True,
+        root: Path = None,
+        pin_memory: bool = False,
+        prefetch: int = None,
+        sampler: SliceSampler = None,
+        collate_fn: Callable = None,
+        writer: Writer = None,
+        transform: "torchrl.envs.Transform" = None,  # noqa: F821
+    ):
+        self.dataset_id = dataset_id
+        self.shuffle = shuffle
+        self.root = _get_root_dir(self.dataset_id) if root is None else root
+        self.root = Path(self.root)
+        self.data_dir = self.root / self.dataset_id
+
+        storage = self._download_or_load_storage()
+
+        super().__init__(
+            storage=storage,
+            sampler=sampler,
+            writer=ImmutableDatasetWriter() if writer is None else writer,
+            collate_fn=_collate_id if collate_fn is None else collate_fn,
+            pin_memory=pin_memory,
+            prefetch=prefetch,
+            batch_size=batch_size,
+            transform=transform,
+        )
+
+    @property
+    def num_samples(self) -> int:
+        return len(self)
+
+    @property
+    def num_episodes(self) -> int:
+        return len(self._storage._storage["episode"].unique())
+
+    def set_transform(self, transform):
+        self.transform = transform
+
+    def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict:
+        stats_path = self.data_dir / "stats.pth"
+        if stats_path.exists():
+            stats = torch.load(stats_path)
+        else:
+            logging.info(f"compute_stats and save to {stats_path}")
+            stats = self._compute_stats(self._storage._storage, num_batch, batch_size)
+            torch.save(stats, stats_path)
+        return stats
+
+    @abc.abstractmethod
+    def _download_and_preproc(self) -> TensorStorage:
+        raise NotImplementedError()
+
+    def _download_or_load_storage(self):
+        if not self._is_downloaded():
+            storage = self._download_and_preproc()
+        else:
+            storage = TensorStorage(TensorDict.load_memmap(self.data_dir))
+        return storage
+
+    def _is_downloaded(self) -> bool:
+        return self.data_dir.is_dir()
+
+    def _compute_stats(self, storage, num_batch=100, batch_size=32):
+        rb = TensorDictReplayBuffer(
+            storage=storage,
+            batch_size=batch_size,
+            prefetch=True,
+        )
+        batch = rb.sample()
+
+        image_channels = batch["observation", "image"].shape[1]
+        image_mean = torch.zeros(image_channels)
+        image_std = torch.zeros(image_channels)
+        image_max = torch.tensor([-math.inf] * image_channels)
+        image_min = torch.tensor([math.inf] * image_channels)
+
+        state_channels = batch["observation", "state"].shape[1]
+        state_mean = torch.zeros(state_channels)
+        state_std = torch.zeros(state_channels)
+        state_max = torch.tensor([-math.inf] * state_channels)
+        state_min = torch.tensor([math.inf] * state_channels)
+
+        action_channels = batch["action"].shape[1]
+        action_mean = torch.zeros(action_channels)
+        action_std = torch.zeros(action_channels)
+        action_max = torch.tensor([-math.inf] * action_channels)
+        action_min = torch.tensor([math.inf] * action_channels)
+
+        for _ in tqdm.tqdm(range(num_batch)):
+            image_mean += einops.reduce(batch["observation", "image"], "b c h w -> c", "mean")
+            state_mean += einops.reduce(batch["observation", "state"], "b c -> c", "mean")
+            action_mean += einops.reduce(batch["action"], "b c -> c", "mean")
+
+            b_image_max = einops.reduce(batch["observation", "image"], "b c h w -> c", "max")
+            b_image_min = einops.reduce(batch["observation", "image"], "b c h w -> c", "min")
+            b_state_max = einops.reduce(batch["observation", "state"], "b c -> c", "max")
+            b_state_min = einops.reduce(batch["observation", "state"], "b c -> c", "min")
+            b_action_max = einops.reduce(batch["action"], "b c -> c", "max")
+            b_action_min = einops.reduce(batch["action"], "b c -> c", "min")
+            image_max = torch.maximum(image_max, b_image_max)
+            image_min = torch.minimum(image_min, b_image_min)
+            state_max = torch.maximum(state_max, b_state_max)
+            state_min = torch.minimum(state_min, b_state_min)
+            action_max = torch.maximum(action_max, b_action_max)
+            action_min = torch.minimum(action_min, b_action_min)
+
+            batch = rb.sample()
+
+        image_mean /= num_batch
+        state_mean /= num_batch
+        action_mean /= num_batch
+
+        for i in tqdm.tqdm(range(num_batch)):
+            b_image_mean = einops.reduce(batch["observation", "image"], "b c h w -> c", "mean")
+            b_state_mean = einops.reduce(batch["observation", "state"], "b c -> c", "mean")
+            b_action_mean = einops.reduce(batch["action"], "b c -> c", "mean")
+            image_std += (b_image_mean - image_mean) ** 2
+            state_std += (b_state_mean - state_mean) ** 2
+            action_std += (b_action_mean - action_mean) ** 2
+
+            b_image_max = einops.reduce(batch["observation", "image"], "b c h w -> c", "max")
+            b_image_min = einops.reduce(batch["observation", "image"], "b c h w -> c", "min")
+            b_state_max = einops.reduce(batch["observation", "state"], "b c -> c", "max")
+            b_state_min = einops.reduce(batch["observation", "state"], "b c -> c", "min")
+            b_action_max = einops.reduce(batch["action"], "b c -> c", "max")
+            b_action_min = einops.reduce(batch["action"], "b c -> c", "min")
+            image_max = torch.maximum(image_max, b_image_max)
+            image_min = torch.minimum(image_min, b_image_min)
+            state_max = torch.maximum(state_max, b_state_max)
+            state_min = torch.minimum(state_min, b_state_min)
+            action_max = torch.maximum(action_max, b_action_max)
+            action_min = torch.minimum(action_min, b_action_min)
+
+            if i < num_batch - 1:
+                batch = rb.sample()
+
+        image_std = torch.sqrt(image_std / num_batch)
+        state_std = torch.sqrt(state_std / num_batch)
+        action_std = torch.sqrt(action_std / num_batch)
+
+        stats = TensorDict(
+            {
+                ("observation", "image", "mean"): image_mean[None, :, None, None],
+                ("observation", "image", "std"): image_std[None, :, None, None],
+                ("observation", "image", "max"): image_max[None, :, None, None],
+                ("observation", "image", "min"): image_min[None, :, None, None],
+                ("observation", "state", "mean"): state_mean[None, :],
"state", "std"): state_std[None, :], + ("observation", "state", "max"): state_max[None, :], + ("observation", "state", "min"): state_min[None, :], + ("action", "mean"): action_mean[None, :], + ("action", "std"): action_std[None, :], + ("action", "max"): action_max[None, :], + ("action", "min"): action_min[None, :], + }, + batch_size=[], + ) + stats["next", "observation", "image"] = stats["observation", "image"] + stats["next", "observation", "state"] = stats["observation", "state"] + return stats diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index e05fb926..feccdf21 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -5,31 +5,10 @@ from pathlib import Path import torch from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler -from lerobot.common.datasets.pusht import PushtExperienceReplay -from lerobot.common.datasets.simxarm import SimxarmExperienceReplay +from lerobot.common.envs.transforms import NormalizeTransform DATA_DIR = Path(os.environ.get("DATA_DIR", "data")) -# TODO(rcadene): implement - -# dataset_d4rl = D4RLExperienceReplay( -# dataset_id="maze2d-umaze-v1", -# split_trajs=False, -# batch_size=1, -# sampler=SamplerWithoutReplacement(drop_last=False), -# prefetch=4, -# direct_download=True, -# ) - -# dataset_openx = OpenXExperienceReplay( -# "cmu_stretch", -# batch_size=1, -# num_slices=1, -# #download="force", -# streaming=False, -# root="data", -# ) - def make_offline_buffer(cfg, sampler=None): if cfg.policy.balanced_sampling: @@ -69,31 +48,49 @@ def make_offline_buffer(cfg, sampler=None): ) if cfg.env.name == "simxarm": - # TODO(rcadene): add PrioritizedSliceSampler inside Simxarm to not have to `sampler.extend(index)` here - offline_buffer = SimxarmExperienceReplay( - f"xarm_{cfg.env.task}_medium", - # download="force", - download=True, - streaming=False, - root=str(DATA_DIR), - sampler=sampler, - batch_size=batch_size, - pin_memory=pin_memory, - prefetch=prefetch if isinstance(prefetch, int) else None, - ) + from lerobot.common.datasets.simxarm import SimxarmExperienceReplay + + clsfunc = SimxarmExperienceReplay + dataset_id = f"xarm_{cfg.env.task}_medium" + elif cfg.env.name == "pusht": - offline_buffer = PushtExperienceReplay( - "pusht", - streaming=False, - root=DATA_DIR, - sampler=sampler, - batch_size=batch_size, - pin_memory=pin_memory, - prefetch=prefetch if isinstance(prefetch, int) else None, - ) + from lerobot.common.datasets.pusht import PushtExperienceReplay + + clsfunc = PushtExperienceReplay + dataset_id = "pusht" else: raise ValueError(cfg.env.name) + offline_buffer = clsfunc( + dataset_id=dataset_id, + root=DATA_DIR, + sampler=sampler, + batch_size=batch_size, + pin_memory=pin_memory, + prefetch=prefetch if isinstance(prefetch, int) else None, + ) + + # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec + stats = offline_buffer.compute_or_load_stats() + in_keys = [("observation", "state"), ("action")] + + if cfg.policy == "tdmpc": + # TODO(rcadene): imagenet normalization is applied inside diffusion policy, but no normalization inside tdmpc + in_keys.append(("observation", "image")) + # since we use next observations in tdmpc + in_keys.append(("next", "observation", "image")) + in_keys.append(("next", "observation", "state")) + + if cfg.policy == "diffusion" and cfg.env.name == "pusht": + # TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this + 
diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py
index e05fb926..feccdf21 100644
--- a/lerobot/common/datasets/factory.py
+++ b/lerobot/common/datasets/factory.py
@@ -5,31 +5,10 @@ from pathlib import Path
 import torch
 from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler
 
-from lerobot.common.datasets.pusht import PushtExperienceReplay
-from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
+from lerobot.common.envs.transforms import NormalizeTransform
 
 DATA_DIR = Path(os.environ.get("DATA_DIR", "data"))
 
-# TODO(rcadene): implement
-
-# dataset_d4rl = D4RLExperienceReplay(
-#     dataset_id="maze2d-umaze-v1",
-#     split_trajs=False,
-#     batch_size=1,
-#     sampler=SamplerWithoutReplacement(drop_last=False),
-#     prefetch=4,
-#     direct_download=True,
-# )
-
-# dataset_openx = OpenXExperienceReplay(
-#     "cmu_stretch",
-#     batch_size=1,
-#     num_slices=1,
-#     #download="force",
-#     streaming=False,
-#     root="data",
-# )
-
 
 def make_offline_buffer(cfg, sampler=None):
     if cfg.policy.balanced_sampling:
@@ -69,31 +48,49 @@ def make_offline_buffer(cfg, sampler=None):
         )
 
     if cfg.env.name == "simxarm":
-        # TODO(rcadene): add PrioritizedSliceSampler inside Simxarm to not have to `sampler.extend(index)` here
-        offline_buffer = SimxarmExperienceReplay(
-            f"xarm_{cfg.env.task}_medium",
-            # download="force",
-            download=True,
-            streaming=False,
-            root=str(DATA_DIR),
-            sampler=sampler,
-            batch_size=batch_size,
-            pin_memory=pin_memory,
-            prefetch=prefetch if isinstance(prefetch, int) else None,
-        )
+        from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
+
+        clsfunc = SimxarmExperienceReplay
+        dataset_id = f"xarm_{cfg.env.task}_medium"
+
     elif cfg.env.name == "pusht":
-        offline_buffer = PushtExperienceReplay(
-            "pusht",
-            streaming=False,
-            root=DATA_DIR,
-            sampler=sampler,
-            batch_size=batch_size,
-            pin_memory=pin_memory,
-            prefetch=prefetch if isinstance(prefetch, int) else None,
-        )
+        from lerobot.common.datasets.pusht import PushtExperienceReplay
+
+        clsfunc = PushtExperienceReplay
+        dataset_id = "pusht"
     else:
         raise ValueError(cfg.env.name)
 
+    offline_buffer = clsfunc(
+        dataset_id=dataset_id,
+        root=DATA_DIR,
+        sampler=sampler,
+        batch_size=batch_size,
+        pin_memory=pin_memory,
+        prefetch=prefetch if isinstance(prefetch, int) else None,
+    )
+
+    # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
+    stats = offline_buffer.compute_or_load_stats()
+    in_keys = [("observation", "state"), ("action")]
+
+    if cfg.policy.name == "tdmpc":
+        # TODO(rcadene): imagenet normalization is applied inside diffusion policy, but no normalization inside tdmpc
+        in_keys.append(("observation", "image"))
+        # since we use next observations in tdmpc
+        in_keys.append(("next", "observation", "image"))
+        in_keys.append(("next", "observation", "state"))
+
+    if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
+        # TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
+        stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
+        stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
+        stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
+        stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
+
+    transform = NormalizeTransform(stats, in_keys, mode="min_max")
+    offline_buffer.set_transform(transform)
+
     if not overwrite_sampler:
         num_steps = len(offline_buffer)
         index = torch.arange(0, num_steps, 1)
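For orientation, `make_offline_buffer` is driven by the Hydra config; only `cfg.env.name`, `cfg.env.task`, `cfg.policy.name` and `cfg.policy.balanced_sampling` are read in the hunks above, so the sketch below fills in just those fields with illustrative values (the sampler settings and the function's return statement live in parts of factory.py not shown in this diff):

```python
from omegaconf import OmegaConf

from lerobot.common.datasets.factory import make_offline_buffer

# Illustrative config; the real one also carries batch size and sampler options.
cfg = OmegaConf.create(
    {
        "env": {"name": "pusht", "task": "pusht"},
        "policy": {"name": "diffusion", "balanced_sampling": False},
    }
)

offline_buffer = make_offline_buffer(cfg)
batch = offline_buffer.sample()  # keys listed in `in_keys` come back min-max normalized
```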
diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py
index 11569ee2..e086df27 100644
--- a/lerobot/common/datasets/pusht.py
+++ b/lerobot/common/datasets/pusht.py
@@ -1,6 +1,3 @@
-import logging
-import math
-import os
 from pathlib import Path
 from typing import Callable
 
@@ -12,16 +9,14 @@ import torch
 import torchrl
 import tqdm
 from tensordict import TensorDict
-from torchrl.data.datasets.utils import _get_root_dir
-from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
-from torchrl.data.replay_buffers.samplers import Sampler
-from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
-from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
+from torchrl.data.replay_buffers.samplers import SliceSampler
+from torchrl.data.replay_buffers.storages import TensorStorage
+from torchrl.data.replay_buffers.writers import Writer
 
 from diffusion_policy.common.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
 from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
+from lerobot.common.datasets.abstract import AbstractExperienceReplay
 from lerobot.common.datasets.utils import download_and_extract_zip
-from lerobot.common.envs.transforms import NormalizeTransform
 
 # as define in env
 SUCCESS_THRESHOLD = 0.95  # 95% coverage,
@@ -87,114 +82,36 @@ def add_tee(
     return body
 
 
-class PushtExperienceReplay(TensorDictReplayBuffer):
+class PushtExperienceReplay(AbstractExperienceReplay):
     def __init__(
         self,
         dataset_id: str,
         batch_size: int = None,
         *,
         shuffle: bool = True,
-        num_slices: int = None,
-        slice_len: int = None,
-        pad: float = None,
-        replacement: bool = None,
-        streaming: bool = False,
         root: Path = None,
-        sampler: Sampler = None,
-        writer: Writer = None,
-        collate_fn: Callable = None,
         pin_memory: bool = False,
         prefetch: int = None,
-        transform: "torchrl.envs.Transform" = None,  # noqa: F821
-        split_trajs: bool = False,
-        strict_length: bool = True,
+        sampler: SliceSampler = None,
+        collate_fn: Callable = None,
+        writer: Writer = None,
+        transform: "torchrl.envs.Transform" = None,  # noqa: F821
     ):
-        if streaming:
-            raise NotImplementedError
-        self.streaming = streaming
-        self.dataset_id = dataset_id
-        self.split_trajs = split_trajs
-        self.shuffle = shuffle
-        self.num_slices = num_slices
-        self.slice_len = slice_len
-        self.pad = pad
-
-        self.strict_length = strict_length
-        if (self.num_slices is not None) and (self.slice_len is not None):
-            raise ValueError("num_slices or slice_len can be not None, but not both.")
-        if split_trajs:
-            raise NotImplementedError
-
-        if root is None:
-            root = _get_root_dir("pusht")
-            os.makedirs(root, exist_ok=True)
-
-        self.root = root
-        if not self._is_downloaded():
-            storage = self._download_and_preproc()
-        else:
-            storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
-
-        stats = self._compute_or_load_stats(storage)
-        transform = NormalizeTransform(
-            stats,
-            in_keys=[
-                # TODO(rcadene): imagenet normalization is applied inside diffusion policy
-                # We need to automate this for tdmpc and others
-                # ("observation", "image"),
-                ("observation", "state"),
-                # TODO(rcadene): for tdmpc, we might want next image and state
-                # ("next", "observation", "image"),
-                # ("next", "observation", "state"),
-                ("action"),
-            ],
-            mode="min_max",
-        )
-
-        # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
-        transform.stats["observation", "state", "min"] = torch.tensor(
-            [13.456424, 32.938293], dtype=torch.float32
-        )
-        transform.stats["observation", "state", "max"] = torch.tensor(
-            [496.14618, 510.9579], dtype=torch.float32
-        )
-        transform.stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
-        transform.stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
-
-        if writer is None:
-            writer = ImmutableDatasetWriter()
-        if collate_fn is None:
-            collate_fn = _collate_id
-
         super().__init__(
-            storage=storage,
-            sampler=sampler,
-            writer=writer,
-            collate_fn=collate_fn,
+            dataset_id,
+            batch_size,
+            shuffle=shuffle,
+            root=root,
             pin_memory=pin_memory,
             prefetch=prefetch,
-            batch_size=batch_size,
+            sampler=sampler,
+            collate_fn=collate_fn,
+            writer=writer,
             transform=transform,
         )
 
-    @property
-    def num_samples(self) -> int:
-        return len(self)
-
-    @property
-    def num_episodes(self) -> int:
-        return len(self._storage._storage["episode"].unique())
-
-    @property
-    def data_path_root(self) -> Path:
-        return None if self.streaming else self.root / self.dataset_id
-
-    def _is_downloaded(self) -> bool:
-        return self.data_path_root.is_dir()
-
     def _download_and_preproc(self):
-        # download
-        raw_dir = self.root / "raw"
+        raw_dir = self.data_dir / "raw"
         zarr_path = (raw_dir / PUSHT_ZARR).resolve()
         if not zarr_path.is_dir():
             raw_dir.mkdir(parents=True, exist_ok=True)
@@ -286,7 +203,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
 
         if episode_id == 0:
             # hack to initialize tensordict data structure to store episodes
-            td_data = episode[0].expand(total_frames).memmap_like(self.root / self.dataset_id)
+            td_data = episode[0].expand(total_frames).memmap_like(self.data_dir)
 
         td_data[idxtd : idxtd + len(episode)] = episode
 
@@ -294,112 +211,3 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
             idxtd = idxtd + len(episode)
 
         return TensorStorage(td_data.lock_())
-
-    def _compute_stats(self, storage, num_batch=100, batch_size=32):
-        rb = TensorDictReplayBuffer(
-            storage=storage,
-            batch_size=batch_size,
-            prefetch=True,
-        )
-        batch = rb.sample()
-
-        image_channels = batch["observation", "image"].shape[1]
-        image_mean = torch.zeros(image_channels)
-        image_std = torch.zeros(image_channels)
-        image_max = torch.tensor([-math.inf] * image_channels)
-        image_min = torch.tensor([math.inf] * image_channels)
-
-        state_channels = batch["observation", "state"].shape[1]
-        state_mean = torch.zeros(state_channels)
-        state_std = torch.zeros(state_channels)
-        state_max = torch.tensor([-math.inf] * state_channels)
-        state_min = torch.tensor([math.inf] * state_channels)
-
-        action_channels = batch["action"].shape[1]
-        action_mean = torch.zeros(action_channels)
-        action_std = torch.zeros(action_channels)
-        action_max = torch.tensor([-math.inf] * action_channels)
-        action_min = torch.tensor([math.inf] * action_channels)
-
-        for _ in tqdm.tqdm(range(num_batch)):
-            image_mean += einops.reduce(batch["observation", "image"], "b c h w -> c", "mean")
-            state_mean += einops.reduce(batch["observation", "state"], "b c -> c", "mean")
-            action_mean += einops.reduce(batch["action"], "b c -> c", "mean")
-
-            b_image_max = einops.reduce(batch["observation", "image"], "b c h w -> c", "max")
-            b_image_min = einops.reduce(batch["observation", "image"], "b c h w -> c", "min")
-            b_state_max = einops.reduce(batch["observation", "state"], "b c -> c", "max")
-            b_state_min = einops.reduce(batch["observation", "state"], "b c -> c", "min")
-            b_action_max = einops.reduce(batch["action"], "b c -> c", "max")
-            b_action_min = einops.reduce(batch["action"], "b c -> c", "min")
-            image_max = torch.maximum(image_max, b_image_max)
-            image_min = torch.maximum(image_min, b_image_min)
-            state_max = torch.maximum(state_max, b_state_max)
-            state_min = torch.maximum(state_min, b_state_min)
-            action_max = torch.maximum(action_max, b_action_max)
-            action_min = torch.maximum(action_min, b_action_min)
-
-            batch = rb.sample()
-
-        image_mean /= num_batch
-        state_mean /= num_batch
-        action_mean /= num_batch
-
-        for i in tqdm.tqdm(range(num_batch)):
-            b_image_mean = einops.reduce(batch["observation", "image"], "b c h w -> c", "mean")
-            b_state_mean = einops.reduce(batch["observation", "state"], "b c -> c", "mean")
-            b_action_mean = einops.reduce(batch["action"], "b c -> c", "mean")
-            image_std += (b_image_mean - image_mean) ** 2
-            state_std += (b_state_mean - state_mean) ** 2
-            action_std += (b_action_mean - action_mean) ** 2
-
-            b_image_max = einops.reduce(batch["observation", "image"], "b c h w -> c", "max")
-            b_image_min = einops.reduce(batch["observation", "image"], "b c h w -> c", "min")
-            b_state_max = einops.reduce(batch["observation", "state"], "b c -> c", "max")
-            b_state_min = einops.reduce(batch["observation", "state"], "b c -> c", "min")
-            b_action_max = einops.reduce(batch["action"], "b c -> c", "max")
-            b_action_min = einops.reduce(batch["action"], "b c -> c", "min")
-            image_max = torch.maximum(image_max, b_image_max)
-            image_min = torch.maximum(image_min, b_image_min)
-            state_max = torch.maximum(state_max, b_state_max)
-            state_min = torch.maximum(state_min, b_state_min)
-            action_max = torch.maximum(action_max, b_action_max)
-            action_min = torch.maximum(action_min, b_action_min)
-
-            if i < num_batch - 1:
-                batch = rb.sample()
-
-        image_std = torch.sqrt(image_std / num_batch)
-        state_std = torch.sqrt(state_std / num_batch)
-        action_std = torch.sqrt(action_std / num_batch)
-
-        stats = TensorDict(
-            {
-                ("observation", "image", "mean"): image_mean[None, :, None, None],
-                ("observation", "image", "std"): image_std[None, :, None, None],
-                ("observation", "image", "max"): image_max[None, :, None, None],
-                ("observation", "image", "min"): image_min[None, :, None, None],
-                ("observation", "state", "mean"): state_mean[None, :],
-                ("observation", "state", "std"): state_std[None, :],
-                ("observation", "state", "max"): state_max[None, :],
-                ("observation", "state", "min"): state_min[None, :],
-                ("action", "mean"): action_mean[None, :],
-                ("action", "std"): action_std[None, :],
-                ("action", "max"): action_max[None, :],
-                ("action", "min"): action_min[None, :],
-            },
-            batch_size=[],
-        )
-        stats["next", "observation", "image"] = stats["observation", "image"]
-        stats["next", "observation", "state"] = stats["observation", "state"]
-        return stats
-
-    def _compute_or_load_stats(self, storage) -> TensorDict:
-        stats_path = self.root / self.dataset_id / "stats.pth"
-        if stats_path.exists():
-            stats = torch.load(stats_path)
-        else:
-            logging.info(f"compute_stats and save to {stats_path}")
-            stats = self._compute_stats(storage)
-            torch.save(stats, stats_path)
-        return stats
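After the refactor, loading PushT reduces to one constructor call; a sketch (the root path and batch size are illustrative):

```python
from lerobot.common.datasets.pusht import PushtExperienceReplay

# The first call downloads the raw zarr archive into `data/pusht/raw` and
# memmaps the converted episodes under `data/pusht`; later calls reload the
# memmap through the base class.
offline_buffer = PushtExperienceReplay("pusht", batch_size=32, root="data")
print(offline_buffer.num_samples, offline_buffer.num_episodes)

batch = offline_buffer.sample()
print(batch["observation", "image"].shape)  # e.g. torch.Size([32, 3, 96, 96])
```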
diff --git a/lerobot/common/datasets/simxarm.py b/lerobot/common/datasets/simxarm.py
index 1a8ba4de..784242cc 100644
--- a/lerobot/common/datasets/simxarm.py
+++ b/lerobot/common/datasets/simxarm.py
@@ -1,4 +1,3 @@
-import os
 import pickle
 from pathlib import Path
 from typing import Callable
@@ -7,130 +6,52 @@ import torch
 import torchrl
 import tqdm
 from tensordict import TensorDict
-from torchrl.data.datasets.utils import _get_root_dir
-from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
 from torchrl.data.replay_buffers.samplers import (
-    Sampler,
     SliceSampler,
-    SliceSamplerWithoutReplacement,
 )
-from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
-from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
+from torchrl.data.replay_buffers.storages import TensorStorage
+from torchrl.data.replay_buffers.writers import Writer
+
+from lerobot.common.datasets.abstract import AbstractExperienceReplay
 
 
-class SimxarmExperienceReplay(TensorDictReplayBuffer):
+class SimxarmExperienceReplay(AbstractExperienceReplay):
     available_datasets = [
         "xarm_lift_medium",
     ]
 
     def __init__(
         self,
-        dataset_id,
+        dataset_id: str,
         batch_size: int = None,
        *,
         shuffle: bool = True,
-        num_slices: int = None,
-        slice_len: int = None,
-        pad: float = None,
-        replacement: bool = None,
-        streaming: bool = False,
         root: Path = None,
-        download: bool = False,
-        sampler: Sampler = None,
-        writer: Writer = None,
-        collate_fn: Callable = None,
         pin_memory: bool = False,
         prefetch: int = None,
-        transform: "torchrl.envs.Transform" = None,  # noqa-F821
-        split_trajs: bool = False,
-        strict_length: bool = True,
+        sampler: SliceSampler = None,
+        collate_fn: Callable = None,
+        writer: Writer = None,
+        transform: "torchrl.envs.Transform" = None,  # noqa: F821
     ):
-        self.download = download
-        if streaming:
-            raise NotImplementedError
-        self.streaming = streaming
-        self.dataset_id = dataset_id
-        self.split_trajs = split_trajs
-        self.shuffle = shuffle
-        self.num_slices = num_slices
-        self.slice_len = slice_len
-        self.pad = pad
-
-        self.strict_length = strict_length
-        if (self.num_slices is not None) and (self.slice_len is not None):
-            raise ValueError("num_slices or slice_len can be not None, but not both.")
-        if split_trajs:
-            raise NotImplementedError
-
-        if root is None:
-            root = _get_root_dir("simxarm")
-            os.makedirs(root, exist_ok=True)
-        self.root = Path(root)
-        if self.download == "force" or (self.download and not self._is_downloaded()):
-            storage = self._download_and_preproc()
-        else:
-            storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
-
-        if num_slices is not None or slice_len is not None:
-            if sampler is not None:
-                raise ValueError("`num_slices` and `slice_len` are exclusive with the `sampler` argument.")
-
-            if replacement:
-                if not self.shuffle:
-                    raise RuntimeError("shuffle=False can only be used when replacement=False.")
-                sampler = SliceSampler(
-                    num_slices=num_slices,
-                    slice_len=slice_len,
-                    strict_length=strict_length,
-                )
-            else:
-                sampler = SliceSamplerWithoutReplacement(
-                    num_slices=num_slices,
-                    slice_len=slice_len,
-                    strict_length=strict_length,
-                    shuffle=self.shuffle,
-                )
-
-        if writer is None:
-            writer = ImmutableDatasetWriter()
-        if collate_fn is None:
-            collate_fn = _collate_id
-
         super().__init__(
-            storage=storage,
-            sampler=sampler,
-            writer=writer,
-            collate_fn=collate_fn,
+            dataset_id,
+            batch_size,
+            shuffle=shuffle,
+            root=root,
             pin_memory=pin_memory,
             prefetch=prefetch,
-            batch_size=batch_size,
+            sampler=sampler,
+            collate_fn=collate_fn,
+            writer=writer,
             transform=transform,
         )
 
-    @property
-    def num_samples(self):
-        return len(self)
-
-    @property
-    def num_episodes(self):
-        return len(self._storage._storage["episode"].unique())
-
-    @property
-    def data_path_root(self):
-        if self.streaming:
-            return None
-        return self.root / self.dataset_id
-
-    def _is_downloaded(self):
-        return os.path.exists(self.data_path_root)
-
     def _download_and_preproc(self):
         # download
         # TODO(rcadene)
 
-        # load
-        dataset_dir = Path("data") / self.dataset_id
-        dataset_path = dataset_dir / "buffer.pkl"
+        dataset_path = self.data_dir / "buffer.pkl"
         print(f"Using offline dataset '{dataset_path}'")
         with open(dataset_path, "rb") as f:
             dataset_dict = pickle.load(f)
@@ -172,7 +93,7 @@ class SimxarmExperienceReplay(TensorDictReplayBuffer):
 
         if episode_id == 0:
             # hack to initialize tensordict data structure to store episodes
-            td_data = episode[0].expand(total_frames).memmap_like(self.root / self.dataset_id)
+            td_data = episode[0].expand(total_frames).memmap_like(self.data_dir)
 
         td_data[idx0:idx1] = episode
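The stats machinery now behaves identically for every dataset; a sketch of the caching flow (the dataset id and root are illustrative, and note the patch still leaves the actual simxarm download as a TODO, so `buffer.pkl` must already be in place):

```python
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay

offline_buffer = SimxarmExperienceReplay("xarm_lift_medium", batch_size=256, root="data")

# Samples `num_batch` batches once, aggregates per-channel mean/std/min/max,
# caches the result to `data/xarm_lift_medium/stats.pth`, and simply reloads
# that file on every later call.
stats = offline_buffer.compute_or_load_stats(num_batch=100, batch_size=32)
print(stats["observation", "state", "min"])  # per-channel minima
print(stats["action", "std"])  # note: std is estimated from per-batch means
```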