Add simxarm back into tests
This commit is contained in:
parent
d3adaf1379
commit
058ac991eb
|
@ -84,7 +84,7 @@ class SimxarmEnv(AbstractEnv):
|
||||||
else:
|
else:
|
||||||
obs = {"state": torch.tensor(raw_obs["observation"], dtype=torch.float32)}
|
obs = {"state": torch.tensor(raw_obs["observation"], dtype=torch.float32)}
|
||||||
|
|
||||||
obs = TensorDict(obs, batch_size=[])
|
# obs = TensorDict(obs, batch_size=[])
|
||||||
return obs
|
return obs
|
||||||
|
|
||||||
def _reset(self, tensordict: Optional[TensorDict] = None):
|
def _reset(self, tensordict: Optional[TensorDict] = None):
|
||||||
|
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:10ec2f944de18f1a2aa3fc2f9a4185c03e3a5afc31148c85c98b58602ac4186e
|
||||||
|
size 800
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:1a589cba6bf0dfce138110864b6509508a804d7ea5c519d0a3cd67c4a87fa2d0
|
||||||
|
size 200
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:6afe7098f30bdc8564526517c085e62613f6cb67194153840567974cfa6f3815
|
||||||
|
size 400
|
|
@ -0,0 +1 @@
|
||||||
|
{"action": {"device": "cpu", "shape": [50, 4], "dtype": "torch.float32"}, "episode": {"device": "cpu", "shape": [50], "dtype": "torch.int32"}, "frame_id": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:dab3a9712c413c4bfcd91c645752ab981306b23d25bcd4da4c422911574673f7
|
||||||
|
size 50
|
|
@ -0,0 +1 @@
|
||||||
|
{"reward": {"device": "cpu", "shape": [50], "dtype": "torch.float32"}, "done": {"device": "cpu", "shape": [50], "dtype": "torch.bool"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:d6f9d1422ce3764e7253f70ed4da278f0c07fafef0d5386479f09d6b6b9d8259
|
||||||
|
size 1058400
|
|
@ -0,0 +1 @@
|
||||||
|
{"image": {"device": "cpu", "shape": [50, 3, 84, 84], "dtype": "torch.uint8"}, "state": {"device": "cpu", "shape": [50, 4], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:52e7c1a3c4fb2423b195e66ffee2c9e23b3ea0ad7c8bfc4dec30a35c65cadcbb
|
||||||
|
size 800
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:c4dbe8ea1966e5cc6da6daf5704805b9b5810f4575de7016b8f6cb1da1d7bb8a
|
||||||
|
size 200
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:8fca8ddbda3f7bb2f6e7553297c18f3ab8f8b73d64b5c9f81a3695ad9379d403
|
||||||
|
size 1058400
|
|
@ -0,0 +1 @@
|
||||||
|
{"image": {"device": "cpu", "shape": [50, 3, 84, 84], "dtype": "torch.uint8"}, "state": {"device": "cpu", "shape": [50, 4], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:7b3e3e12896d553c208ee152f6d447c877c435e15d010c4a6171966d5b8a0c0b
|
||||||
|
size 800
|
Binary file not shown.
|
@ -10,9 +10,8 @@ from .utils import DEVICE, init_config
|
||||||
"env_name,dataset_id",
|
"env_name,dataset_id",
|
||||||
[
|
[
|
||||||
# TODO(rcadene): simxarm is depreciated for now
|
# TODO(rcadene): simxarm is depreciated for now
|
||||||
# ("simxarm", "lift"),
|
("simxarm", "lift"),
|
||||||
("pusht", "pusht"),
|
("pusht", "pusht"),
|
||||||
# TODO(aliberts): add aloha when dataset is available on hub
|
|
||||||
("aloha", "sim_insertion_human"),
|
("aloha", "sim_insertion_human"),
|
||||||
("aloha", "sim_insertion_scripted"),
|
("aloha", "sim_insertion_scripted"),
|
||||||
("aloha", "sim_transfer_cube_human"),
|
("aloha", "sim_transfer_cube_human"),
|
||||||
|
|
|
@ -39,19 +39,20 @@ def print_spec_rollout(env):
|
||||||
print("data from rollout:", simple_rollout(100))
|
print("data from rollout:", simple_rollout(100))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="Simxarm is deprecated")
|
# @pytest.mark.skip(reason="Simxarm is deprecated")
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"task,from_pixels,pixels_only",
|
"task,from_pixels,pixels_only",
|
||||||
[
|
[
|
||||||
("lift", False, False),
|
("lift", False, False),
|
||||||
("lift", True, False),
|
("lift", True, False),
|
||||||
("lift", True, True),
|
("lift", True, True),
|
||||||
("reach", False, False),
|
# TODO(aliberts): Add simxarm other task or remove them completely from repo
|
||||||
("reach", True, False),
|
# ("reach", False, False),
|
||||||
("push", False, False),
|
# ("reach", True, False),
|
||||||
("push", True, False),
|
# ("push", False, False),
|
||||||
("peg_in_box", False, False),
|
# ("push", True, False),
|
||||||
("peg_in_box", True, False),
|
# ("peg_in_box", False, False),
|
||||||
|
# ("peg_in_box", True, False),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_simxarm(task, from_pixels, pixels_only):
|
def test_simxarm(task, from_pixels, pixels_only):
|
||||||
|
@ -84,7 +85,7 @@ def test_pusht(from_pixels, pixels_only):
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"env_name",
|
"env_name",
|
||||||
[
|
[
|
||||||
# "simxarm",
|
"simxarm",
|
||||||
"pusht",
|
"pusht",
|
||||||
"aloha",
|
"aloha",
|
||||||
],
|
],
|
||||||
|
|
|
@ -19,12 +19,13 @@ from .utils import DEVICE, init_config
|
||||||
[
|
[
|
||||||
("simxarm", "tdmpc", ["policy.mpc=true"]),
|
("simxarm", "tdmpc", ["policy.mpc=true"]),
|
||||||
("pusht", "tdmpc", ["policy.mpc=false"]),
|
("pusht", "tdmpc", ["policy.mpc=false"]),
|
||||||
("simxarm", "diffusion", []),
|
|
||||||
("pusht", "diffusion", []),
|
("pusht", "diffusion", []),
|
||||||
("aloha", "act", ["env.task=sim_insertion_scripted"]),
|
("aloha", "act", ["env.task=sim_insertion_scripted"]),
|
||||||
("aloha", "act", ["env.task=sim_insertion_human"]),
|
("aloha", "act", ["env.task=sim_insertion_human"]),
|
||||||
("aloha", "act", ["env.task=sim_transfer_cube_scripted"]),
|
("aloha", "act", ["env.task=sim_transfer_cube_scripted"]),
|
||||||
("aloha", "act", ["env.task=sim_transfer_cube_human"]),
|
("aloha", "act", ["env.task=sim_transfer_cube_human"]),
|
||||||
|
# TODO(aliberts): simxarm not working with diffusion
|
||||||
|
# ("simxarm", "diffusion", []),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_concrete_policy(env_name, policy_name, extra_overrides):
|
def test_concrete_policy(env_name, policy_name, extra_overrides):
|
||||||
|
@ -45,13 +46,6 @@ def test_concrete_policy(env_name, policy_name, extra_overrides):
|
||||||
# Check that we can make the policy object.
|
# Check that we can make the policy object.
|
||||||
policy = make_policy(cfg)
|
policy = make_policy(cfg)
|
||||||
# Check that we run select_actions and get the appropriate output.
|
# Check that we run select_actions and get the appropriate output.
|
||||||
if env_name == "simxarm":
|
|
||||||
# TODO(rcadene): Not implemented
|
|
||||||
return
|
|
||||||
if policy_name == "tdmpc":
|
|
||||||
# TODO(alexander-soare): TDMPC does not use n_obs_steps but the environment requires this.
|
|
||||||
with open_dict(cfg):
|
|
||||||
cfg["n_obs_steps"] = 1
|
|
||||||
offline_buffer = make_offline_buffer(cfg)
|
offline_buffer = make_offline_buffer(cfg)
|
||||||
env = make_env(cfg, transform=offline_buffer.transform)
|
env = make_env(cfg, transform=offline_buffer.transform)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue