diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py index 34b33c2e..4ce447bf 100644 --- a/lerobot/common/datasets/abstract.py +++ b/lerobot/common/datasets/abstract.py @@ -49,9 +49,9 @@ class AbstractExperienceReplay(TensorDictReplayBuffer): @property def stats_patterns(self) -> dict: return { - ("observation", "state"): "b c -> 1 c", - ("observation", "image"): "b c h w -> 1 c 1 1", - ("action",): "b c -> 1 c", + ("observation", "state"): "b c -> c", + ("observation", "image"): "b c h w -> c 1 1", + ("action",): "b c -> c", } @property diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py index 52a5676e..b1a5806f 100644 --- a/lerobot/common/datasets/aloha.py +++ b/lerobot/common/datasets/aloha.py @@ -113,11 +113,11 @@ class AlohaExperienceReplay(AbstractExperienceReplay): @property def stats_patterns(self) -> dict: d = { - ("observation", "state"): "b c -> 1 c", - ("action",): "b c -> 1 c", + ("observation", "state"): "b c -> c", + ("action",): "b c -> c", } for cam in CAMERAS[self.dataset_id]: - d[("observation", "image", cam)] = "b c h w -> 1 c 1 1" + d[("observation", "image", cam)] = "b c h w -> c 1 1" return d @property diff --git a/tests/data/aloha_sim_insertion_human/stats.pth b/tests/data/aloha_sim_insertion_human/stats.pth index 869d26cd..d7fc9495 100644 Binary files a/tests/data/aloha_sim_insertion_human/stats.pth and b/tests/data/aloha_sim_insertion_human/stats.pth differ diff --git a/tests/data/pusht/stats.pth b/tests/data/pusht/stats.pth index 037e02f0..9b460898 100644 Binary files a/tests/data/pusht/stats.pth and b/tests/data/pusht/stats.pth differ