diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py index 529bf6db..8295ed48 100644 --- a/lerobot/common/datasets/abstract.py +++ b/lerobot/common/datasets/abstract.py @@ -14,6 +14,8 @@ from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer from torchrl.envs.transforms.transforms import Compose +HF_USER = "lerobot" + class AbstractExperienceReplay(TensorDictReplayBuffer): def __init__( @@ -106,7 +108,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer): if self.root is None: self.data_dir = Path( snapshot_download( - repo_id=f"cadene/{self.dataset_id}", repo_type="dataset", revision=self.version + repo_id=f"{HF_USER}/{self.dataset_id}", repo_type="dataset", revision=self.version ) ) else: diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py index e891ccdd..7c0c9d44 100644 --- a/lerobot/common/datasets/aloha.py +++ b/lerobot/common/datasets/aloha.py @@ -84,7 +84,7 @@ class AlohaExperienceReplay(AbstractExperienceReplay): def __init__( self, dataset_id: str, - version: str | None = "v1.1", + version: str | None = "v1.2", batch_size: int = None, *, shuffle: bool = True, diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index a8a47da8..bcbb10b8 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -87,7 +87,7 @@ class PushtExperienceReplay(AbstractExperienceReplay): def __init__( self, dataset_id: str, - version: str | None = "v1.1", + version: str | None = "v1.2", batch_size: int = None, *, shuffle: bool = True, diff --git a/tests/data/aloha_sim_insertion_human/stats.pth b/tests/data/aloha_sim_insertion_human/stats.pth index d41ac18c..87b18a24 100644 Binary files a/tests/data/aloha_sim_insertion_human/stats.pth and b/tests/data/aloha_sim_insertion_human/stats.pth differ diff --git a/tests/data/aloha_sim_insertion_scripted/stats.pth b/tests/data/aloha_sim_insertion_scripted/stats.pth index 4e1c1884..7d149ca4 100644 Binary files a/tests/data/aloha_sim_insertion_scripted/stats.pth and b/tests/data/aloha_sim_insertion_scripted/stats.pth differ diff --git a/tests/data/aloha_sim_transfer_cube_human/stats.pth b/tests/data/aloha_sim_transfer_cube_human/stats.pth index 844550e1..22f3e4d9 100644 Binary files a/tests/data/aloha_sim_transfer_cube_human/stats.pth and b/tests/data/aloha_sim_transfer_cube_human/stats.pth differ diff --git a/tests/data/aloha_sim_transfer_cube_scripted/stats.pth b/tests/data/aloha_sim_transfer_cube_scripted/stats.pth index 836b29cf..63465344 100644 Binary files a/tests/data/aloha_sim_transfer_cube_scripted/stats.pth and b/tests/data/aloha_sim_transfer_cube_scripted/stats.pth differ diff --git a/tests/data/pusht/stats.pth b/tests/data/pusht/stats.pth index 039d5db3..d7107185 100644 Binary files a/tests/data/pusht/stats.pth and b/tests/data/pusht/stats.pth differ diff --git a/tests/test_policies.py b/tests/test_policies.py index e6cfdfbc..92508dac 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -22,6 +22,9 @@ from .utils import DEVICE, init_config ("simxarm", "diffusion", []), ("pusht", "diffusion", []), ("aloha", "act", ["env.task=sim_insertion_scripted"]), + ("aloha", "act", ["env.task=sim_insertion_human"]), + ("aloha", "act", ["env.task=sim_transfer_cube_scripted"]), + ("aloha", "act", ["env.task=sim_transfer_cube_human"]), ], ) def test_concrete_policy(env_name, policy_name, extra_overrides):