diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index ed5cb926..c8d10851 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -26,12 +26,14 @@ def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv: ) raise e - handle = f"{package_name}/{cfg.env.handle}" + gym_handle = f"{package_name}/{cfg.env.task}" if num_parallel_envs == 0: # non-batched version of the env that returns an observation of shape (c) - env = gym.make(handle, **kwargs) + env = gym.make(gym_handle, **kwargs) else: # batched version of the env that returns an observation of shape (b, c) - env = gym.vector.SyncVectorEnv([lambda: gym.make(handle, **kwargs) for _ in range(num_parallel_envs)]) + env = gym.vector.SyncVectorEnv( + [lambda: gym.make(gym_handle, **kwargs) for _ in range(num_parallel_envs)] + ) return env diff --git a/lerobot/configs/env/aloha.yaml b/lerobot/configs/env/aloha.yaml index 146a4598..7a8d8b58 100644 --- a/lerobot/configs/env/aloha.yaml +++ b/lerobot/configs/env/aloha.yaml @@ -14,9 +14,7 @@ dataset_id: aloha_sim_insertion_human env: name: aloha - handle: AlohaInsertion-v0 - # TODO(aliberts): replace task with handle - task: insertion + task: AlohaInsertion-v0 from_pixels: True pixels_only: False image_size: [3, 480, 640] diff --git a/lerobot/configs/env/pusht.yaml b/lerobot/configs/env/pusht.yaml index aafd766a..a5fbcc25 100644 --- a/lerobot/configs/env/pusht.yaml +++ b/lerobot/configs/env/pusht.yaml @@ -14,9 +14,7 @@ dataset_id: pusht env: name: pusht - handle: PushT-v0 - # TODO(aliberts): replace task with handle - task: pusht + task: PushT-v0 from_pixels: True pixels_only: False image_size: 96 diff --git a/lerobot/configs/env/xarm.yaml b/lerobot/configs/env/xarm.yaml index 5eb1700e..8b3c72ef 100644 --- a/lerobot/configs/env/xarm.yaml +++ b/lerobot/configs/env/xarm.yaml @@ -13,9 +13,7 @@ dataset_id: xarm_lift_medium env: name: xarm - handle: XarmLift-v0 - # TODO(aliberts): replace task with handle - task: lift + task: XarmLift-v0 from_pixels: True pixels_only: False image_size: 84 diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 3bf09b5f..cca26902 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -162,7 +162,7 @@ def train(cfg: dict, out_dir=None, job_name=None): logger = Logger(out_dir, job_name, cfg) log_output_dir(out_dir) - logging.info(f"{cfg.env.handle=}") + logging.info(f"{cfg.env.task=}") logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})") logging.info(f"{cfg.online_steps=}") logging.info(f"{cfg.env.action_repeat=}") diff --git a/tests/test_envs.py b/tests/test_envs.py index 8fcf3a48..72bc93c4 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -14,7 +14,7 @@ from .utils import DEVICE, DEFAULT_CONFIG_PATH @pytest.mark.parametrize( - "env_name, handle, obs_type", + "env_name, task, obs_type", [ # ("AlohaInsertion-v0", "state"), ("aloha", "AlohaInsertion-v0", "pixels"), @@ -29,10 +29,10 @@ from .utils import DEVICE, DEFAULT_CONFIG_PATH ("pusht", "PushT-v0", "pixels_agent_pos"), ], ) -def test_env(env_name, handle, obs_type): +def test_env(env_name, task, obs_type): package_name = f"gym_{env_name}" importlib.import_module(package_name) - env = gym.make(f"{package_name}/{handle}", obs_type=obs_type) + env = gym.make(f"{package_name}/{task}", obs_type=obs_type) check_env(env.unwrapped)