diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py
index 5db9de86..1e67e9d8 100644
--- a/lerobot/common/datasets/abstract.py
+++ b/lerobot/common/datasets/abstract.py
@@ -54,7 +54,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
         return {
             ("observation", "state"): "b c -> 1 c",
             ("observation", "image"): "b c h w -> 1 c 1 1",
-            "action": "b c -> 1 c",
+            ("action",): "b c -> 1 c",
         }
 
     @property
@@ -73,8 +73,12 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
     def num_episodes(self) -> int:
         return len(self._storage._storage["episode"].unique())
 
+    @property
+    def transform(self):
+        return self._transform
+
     def set_transform(self, transform):
-        self.transform = transform
+        self._transform = transform
 
     def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict:
         stats_path = self.data_dir / "stats.pth"
diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py
index f9351e20..afc28b1c 100644
--- a/lerobot/common/datasets/aloha.py
+++ b/lerobot/common/datasets/aloha.py
@@ -114,7 +114,7 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
     def stats_patterns(self) -> dict:
         d = {
             ("observation", "state"): "b c -> 1 c",
-            "action": "b c -> 1 c",
+            ("action",): "b c -> 1 c",
         }
         for cam in CAMERAS[self.dataset_id]:
             d[("observation", "image", cam)] = "b c h w -> 1 c 1 1"
diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py
index 5bb4b14f..29f40bc6 100644
--- a/lerobot/common/datasets/factory.py
+++ b/lerobot/common/datasets/factory.py
@@ -87,9 +87,11 @@ def make_offline_buffer(
     if normalize:
         # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
         stats = offline_buffer.compute_or_load_stats()
+
+        # we only normalize the state and action, since the images are usually normalized inside the model for now (except for tdmpc: see the following)
         in_keys = [("observation", "state"), ("action")]
 
-        if cfg.policy == "tdmpc":
+        if cfg.policy.name == "tdmpc":
             for key in offline_buffer.image_keys:
                 # TODO(rcadene): imagenet normalization is applied inside diffusion policy, but no normalization inside tdmpc
                 in_keys.append(key)
@@ -97,7 +99,7 @@
                 in_keys.append(("next", *key))
             in_keys.append(("next", "observation", "state"))
 
-        if cfg.policy == "diffusion" and cfg.env.name == "pusht":
+        if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
             # TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
             stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
             stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index c9338dca..6435310a 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -126,7 +126,7 @@ def eval(cfg: dict, out_dir=None):
     offline_buffer = make_offline_buffer(cfg)
 
     logging.info("make_env")
-    env = make_env(cfg, transform=offline_buffer._transform)
+    env = make_env(cfg, transform=offline_buffer.transform)
 
     if cfg.policy.pretrained_model_path:
         policy = make_policy(cfg)
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index be3bef8b..f4b22604 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -142,11 +142,11 @@ def train(cfg: dict, out_dir=None, job_name=None):
     online_buffer = TensorDictReplayBuffer(
         storage=LazyMemmapStorage(100_000),
         sampler=online_sampler,
-        transform=offline_buffer._transform,
+        transform=offline_buffer.transform,
     )
 
     logging.info("make_env")
-    env = make_env(cfg, transform=offline_buffer._transform)
+    env = make_env(cfg, transform=offline_buffer.transform)
 
     logging.info("make_policy")
     policy = make_policy(cfg)
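
Note for reviewers: the common thread across these files is that callers stop reaching into the buffer's private `_transform` attribute and instead use the new read-only `transform` property, while `set_transform` now writes to `_transform`. A minimal, self-contained sketch of that accessor pattern follows; the `Buffer` class and the usage lines are illustrative stand-ins, not LeRobot code.

# Illustrative sketch only: stand-in class mirroring the property/set_transform pattern in this diff.
class Buffer:
    def __init__(self, transform=None):
        self._transform = transform  # private storage, as in AbstractExperienceReplay

    @property
    def transform(self):
        # read-only public accessor; callers no longer touch `_transform` directly
        return self._transform

    def set_transform(self, transform):
        self._transform = transform


buf = Buffer()
buf.set_transform(lambda td: td)
env_transform = buf.transform  # replaces the old `buf._transform` access seen in eval.py / train.py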