Add simxarm back into tests

This commit is contained in:
Simon Alibert 2024-03-25 16:35:46 +01:00
parent d3adaf1379
commit 058ac991eb
18 changed files with 44 additions and 19 deletions

View File

@ -84,7 +84,7 @@ class SimxarmEnv(AbstractEnv):
else: else:
obs = {"state": torch.tensor(raw_obs["observation"], dtype=torch.float32)} obs = {"state": torch.tensor(raw_obs["observation"], dtype=torch.float32)}
obs = TensorDict(obs, batch_size=[]) # obs = TensorDict(obs, batch_size=[])
return obs return obs
def _reset(self, tensordict: Optional[TensorDict] = None): def _reset(self, tensordict: Optional[TensorDict] = None):

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:10ec2f944de18f1a2aa3fc2f9a4185c03e3a5afc31148c85c98b58602ac4186e
size 800

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1a589cba6bf0dfce138110864b6509508a804d7ea5c519d0a3cd67c4a87fa2d0
size 200

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6afe7098f30bdc8564526517c085e62613f6cb67194153840567974cfa6f3815
size 400

View File

@ -0,0 +1 @@
{"action": {"device": "cpu", "shape": [50, 4], "dtype": "torch.float32"}, "episode": {"device": "cpu", "shape": [50], "dtype": "torch.int32"}, "frame_id": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dab3a9712c413c4bfcd91c645752ab981306b23d25bcd4da4c422911574673f7
size 50

View File

@ -0,0 +1 @@
{"reward": {"device": "cpu", "shape": [50], "dtype": "torch.float32"}, "done": {"device": "cpu", "shape": [50], "dtype": "torch.bool"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d6f9d1422ce3764e7253f70ed4da278f0c07fafef0d5386479f09d6b6b9d8259
size 1058400

View File

@ -0,0 +1 @@
{"image": {"device": "cpu", "shape": [50, 3, 84, 84], "dtype": "torch.uint8"}, "state": {"device": "cpu", "shape": [50, 4], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:52e7c1a3c4fb2423b195e66ffee2c9e23b3ea0ad7c8bfc4dec30a35c65cadcbb
size 800

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c4dbe8ea1966e5cc6da6daf5704805b9b5810f4575de7016b8f6cb1da1d7bb8a
size 200

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8fca8ddbda3f7bb2f6e7553297c18f3ab8f8b73d64b5c9f81a3695ad9379d403
size 1058400

View File

@ -0,0 +1 @@
{"image": {"device": "cpu", "shape": [50, 3, 84, 84], "dtype": "torch.uint8"}, "state": {"device": "cpu", "shape": [50, 4], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7b3e3e12896d553c208ee152f6d447c877c435e15d010c4a6171966d5b8a0c0b
size 800

Binary file not shown.

View File

@ -10,9 +10,8 @@ from .utils import DEVICE, init_config
"env_name,dataset_id", "env_name,dataset_id",
[ [
# TODO(rcadene): simxarm is depreciated for now # TODO(rcadene): simxarm is depreciated for now
# ("simxarm", "lift"), ("simxarm", "lift"),
("pusht", "pusht"), ("pusht", "pusht"),
# TODO(aliberts): add aloha when dataset is available on hub
("aloha", "sim_insertion_human"), ("aloha", "sim_insertion_human"),
("aloha", "sim_insertion_scripted"), ("aloha", "sim_insertion_scripted"),
("aloha", "sim_transfer_cube_human"), ("aloha", "sim_transfer_cube_human"),

View File

@ -39,19 +39,20 @@ def print_spec_rollout(env):
print("data from rollout:", simple_rollout(100)) print("data from rollout:", simple_rollout(100))
@pytest.mark.skip(reason="Simxarm is deprecated") # @pytest.mark.skip(reason="Simxarm is deprecated")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"task,from_pixels,pixels_only", "task,from_pixels,pixels_only",
[ [
("lift", False, False), ("lift", False, False),
("lift", True, False), ("lift", True, False),
("lift", True, True), ("lift", True, True),
("reach", False, False), # TODO(aliberts): Add simxarm other task or remove them completely from repo
("reach", True, False), # ("reach", False, False),
("push", False, False), # ("reach", True, False),
("push", True, False), # ("push", False, False),
("peg_in_box", False, False), # ("push", True, False),
("peg_in_box", True, False), # ("peg_in_box", False, False),
# ("peg_in_box", True, False),
], ],
) )
def test_simxarm(task, from_pixels, pixels_only): def test_simxarm(task, from_pixels, pixels_only):
@ -84,7 +85,7 @@ def test_pusht(from_pixels, pixels_only):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"env_name", "env_name",
[ [
# "simxarm", "simxarm",
"pusht", "pusht",
"aloha", "aloha",
], ],

View File

@ -19,12 +19,13 @@ from .utils import DEVICE, init_config
[ [
("simxarm", "tdmpc", ["policy.mpc=true"]), ("simxarm", "tdmpc", ["policy.mpc=true"]),
("pusht", "tdmpc", ["policy.mpc=false"]), ("pusht", "tdmpc", ["policy.mpc=false"]),
("simxarm", "diffusion", []),
("pusht", "diffusion", []), ("pusht", "diffusion", []),
("aloha", "act", ["env.task=sim_insertion_scripted"]), ("aloha", "act", ["env.task=sim_insertion_scripted"]),
("aloha", "act", ["env.task=sim_insertion_human"]), ("aloha", "act", ["env.task=sim_insertion_human"]),
("aloha", "act", ["env.task=sim_transfer_cube_scripted"]), ("aloha", "act", ["env.task=sim_transfer_cube_scripted"]),
("aloha", "act", ["env.task=sim_transfer_cube_human"]), ("aloha", "act", ["env.task=sim_transfer_cube_human"]),
# TODO(aliberts): simxarm not working with diffusion
# ("simxarm", "diffusion", []),
], ],
) )
def test_concrete_policy(env_name, policy_name, extra_overrides): def test_concrete_policy(env_name, policy_name, extra_overrides):
@ -45,13 +46,6 @@ def test_concrete_policy(env_name, policy_name, extra_overrides):
# Check that we can make the policy object. # Check that we can make the policy object.
policy = make_policy(cfg) policy = make_policy(cfg)
# Check that we run select_actions and get the appropriate output. # Check that we run select_actions and get the appropriate output.
if env_name == "simxarm":
# TODO(rcadene): Not implemented
return
if policy_name == "tdmpc":
# TODO(alexander-soare): TDMPC does not use n_obs_steps but the environment requires this.
with open_dict(cfg):
cfg["n_obs_steps"] = 1
offline_buffer = make_offline_buffer(cfg) offline_buffer = make_offline_buffer(cfg)
env = make_env(cfg, transform=offline_buffer.transform) env = make_env(cfg, transform=offline_buffer.transform)