diff --git a/lerobot/common/envs/simxarm/env.py b/lerobot/common/envs/simxarm/env.py index 9c996139..9b08be6a 100644 --- a/lerobot/common/envs/simxarm/env.py +++ b/lerobot/common/envs/simxarm/env.py @@ -84,7 +84,7 @@ class SimxarmEnv(AbstractEnv): else: obs = {"state": torch.tensor(raw_obs["observation"], dtype=torch.float32)} - obs = TensorDict(obs, batch_size=[]) + # obs = TensorDict(obs, batch_size=[]) return obs def _reset(self, tensordict: Optional[TensorDict] = None): diff --git a/tests/data/xarm_lift_medium/replay_buffer/action.memmap b/tests/data/xarm_lift_medium/replay_buffer/action.memmap new file mode 100644 index 00000000..c90afbe9 --- /dev/null +++ b/tests/data/xarm_lift_medium/replay_buffer/action.memmap @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:10ec2f944de18f1a2aa3fc2f9a4185c03e3a5afc31148c85c98b58602ac4186e +size 800 diff --git a/tests/data/xarm_lift_medium/replay_buffer/episode.memmap b/tests/data/xarm_lift_medium/replay_buffer/episode.memmap new file mode 100644 index 00000000..7924f028 --- /dev/null +++ b/tests/data/xarm_lift_medium/replay_buffer/episode.memmap @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1a589cba6bf0dfce138110864b6509508a804d7ea5c519d0a3cd67c4a87fa2d0 +size 200 diff --git a/tests/data/xarm_lift_medium/replay_buffer/frame_id.memmap b/tests/data/xarm_lift_medium/replay_buffer/frame_id.memmap new file mode 100644 index 00000000..a633d346 --- /dev/null +++ b/tests/data/xarm_lift_medium/replay_buffer/frame_id.memmap @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6afe7098f30bdc8564526517c085e62613f6cb67194153840567974cfa6f3815 +size 400 diff --git a/tests/data/xarm_lift_medium/replay_buffer/meta.json b/tests/data/xarm_lift_medium/replay_buffer/meta.json new file mode 100644 index 00000000..33dc932c --- /dev/null +++ b/tests/data/xarm_lift_medium/replay_buffer/meta.json @@ -0,0 +1 @@ +{"action": {"device": "cpu", "shape": [50, 4], "dtype": "torch.float32"}, "episode": {"device": "cpu", "shape": [50], "dtype": "torch.int32"}, "frame_id": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/xarm_lift_medium/replay_buffer/next/done.memmap b/tests/data/xarm_lift_medium/replay_buffer/next/done.memmap new file mode 100644 index 00000000..cf5e9cca --- /dev/null +++ b/tests/data/xarm_lift_medium/replay_buffer/next/done.memmap @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dab3a9712c413c4bfcd91c645752ab981306b23d25bcd4da4c422911574673f7 +size 50 diff --git a/tests/data/xarm_lift_medium/replay_buffer/next/meta.json b/tests/data/xarm_lift_medium/replay_buffer/next/meta.json new file mode 100644 index 00000000..d69cadad --- /dev/null +++ b/tests/data/xarm_lift_medium/replay_buffer/next/meta.json @@ -0,0 +1 @@ +{"reward": {"device": "cpu", "shape": [50], "dtype": "torch.float32"}, "done": {"device": "cpu", "shape": [50], "dtype": "torch.bool"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/xarm_lift_medium/replay_buffer/next/observation/image.memmap b/tests/data/xarm_lift_medium/replay_buffer/next/observation/image.memmap new file mode 100644 index 00000000..462d0117 --- /dev/null +++ b/tests/data/xarm_lift_medium/replay_buffer/next/observation/image.memmap @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d6f9d1422ce3764e7253f70ed4da278f0c07fafef0d5386479f09d6b6b9d8259 +size 1058400 diff --git a/tests/data/xarm_lift_medium/replay_buffer/next/observation/meta.json b/tests/data/xarm_lift_medium/replay_buffer/next/observation/meta.json new file mode 100644 index 00000000..b13b8ec9 --- /dev/null +++ b/tests/data/xarm_lift_medium/replay_buffer/next/observation/meta.json @@ -0,0 +1 @@ +{"image": {"device": "cpu", "shape": [50, 3, 84, 84], "dtype": "torch.uint8"}, "state": {"device": "cpu", "shape": [50, 4], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/xarm_lift_medium/replay_buffer/next/observation/state.memmap b/tests/data/xarm_lift_medium/replay_buffer/next/observation/state.memmap new file mode 100644 index 00000000..1dbe6024 --- /dev/null +++ b/tests/data/xarm_lift_medium/replay_buffer/next/observation/state.memmap @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:52e7c1a3c4fb2423b195e66ffee2c9e23b3ea0ad7c8bfc4dec30a35c65cadcbb +size 800 diff --git a/tests/data/xarm_lift_medium/replay_buffer/next/reward.memmap b/tests/data/xarm_lift_medium/replay_buffer/next/reward.memmap new file mode 100644 index 00000000..9ff5d5a1 --- /dev/null +++ b/tests/data/xarm_lift_medium/replay_buffer/next/reward.memmap @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c4dbe8ea1966e5cc6da6daf5704805b9b5810f4575de7016b8f6cb1da1d7bb8a +size 200 diff --git a/tests/data/xarm_lift_medium/replay_buffer/observation/image.memmap b/tests/data/xarm_lift_medium/replay_buffer/observation/image.memmap new file mode 100644 index 00000000..c9416940 --- /dev/null +++ b/tests/data/xarm_lift_medium/replay_buffer/observation/image.memmap @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8fca8ddbda3f7bb2f6e7553297c18f3ab8f8b73d64b5c9f81a3695ad9379d403 +size 1058400 diff --git a/tests/data/xarm_lift_medium/replay_buffer/observation/meta.json b/tests/data/xarm_lift_medium/replay_buffer/observation/meta.json new file mode 100644 index 00000000..b13b8ec9 --- /dev/null +++ b/tests/data/xarm_lift_medium/replay_buffer/observation/meta.json @@ -0,0 +1 @@ +{"image": {"device": "cpu", "shape": [50, 3, 84, 84], "dtype": "torch.uint8"}, "state": {"device": "cpu", "shape": [50, 4], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/xarm_lift_medium/replay_buffer/observation/state.memmap b/tests/data/xarm_lift_medium/replay_buffer/observation/state.memmap new file mode 100644 index 00000000..3bae16df --- /dev/null +++ b/tests/data/xarm_lift_medium/replay_buffer/observation/state.memmap @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7b3e3e12896d553c208ee152f6d447c877c435e15d010c4a6171966d5b8a0c0b +size 800 diff --git a/tests/data/xarm_lift_medium/stats.pth b/tests/data/xarm_lift_medium/stats.pth new file mode 100644 index 00000000..0accffb0 Binary files /dev/null and b/tests/data/xarm_lift_medium/stats.pth differ diff --git a/tests/test_datasets.py b/tests/test_datasets.py index b7d1e6f8..c3fcfccd 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -10,9 +10,8 @@ from .utils import DEVICE, init_config "env_name,dataset_id", [ # TODO(rcadene): simxarm is depreciated for now - # ("simxarm", "lift"), + ("simxarm", "lift"), ("pusht", "pusht"), - # TODO(aliberts): add aloha when dataset is available on hub ("aloha", "sim_insertion_human"), ("aloha", "sim_insertion_scripted"), ("aloha", "sim_transfer_cube_human"), diff --git a/tests/test_envs.py b/tests/test_envs.py index 8931cf52..6535535e 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -39,19 +39,20 @@ def print_spec_rollout(env): print("data from rollout:", simple_rollout(100)) -@pytest.mark.skip(reason="Simxarm is deprecated") +# @pytest.mark.skip(reason="Simxarm is deprecated") @pytest.mark.parametrize( "task,from_pixels,pixels_only", [ ("lift", False, False), ("lift", True, False), ("lift", True, True), - ("reach", False, False), - ("reach", True, False), - ("push", False, False), - ("push", True, False), - ("peg_in_box", False, False), - ("peg_in_box", True, False), + # TODO(aliberts): Add simxarm other task or remove them completely from repo + # ("reach", False, False), + # ("reach", True, False), + # ("push", False, False), + # ("push", True, False), + # ("peg_in_box", False, False), + # ("peg_in_box", True, False), ], ) def test_simxarm(task, from_pixels, pixels_only): @@ -84,7 +85,7 @@ def test_pusht(from_pixels, pixels_only): @pytest.mark.parametrize( "env_name", [ - # "simxarm", + "simxarm", "pusht", "aloha", ], diff --git a/tests/test_policies.py b/tests/test_policies.py index 92508dac..cd08fc4e 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -19,12 +19,13 @@ from .utils import DEVICE, init_config [ ("simxarm", "tdmpc", ["policy.mpc=true"]), ("pusht", "tdmpc", ["policy.mpc=false"]), - ("simxarm", "diffusion", []), ("pusht", "diffusion", []), ("aloha", "act", ["env.task=sim_insertion_scripted"]), ("aloha", "act", ["env.task=sim_insertion_human"]), ("aloha", "act", ["env.task=sim_transfer_cube_scripted"]), ("aloha", "act", ["env.task=sim_transfer_cube_human"]), + # TODO(aliberts): simxarm not working with diffusion + # ("simxarm", "diffusion", []), ], ) def test_concrete_policy(env_name, policy_name, extra_overrides): @@ -45,13 +46,6 @@ def test_concrete_policy(env_name, policy_name, extra_overrides): # Check that we can make the policy object. policy = make_policy(cfg) # Check that we run select_actions and get the appropriate output. - if env_name == "simxarm": - # TODO(rcadene): Not implemented - return - if policy_name == "tdmpc": - # TODO(alexander-soare): TDMPC does not use n_obs_steps but the environment requires this. - with open_dict(cfg): - cfg["n_obs_steps"] = 1 offline_buffer = make_offline_buffer(cfg) env = make_env(cfg, transform=offline_buffer.transform)