From 3d53e0fe0fad3710e567776e15f9cc944c9bfff0 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Fri, 29 Mar 2024 15:22:32 +0100 Subject: [PATCH] Move make_env_task logic to aloha --- envs/sim_aloha/aloha/utils.py | 43 +++++++++++++++++++++++++++++++++++ envs/sim_aloha/pyproject.toml | 2 +- 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/envs/sim_aloha/aloha/utils.py b/envs/sim_aloha/aloha/utils.py index 5ac8b955..2f8c4893 100644 --- a/envs/sim_aloha/aloha/utils.py +++ b/envs/sim_aloha/aloha/utils.py @@ -1,4 +1,47 @@ import numpy as np +from dm_control import mujoco +from dm_control.rl import control + +from aloha.constants import ( + ASSETS_DIR, + DT, +) +from aloha.tasks.sim import InsertionTask, TransferCubeTask +from aloha.tasks.sim_end_effector import ( + InsertionEndEffectorTask, + TransferCubeEndEffectorTask, +) + + +def make_env_task(task_name): + # time limit is controlled by StepCounter in env factory + time_limit = float("inf") + + if "sim_transfer_cube" in task_name: + xml_path = ASSETS_DIR / "bimanual_viperx_transfer_cube.xml" + physics = mujoco.Physics.from_xml_path(str(xml_path)) + task = TransferCubeTask(random=False) + elif "sim_insertion" in task_name: + xml_path = ASSETS_DIR / "bimanual_viperx_insertion.xml" + physics = mujoco.Physics.from_xml_path(str(xml_path)) + task = InsertionTask(random=False) + elif "sim_end_effector_transfer_cube" in task_name: + raise NotImplementedError() + xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_transfer_cube.xml" + physics = mujoco.Physics.from_xml_path(str(xml_path)) + task = TransferCubeEndEffectorTask(random=False) + elif "sim_end_effector_insertion" in task_name: + raise NotImplementedError() + xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_insertion.xml" + physics = mujoco.Physics.from_xml_path(str(xml_path)) + task = InsertionEndEffectorTask(random=False) + else: + raise NotImplementedError(task_name) + + env = control.Environment( + physics, task, time_limit, control_timestep=DT, n_sub_steps=None, flat_observation=False + ) + return env def sample_box_pose(): diff --git a/envs/sim_aloha/pyproject.toml b/envs/sim_aloha/pyproject.toml index a7b60845..63dc9122 100644 --- a/envs/sim_aloha/pyproject.toml +++ b/envs/sim_aloha/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "sim_aloha" -version = "0.1.0" +version = "0.1.1" description = "ALOHA environment for LeRobot" authors = [ "Rémi Cadène ",