From 8fc1008809a0ba72c84207e857c59fcfa3bb4afc Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Fri, 29 Mar 2024 16:47:18 +0100 Subject: [PATCH] Fix imports --- envs/sim_aloha/aloha/env.py | 40 ++++++++++++++++++++++++++++++++ envs/sim_aloha/aloha/utils.py | 43 ----------------------------------- envs/sim_aloha/pyproject.toml | 2 +- 3 files changed, 41 insertions(+), 44 deletions(-) create mode 100644 envs/sim_aloha/aloha/env.py diff --git a/envs/sim_aloha/aloha/env.py b/envs/sim_aloha/aloha/env.py new file mode 100644 index 00000000..11020245 --- /dev/null +++ b/envs/sim_aloha/aloha/env.py @@ -0,0 +1,40 @@ +from dm_control import mujoco +from dm_control.rl import control + +from aloha.constants import ASSETS_DIR, DT +from aloha.tasks.sim import InsertionTask, TransferCubeTask +from aloha.tasks.sim_end_effector import ( + InsertionEndEffectorTask, + TransferCubeEndEffectorTask, +) + + +def make_env_task(task_name): + # time limit is controlled by StepCounter in env factory + time_limit = float("inf") + + if "sim_transfer_cube" in task_name: + xml_path = ASSETS_DIR / "bimanual_viperx_transfer_cube.xml" + physics = mujoco.Physics.from_xml_path(str(xml_path)) + task = TransferCubeTask(random=False) + elif "sim_insertion" in task_name: + xml_path = ASSETS_DIR / "bimanual_viperx_insertion.xml" + physics = mujoco.Physics.from_xml_path(str(xml_path)) + task = InsertionTask(random=False) + elif "sim_end_effector_transfer_cube" in task_name: + raise NotImplementedError() + xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_transfer_cube.xml" + physics = mujoco.Physics.from_xml_path(str(xml_path)) + task = TransferCubeEndEffectorTask(random=False) + elif "sim_end_effector_insertion" in task_name: + raise NotImplementedError() + xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_insertion.xml" + physics = mujoco.Physics.from_xml_path(str(xml_path)) + task = InsertionEndEffectorTask(random=False) + else: + raise NotImplementedError(task_name) + + env = control.Environment( + physics, task, time_limit, control_timestep=DT, n_sub_steps=None, flat_observation=False + ) + return env diff --git a/envs/sim_aloha/aloha/utils.py b/envs/sim_aloha/aloha/utils.py index 2f8c4893..5ac8b955 100644 --- a/envs/sim_aloha/aloha/utils.py +++ b/envs/sim_aloha/aloha/utils.py @@ -1,47 +1,4 @@ import numpy as np -from dm_control import mujoco -from dm_control.rl import control - -from aloha.constants import ( - ASSETS_DIR, - DT, -) -from aloha.tasks.sim import InsertionTask, TransferCubeTask -from aloha.tasks.sim_end_effector import ( - InsertionEndEffectorTask, - TransferCubeEndEffectorTask, -) - - -def make_env_task(task_name): - # time limit is controlled by StepCounter in env factory - time_limit = float("inf") - - if "sim_transfer_cube" in task_name: - xml_path = ASSETS_DIR / "bimanual_viperx_transfer_cube.xml" - physics = mujoco.Physics.from_xml_path(str(xml_path)) - task = TransferCubeTask(random=False) - elif "sim_insertion" in task_name: - xml_path = ASSETS_DIR / "bimanual_viperx_insertion.xml" - physics = mujoco.Physics.from_xml_path(str(xml_path)) - task = InsertionTask(random=False) - elif "sim_end_effector_transfer_cube" in task_name: - raise NotImplementedError() - xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_transfer_cube.xml" - physics = mujoco.Physics.from_xml_path(str(xml_path)) - task = TransferCubeEndEffectorTask(random=False) - elif "sim_end_effector_insertion" in task_name: - raise NotImplementedError() - xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_insertion.xml" - physics = mujoco.Physics.from_xml_path(str(xml_path)) - task = InsertionEndEffectorTask(random=False) - else: - raise NotImplementedError(task_name) - - env = control.Environment( - physics, task, time_limit, control_timestep=DT, n_sub_steps=None, flat_observation=False - ) - return env def sample_box_pose(): diff --git a/envs/sim_aloha/pyproject.toml b/envs/sim_aloha/pyproject.toml index 63dc9122..0cc095cc 100644 --- a/envs/sim_aloha/pyproject.toml +++ b/envs/sim_aloha/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "sim_aloha" -version = "0.1.1" +version = "0.1.2" description = "ALOHA environment for LeRobot" authors = [ "Rémi Cadène ",