From 6bddcb647e1f80e2cd37b971fb903a2820018f08 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Tue, 26 Mar 2024 16:23:34 +0100 Subject: [PATCH] Add test_aloha env test --- lerobot/common/envs/aloha/env.py | 2 +- tests/test_envs.py | 22 ++++++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/lerobot/common/envs/aloha/env.py b/lerobot/common/envs/aloha/env.py index 357a96ec..8f907650 100644 --- a/lerobot/common/envs/aloha/env.py +++ b/lerobot/common/envs/aloha/env.py @@ -206,7 +206,7 @@ class AlohaEnv(AbstractEnv): if self.from_pixels: if isinstance(self.image_size, int): image_shape = (3, self.image_size, self.image_size) - elif OmegaConf.is_list(self.image_size): + elif OmegaConf.is_list(self.image_size) or isinstance(self.image_size, list): assert len(self.image_size) == 3 # c h w assert self.image_size[0] == 3 # c is RGB image_shape = tuple(self.image_size) diff --git a/tests/test_envs.py b/tests/test_envs.py index 2beafbda..3c8c6157 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -7,6 +7,7 @@ from lerobot.common.datasets.factory import make_offline_buffer from lerobot.common.envs.factory import make_env from lerobot.common.envs.pusht.env import PushtEnv from lerobot.common.envs.simxarm.env import SimxarmEnv +from lerobot.common.envs.aloha.env import AlohaEnv from .utils import DEVICE, init_config @@ -38,6 +39,27 @@ def print_spec_rollout(env): print("data from rollout:", simple_rollout(100)) +@pytest.mark.parametrize( + "task,from_pixels,pixels_only", + [ + ("sim_insertion", True, False), + ("sim_insertion", True, True), + ("sim_transfer_cube", True, False), + ("sim_transfer_cube", True, True), + # TODO(aliberts): Add aloha other tasks + ], +) +def test_aloha(task, from_pixels, pixels_only): + env = AlohaEnv( + task, + from_pixels=from_pixels, + pixels_only=pixels_only, + image_size=[3, 480, 640] if from_pixels else None, + ) + # print_spec_rollout(env) + check_env_specs(env) + + @pytest.mark.parametrize( "task,from_pixels,pixels_only", [