diff --git a/lerobot/common/envs/aloha/env.py b/lerobot/common/envs/aloha/env.py index 357a96ec..8f907650 100644 --- a/lerobot/common/envs/aloha/env.py +++ b/lerobot/common/envs/aloha/env.py @@ -206,7 +206,7 @@ class AlohaEnv(AbstractEnv): if self.from_pixels: if isinstance(self.image_size, int): image_shape = (3, self.image_size, self.image_size) - elif OmegaConf.is_list(self.image_size): + elif OmegaConf.is_list(self.image_size) or isinstance(self.image_size, list): assert len(self.image_size) == 3 # c h w assert self.image_size[0] == 3 # c is RGB image_shape = tuple(self.image_size) diff --git a/tests/test_envs.py b/tests/test_envs.py index 2beafbda..3c8c6157 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -7,6 +7,7 @@ from lerobot.common.datasets.factory import make_offline_buffer from lerobot.common.envs.factory import make_env from lerobot.common.envs.pusht.env import PushtEnv from lerobot.common.envs.simxarm.env import SimxarmEnv +from lerobot.common.envs.aloha.env import AlohaEnv from .utils import DEVICE, init_config @@ -38,6 +39,27 @@ def print_spec_rollout(env): print("data from rollout:", simple_rollout(100)) +@pytest.mark.parametrize( + "task,from_pixels,pixels_only", + [ + ("sim_insertion", True, False), + ("sim_insertion", True, True), + ("sim_transfer_cube", True, False), + ("sim_transfer_cube", True, True), + # TODO(aliberts): Add aloha other tasks + ], +) +def test_aloha(task, from_pixels, pixels_only): + env = AlohaEnv( + task, + from_pixels=from_pixels, + pixels_only=pixels_only, + image_size=[3, 480, 640] if from_pixels else None, + ) + # print_spec_rollout(env) + check_env_specs(env) + + @pytest.mark.parametrize( "task,from_pixels,pixels_only", [