diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py
index af30cf8c..d6a51246 100644
--- a/lerobot/common/datasets/abstract.py
+++ b/lerobot/common/datasets/abstract.py
@@ -13,6 +13,7 @@
 from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
 from torchrl.data.replay_buffers.samplers import SliceSampler
 from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
 from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
+from torchrl.envs.transforms.transforms import Compose
 
 class AbstractExperienceReplay(TensorDictReplayBuffer):
@@ -54,7 +55,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
         return {
             ("observation", "state"): "b c -> 1 c",
             ("observation", "image"): "b c h w -> 1 c 1 1",
-            ("action"): "b c -> 1 c",
+            ("action",): "b c -> 1 c",
         }
 
     @property
@@ -73,8 +74,16 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
     def num_episodes(self) -> int:
         return len(self._storage._storage["episode"].unique())
 
+    @property
+    def transform(self):
+        return self._transform
+
     def set_transform(self, transform):
-        self.transform = transform
+        if not isinstance(transform, Compose):
+            # required since torchrl calls `len(self._transform)` downstream
+            self._transform = Compose(transform)
+        else:
+            self._transform = transform
 
     def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict:
         stats_path = self.data_dir / "stats.pth"
diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py
index 7397327d..afc28b1c 100644
--- a/lerobot/common/datasets/aloha.py
+++ b/lerobot/common/datasets/aloha.py
@@ -114,7 +114,7 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
     def stats_patterns(self) -> dict:
         d = {
             ("observation", "state"): "b c -> 1 c",
-            ("action"): "b c -> 1 c",
+            ("action",): "b c -> 1 c",
         }
         for cam in CAMERAS[self.dataset_id]:
             d[("observation", "image", cam)] = "b c h w -> 1 c 1 1"
diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py
index 5bb4b14f..29f40bc6 100644
--- a/lerobot/common/datasets/factory.py
+++ b/lerobot/common/datasets/factory.py
@@ -87,9 +87,11 @@ def make_offline_buffer(
     if normalize:
         # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
         stats = offline_buffer.compute_or_load_stats()
+
+        # we only normalize the state and action, since the images are usually normalized inside the model for now (except for tdmpc: see the following)
         in_keys = [("observation", "state"), ("action")]
 
-        if cfg.policy == "tdmpc":
+        if cfg.policy.name == "tdmpc":
             for key in offline_buffer.image_keys:
                 # TODO(rcadene): imagenet normalization is applied inside diffusion policy, but no normalization inside tdmpc
                 in_keys.append(key)
@@ -97,7 +99,7 @@ def make_offline_buffer(
                 in_keys.append(("next", *key))
             in_keys.append(("next", "observation", "state"))
 
-        if cfg.policy == "diffusion" and cfg.env.name == "pusht":
+        if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
             # TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
             stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
             stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py
index dd8ab2f7..269009db 100644
--- a/lerobot/common/envs/factory.py
+++ b/lerobot/common/envs/factory.py
@@ -1,4 +1,4 @@
-from torchrl.envs.transforms import StepCounter, TransformedEnv
+from torchrl.envs.transforms import Compose, StepCounter, Transform, TransformedEnv
 
 
 def make_env(cfg, transform=None):
@@ -33,7 +33,13 @@ def make_env(cfg, transform=None):
 
     if transform is not None:
         # useful to add normalization
-        env.append_transform(transform)
+        if isinstance(transform, Compose):
+            for tf in transform:
+                env.append_transform(tf.clone())
+        elif isinstance(transform, Transform):
+            env.append_transform(transform.clone())
+        else:
+            raise NotImplementedError()
 
     return env
diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index c9338dca..6435310a 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -126,7 +126,7 @@ def eval(cfg: dict, out_dir=None):
     offline_buffer = make_offline_buffer(cfg)
 
     logging.info("make_env")
-    env = make_env(cfg, transform=offline_buffer._transform)
+    env = make_env(cfg, transform=offline_buffer.transform)
 
     if cfg.policy.pretrained_model_path:
         policy = make_policy(cfg)
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index be3bef8b..f4b22604 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -142,11 +142,11 @@ def train(cfg: dict, out_dir=None, job_name=None):
     online_buffer = TensorDictReplayBuffer(
         storage=LazyMemmapStorage(100_000),
         sampler=online_sampler,
-        transform=offline_buffer._transform,
+        transform=offline_buffer.transform,
    )
 
     logging.info("make_env")
-    env = make_env(cfg, transform=offline_buffer._transform)
+    env = make_env(cfg, transform=offline_buffer.transform)
 
     logging.info("make_policy")
     policy = make_policy(cfg)
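A note on the `("action")` -> `("action",)` changes in abstract.py and aloha.py: parentheses alone do not create a tuple in Python, so `("action")` is just the string `"action"`, while the neighboring keys are genuine tuples. The trailing comma makes the key a 1-element tuple, consistent with the other entries. A minimal, self-contained illustration:

    # Parentheses without a trailing comma are only grouping:
    assert ("action") == "action"            # plain string
    assert ("action",) == tuple(["action"])  # 1-element tuple

    # Mixed key types are easy to miss in a dict of stats patterns:
    patterns = {
        ("observation", "state"): "b c -> 1 c",  # tuple key
        ("action",): "b c -> 1 c",               # tuple key, now consistent
    }
    assert all(isinstance(k, tuple) for k in patterns)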
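The `cfg.policy` -> `cfg.policy.name` fixes in datasets/factory.py address a silent bug: under Hydra/OmegaConf, `cfg.policy` is a whole config group, so comparing it to a string is always False and both branches were skipped. A short sketch, assuming an OmegaConf config shaped roughly like lerobot's (the `lr` field is purely illustrative):

    from omegaconf import OmegaConf

    cfg = OmegaConf.create({"policy": {"name": "tdmpc", "lr": 1e-3}})

    assert cfg.policy != "tdmpc"       # a DictConfig never equals a string
    assert cfg.policy.name == "tdmpc"  # compare the `name` field instead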
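The new `set_transform` wraps a bare transform in `Compose` because, as the inline comment notes, torchrl calls `len(self._transform)` downstream, and a single `Transform` does not support `len()` while a `Compose` does. A minimal sketch of the idea, with a hypothetical `as_compose` helper and `StepCounter` standing in for the buffer's normalization transform:

    from torchrl.envs.transforms import Compose, StepCounter, Transform

    def as_compose(transform: Transform) -> Compose:
        # A bare Transform has no __len__; Compose does, and it also
        # supports indexing and iteration over its sub-transforms.
        if isinstance(transform, Compose):
            return transform
        return Compose(transform)

    transform = as_compose(StepCounter(max_steps=10))
    assert len(transform) == 1
    for tf in transform:
        print(type(tf).__name__)  # -> StepCounter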
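Relatedly, the `clone()` calls in `make_env` give the environment its own copies of the buffer's transforms: a torchrl `Transform` tracks the container it is attached to, so sharing one instance between a replay buffer and an env could couple their state. A sketch of the pattern (`StepCounter` again stands in for the real normalization transform):

    from torchrl.envs.transforms import StepCounter

    original = StepCounter(max_steps=10)
    copy = original.clone()

    # clone() yields an independent, detached transform: attaching the
    # copy to an env leaves the buffer's original untouched.
    assert copy is not original
    assert copy.max_steps == original.max_steps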