From a7ef4a6a33c88730df3a01007d4426b1fd7a16ca Mon Sep 17 00:00:00 2001
From: Cadene
Date: Mon, 11 Mar 2024 09:37:56 +0000
Subject: [PATCH 1/4] fix bug in compute_stats for action normalization

---
 lerobot/common/datasets/abstract.py | 2 +-
 lerobot/common/datasets/aloha.py    | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py
index af30cf8c..5db9de86 100644
--- a/lerobot/common/datasets/abstract.py
+++ b/lerobot/common/datasets/abstract.py
@@ -54,7 +54,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
         return {
             ("observation", "state"): "b c -> 1 c",
             ("observation", "image"): "b c h w -> 1 c 1 1",
-            ("action"): "b c -> 1 c",
+            "action": "b c -> 1 c",
         }
 
     @property
diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py
index 7397327d..f9351e20 100644
--- a/lerobot/common/datasets/aloha.py
+++ b/lerobot/common/datasets/aloha.py
@@ -114,7 +114,7 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
     def stats_patterns(self) -> dict:
         d = {
             ("observation", "state"): "b c -> 1 c",
-            ("action"): "b c -> 1 c",
+            "action": "b c -> 1 c",
         }
         for cam in CAMERAS[self.dataset_id]:
             d[("observation", "image", cam)] = "b c h w -> 1 c 1 1"
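Note on the fix above: in Python, ("action") is not a tuple; only the trailing comma makes one, so the original dict key was silently the plain string "action" rather than a nested key like ("observation", "state"). Patch 2 below settles on the explicit one-element tuple ("action",). The sketch that follows is illustrative only (it assumes torch and einops are installed; the variable names are made up): it shows the string/tuple distinction and the einops reduction that the "b c -> 1 c" stats pattern describes.

    # Parentheses alone do not make a tuple; the comma does.
    key_str = ("action")   # just the string "action"
    key_tup = ("action",)  # a one-element tuple, like the nested observation keys
    assert key_str == "action"
    assert key_tup != key_str and isinstance(key_tup, tuple)

    # The pattern "b c -> 1 c" is an einops reduction: collapse the batch
    # axis, keep one statistic per channel.
    import torch
    from einops import reduce

    actions = torch.randn(32, 4)                  # b=32 samples, c=4 action dims
    mean = reduce(actions, "b c -> 1 c", "mean")  # shape (1, 4)
    vmax = reduce(actions, "b c -> 1 c", "max")   # shape (1, 4)
    assert mean.shape == vmax.shape == (1, 4)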
From 816b2e9d639ff694be1b81a5308f381a18427363 Mon Sep 17 00:00:00 2001
From: Cadene
Date: Mon, 11 Mar 2024 11:03:13 +0000
Subject: [PATCH 2/4] fix more bugs in normalization

---
 lerobot/common/datasets/abstract.py | 8 ++++++--
 lerobot/common/datasets/aloha.py    | 2 +-
 lerobot/common/datasets/factory.py  | 6 ++++--
 lerobot/scripts/eval.py             | 2 +-
 lerobot/scripts/train.py            | 4 ++--
 5 files changed, 14 insertions(+), 8 deletions(-)

diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py
index 5db9de86..1e67e9d8 100644
--- a/lerobot/common/datasets/abstract.py
+++ b/lerobot/common/datasets/abstract.py
@@ -54,7 +54,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
         return {
             ("observation", "state"): "b c -> 1 c",
             ("observation", "image"): "b c h w -> 1 c 1 1",
-            "action": "b c -> 1 c",
+            ("action",): "b c -> 1 c",
         }
 
     @property
@@ -73,8 +73,12 @@
     def num_episodes(self) -> int:
         return len(self._storage._storage["episode"].unique())
 
+    @property
+    def transform(self):
+        return self._transform
+
     def set_transform(self, transform):
-        self.transform = transform
+        self._transform = transform
 
     def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict:
         stats_path = self.data_dir / "stats.pth"
diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py
index f9351e20..afc28b1c 100644
--- a/lerobot/common/datasets/aloha.py
+++ b/lerobot/common/datasets/aloha.py
@@ -114,7 +114,7 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
     def stats_patterns(self) -> dict:
         d = {
             ("observation", "state"): "b c -> 1 c",
-            "action": "b c -> 1 c",
+            ("action",): "b c -> 1 c",
         }
         for cam in CAMERAS[self.dataset_id]:
             d[("observation", "image", cam)] = "b c h w -> 1 c 1 1"
diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py
index 5bb4b14f..29f40bc6 100644
--- a/lerobot/common/datasets/factory.py
+++ b/lerobot/common/datasets/factory.py
@@ -87,9 +87,11 @@ def make_offline_buffer(
     if normalize:
         # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
         stats = offline_buffer.compute_or_load_stats()
+
+        # we only normalize the state and action, since the images are usually normalized inside the model for now (except for tdmpc: see the following)
         in_keys = [("observation", "state"), ("action")]
 
-        if cfg.policy == "tdmpc":
+        if cfg.policy.name == "tdmpc":
             for key in offline_buffer.image_keys:
                 # TODO(rcadene): imagenet normalization is applied inside diffusion policy, but no normalization inside tdmpc
                 in_keys.append(key)
@@ -97,7 +99,7 @@ def make_offline_buffer(
             in_keys.append(("next", *key))
         in_keys.append(("next", "observation", "state"))
 
-    if cfg.policy == "diffusion" and cfg.env.name == "pusht":
+    if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
         # TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
         stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
         stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index c9338dca..6435310a 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -126,7 +126,7 @@ def eval(cfg: dict, out_dir=None):
     offline_buffer = make_offline_buffer(cfg)
 
     logging.info("make_env")
-    env = make_env(cfg, transform=offline_buffer._transform)
+    env = make_env(cfg, transform=offline_buffer.transform)
 
     if cfg.policy.pretrained_model_path:
         policy = make_policy(cfg)
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index be3bef8b..f4b22604 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -142,11 +142,11 @@ def train(cfg: dict, out_dir=None, job_name=None):
     online_buffer = TensorDictReplayBuffer(
         storage=LazyMemmapStorage(100_000),
         sampler=online_sampler,
-        transform=offline_buffer._transform,
+        transform=offline_buffer.transform,
     )
 
     logging.info("make_env")
-    env = make_env(cfg, transform=offline_buffer._transform)
+    env = make_env(cfg, transform=offline_buffer.transform)
 
     logging.info("make_policy")
     policy = make_policy(cfg)
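Two bugs in patch 2 deserve a word. First, set_transform used to assign self.transform, but torchrl's replay buffer applies self._transform when sampling, so the assignment landed on a different attribute and samples came back untransformed; the patch writes _transform and adds a read-only transform property that eval.py and train.py now use. Second, cfg.policy is a config subtree (Hydra, in lerobot's case), so comparing it to the string "tdmpc" was always false; the string lives in cfg.policy.name. (The leftover ("action") in in_keys is again just the bare string "action", which is still accepted as a plain key.) Below, a minimal sketch of the accessor pattern using a stand-in class, not torchrl's actual implementation:

    class Buffer:
        """Stand-in: the sampling machinery reads self._transform."""

        def __init__(self):
            self._transform = None

        @property
        def transform(self):
            # public, read-only view; callers no longer touch _transform
            return self._transform

        def set_transform(self, transform):
            # the old `self.transform = transform` would now raise an
            # AttributeError (property without a setter); before the patch
            # it silently created an unrelated instance attribute
            self._transform = transform

    buf = Buffer()
    buf.set_transform(lambda td: td)
    assert buf.transform is buf._transform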
From ccd5dc5a4253b3f284a96b63896e7ed328d60b02 Mon Sep 17 00:00:00 2001
From: Cadene
Date: Mon, 11 Mar 2024 12:33:15 +0000
Subject: [PATCH 3/4] fix training

---
 lerobot/common/datasets/abstract.py |  8 +++++++-
 lerobot/common/envs/factory.py      | 10 +++++++++-
 2 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py
index 1e67e9d8..0c0746e1 100644
--- a/lerobot/common/datasets/abstract.py
+++ b/lerobot/common/datasets/abstract.py
@@ -14,6 +14,8 @@ from torchrl.data.replay_buffers.samplers import SliceSampler
 from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
 from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
 
+from rl.torchrl.envs.transforms.transforms import Compose
+
 
 class AbstractExperienceReplay(TensorDictReplayBuffer):
     def __init__(
@@ -78,7 +80,11 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
         return self._transform
 
     def set_transform(self, transform):
-        self._transform = transform
+        if not isinstance(transform, Compose):
+            # required since torchrl calls `len(self._transform)` downstream
+            self._transform = Compose(transform)
+        else:
+            self._transform = transform
 
     def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict:
         stats_path = self.data_dir / "stats.pth"
diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py
index 984b866a..184646cf 100644
--- a/lerobot/common/envs/factory.py
+++ b/lerobot/common/envs/factory.py
@@ -1,5 +1,7 @@
 from torchrl.envs.transforms import StepCounter, TransformedEnv
 
+from rl.torchrl.envs.transforms.transforms import Compose, Transform
+
 
 def make_env(cfg, transform=None):
     kwargs = {
@@ -33,7 +35,13 @@ def make_env(cfg, transform=None):
 
     if transform is not None:
         # useful to add normalization
-        env.append_transform(transform)
+        if isinstance(transform, Compose):
+            for tf in transform:
+                env.append_transform(tf.clone())
+        elif isinstance(transform, Transform):
+            env.append_transform(transform.clone())
+        else:
+            raise NotImplementedError()
 
     return env
 
From 84a1647c012e972cb453bc715b76fb6d0441190b Mon Sep 17 00:00:00 2001
From: Cadene
Date: Mon, 11 Mar 2024 12:41:14 +0000
Subject: [PATCH 4/4] fix import

---
 lerobot/common/datasets/abstract.py | 3 +--
 lerobot/common/envs/factory.py      | 4 +---
 2 files changed, 2 insertions(+), 5 deletions(-)

diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py
index 0c0746e1..d6a51246 100644
--- a/lerobot/common/datasets/abstract.py
+++ b/lerobot/common/datasets/abstract.py
@@ -13,8 +13,7 @@ from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
 from torchrl.data.replay_buffers.samplers import SliceSampler
 from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
 from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
-
-from rl.torchrl.envs.transforms.transforms import Compose
+from torchrl.envs.transforms.transforms import Compose
 
 
 class AbstractExperienceReplay(TensorDictReplayBuffer):
diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py
index 184646cf..c20ef441 100644
--- a/lerobot/common/envs/factory.py
+++ b/lerobot/common/envs/factory.py
@@ -1,6 +1,4 @@
-from torchrl.envs.transforms import StepCounter, TransformedEnv
-
-from rl.torchrl.envs.transforms.transforms import Compose, Transform
+from torchrl.envs.transforms import Compose, StepCounter, Transform, TransformedEnv
 
 
 def make_env(cfg, transform=None):
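Patch 3 does two defensive things: it always stores the buffer's transform as a Compose, because downstream torchrl code treats it as a sequence (the in-diff comment mentions len(self._transform)), and make_env appends clones so the env and the replay buffer never share one mutable transform instance. Patch 4 then repairs the import path: rl.torchrl... pointed into a local source checkout, while the same classes are importable from the installed torchrl package. A plain-Python sketch of the wrapping/cloning pattern, with stand-in Transform and Compose classes that only mimic the relevant torchrl behavior (this is an illustration, not torchrl's implementation):

    class Transform:
        """Stand-in for torchrl's Transform: a unit with mutable state."""

        def __init__(self, scale=1.0):
            self.scale = scale

        def clone(self):
            # a fresh copy, so the env cannot mutate the buffer's instance
            return Transform(self.scale)


    class Compose(Transform):
        """Stand-in for torchrl's Compose: an iterable pipeline of transforms."""

        def __init__(self, *transforms):
            self.transforms = list(transforms)

        def __len__(self):
            return len(self.transforms)

        def __iter__(self):
            return iter(self.transforms)


    def set_transform(buffer, transform):
        # patch 3's rule: always store a Compose so len(buffer._transform) is defined
        buffer._transform = transform if isinstance(transform, Compose) else Compose(transform)


    def append_to_env(env_transforms, transform):
        # patch 3's make_env logic: append clones, one transform at a time
        if isinstance(transform, Compose):
            for tf in transform:
                env_transforms.append(tf.clone())
        elif isinstance(transform, Transform):
            env_transforms.append(transform.clone())
        else:
            raise NotImplementedError()


    class FakeBuffer:
        pass


    buf = FakeBuffer()
    set_transform(buf, Transform(scale=2.0))
    env_pipeline = []
    append_to_env(env_pipeline, buf._transform)
    env_pipeline[0].scale = 3.0                        # mutate the env's clone...
    assert buf._transform.transforms[0].scale == 2.0   # ...the buffer's copy is intact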