diff --git a/README.md b/README.md index 551270b5..83e75ff0 100644 --- a/README.md +++ b/README.md @@ -55,19 +55,10 @@ env=pusht ## TODO -- [x] priority update doesnt match FOWM or original paper -- [x] self.step=100000 should be updated at every step to adjust to horizon of planner -- [ ] prefetch replay buffer to speedup training -- [ ] parallelize env to speedup eval -- [ ] clean checkpointing / loading -- [ ] clean logging -- [ ] clean config -- [ ] clean hyperparameter tuning -- [ ] add pusht -- [ ] add aloha -- [ ] add act -- [ ] add diffusion -- [ ] add aloha 2 +If you don't know how to contribute or want to know which features we are working on next, look at this project page: [LeRobot TODO](https://github.com/users/Cadene/projects/1) + +Ask [Remi Cadene](mailto:re.cadene@gmail.com) for access if needed. + ## Profile diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py index 7397327d..851cc75b 100644 --- a/lerobot/common/datasets/aloha.py +++ b/lerobot/common/datasets/aloha.py @@ -73,11 +73,11 @@ def download(data_dir, dataset_id): data_dir.mkdir(parents=True, exist_ok=True) - gdown.download_folder(FOLDER_URLS[dataset_id], output=data_dir) + gdown.download_folder(FOLDER_URLS[dataset_id], output=str(data_dir)) # because of the 50 files limit per directory, two files episode 48 and 49 were missing - gdown.download(EP48_URLS[dataset_id], output=data_dir / "episode_48.hdf5", fuzzy=True) - gdown.download(EP49_URLS[dataset_id], output=data_dir / "episode_49.hdf5", fuzzy=True) + gdown.download(EP48_URLS[dataset_id], output=str(data_dir / "episode_48.hdf5"), fuzzy=True) + gdown.download(EP49_URLS[dataset_id], output=str(data_dir / "episode_49.hdf5"), fuzzy=True) class AlohaExperienceReplay(AbstractExperienceReplay): @@ -124,9 +124,6 @@ class AlohaExperienceReplay(AbstractExperienceReplay): def image_keys(self) -> list: return [("observation", "image", cam) for cam in CAMERAS[self.dataset_id]] - # def _is_downloaded(self) -> bool: - # return False - def _download_and_preproc(self): raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw" if not raw_dir.is_dir(): diff --git a/lerobot/common/datasets/simxarm.py b/lerobot/common/datasets/simxarm.py index 784242cc..1d56850e 100644 --- a/lerobot/common/datasets/simxarm.py +++ b/lerobot/common/datasets/simxarm.py @@ -1,4 +1,5 @@ import pickle +import zipfile from pathlib import Path from typing import Callable @@ -15,6 +16,22 @@ from torchrl.data.replay_buffers.writers import Writer from lerobot.common.datasets.abstract import AbstractExperienceReplay +def download(): + raise NotImplementedError() + import gdown + + url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya" + download_path = "data.zip" + gdown.download(url, download_path, quiet=False) + print("Extracting...") + with zipfile.ZipFile(download_path, "r") as zip_f: + for member in zip_f.namelist(): + if member.startswith("data/xarm") and member.endswith(".pkl"): + print(member) + zip_f.extract(member=member) + Path(download_path).unlink() + + class SimxarmExperienceReplay(AbstractExperienceReplay): available_datasets = [ "xarm_lift_medium", @@ -48,8 +65,8 @@ class SimxarmExperienceReplay(AbstractExperienceReplay): ) def _download_and_preproc(self): - # download - # TODO(rcadene) + # TODO(rcadene): finish download + download() dataset_path = self.data_dir / "buffer.pkl" print(f"Using offline dataset '{dataset_path}'") diff --git a/lerobot/common/envs/abstract.py b/lerobot/common/envs/abstract.py new file mode 100644 index 00000000..2901e4d2 --- 
/dev/null +++ b/lerobot/common/envs/abstract.py @@ -0,0 +1,75 @@ +import abc +from collections import deque +from typing import Optional + +from tensordict import TensorDict +from torchrl.envs import EnvBase + + +class AbstractEnv(EnvBase): + def __init__( + self, + task, + frame_skip: int = 1, + from_pixels: bool = False, + pixels_only: bool = False, + image_size=None, + seed=1337, + device="cpu", + num_prev_obs=1, + num_prev_action=0, + ): + super().__init__(device=device, batch_size=[]) + self.task = task + self.frame_skip = frame_skip + self.from_pixels = from_pixels + self.pixels_only = pixels_only + self.image_size = image_size + self.num_prev_obs = num_prev_obs + self.num_prev_action = num_prev_action + self._rendering_hooks = [] + + if pixels_only: + assert from_pixels + if from_pixels: + assert image_size + + self._make_spec() + self._current_seed = self.set_seed(seed) + + if self.num_prev_obs > 0: + self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs) + self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs) + if self.num_prev_action > 0: + raise NotImplementedError() + # self._prev_action_queue = deque(maxlen=self.num_prev_action) + + def register_rendering_hook(self, func): + self._rendering_hooks.append(func) + + def call_rendering_hooks(self): + for func in self._rendering_hooks: + func(self) + + def reset_rendering_hooks(self): + self._rendering_hooks = [] + + @abc.abstractmethod + def render(self, mode="rgb_array", width=640, height=480): + raise NotImplementedError() + + @abc.abstractmethod + def _reset(self, tensordict: Optional[TensorDict] = None): + raise NotImplementedError() + + @abc.abstractmethod + def _step(self, tensordict: TensorDict): + raise NotImplementedError() + + @abc.abstractmethod + def _make_spec(self): + raise NotImplementedError() + + @abc.abstractmethod + def _set_seed(self, seed: Optional[int]): + raise NotImplementedError() diff --git a/lerobot/common/envs/aloha/assets/bimanual_viperx_end_effector_insertion.xml b/lerobot/common/envs/aloha/assets/bimanual_viperx_end_effector_insertion.xml new file mode 100644 index 00000000..8002838c --- /dev/null +++ b/lerobot/common/envs/aloha/assets/bimanual_viperx_end_effector_insertion.xml @@ -0,0 +1,59 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/lerobot/common/envs/aloha/assets/bimanual_viperx_end_effector_transfer_cube.xml b/lerobot/common/envs/aloha/assets/bimanual_viperx_end_effector_transfer_cube.xml new file mode 100644 index 00000000..05249ad2 --- /dev/null +++ b/lerobot/common/envs/aloha/assets/bimanual_viperx_end_effector_transfer_cube.xml @@ -0,0 +1,48 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/lerobot/common/envs/aloha/assets/bimanual_viperx_insertion.xml b/lerobot/common/envs/aloha/assets/bimanual_viperx_insertion.xml new file mode 100644 index 00000000..511f7947 --- /dev/null +++ b/lerobot/common/envs/aloha/assets/bimanual_viperx_insertion.xml @@ -0,0 +1,53 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/lerobot/common/envs/aloha/assets/bimanual_viperx_transfer_cube.xml b/lerobot/common/envs/aloha/assets/bimanual_viperx_transfer_cube.xml new file mode 100644 index 00000000..2d85a47c --- /dev/null +++ b/lerobot/common/envs/aloha/assets/bimanual_viperx_transfer_cube.xml @@ -0,0 +1,42 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
+ + + + + diff --git a/lerobot/common/envs/aloha/assets/scene.xml b/lerobot/common/envs/aloha/assets/scene.xml new file mode 100644 index 00000000..0f61b8a5 --- /dev/null +++ b/lerobot/common/envs/aloha/assets/scene.xml @@ -0,0 +1,38 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/lerobot/common/envs/aloha/assets/tabletop.stl b/lerobot/common/envs/aloha/assets/tabletop.stl new file mode 100644 index 00000000..ab35cdf7 Binary files /dev/null and b/lerobot/common/envs/aloha/assets/tabletop.stl differ diff --git a/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_left.stl b/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_left.stl new file mode 100644 index 00000000..534c7af9 Binary files /dev/null and b/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_left.stl differ diff --git a/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_right.stl b/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_right.stl new file mode 100644 index 00000000..d6a492c2 Binary files /dev/null and b/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_right.stl differ diff --git a/lerobot/common/envs/aloha/assets/vx300s_10_gripper_finger.stl b/lerobot/common/envs/aloha/assets/vx300s_10_gripper_finger.stl new file mode 100644 index 00000000..d6df86be Binary files /dev/null and b/lerobot/common/envs/aloha/assets/vx300s_10_gripper_finger.stl differ diff --git a/lerobot/common/envs/aloha/assets/vx300s_11_ar_tag.stl b/lerobot/common/envs/aloha/assets/vx300s_11_ar_tag.stl new file mode 100644 index 00000000..193014b6 Binary files /dev/null and b/lerobot/common/envs/aloha/assets/vx300s_11_ar_tag.stl differ diff --git a/lerobot/common/envs/aloha/assets/vx300s_1_base.stl b/lerobot/common/envs/aloha/assets/vx300s_1_base.stl new file mode 100644 index 00000000..5a7efda2 Binary files /dev/null and b/lerobot/common/envs/aloha/assets/vx300s_1_base.stl differ diff --git a/lerobot/common/envs/aloha/assets/vx300s_2_shoulder.stl b/lerobot/common/envs/aloha/assets/vx300s_2_shoulder.stl new file mode 100644 index 00000000..dc22aa7e Binary files /dev/null and b/lerobot/common/envs/aloha/assets/vx300s_2_shoulder.stl differ diff --git a/lerobot/common/envs/aloha/assets/vx300s_3_upper_arm.stl b/lerobot/common/envs/aloha/assets/vx300s_3_upper_arm.stl new file mode 100644 index 00000000..111c586e Binary files /dev/null and b/lerobot/common/envs/aloha/assets/vx300s_3_upper_arm.stl differ diff --git a/lerobot/common/envs/aloha/assets/vx300s_4_upper_forearm.stl b/lerobot/common/envs/aloha/assets/vx300s_4_upper_forearm.stl new file mode 100644 index 00000000..8170d21c Binary files /dev/null and b/lerobot/common/envs/aloha/assets/vx300s_4_upper_forearm.stl differ diff --git a/lerobot/common/envs/aloha/assets/vx300s_5_lower_forearm.stl b/lerobot/common/envs/aloha/assets/vx300s_5_lower_forearm.stl new file mode 100644 index 00000000..39581f83 Binary files /dev/null and b/lerobot/common/envs/aloha/assets/vx300s_5_lower_forearm.stl differ diff --git a/lerobot/common/envs/aloha/assets/vx300s_6_wrist.stl b/lerobot/common/envs/aloha/assets/vx300s_6_wrist.stl new file mode 100644 index 00000000..ab8423e9 Binary files /dev/null and b/lerobot/common/envs/aloha/assets/vx300s_6_wrist.stl differ diff --git a/lerobot/common/envs/aloha/assets/vx300s_7_gripper.stl b/lerobot/common/envs/aloha/assets/vx300s_7_gripper.stl new file mode 100644 index 00000000..043db9ca Binary files /dev/null and b/lerobot/common/envs/aloha/assets/vx300s_7_gripper.stl differ diff --git 
a/lerobot/common/envs/aloha/assets/vx300s_8_gripper_prop.stl b/lerobot/common/envs/aloha/assets/vx300s_8_gripper_prop.stl new file mode 100644 index 00000000..36099b42 Binary files /dev/null and b/lerobot/common/envs/aloha/assets/vx300s_8_gripper_prop.stl differ diff --git a/lerobot/common/envs/aloha/assets/vx300s_9_gripper_bar.stl b/lerobot/common/envs/aloha/assets/vx300s_9_gripper_bar.stl new file mode 100644 index 00000000..eba3caa2 Binary files /dev/null and b/lerobot/common/envs/aloha/assets/vx300s_9_gripper_bar.stl differ diff --git a/lerobot/common/envs/aloha/assets/vx300s_dependencies.xml b/lerobot/common/envs/aloha/assets/vx300s_dependencies.xml new file mode 100644 index 00000000..93037ab7 --- /dev/null +++ b/lerobot/common/envs/aloha/assets/vx300s_dependencies.xml @@ -0,0 +1,17 @@ + + + + + + + + + + + + + + + + + diff --git a/lerobot/common/envs/aloha/assets/vx300s_left.xml b/lerobot/common/envs/aloha/assets/vx300s_left.xml new file mode 100644 index 00000000..3af6c235 --- /dev/null +++ b/lerobot/common/envs/aloha/assets/vx300s_left.xml @@ -0,0 +1,59 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/lerobot/common/envs/aloha/assets/vx300s_right.xml b/lerobot/common/envs/aloha/assets/vx300s_right.xml new file mode 100644 index 00000000..495df478 --- /dev/null +++ b/lerobot/common/envs/aloha/assets/vx300s_right.xml @@ -0,0 +1,59 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/lerobot/common/envs/aloha/constants.py b/lerobot/common/envs/aloha/constants.py new file mode 100644 index 00000000..e582e5f3 --- /dev/null +++ b/lerobot/common/envs/aloha/constants.py @@ -0,0 +1,163 @@ +from pathlib import Path + +### Simulation envs fixed constants +DT = 0.02 # 0.02 ms -> 1/0.2 = 50 hz +FPS = 50 + + +JOINTS = [ + # absolute joint position + "left_arm_waist", + "left_arm_shoulder", + "left_arm_elbow", + "left_arm_forearm_roll", + "left_arm_wrist_angle", + "left_arm_wrist_rotate", + # normalized gripper position 0: close, 1: open + "left_arm_gripper", + # absolute joint position + "right_arm_waist", + "right_arm_shoulder", + "right_arm_elbow", + "right_arm_forearm_roll", + "right_arm_wrist_angle", + "right_arm_wrist_rotate", + # normalized gripper position 0: close, 1: open + "right_arm_gripper", +] + +ACTIONS = [ + # position and quaternion for end effector + "left_arm_waist", + "left_arm_shoulder", + "left_arm_elbow", + "left_arm_forearm_roll", + "left_arm_wrist_angle", + "left_arm_wrist_rotate", + # normalized gripper position (0: close, 1: open) + "left_arm_gripper", + "right_arm_waist", + "right_arm_shoulder", + "right_arm_elbow", + "right_arm_forearm_roll", + "right_arm_wrist_angle", + "right_arm_wrist_rotate", + # normalized gripper position (0: close, 1: open) + "right_arm_gripper", +] + + +START_ARM_POSE = [ + 0, + -0.96, + 1.16, + 0, + -0.3, + 0, + 0.02239, + -0.02239, + 0, + -0.96, + 1.16, + 0, + -0.3, + 0, + 0.02239, + -0.02239, +] + +ASSETS_DIR = Path(__file__).parent.resolve() / "assets" # note: absolute path + +# Left finger position limits (qpos[7]), right_finger = -1 * left_finger +MASTER_GRIPPER_POSITION_OPEN = 0.02417 +MASTER_GRIPPER_POSITION_CLOSE = 0.01244 +PUPPET_GRIPPER_POSITION_OPEN = 0.05800 +PUPPET_GRIPPER_POSITION_CLOSE = 0.01844 + +# Gripper joint limits (qpos[6]) +MASTER_GRIPPER_JOINT_OPEN = 0.3083 +MASTER_GRIPPER_JOINT_CLOSE = -0.6842 +PUPPET_GRIPPER_JOINT_OPEN = 1.4910 
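# The helper functions defined below all implement the same linear range
# mapping: normalize_* sends [CLOSE, OPEN] to [0, 1] and unnormalize_*
# inverts it. As an illustrative check using the constants in this file,
#   convert_join_from_master_to_puppet(MASTER_GRIPPER_JOINT_OPEN)
# equals PUPPET_GRIPPER_JOINT_OPEN, since the composition reduces to
#   puppet = (x - M_CLOSE) / (M_OPEN - M_CLOSE) * (P_OPEN - P_CLOSE) + P_CLOSE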
+PUPPET_GRIPPER_JOINT_CLOSE = -0.6213 + +MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE) / 2 + +############################ Helper functions ############################ + + +def normalize_master_gripper_position(x): + return (x - MASTER_GRIPPER_POSITION_CLOSE) / ( + MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE + ) + + +def normalize_puppet_gripper_position(x): + return (x - PUPPET_GRIPPER_POSITION_CLOSE) / ( + PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE + ) + + +def unnormalize_master_gripper_position(x): + return x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE + + +def unnormalize_puppet_gripper_position(x): + return x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE + + +def convert_position_from_master_to_puppet(x): + return unnormalize_puppet_gripper_position(normalize_master_gripper_position(x)) + + +def normalizer_master_gripper_joint(x): + return (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + + +def normalize_puppet_gripper_joint(x): + return (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + + +def unnormalize_master_gripper_joint(x): + return x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE + + +def unnormalize_puppet_gripper_joint(x): + return x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE + + +def convert_join_from_master_to_puppet(x): + return unnormalize_puppet_gripper_joint(normalizer_master_gripper_joint(x)) + + +def normalize_master_gripper_velocity(x): + return x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + + +def normalize_puppet_gripper_velocity(x): + return x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + + +def convert_master_from_position_to_joint(x): + return ( + normalize_master_gripper_position(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + + MASTER_GRIPPER_JOINT_CLOSE + ) + + +def convert_master_from_joint_to_position(x): + return unnormalize_master_gripper_position( + (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + ) + + +def convert_puppet_from_position_to_join(x): + return ( + normalize_puppet_gripper_position(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + + PUPPET_GRIPPER_JOINT_CLOSE + ) + + +def convert_puppet_from_joint_to_position(x): + return unnormalize_puppet_gripper_position( + (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + ) diff --git a/lerobot/common/envs/aloha/env.py b/lerobot/common/envs/aloha/env.py new file mode 100644 index 00000000..f0cbb25d --- /dev/null +++ b/lerobot/common/envs/aloha/env.py @@ -0,0 +1,306 @@ +import importlib +import logging +from collections import deque +from typing import Optional + +import einops +import numpy as np +import torch +from dm_control import mujoco +from dm_control.rl import control +from tensordict import TensorDict +from torchrl.data.tensor_specs import ( + BoundedTensorSpec, + CompositeSpec, + DiscreteTensorSpec, + UnboundedContinuousTensorSpec, +) + +from lerobot.common.envs.abstract import AbstractEnv +from lerobot.common.envs.aloha.constants import ( + ACTIONS, + ASSETS_DIR, + DT, + JOINTS, +) +from lerobot.common.envs.aloha.tasks.sim import BOX_POSE, InsertionTask, TransferCubeTask +from 
lerobot.common.envs.aloha.tasks.sim_end_effector import ( + InsertionEndEffectorTask, + TransferCubeEndEffectorTask, +) +from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose +from lerobot.common.utils import set_seed + +_has_gym = importlib.util.find_spec("gym") is not None + + +class AlohaEnv(AbstractEnv): + def __init__( + self, + task, + frame_skip: int = 1, + from_pixels: bool = False, + pixels_only: bool = False, + image_size=None, + seed=1337, + device="cpu", + num_prev_obs=1, + num_prev_action=0, + ): + super().__init__( + task=task, + frame_skip=frame_skip, + from_pixels=from_pixels, + pixels_only=pixels_only, + image_size=image_size, + seed=seed, + device=device, + num_prev_obs=num_prev_obs, + num_prev_action=num_prev_action, + ) + if not _has_gym: + raise ImportError("Cannot import gym.") + + if not from_pixels: + raise NotImplementedError() + + self._env = self._make_env_task(task) + + def render(self, mode="rgb_array", width=640, height=480): + # TODO(rcadene): render and visualize several cameras (e.g. angle, front_close) + image = self._env.physics.render(height=height, width=width, camera_id="top") + return image + + def _make_env_task(self, task_name): + # time limit is controlled by StepCounter in env factory + time_limit = float("inf") + + if "sim_transfer_cube" in task_name: + xml_path = ASSETS_DIR / "bimanual_viperx_transfer_cube.xml" + physics = mujoco.Physics.from_xml_path(str(xml_path)) + task = TransferCubeTask(random=False) + elif "sim_insertion" in task_name: + xml_path = ASSETS_DIR / "bimanual_viperx_insertion.xml" + physics = mujoco.Physics.from_xml_path(str(xml_path)) + task = InsertionTask(random=False) + elif "sim_end_effector_transfer_cube" in task_name: + raise NotImplementedError() + xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_transfer_cube.xml" + physics = mujoco.Physics.from_xml_path(str(xml_path)) + task = TransferCubeEndEffectorTask(random=False) + elif "sim_end_effector_insertion" in task_name: + raise NotImplementedError() + xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_insertion.xml" + physics = mujoco.Physics.from_xml_path(str(xml_path)) + task = InsertionEndEffectorTask(random=False) + else: + raise NotImplementedError(task_name) + + env = control.Environment( + physics, task, time_limit, control_timestep=DT, n_sub_steps=None, flat_observation=False + ) + return env + + def _format_raw_obs(self, raw_obs): + if self.from_pixels: + image = torch.from_numpy(raw_obs["images"]["top"].copy()) + image = einops.rearrange(image, "h w c -> c h w") + obs = {"image": image.type(torch.float32) / 255.0} + + if not self.pixels_only: + obs["state"] = torch.from_numpy(raw_obs["qpos"]).type(torch.float32) + else: + # TODO(rcadene): + raise NotImplementedError() + # obs = {"state": torch.from_numpy(raw_obs["observation"]).type(torch.float32)} + + return obs + + def _reset(self, tensordict: Optional[TensorDict] = None): + td = tensordict + if td is None or td.is_empty(): + # we need to handle seed iteration, since self._env.reset() relies on an internal _seed. 
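# Each reset therefore bumps the seed by one before reseeding, so
# consecutive episodes are distinct while the whole rollout remains
# reproducible from the starting seed (episode k uses the k-th increment).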
+ self._current_seed += 1 + self.set_seed(self._current_seed) + + # TODO(rcadene): do not use global variable for this + if "sim_transfer_cube" in self.task: + BOX_POSE[0] = sample_box_pose() # used in sim reset + elif "sim_insertion" in self.task: + BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset + + raw_obs = self._env.reset() + # TODO(rcadene): add assert + # assert self._current_seed == self._env._seed + + obs = self._format_raw_obs(raw_obs.observation) + + if self.num_prev_obs > 0: + stacked_obs = {} + if "image" in obs: + self._prev_obs_image_queue = deque( + [obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1) + ) + stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue)) + if "state" in obs: + self._prev_obs_state_queue = deque( + [obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1) + ) + stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) + obs = stacked_obs + + td = TensorDict( + { + "observation": TensorDict(obs, batch_size=[]), + "done": torch.tensor([False], dtype=torch.bool), + }, + batch_size=[], + ) + else: + raise NotImplementedError() + + self.call_rendering_hooks() + return td + + def _step(self, tensordict: TensorDict): + td = tensordict + action = td["action"].numpy() + # a 1D action is repeated over frame_skip steps; a sequence of actions is executed as given + # TODO(rcadene): add info["is_success"] and info["success"] ? + sum_reward = 0 + + if action.ndim == 1: + action = einops.repeat(action, "c -> t c", t=self.frame_skip) + else: + if self.frame_skip > 1: + raise NotImplementedError() + + num_action_steps = action.shape[0] + for i in range(num_action_steps): + _, reward, discount, raw_obs = self._env.step(action[i]) + del discount # not used + + # TODO(rcadene): add an enum + success = done = reward == 4 + sum_reward += reward + obs = self._format_raw_obs(raw_obs) + + if self.num_prev_obs > 0: + stacked_obs = {} + if "image" in obs: + self._prev_obs_image_queue.append(obs["image"]) + stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue)) + if "state" in obs: + self._prev_obs_state_queue.append(obs["state"]) + stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) + obs = stacked_obs + + self.call_rendering_hooks() + + td = TensorDict( + { + "observation": TensorDict(obs, batch_size=[]), + "reward": torch.tensor([sum_reward], dtype=torch.float32), + # success and done are true when the task's maximum reward (4) is reached + "done": torch.tensor([done], dtype=torch.bool), + "success": torch.tensor([success], dtype=torch.bool), + }, + batch_size=[], + ) + return td + + def _make_spec(self): + obs = {} + from omegaconf import OmegaConf + + if self.from_pixels: + if isinstance(self.image_size, int): + image_shape = (3, self.image_size, self.image_size) + elif OmegaConf.is_list(self.image_size): + assert len(self.image_size) == 3 # c h w + assert self.image_size[0] == 3 # c is RGB + image_shape = tuple(self.image_size) + else: + raise ValueError(self.image_size) + if self.num_prev_obs > 0: + image_shape = (self.num_prev_obs + 1, *image_shape) + + obs["image"] = BoundedTensorSpec( + low=0, + high=1, + shape=image_shape, + dtype=torch.float32, + device=self.device, + ) + if not self.pixels_only: + state_shape = (len(JOINTS),) + if self.num_prev_obs > 0: + state_shape = (self.num_prev_obs + 1, *state_shape) + + obs["state"] = UnboundedContinuousTensorSpec( + # TODO: add low and high bounds + shape=state_shape, + dtype=torch.float32, + device=self.device, + ) + else: + # TODO(rcadene): add 
observation_space achieved_goal and desired_goal? + state_shape = (len(JOINTS),) + if self.num_prev_obs > 0: + state_shape = (self.num_prev_obs + 1, *state_shape) + + obs["state"] = UnboundedContinuousTensorSpec( + # TODO: add low and high bounds + shape=state_shape, + dtype=torch.float32, + device=self.device, + ) + self.observation_spec = CompositeSpec({"observation": obs}) + + # TODO(rcadene): valid when controlling end effector? + # action_space = self._env.action_spec() + # self.action_spec = BoundedTensorSpec( + # low=action_space.minimum, + # high=action_space.maximum, + # shape=action_space.shape, + # dtype=torch.float32, + # device=self.device, + # ) + + # TODO(rcadene): add bounds (where are they?) + self.action_spec = BoundedTensorSpec( + shape=(len(ACTIONS),), + low=-1, + high=1, + dtype=torch.float32, + device=self.device, + ) + + self.reward_spec = UnboundedContinuousTensorSpec( + shape=(1,), + dtype=torch.float32, + device=self.device, + ) + + self.done_spec = CompositeSpec( + { + "done": DiscreteTensorSpec( + 2, + shape=(1,), + dtype=torch.bool, + device=self.device, + ), + "success": DiscreteTensorSpec( + 2, + shape=(1,), + dtype=torch.bool, + device=self.device, + ), + } + ) + + def _set_seed(self, seed: Optional[int]): + set_seed(seed) + # TODO(rcadene): seed the env + # self._env.seed(seed) + logging.warning("Aloha env is not seeded") diff --git a/lerobot/common/envs/aloha/tasks/sim.py b/lerobot/common/envs/aloha/tasks/sim.py new file mode 100644 index 00000000..ee1d0927 --- /dev/null +++ b/lerobot/common/envs/aloha/tasks/sim.py @@ -0,0 +1,219 @@ +import collections + +import numpy as np +from dm_control.suite import base + +from lerobot.common.envs.aloha.constants import ( + START_ARM_POSE, + normalize_puppet_gripper_position, + normalize_puppet_gripper_velocity, + unnormalize_puppet_gripper_position, +) + +BOX_POSE = [None] # to be changed from outside + +""" +Environment for simulated robot bi-manual manipulation, with joint position control +Action space: [left_arm_qpos (6), # absolute joint position + left_gripper_positions (1), # normalized gripper position (0: close, 1: open) + right_arm_qpos (6), # absolute joint position + right_gripper_positions (1),] # normalized gripper position (0: close, 1: open) + +Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position + left_gripper_position (1), # normalized gripper position (0: close, 1: open) + right_arm_qpos (6), # absolute joint position + right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open) + "qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad) + left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing) + right_arm_qvel (6), # absolute joint velocity (rad) + right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing) + "images": {"main": (480x640x3)} # h, w, c, dtype='uint8' +""" + + +class BimanualViperXTask(base.Task): + def __init__(self, random=None): + super().__init__(random=random) + + def before_step(self, action, physics): + left_arm_action = action[:6] + right_arm_action = action[7 : 7 + 6] + normalized_left_gripper_action = action[6] + normalized_right_gripper_action = action[7 + 6] + + left_gripper_action = unnormalize_puppet_gripper_position(normalized_left_gripper_action) + right_gripper_action = unnormalize_puppet_gripper_position(normalized_right_gripper_action) + + full_left_gripper_action = [left_gripper_action, -left_gripper_action] + full_right_gripper_action = [right_gripper_action, 
-right_gripper_action] + + env_action = np.concatenate( + [left_arm_action, full_left_gripper_action, right_arm_action, full_right_gripper_action] + ) + super().before_step(env_action, physics) + return + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode.""" + super().initialize_episode(physics) + + @staticmethod + def get_qpos(physics): + qpos_raw = physics.data.qpos.copy() + left_qpos_raw = qpos_raw[:8] + right_qpos_raw = qpos_raw[8:16] + left_arm_qpos = left_qpos_raw[:6] + right_arm_qpos = right_qpos_raw[:6] + left_gripper_qpos = [normalize_puppet_gripper_position(left_qpos_raw[6])] + right_gripper_qpos = [normalize_puppet_gripper_position(right_qpos_raw[6])] + return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos]) + + @staticmethod + def get_qvel(physics): + qvel_raw = physics.data.qvel.copy() + left_qvel_raw = qvel_raw[:8] + right_qvel_raw = qvel_raw[8:16] + left_arm_qvel = left_qvel_raw[:6] + right_arm_qvel = right_qvel_raw[:6] + left_gripper_qvel = [normalize_puppet_gripper_velocity(left_qvel_raw[6])] + right_gripper_qvel = [normalize_puppet_gripper_velocity(right_qvel_raw[6])] + return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel]) + + @staticmethod + def get_env_state(physics): + raise NotImplementedError + + def get_observation(self, physics): + obs = collections.OrderedDict() + obs["qpos"] = self.get_qpos(physics) + obs["qvel"] = self.get_qvel(physics) + obs["env_state"] = self.get_env_state(physics) + obs["images"] = {} + obs["images"]["top"] = physics.render(height=480, width=640, camera_id="top") + obs["images"]["angle"] = physics.render(height=480, width=640, camera_id="angle") + obs["images"]["vis"] = physics.render(height=480, width=640, camera_id="front_close") + + return obs + + def get_reward(self, physics): + # return whether left gripper is holding the box + raise NotImplementedError + + +class TransferCubeTask(BimanualViperXTask): + def __init__(self, random=None): + super().__init__(random=random) + self.max_reward = 4 + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode.""" + # TODO Notice: this function does not randomize the env configuration. 
Instead, set BOX_POSE from outside + # reset qpos, control and box position + with physics.reset_context(): + physics.named.data.qpos[:16] = START_ARM_POSE + np.copyto(physics.data.ctrl, START_ARM_POSE) + assert BOX_POSE[0] is not None + physics.named.data.qpos[-7:] = BOX_POSE[0] + # print(f"{BOX_POSE=}") + super().initialize_episode(physics) + + @staticmethod + def get_env_state(physics): + env_state = physics.data.qpos.copy()[16:] + return env_state + + def get_reward(self, physics): + # return whether left gripper is holding the box + all_contact_pairs = [] + for i_contact in range(physics.data.ncon): + id_geom_1 = physics.data.contact[i_contact].geom1 + id_geom_2 = physics.data.contact[i_contact].geom2 + name_geom_1 = physics.model.id2name(id_geom_1, "geom") + name_geom_2 = physics.model.id2name(id_geom_2, "geom") + contact_pair = (name_geom_1, name_geom_2) + all_contact_pairs.append(contact_pair) + + touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs + touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs + touch_table = ("red_box", "table") in all_contact_pairs + + reward = 0 + if touch_right_gripper: + reward = 1 + if touch_right_gripper and not touch_table: # lifted + reward = 2 + if touch_left_gripper: # attempted transfer + reward = 3 + if touch_left_gripper and not touch_table: # successful transfer + reward = 4 + return reward + + +class InsertionTask(BimanualViperXTask): + def __init__(self, random=None): + super().__init__(random=random) + self.max_reward = 4 + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode.""" + # TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside + # reset qpos, control and box position + with physics.reset_context(): + physics.named.data.qpos[:16] = START_ARM_POSE + np.copyto(physics.data.ctrl, START_ARM_POSE) + assert BOX_POSE[0] is not None + physics.named.data.qpos[-7 * 2 :] = BOX_POSE[0] # two objects + # print(f"{BOX_POSE=}") + super().initialize_episode(physics) + + @staticmethod + def get_env_state(physics): + env_state = physics.data.qpos.copy()[16:] + return env_state + + def get_reward(self, physics): + # return whether peg touches the pin + all_contact_pairs = [] + for i_contact in range(physics.data.ncon): + id_geom_1 = physics.data.contact[i_contact].geom1 + id_geom_2 = physics.data.contact[i_contact].geom2 + name_geom_1 = physics.model.id2name(id_geom_1, "geom") + name_geom_2 = physics.model.id2name(id_geom_2, "geom") + contact_pair = (name_geom_1, name_geom_2) + all_contact_pairs.append(contact_pair) + + touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs + touch_left_gripper = ( + ("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs + or ("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs + or ("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs + or ("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs + ) + + peg_touch_table = ("red_peg", "table") in all_contact_pairs + socket_touch_table = ( + ("socket-1", "table") in all_contact_pairs + or ("socket-2", "table") in all_contact_pairs + or ("socket-3", "table") in all_contact_pairs + or ("socket-4", "table") in all_contact_pairs + ) + peg_touch_socket = ( + ("red_peg", "socket-1") in all_contact_pairs + or ("red_peg", "socket-2") in all_contact_pairs + or ("red_peg", "socket-3") 
in all_contact_pairs + or ("red_peg", "socket-4") in all_contact_pairs + ) + pin_touched = ("red_peg", "pin") in all_contact_pairs + + reward = 0 + if touch_left_gripper and touch_right_gripper: # touch both + reward = 1 + if ( + touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table) + ): # grasp both + reward = 2 + if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching + reward = 3 + if pin_touched: # successful insertion + reward = 4 + return reward diff --git a/lerobot/common/envs/aloha/tasks/sim_end_effector.py b/lerobot/common/envs/aloha/tasks/sim_end_effector.py new file mode 100644 index 00000000..d93c8330 --- /dev/null +++ b/lerobot/common/envs/aloha/tasks/sim_end_effector.py @@ -0,0 +1,263 @@ +import collections + +import numpy as np +from dm_control.suite import base + +from lerobot.common.envs.aloha.constants import ( + PUPPET_GRIPPER_POSITION_CLOSE, + START_ARM_POSE, + normalize_puppet_gripper_position, + normalize_puppet_gripper_velocity, + unnormalize_puppet_gripper_position, +) +from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose + +""" +Environment for simulated robot bi-manual manipulation, with end-effector control. +Action space: [left_arm_pose (7), # position and quaternion for end effector + left_gripper_positions (1), # normalized gripper position (0: close, 1: open) + right_arm_pose (7), # position and quaternion for end effector + right_gripper_positions (1),] # normalized gripper position (0: close, 1: open) + +Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position + left_gripper_position (1), # normalized gripper position (0: close, 1: open) + right_arm_qpos (6), # absolute joint position + right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open) + "qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad) + left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing) + right_arm_qvel (6), # absolute joint velocity (rad) + right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing) + "images": {"main": (480x640x3)} # h, w, c, dtype='uint8' +""" + + +class BimanualViperXEndEffectorTask(base.Task): + def __init__(self, random=None): + super().__init__(random=random) + + def before_step(self, action, physics): + a_len = len(action) // 2 + action_left = action[:a_len] + action_right = action[a_len:] + + # set mocap position and quat + # left + np.copyto(physics.data.mocap_pos[0], action_left[:3]) + np.copyto(physics.data.mocap_quat[0], action_left[3:7]) + # right + np.copyto(physics.data.mocap_pos[1], action_right[:3]) + np.copyto(physics.data.mocap_quat[1], action_right[3:7]) + + # set gripper + g_left_ctrl = unnormalize_puppet_gripper_position(action_left[7]) + g_right_ctrl = unnormalize_puppet_gripper_position(action_right[7]) + np.copyto(physics.data.ctrl, np.array([g_left_ctrl, -g_left_ctrl, g_right_ctrl, -g_right_ctrl])) + + def initialize_robots(self, physics): + # reset joint position + physics.named.data.qpos[:16] = START_ARM_POSE + + # reset mocap to align with end effector + # to obtain these numbers: + # (1) make an ee_sim env and reset to the same start_pose + # (2) get env._physics.named.data.xpos['vx300s_left/gripper_link'] + # get env._physics.named.data.xquat['vx300s_left/gripper_link'] + # repeat the same for right side + np.copyto(physics.data.mocap_pos[0], [-0.31718881, 0.5, 0.29525084]) + np.copyto(physics.data.mocap_quat[0], [1, 
0, 0, 0]) + # right + np.copyto(physics.data.mocap_pos[1], np.array([0.31718881, 0.49999888, 0.29525084])) + np.copyto(physics.data.mocap_quat[1], [1, 0, 0, 0]) + + # reset gripper control + close_gripper_control = np.array( + [ + PUPPET_GRIPPER_POSITION_CLOSE, + -PUPPET_GRIPPER_POSITION_CLOSE, + PUPPET_GRIPPER_POSITION_CLOSE, + -PUPPET_GRIPPER_POSITION_CLOSE, + ] + ) + np.copyto(physics.data.ctrl, close_gripper_control) + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode.""" + super().initialize_episode(physics) + + @staticmethod + def get_qpos(physics): + qpos_raw = physics.data.qpos.copy() + left_qpos_raw = qpos_raw[:8] + right_qpos_raw = qpos_raw[8:16] + left_arm_qpos = left_qpos_raw[:6] + right_arm_qpos = right_qpos_raw[:6] + left_gripper_qpos = [normalize_puppet_gripper_position(left_qpos_raw[6])] + right_gripper_qpos = [normalize_puppet_gripper_position(right_qpos_raw[6])] + return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos]) + + @staticmethod + def get_qvel(physics): + qvel_raw = physics.data.qvel.copy() + left_qvel_raw = qvel_raw[:8] + right_qvel_raw = qvel_raw[8:16] + left_arm_qvel = left_qvel_raw[:6] + right_arm_qvel = right_qvel_raw[:6] + left_gripper_qvel = [normalize_puppet_gripper_velocity(left_qvel_raw[6])] + right_gripper_qvel = [normalize_puppet_gripper_velocity(right_qvel_raw[6])] + return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel]) + + @staticmethod + def get_env_state(physics): + raise NotImplementedError + + def get_observation(self, physics): + # note: it is important to do .copy() + obs = collections.OrderedDict() + obs["qpos"] = self.get_qpos(physics) + obs["qvel"] = self.get_qvel(physics) + obs["env_state"] = self.get_env_state(physics) + obs["images"] = {} + obs["images"]["top"] = physics.render(height=480, width=640, camera_id="top") + obs["images"]["angle"] = physics.render(height=480, width=640, camera_id="angle") + obs["images"]["vis"] = physics.render(height=480, width=640, camera_id="front_close") + # used in scripted policy to obtain starting pose + obs["mocap_pose_left"] = np.concatenate( + [physics.data.mocap_pos[0], physics.data.mocap_quat[0]] + ).copy() + obs["mocap_pose_right"] = np.concatenate( + [physics.data.mocap_pos[1], physics.data.mocap_quat[1]] + ).copy() + + # used when replaying joint trajectory + obs["gripper_ctrl"] = physics.data.ctrl.copy() + return obs + + def get_reward(self, physics): + raise NotImplementedError + + +class TransferCubeEndEffectorTask(BimanualViperXEndEffectorTask): + def __init__(self, random=None): + super().__init__(random=random) + self.max_reward = 4 + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode.""" + self.initialize_robots(physics) + # randomize box position + cube_pose = sample_box_pose() + box_start_idx = physics.model.name2id("red_box_joint", "joint") + np.copyto(physics.data.qpos[box_start_idx : box_start_idx + 7], cube_pose) + # print(f"randomized cube position to {cube_position}") + + super().initialize_episode(physics) + + @staticmethod + def get_env_state(physics): + env_state = physics.data.qpos.copy()[16:] + return env_state + + def get_reward(self, physics): + # return whether left gripper is holding the box + all_contact_pairs = [] + for i_contact in range(physics.data.ncon): + id_geom_1 = physics.data.contact[i_contact].geom1 + id_geom_2 = physics.data.contact[i_contact].geom2 + name_geom_1 = 
physics.model.id2name(id_geom_1, "geom") + name_geom_2 = physics.model.id2name(id_geom_2, "geom") + contact_pair = (name_geom_1, name_geom_2) + all_contact_pairs.append(contact_pair) + + touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs + touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs + touch_table = ("red_box", "table") in all_contact_pairs + + reward = 0 + if touch_right_gripper: + reward = 1 + if touch_right_gripper and not touch_table: # lifted + reward = 2 + if touch_left_gripper: # attempted transfer + reward = 3 + if touch_left_gripper and not touch_table: # successful transfer + reward = 4 + return reward + + +class InsertionEndEffectorTask(BimanualViperXEndEffectorTask): + def __init__(self, random=None): + super().__init__(random=random) + self.max_reward = 4 + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode.""" + self.initialize_robots(physics) + # randomize peg and socket position + peg_pose, socket_pose = sample_insertion_pose() + + def id2index(j_id): + return 16 + (j_id - 16) * 7 # first 16 is robot qpos, 7 is pose dim # hacky + + peg_start_id = physics.model.name2id("red_peg_joint", "joint") + peg_start_idx = id2index(peg_start_id) + np.copyto(physics.data.qpos[peg_start_idx : peg_start_idx + 7], peg_pose) + # print(f"randomized cube position to {cube_position}") + + socket_start_id = physics.model.name2id("blue_socket_joint", "joint") + socket_start_idx = id2index(socket_start_id) + np.copyto(physics.data.qpos[socket_start_idx : socket_start_idx + 7], socket_pose) + # print(f"randomized cube position to {cube_position}") + + super().initialize_episode(physics) + + @staticmethod + def get_env_state(physics): + env_state = physics.data.qpos.copy()[16:] + return env_state + + def get_reward(self, physics): + # return whether peg touches the pin + all_contact_pairs = [] + for i_contact in range(physics.data.ncon): + id_geom_1 = physics.data.contact[i_contact].geom1 + id_geom_2 = physics.data.contact[i_contact].geom2 + name_geom_1 = physics.model.id2name(id_geom_1, "geom") + name_geom_2 = physics.model.id2name(id_geom_2, "geom") + contact_pair = (name_geom_1, name_geom_2) + all_contact_pairs.append(contact_pair) + + touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs + touch_left_gripper = ( + ("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs + or ("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs + or ("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs + or ("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs + ) + + peg_touch_table = ("red_peg", "table") in all_contact_pairs + socket_touch_table = ( + ("socket-1", "table") in all_contact_pairs + or ("socket-2", "table") in all_contact_pairs + or ("socket-3", "table") in all_contact_pairs + or ("socket-4", "table") in all_contact_pairs + ) + peg_touch_socket = ( + ("red_peg", "socket-1") in all_contact_pairs + or ("red_peg", "socket-2") in all_contact_pairs + or ("red_peg", "socket-3") in all_contact_pairs + or ("red_peg", "socket-4") in all_contact_pairs + ) + pin_touched = ("red_peg", "pin") in all_contact_pairs + + reward = 0 + if touch_left_gripper and touch_right_gripper: # touch both + reward = 1 + if ( + touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table) + ): # grasp both + reward = 2 + if 
peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching + reward = 3 + if pin_touched: # successful insertion + reward = 4 + return reward diff --git a/lerobot/common/envs/aloha/utils.py b/lerobot/common/envs/aloha/utils.py new file mode 100644 index 00000000..5ac8b955 --- /dev/null +++ b/lerobot/common/envs/aloha/utils.py @@ -0,0 +1,39 @@ +import numpy as np + + +def sample_box_pose(): + x_range = [0.0, 0.2] + y_range = [0.4, 0.6] + z_range = [0.05, 0.05] + + ranges = np.vstack([x_range, y_range, z_range]) + cube_position = np.random.uniform(ranges[:, 0], ranges[:, 1]) + + cube_quat = np.array([1, 0, 0, 0]) + return np.concatenate([cube_position, cube_quat]) + + +def sample_insertion_pose(): + # Peg + x_range = [0.1, 0.2] + y_range = [0.4, 0.6] + z_range = [0.05, 0.05] + + ranges = np.vstack([x_range, y_range, z_range]) + peg_position = np.random.uniform(ranges[:, 0], ranges[:, 1]) + + peg_quat = np.array([1, 0, 0, 0]) + peg_pose = np.concatenate([peg_position, peg_quat]) + + # Socket + x_range = [-0.2, -0.1] + y_range = [0.4, 0.6] + z_range = [0.05, 0.05] + + ranges = np.vstack([x_range, y_range, z_range]) + socket_position = np.random.uniform(ranges[:, 0], ranges[:, 1]) + + socket_quat = np.array([1, 0, 0, 0]) + socket_pose = np.concatenate([socket_position, socket_quat]) + + return peg_pose, socket_pose diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index dd8ab2f7..1d7eab5e 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -23,6 +23,11 @@ def make_env(cfg, transform=None): # assert kwargs["seed"] > 200, "Seed 0-200 are used for the demonstration dataset, so we don't want to seed the eval env with this range." clsfunc = PushtEnv + elif cfg.env.name == "aloha": + from lerobot.common.envs.aloha.env import AlohaEnv + + kwargs["task"] = cfg.env.task + clsfunc = AlohaEnv else: raise ValueError(cfg.env.name) diff --git a/lerobot/common/policies/act/backbone.py b/lerobot/common/policies/act/backbone.py new file mode 100644 index 00000000..6399d339 --- /dev/null +++ b/lerobot/common/policies/act/backbone.py @@ -0,0 +1,115 @@ +from typing import List + +import torch +import torchvision +from torch import nn +from torchvision.models._utils import IntermediateLayerGetter + +from .position_encoding import build_position_encoding +from .utils import NestedTensor, is_main_process + + +class FrozenBatchNorm2d(torch.nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. + + Copy-paste from torchvision.misc.ops with added eps before rsqrt, + without which any policy_models other than torchvision.policy_models.resnet[18,34,50,101] + produce nans. 
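    Functionally this matches nn.BatchNorm2d in eval mode:
        y = (x - running_mean) / sqrt(running_var + eps) * weight + bias
    with all four tensors stored as fixed, non-trainable buffers.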
+ """ + + def __init__(self, n): + super().__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + num_batches_tracked_key = prefix + "num_batches_tracked" + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + + def forward(self, x): + # move reshapes to the beginning + # to make it fuser-friendly + w = self.weight.reshape(1, -1, 1, 1) + b = self.bias.reshape(1, -1, 1, 1) + rv = self.running_var.reshape(1, -1, 1, 1) + rm = self.running_mean.reshape(1, -1, 1, 1) + eps = 1e-5 + scale = w * (rv + eps).rsqrt() + bias = b - rm * scale + return x * scale + bias + + +class BackboneBase(nn.Module): + def __init__( + self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool + ): + super().__init__() + # for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this? + # if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: + # parameter.requires_grad_(False) + if return_interm_layers: + return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} + else: + return_layers = {"layer4": "0"} + self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) + self.num_channels = num_channels + + def forward(self, tensor): + xs = self.body(tensor) + return xs + # out: Dict[str, NestedTensor] = {} + # for name, x in xs.items(): + # m = tensor_list.mask + # assert m is not None + # mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] + # out[name] = NestedTensor(x, mask) + # return out + + +class Backbone(BackboneBase): + """ResNet backbone with frozen BatchNorm.""" + + def __init__(self, name: str, train_backbone: bool, return_interm_layers: bool, dilation: bool): + backbone = getattr(torchvision.models, name)( + replace_stride_with_dilation=[False, False, dilation], + pretrained=is_main_process(), + norm_layer=FrozenBatchNorm2d, + ) # pretrained # TODO do we want frozen batch_norm?? 
+ num_channels = 512 if name in ("resnet18", "resnet34") else 2048 + super().__init__(backbone, train_backbone, num_channels, return_interm_layers) + + +class Joiner(nn.Sequential): + def __init__(self, backbone, position_embedding): + super().__init__(backbone, position_embedding) + + def forward(self, tensor_list: NestedTensor): + xs = self[0](tensor_list) + out: List[NestedTensor] = [] + pos = [] + for _, x in xs.items(): + out.append(x) + # position encoding + pos.append(self[1](x).to(x.dtype)) + + return out, pos + + +def build_backbone(args): + position_embedding = build_position_encoding(args) + train_backbone = args.lr_backbone > 0 + return_interm_layers = args.masks + backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation) + model = Joiner(backbone, position_embedding) + model.num_channels = backbone.num_channels + return model diff --git a/lerobot/common/policies/act/detr_vae.py b/lerobot/common/policies/act/detr_vae.py new file mode 100644 index 00000000..0f2626f7 --- /dev/null +++ b/lerobot/common/policies/act/detr_vae.py @@ -0,0 +1,212 @@ +import numpy as np +import torch +from torch import nn +from torch.autograd import Variable + +from .backbone import build_backbone +from .transformer import TransformerEncoder, TransformerEncoderLayer, build_transformer + + +def reparametrize(mu, logvar): + std = logvar.div(2).exp() + eps = Variable(std.data.new(std.size()).normal_()) + return mu + std * eps + + +def get_sinusoid_encoding_table(n_position, d_hid): + def get_position_angle_vec(position): + return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] + + sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + + return torch.FloatTensor(sinusoid_table).unsqueeze(0) + + +class DETRVAE(nn.Module): + """This is the DETR module that performs object detection""" + + def __init__( + self, backbones, transformer, encoder, state_dim, action_dim, num_queries, camera_names, vae + ): + """Initializes the model. + Parameters: + backbones: torch module of the backbone to be used. See backbone.py + transformer: torch module of the transformer architecture. See transformer.py + state_dim: robot state dimension of the environment + num_queries: number of object queries, ie detection slot. This is the maximal number of objects + DETR can detect in a single image. For COCO, we recommend 100 queries. + aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. 
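            vae: if True, a CVAE-style encoder maps (qpos, action sequence) to a latent z
                 that conditions the decoder during training; at inference (or when vae
                 is False) the latent is zeroed, i.e. set to the prior mean.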
+ """ + super().__init__() + self.num_queries = num_queries + self.camera_names = camera_names + self.transformer = transformer + self.encoder = encoder + self.vae = vae + hidden_dim = transformer.d_model + self.action_head = nn.Linear(hidden_dim, action_dim) + self.is_pad_head = nn.Linear(hidden_dim, 1) + self.query_embed = nn.Embedding(num_queries, hidden_dim) + if backbones is not None: + self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1) + self.backbones = nn.ModuleList(backbones) + self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim) + else: + # input_dim = 14 + 7 # robot_state + env_state + self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim) + # TODO(rcadene): understand what is env_state, and why it needs to be 7 + self.input_proj_env_state = nn.Linear(state_dim // 2, hidden_dim) + self.pos = torch.nn.Embedding(2, hidden_dim) + self.backbones = None + + # encoder extra parameters + self.latent_dim = 32 # final size of latent z # TODO tune + self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding + self.encoder_action_proj = nn.Linear(14, hidden_dim) # project action to embedding + self.encoder_joint_proj = nn.Linear(14, hidden_dim) # project qpos to embedding + self.latent_proj = nn.Linear( + hidden_dim, self.latent_dim * 2 + ) # project hidden state to latent std, var + self.register_buffer( + "pos_table", get_sinusoid_encoding_table(1 + 1 + num_queries, hidden_dim) + ) # [CLS], qpos, a_seq + + # decoder extra parameters + self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding + self.additional_pos_embed = nn.Embedding( + 2, hidden_dim + ) # learned position embedding for proprio and latent + + def forward(self, qpos, image, env_state, actions=None, is_pad=None): + """ + qpos: batch, qpos_dim + image: batch, num_cam, channel, height, width + env_state: None + actions: batch, seq, action_dim + """ + is_training = actions is not None # train or val + bs, _ = qpos.shape + ### Obtain latent z from action sequence + if self.vae and is_training: + # project action sequence to embedding dim, and concat with a CLS token + action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim) + qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim) + qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim) + cls_embed = self.cls_embed.weight # (1, hidden_dim) + cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim) + encoder_input = torch.cat( + [cls_embed, qpos_embed, action_embed], axis=1 + ) # (bs, seq+1, hidden_dim) + encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim) + # do not mask cls token + # cls_joint_is_pad = torch.full((bs, 2), False).to(qpos.device) # False: not a padding + # is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1) + # obtain position embedding + pos_embed = self.pos_table.clone().detach() + pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim) + # query model + encoder_output = self.encoder(encoder_input, pos=pos_embed) # , src_key_padding_mask=is_pad) + encoder_output = encoder_output[0] # take cls output only + latent_info = self.latent_proj(encoder_output) + mu = latent_info[:, : self.latent_dim] + logvar = latent_info[:, self.latent_dim :] + latent_sample = reparametrize(mu, logvar) + latent_input = self.latent_out_proj(latent_sample) + else: + mu = logvar = None + latent_sample = torch.zeros([bs, self.latent_dim], 
dtype=torch.float32).to(qpos.device) + latent_input = self.latent_out_proj(latent_sample) + + if self.backbones is not None: + # Image observation features and position embeddings + all_cam_features = [] + all_cam_pos = [] + for cam_id, _ in enumerate(self.camera_names): + features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED + features = features[0] # take the last layer feature + pos = pos[0] + all_cam_features.append(self.input_proj(features)) + all_cam_pos.append(pos) + # proprioception features + proprio_input = self.input_proj_robot_state(qpos) + # fold camera dimension into width dimension + src = torch.cat(all_cam_features, axis=3) + pos = torch.cat(all_cam_pos, axis=3) + hs = self.transformer( + src, + None, + self.query_embed.weight, + pos, + latent_input, + proprio_input, + self.additional_pos_embed.weight, + )[0] + else: + qpos = self.input_proj_robot_state(qpos) + env_state = self.input_proj_env_state(env_state) + transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2 + hs = self.transformer(transformer_input, None, self.query_embed.weight, self.pos.weight)[0] + a_hat = self.action_head(hs) + is_pad_hat = self.is_pad_head(hs) + return a_hat, is_pad_hat, [mu, logvar] + + +def mlp(input_dim, hidden_dim, output_dim, hidden_depth): + if hidden_depth == 0: + mods = [nn.Linear(input_dim, output_dim)] + else: + mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)] + for _ in range(hidden_depth - 1): + mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)] + mods.append(nn.Linear(hidden_dim, output_dim)) + trunk = nn.Sequential(*mods) + return trunk + + +def build_encoder(args): + d_model = args.hidden_dim # 256 + dropout = args.dropout # 0.1 + nhead = args.nheads # 8 + dim_feedforward = args.dim_feedforward # 2048 + num_encoder_layers = args.enc_layers # 4 # TODO shared with VAE decoder + normalize_before = args.pre_norm # False + activation = "relu" + + encoder_layer = TransformerEncoderLayer( + d_model, nhead, dim_feedforward, dropout, activation, normalize_before + ) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) + + return encoder + + +def build(args): + # From state + # backbone = None # from state for now, no need for conv nets + # From image + backbones = [] + backbone = build_backbone(args) + backbones.append(backbone) + + transformer = build_transformer(args) + + encoder = build_encoder(args) + + model = DETRVAE( + backbones, + transformer, + encoder, + state_dim=args.state_dim, + action_dim=args.action_dim, + num_queries=args.num_queries, + camera_names=args.camera_names, + vae=args.vae, + ) + + n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) + print("number of parameters: {:.2f}M".format(n_parameters / 1e6)) + + return model diff --git a/lerobot/common/policies/act/policy.py b/lerobot/common/policies/act/policy.py new file mode 100644 index 00000000..7928b3ab --- /dev/null +++ b/lerobot/common/policies/act/policy.py @@ -0,0 +1,218 @@ +import logging +import time + +import torch +import torch.nn as nn +import torch.nn.functional as F # noqa: N812 +import torchvision.transforms as transforms + +from lerobot.common.policies.act.detr_vae import build + + +def build_act_model_and_optimizer(cfg): + model = build(cfg) + + param_dicts = [ + {"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]}, + { + "params": [p for n, p in model.named_parameters() if 
"backbone" in n and p.requires_grad], + "lr": cfg.lr_backbone, + }, + ] + optimizer = torch.optim.AdamW(param_dicts, lr=cfg.lr, weight_decay=cfg.weight_decay) + + return model, optimizer + + +def kl_divergence(mu, logvar): + batch_size = mu.size(0) + assert batch_size != 0 + if mu.data.ndimension() == 4: + mu = mu.view(mu.size(0), mu.size(1)) + if logvar.data.ndimension() == 4: + logvar = logvar.view(logvar.size(0), logvar.size(1)) + + klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()) + total_kld = klds.sum(1).mean(0, True) + dimension_wise_kld = klds.mean(0) + mean_kld = klds.mean(1).mean(0, True) + + return total_kld, dimension_wise_kld, mean_kld + + +class ActionChunkingTransformerPolicy(nn.Module): + def __init__(self, cfg, device, n_action_steps=1): + super().__init__() + self.cfg = cfg + self.n_action_steps = n_action_steps + self.device = device + self.model, self.optimizer = build_act_model_and_optimizer(cfg) + self.kl_weight = self.cfg.kl_weight + logging.info(f"KL Weight {self.kl_weight}") + + self.to(self.device) + + def update(self, replay_buffer, step): + del step + + start_time = time.time() + + self.train() + + num_slices = self.cfg.batch_size + batch_size = self.cfg.horizon * num_slices + + assert batch_size % self.cfg.horizon == 0 + assert batch_size % num_slices == 0 + + def process_batch(batch, horizon, num_slices): + # trajectory t = 64, horizon h = 16 + # (t h) ... -> t h ... + batch = batch.reshape(num_slices, horizon) + + image = batch["observation", "image", "top"] + image = image[:, 0] # first observation t=0 + # batch, num_cam, channel, height, width + image = image.unsqueeze(1) + assert image.ndim == 5 + image = image.float() + + state = batch["observation", "state"] + state = state[:, 0] # first observation t=0 + # batch, qpos_dim + assert state.ndim == 2 + + action = batch["action"] + # batch, seq, action_dim + assert action.ndim == 3 + assert action.shape[1] == horizon + + if self.cfg.n_obs_steps > 1: + raise NotImplementedError() + # # keep first n observations of the slice corresponding to t=[-1,0] + # image = image[:, : self.cfg.n_obs_steps] + # state = state[:, : self.cfg.n_obs_steps] + + out = { + "obs": { + "image": image.to(self.device, non_blocking=True), + "agent_pos": state.to(self.device, non_blocking=True), + }, + "action": action.to(self.device, non_blocking=True), + } + return out + + batch = replay_buffer.sample(batch_size) + batch = process_batch(batch, self.cfg.horizon, num_slices) + + data_s = time.time() - start_time + + loss = self.compute_loss(batch) + loss.backward() + + grad_norm = torch.nn.utils.clip_grad_norm_( + self.model.parameters(), + self.cfg.grad_clip_norm, + error_if_nonfinite=False, + ) + + self.optimizer.step() + self.optimizer.zero_grad() + # self.lr_scheduler.step() + + info = { + "loss": loss.item(), + "grad_norm": float(grad_norm), + # "lr": self.lr_scheduler.get_last_lr()[0], + "lr": self.cfg.lr, + "data_s": data_s, + "update_s": time.time() - start_time, + } + + return info + + def save(self, fp): + torch.save(self.state_dict(), fp) + + def load(self, fp): + d = torch.load(fp) + self.load_state_dict(d) + + def compute_loss(self, batch): + loss_dict = self._forward( + qpos=batch["obs"]["agent_pos"], + image=batch["obs"]["image"], + actions=batch["action"], + ) + loss = loss_dict["loss"] + return loss + + @torch.no_grad() + def forward(self, observation, step_count): + # TODO(rcadene): remove unused step_count + del step_count + + self.eval() + + # TODO(rcadene): remove unsqueeze hack to add bsize=1 + 
observation["image"] = observation["image"].unsqueeze(0) + observation["state"] = observation["state"].unsqueeze(0) + + # TODO(rcadene): remove hack + # add 1 camera dimension + observation["image"] = observation["image"].unsqueeze(1) + + obs_dict = { + "image": observation["image"], + "agent_pos": observation["state"], + } + action = self._forward(qpos=obs_dict["agent_pos"], image=obs_dict["image"]) + + if self.cfg.temporal_agg: + # TODO(rcadene): implement temporal aggregation + raise NotImplementedError() + # all_time_actions[[t], t:t+num_queries] = action + # actions_for_curr_step = all_time_actions[:, t] + # actions_populated = torch.all(actions_for_curr_step != 0, axis=1) + # actions_for_curr_step = actions_for_curr_step[actions_populated] + # k = 0.01 + # exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step))) + # exp_weights = exp_weights / exp_weights.sum() + # exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1) + # raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True) + + # remove bsize=1 + action = action.squeeze(0) + + # take first predicted action or n first actions + action = action[0] if self.n_action_steps == 1 else action[: self.n_action_steps] + return action + + def _forward(self, qpos, image, actions=None, is_pad=None): + env_state = None + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + image = normalize(image) + + is_training = actions is not None + if is_training: # training time + actions = actions[:, : self.model.num_queries] + if is_pad is not None: + is_pad = is_pad[:, : self.model.num_queries] + + a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad) + + all_l1 = F.l1_loss(actions, a_hat, reduction="none") + l1 = all_l1.mean() if is_pad is None else (all_l1 * ~is_pad.unsqueeze(-1)).mean() + + loss_dict = {} + loss_dict["l1"] = l1 + if self.cfg.vae: + total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar) + loss_dict["kl"] = total_kld[0] + loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight + else: + loss_dict["loss"] = loss_dict["l1"] + return loss_dict + else: + action, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior + return action diff --git a/lerobot/common/policies/act/position_encoding.py b/lerobot/common/policies/act/position_encoding.py new file mode 100644 index 00000000..94e862f6 --- /dev/null +++ b/lerobot/common/policies/act/position_encoding.py @@ -0,0 +1,101 @@ +""" +Various positional encodings for the transformer. +""" +import math + +import torch +from torch import nn + +from .utils import NestedTensor + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. 
+ """ + + def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, tensor): + x = tensor + # mask = tensor_list.mask + # assert mask is not None + # not_mask = ~mask + + not_mask = torch.ones_like(x[0, [0]]) + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +class PositionEmbeddingLearned(nn.Module): + """ + Absolute pos embedding, learned. + """ + + def __init__(self, num_pos_feats=256): + super().__init__() + self.row_embed = nn.Embedding(50, num_pos_feats) + self.col_embed = nn.Embedding(50, num_pos_feats) + self.reset_parameters() + + def reset_parameters(self): + nn.init.uniform_(self.row_embed.weight) + nn.init.uniform_(self.col_embed.weight) + + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + h, w = x.shape[-2:] + i = torch.arange(w, device=x.device) + j = torch.arange(h, device=x.device) + x_emb = self.col_embed(i) + y_emb = self.row_embed(j) + pos = ( + torch.cat( + [ + x_emb.unsqueeze(0).repeat(h, 1, 1), + y_emb.unsqueeze(1).repeat(1, w, 1), + ], + dim=-1, + ) + .permute(2, 0, 1) + .unsqueeze(0) + .repeat(x.shape[0], 1, 1, 1) + ) + return pos + + +def build_position_encoding(args): + n_steps = args.hidden_dim // 2 + if args.position_embedding in ("v2", "sine"): + # TODO find a better way of exposing other arguments + position_embedding = PositionEmbeddingSine(n_steps, normalize=True) + elif args.position_embedding in ("v3", "learned"): + position_embedding = PositionEmbeddingLearned(n_steps) + else: + raise ValueError(f"not supported {args.position_embedding}") + + return position_embedding diff --git a/lerobot/common/policies/act/transformer.py b/lerobot/common/policies/act/transformer.py new file mode 100644 index 00000000..b2bd3685 --- /dev/null +++ b/lerobot/common/policies/act/transformer.py @@ -0,0 +1,370 @@ +""" +DETR Transformer class. 
+ +Copy-paste from torch.nn.Transformer with modifications: + * positional encodings are passed in MHattention + * extra LN at the end of encoder is removed + * decoder returns a stack of activations from all decoding layers +""" +import copy +from typing import Optional + +import torch +import torch.nn.functional as F # noqa: N812 +from torch import Tensor, nn + + +class Transformer(nn.Module): + def __init__( + self, + d_model=512, + nhead=8, + num_encoder_layers=6, + num_decoder_layers=6, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + return_intermediate_dec=False, + ): + super().__init__() + + encoder_layer = TransformerEncoderLayer( + d_model, nhead, dim_feedforward, dropout, activation, normalize_before + ) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) + + decoder_layer = TransformerDecoderLayer( + d_model, nhead, dim_feedforward, dropout, activation, normalize_before + ) + decoder_norm = nn.LayerNorm(d_model) + self.decoder = TransformerDecoder( + decoder_layer, num_decoder_layers, decoder_norm, return_intermediate=return_intermediate_dec + ) + + self._reset_parameters() + + self.d_model = d_model + self.nhead = nhead + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward( + self, + src, + mask, + query_embed, + pos_embed, + latent_input=None, + proprio_input=None, + additional_pos_embed=None, + ): + # TODO flatten only when input has H and W + if len(src.shape) == 4: # has H and W + # flatten NxCxHxW to HWxNxC + bs, c, h, w = src.shape + src = src.flatten(2).permute(2, 0, 1) + pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1) + query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) + # mask = mask.flatten(1) + + additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim + pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0) + + addition_input = torch.stack([latent_input, proprio_input], axis=0) + src = torch.cat([addition_input, src], axis=0) + else: + assert len(src.shape) == 3 + # flatten NxHWxC to HWxNxC + bs, hw, c = src.shape + src = src.permute(1, 0, 2) + pos_embed = pos_embed.unsqueeze(1).repeat(1, bs, 1) + query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) + + tgt = torch.zeros_like(query_embed) + memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) + hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed) + hs = hs.transpose(1, 2) + return hs + + +class TransformerEncoder(nn.Module): + def __init__(self, encoder_layer, num_layers, norm=None): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward( + self, + src, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + output = src + + for layer in self.layers: + output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerDecoder(nn.Module): + def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + self.return_intermediate = return_intermediate + + def 
forward( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + output = tgt + + intermediate = [] + + for layer in self.layers: + output = layer( + output, + memory, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, + query_pos=query_pos, + ) + if self.return_intermediate: + intermediate.append(self.norm(output)) + + if self.norm is not None: + output = self.norm(output) + if self.return_intermediate: + intermediate.pop() + intermediate.append(output) + + if self.return_intermediate: + return torch.stack(intermediate) + + return output.unsqueeze(0) + + +class TransformerEncoderLayer(nn.Module): + def __init__( + self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False + ): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post( + self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + q = k = self.with_pos_embed(src, pos) + src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + def forward_pre( + self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + src2 = self.norm1(src) + q = k = self.with_pos_embed(src2, pos) + src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src2 = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) + src = src + self.dropout2(src2) + return src + + def forward( + self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + if self.normalize_before: + return self.forward_pre(src, src_mask, src_key_padding_mask, pos) + return self.forward_post(src, src_mask, src_key_padding_mask, pos) + + +class TransformerDecoderLayer(nn.Module): + def __init__( + self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False + ): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + 
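+ # one LayerNorm/dropout pair per sublayer: self-attention, cross-attention, feedforward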
self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn( + query=self.with_pos_embed(tgt, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + )[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward_pre( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + tgt2 = self.norm1(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt2 = self.norm2(tgt) + tgt2 = self.multihead_attn( + query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + )[0] + tgt = tgt + self.dropout2(tgt2) + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + def forward( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + if self.normalize_before: + return self.forward_pre( + tgt, + memory, + tgt_mask, + memory_mask, + tgt_key_padding_mask, + memory_key_padding_mask, + pos, + query_pos, + ) + return self.forward_post( + tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos + ) + + +def _get_clones(module, n): + return nn.ModuleList([copy.deepcopy(module) for _ in range(n)]) + + +def build_transformer(args): + return Transformer( + d_model=args.hidden_dim, + dropout=args.dropout, + nhead=args.nheads, + dim_feedforward=args.dim_feedforward, + num_encoder_layers=args.enc_layers, + num_decoder_layers=args.dec_layers, + normalize_before=args.pre_norm, + return_intermediate_dec=True, + ) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(f"activation should be relu/gelu, 
not {activation}.") diff --git a/lerobot/common/policies/act/utils.py b/lerobot/common/policies/act/utils.py new file mode 100644 index 00000000..2ce92094 --- /dev/null +++ b/lerobot/common/policies/act/utils.py @@ -0,0 +1,477 @@ +""" +Misc functions, including distributed helpers. + +Mostly copy-paste from torchvision references. +""" +import datetime +import os +import pickle +import subprocess +import time +from collections import defaultdict, deque +from typing import List, Optional + +import torch +import torch.distributed as dist + +# needed due to empty tensor bug in pytorch and torchvision 0.5 +import torchvision +from packaging import version +from torch import Tensor + +if version.parse(torchvision.__version__) < version.parse("0.7"): + from torchvision.ops import _new_empty_tensor + from torchvision.ops.misc import _output_size + + +class SmoothedValue: + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value + ) + + +def all_gather(data): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + world_size = get_world_size() + if world_size == 1: + return [data] + + # serialized to a Tensor + buffer = pickle.dumps(data) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to("cuda") + + # obtain Tensor size of each rank + local_size = torch.tensor([tensor.numel()], device="cuda") + size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] + dist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) + if local_size != max_size: + padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") + tensor = torch.cat((tensor, padding), dim=0) + dist.all_gather(tensor_list, tensor) + + data_list = [] + for size, tensor in zip(size_list, tensor_list, strict=False): + buffer = tensor.cpu().numpy().tobytes()[:size] + 
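+ # drop the padding added above before unpickling each rank's payload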
data_list.append(pickle.loads(buffer)) + + return data_list + + +def reduce_dict(input_dict, average=True): + """ + Args: + input_dict (dict): all the values will be reduced + average (bool): whether to do average or sum + Reduce the values in the dictionary from all processes so that all processes + have the averaged results. Returns a dict with the same fields as + input_dict, after reduction. + """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.all_reduce(values) + if average: + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values, strict=False)} # noqa: C416 + return reduced_dict + + +class MetricLogger: + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append("{}: {}".format(name, str(meter))) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + if not header: + header = "" + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt="{avg:.4f}") + data_time = SmoothedValue(fmt="{avg:.4f}") + space_fmt = ":" + str(len(str(len(iterable)))) + "d" + if torch.cuda.is_available(): + log_msg = self.delimiter.join( + [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + "max mem: {memory:.0f}", + ] + ) + else: + log_msg = self.delimiter.join( + [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + ] + ) + mega_b = 1024.0 * 1024.0 + for i, obj in enumerate(iterable): + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / mega_b, + ) + ) + else: + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + ) + ) + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print("{} Total time: {} ({:.4f} s / it)".format(header, total_time_str, total_time / len(iterable))) + + +def get_sha(): + cwd = os.path.dirname(os.path.abspath(__file__)) + + def _run(command): + return subprocess.check_output(command, cwd=cwd).decode("ascii").strip() + 
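+ # the defaults below are kept if git metadata cannot be read (e.g. outside a checkout)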
+ sha = "N/A" + diff = "clean" + branch = "N/A" + try: + sha = _run(["git", "rev-parse", "HEAD"]) + subprocess.check_output(["git", "diff"], cwd=cwd) + diff = _run(["git", "diff-index", "HEAD"]) + diff = "has uncommited changes" if diff else "clean" + branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"]) + except Exception: + pass + message = f"sha: {sha}, status: {diff}, branch: {branch}" + return message + + +def collate_fn(batch): + batch = list(zip(*batch, strict=False)) + batch[0] = nested_tensor_from_tensor_list(batch[0]) + return tuple(batch) + + +def _max_by_axis(the_list): + # type: (List[List[int]]) -> List[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + +class NestedTensor: + def __init__(self, tensors, mask: Optional[Tensor]): + self.tensors = tensors + self.mask = mask + + def to(self, device): + # type: (Device) -> NestedTensor # noqa + cast_tensor = self.tensors.to(device) + mask = self.mask + if mask is not None: + assert mask is not None + cast_mask = mask.to(device) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + + +def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): + # TODO make this more general + if tensor_list[0].ndim == 3: + if torchvision._is_tracing(): + # nested_tensor_from_tensor_list() does not export well to ONNX + # call _onnx_nested_tensor_from_tensor_list() instead + return _onnx_nested_tensor_from_tensor_list(tensor_list) + + # TODO make it support different-sized images + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) + batch_shape = [len(tensor_list)] + max_size + b, c, h, w = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((b, h, w), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask, strict=False): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + m[: img.shape[1], : img.shape[2]] = False + else: + raise ValueError("not supported") + return NestedTensor(tensor, mask) + + +# _onnx_nested_tensor_from_tensor_list() is an implementation of +# nested_tensor_from_tensor_list() that is supported by ONNX tracing. 
+@torch.jit.unused +def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: + max_size = [] + for i in range(tensor_list[0].dim()): + max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to( + torch.int64 + ) + max_size.append(max_size_i) + max_size = tuple(max_size) + + # work around for + # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + # m[: img.shape[1], :img.shape[2]] = False + # which is not yet supported in onnx + padded_imgs = [] + padded_masks = [] + for img in tensor_list: + padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape), strict=False)] + padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) + padded_imgs.append(padded_img) + + m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) + padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) + padded_masks.append(padded_mask.to(torch.bool)) + + tensor = torch.stack(padded_imgs) + mask = torch.stack(padded_masks) + + return NestedTensor(tensor, mask=mask) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop("force", False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ["WORLD_SIZE"]) + args.gpu = int(os.environ["LOCAL_RANK"]) + elif "SLURM_PROCID" in os.environ: + args.rank = int(os.environ["SLURM_PROCID"]) + args.gpu = args.rank % torch.cuda.device_count() + else: + print("Not using distributed mode") + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = "nccl" + print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True) + torch.distributed.init_process_group( + backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank + ) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +@torch.no_grad() +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + if target.numel() == 0: + return [torch.zeros([], device=output.device)] + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): + # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor + """ + Equivalent to nn.functional.interpolate, but 
with support for empty batch sizes.
+    This will eventually be supported natively by PyTorch, and this
+    class can go away.
+    """
+    if version.parse(torchvision.__version__) < version.parse("0.7"):
+        if input.numel() > 0:
+            return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners)
+
+        output_shape = _output_size(2, input, size, scale_factor)
+        output_shape = list(input.shape[:-2]) + list(output_shape)
+        return _new_empty_tensor(input, output_shape)
+    else:
+        return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py
index a956cb4b..c5e45300 100644
--- a/lerobot/common/policies/factory.py
+++ b/lerobot/common/policies/factory.py
@@ -17,6 +17,12 @@ def make_policy(cfg):
             n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
             **cfg.policy,
         )
+    elif cfg.policy.name == "act":
+        from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
+
+        policy = ActionChunkingTransformerPolicy(
+            cfg.policy, cfg.device, n_action_steps=cfg.n_action_steps + cfg.n_latency_steps
+        )
     else:
         raise ValueError(cfg.policy.name)
diff --git a/lerobot/configs/env/aloha.yaml b/lerobot/configs/env/aloha.yaml
index 5b5ecbb7..df464c75 100644
--- a/lerobot/configs/env/aloha.yaml
+++ b/lerobot/configs/env/aloha.yaml
@@ -15,11 +15,11 @@ env:
   task: sim_insertion_human
   from_pixels: True
   pixels_only: False
-  image_size: 96
+  image_size: [3, 480, 640]
   action_repeat: 1
-  episode_length: 300
+  episode_length: 400
   fps: ${fps}
 
 policy:
-  state_dim: 2
-  action_dim: 2
+  state_dim: 14
+  action_dim: 14
diff --git a/lerobot/configs/policy/act.yaml b/lerobot/configs/policy/act.yaml
new file mode 100644
index 00000000..a52c3f54
--- /dev/null
+++ b/lerobot/configs/policy/act.yaml
@@ -0,0 +1,58 @@
+# @package _global_
+
+offline_steps: 1344000
+online_steps: 0
+
+eval_episodes: 1
+eval_freq: 10000
+save_freq: 100000
+log_freq: 250
+
+horizon: 100
+n_obs_steps: 1
+n_latency_steps: 0
+# when temporal_agg=False, n_action_steps=horizon
+n_action_steps: ${horizon}
+
+policy:
+  name: act
+
+  pretrained_model_path:
+
+  lr: 1e-5
+  lr_backbone: 1e-5
+  weight_decay: 1e-4
+  grad_clip_norm: 10
+  backbone: resnet18
+  num_queries: ${horizon} # chunk_size
+  horizon: ${horizon} # chunk_size
+  kl_weight: 10
+  hidden_dim: 512
+  dim_feedforward: 3200
+  enc_layers: 4
+  dec_layers: 7
+  nheads: 8
+  #camera_names: [top, front_close, left_pillar, right_pillar]
+  camera_names: [top]
+  position_embedding: sine
+  masks: false
+  dilation: false
+  dropout: 0.1
+  pre_norm: false
+
+  vae: true
+
+  batch_size: 8
+
+  per_alpha: 0.6
+  per_beta: 0.4
+
+  balanced_sampling: false
+  utd: 1
+
+  n_obs_steps: ${n_obs_steps}
+
+  temporal_agg: false
+
+  state_dim: ???
+  action_dim: ???
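For orientation, a minimal sketch of how this config is consumed end to end, assuming a Hydra-composed `cfg` whose policy group points at `act.yaml` above, `cfg.device="cpu"`, and an `offline_buffer` replay buffer as built in `train.py` (`cfg` and `offline_buffer` are stand-ins here, not part of this diff):

```python
import torch

from lerobot.common.policies.factory import make_policy

# hypothetical: cfg is composed by Hydra from the yaml files above,
# offline_buffer is the replay buffer constructed in lerobot/scripts/train.py
policy = make_policy(cfg)  # dispatches on cfg.policy.name == "act"

# offline training step: samples horizon-length action chunks from the buffer
train_info = policy.update(offline_buffer, step=0)
print(train_info["loss"], train_info["update_s"])

# rollout step: one unbatched observation in, a chunk of actions out
observation = {
    "image": torch.rand(3, 480, 640),  # image_size from aloha.yaml
    "state": torch.rand(14),           # state_dim=14
}
action = policy(observation, step_count=0)
```

With `temporal_agg: false`, `n_action_steps` equals the full `horizon`, so the rollout call above returns the entire predicted chunk of 100 actions.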
diff --git a/lerobot/scripts/download.py b/lerobot/scripts/download.py deleted file mode 100644 index ac935f48..00000000 --- a/lerobot/scripts/download.py +++ /dev/null @@ -1,22 +0,0 @@ -# TODO(rcadene): obsolete remove -import os -import zipfile - -import gdown - - -def download(): - url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya" - download_path = "data.zip" - gdown.download(url, download_path, quiet=False) - print("Extracting...") - with zipfile.ZipFile(download_path, "r") as zip_f: - for member in zip_f.namelist(): - if member.startswith("data/xarm") and member.endswith(".pkl"): - print(member) - zip_f.extract(member=member) - os.remove(download_path) - - -if __name__ == "__main__": - download() diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index c9338dca..8d0b2e88 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -38,27 +38,18 @@ def eval_policy( successes = [] threads = [] for i in tqdm.tqdm(range(num_episodes)): - tensordict = env.reset() - ep_frames = [] - if save_video or (return_first_video and i == 0): - def rendering_callback(env, td=None): + def render_frame(env): ep_frames.append(env.render()) # noqa: B023 - # render first frame before rollout - rendering_callback(env) - else: - rendering_callback = None + env.register_rendering_hook(render_frame) with torch.inference_mode(): rollout = env.rollout( max_steps=max_steps, policy=policy, - callback=rendering_callback, - auto_reset=False, - tensordict=tensordict, auto_cast_to_device=True, ) # print(", ".join([f"{x:.3f}" for x in rollout["next", "reward"][:,0].tolist()])) @@ -85,6 +76,8 @@ def eval_policy( if return_first_video and i == 0: first_video = stacked_frames.transpose(0, 3, 1, 2) + env.reset_rendering_hooks() + for thread in threads: thread.join() diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index be3bef8b..02b1efae 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -1,4 +1,5 @@ import logging +from pathlib import Path import hydra import numpy as np @@ -192,6 +193,8 @@ def train(cfg: dict, out_dir=None, job_name=None): num_episodes=cfg.eval_episodes, max_steps=cfg.env.episode_length // cfg.n_action_steps, return_first_video=True, + video_dir=Path(out_dir) / "eval", + save_video=True, ) log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline) if cfg.wandb.enable: diff --git a/poetry.lock b/poetry.lock index db4f8f3e..bbf7f353 100644 --- a/poetry.lock +++ b/poetry.lock @@ -488,6 +488,37 @@ files = [ {file = "distlib-0.3.8.tar.gz", hash = "sha256:1530ea13e350031b6312d8580ddb6b27a104275a31106523b8f123787f494f64"}, ] +[[package]] +name = "dm-control" +version = "1.0.14" +description = "Continuous control environments and MuJoCo Python bindings." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "dm_control-1.0.14-py3-none-any.whl", hash = "sha256:883c63244a7ebf598700a97564ed19fffd3479ca79efd090aed881609cdb9fc6"}, + {file = "dm_control-1.0.14.tar.gz", hash = "sha256:def1ece747b6f175c581150826b50f1a6134086dab34f8f3fd2d088ea035cf3d"}, +] + +[package.dependencies] +absl-py = ">=0.7.0" +dm-env = "*" +dm-tree = "!=0.1.2" +glfw = "*" +labmaze = "*" +lxml = "*" +mujoco = ">=2.3.7" +numpy = ">=1.9.0" +protobuf = ">=3.19.4" +pyopengl = ">=3.1.4" +pyparsing = ">=3.0.0" +requests = "*" +scipy = "*" +setuptools = "!=50.0.0" +tqdm = "*" + +[package.extras] +hdf5 = ["h5py"] + [[package]] name = "dm-env" version = "1.6" @@ -584,43 +615,6 @@ files = [ {file = "einops-0.7.0.tar.gz", hash = "sha256:b2b04ad6081a3b227080c9bf5e3ace7160357ff03043cd66cc5b2319eb7031d1"}, ] -[[package]] -name = "etils" -version = "1.7.0" -description = "Collection of common python utils" -optional = false -python-versions = ">=3.10" -files = [ - {file = "etils-1.7.0-py3-none-any.whl", hash = "sha256:61af8f7c242171de15e22e5da02d527cb9e677d11f8bcafe18fcc3548eee3e60"}, - {file = "etils-1.7.0.tar.gz", hash = "sha256:97b68fd25e185683215286ef3a54e38199b6245f5fe8be6bedc1189be4256350"}, -] - -[package.dependencies] -fsspec = {version = "*", optional = true, markers = "extra == \"epath\""} -importlib_resources = {version = "*", optional = true, markers = "extra == \"epath\""} -typing_extensions = {version = "*", optional = true, markers = "extra == \"epy\""} -zipp = {version = "*", optional = true, markers = "extra == \"epath\""} - -[package.extras] -all = ["etils[array-types]", "etils[eapp]", "etils[ecolab]", "etils[edc]", "etils[enp]", "etils[epath-gcs]", "etils[epath-s3]", "etils[epath]", "etils[epy]", "etils[etqdm]", "etils[etree-dm]", "etils[etree-jax]", "etils[etree-tf]", "etils[etree]"] -array-types = ["etils[enp]"] -dev = ["chex", "dataclass_array", "optree", "pyink", "pylint (>=2.6.0)", "pytest", "pytest-subtests", "pytest-xdist", "torch"] -docs = ["etils[all,dev]", "sphinx-apitree[ext]"] -eapp = ["absl-py", "etils[epy]", "simple_parsing"] -ecolab = ["etils[enp]", "etils[epy]", "etils[etree]", "jupyter", "mediapy", "numpy", "packaging", "protobuf"] -edc = ["etils[epy]"] -enp = ["etils[epy]", "numpy"] -epath = ["etils[epy]", "fsspec", "importlib_resources", "typing_extensions", "zipp"] -epath-gcs = ["etils[epath]", "gcsfs"] -epath-s3 = ["etils[epath]", "s3fs"] -epy = ["typing_extensions"] -etqdm = ["absl-py", "etils[epy]", "tqdm"] -etree = ["etils[array-types]", "etils[enp]", "etils[epy]", "etils[etqdm]"] -etree-dm = ["dm-tree", "etils[etree]"] -etree-jax = ["etils[etree]", "jax[cpu]"] -etree-tf = ["etils[etree]", "tensorflow"] -lazy-imports = ["etils[ecolab]"] - [[package]] name = "exceptiongroup" version = "1.2.0" @@ -988,21 +982,6 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.link perf = ["ipython"] testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"] -[[package]] -name = "importlib-resources" -version = "6.1.2" -description = "Read resources from Python packages" -optional = false -python-versions = ">=3.8" -files = [ - {file = "importlib_resources-6.1.2-py3-none-any.whl", hash = "sha256:9a0a862501dc38b68adebc82970140c9e4209fc99601782925178f8386339938"}, - {file = "importlib_resources-6.1.2.tar.gz", hash = 
"sha256:308abf8474e2dba5f867d279237cd4076482c3de7104a40b41426370e891549b"}, -] - -[package.extras] -docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] -testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-ruff (>=0.2.1)", "zipp (>=3.17)"] - [[package]] name = "iniconfig" version = "2.0.0" @@ -1031,6 +1010,50 @@ MarkupSafe = ">=2.0" [package.extras] i18n = ["Babel (>=2.7)"] +[[package]] +name = "labmaze" +version = "1.0.6" +description = "LabMaze: DeepMind Lab's text maze generator." +optional = false +python-versions = "*" +files = [ + {file = "labmaze-1.0.6-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:b2ddef976dfd8d992b19cfa6c633f2eba7576d759c2082da534e3f727479a84a"}, + {file = "labmaze-1.0.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:157efaa93228c8ccce5cae337902dd652093e0fba9d3a0f6506e4bee272bb66f"}, + {file = "labmaze-1.0.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b3ce98b9541c5fe6a306e411e7d018121dd646f2c9978d763fad86f9f30c5f57"}, + {file = "labmaze-1.0.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4e6433bd49bc541791de8191040526fddfebb77151620eb04203453f43ee486a"}, + {file = "labmaze-1.0.6-cp310-cp310-win_amd64.whl", hash = "sha256:6a507fc35961f1b1479708e2716f65e0d0611cefb55f31a77be29ce2339b6fef"}, + {file = "labmaze-1.0.6-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:a0c2cb9dec971814ea9c5d7150af15fa3964482131fa969e0afb94bd224348af"}, + {file = "labmaze-1.0.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2c6ba9538d819543f4be448d36b4926a3881e53646a2b331ebb5a1f353047d05"}, + {file = "labmaze-1.0.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70635d1cdb0147a02efb6b3f607a52cdc51723bc3dcc42717a0d4ef55fa0a987"}, + {file = "labmaze-1.0.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ff472793238bd9b6dabea8094594d6074ad3c111455de3afcae72f6c40c6817e"}, + {file = "labmaze-1.0.6-cp311-cp311-win_amd64.whl", hash = "sha256:2317e65e12fa3d1abecda7e0488dab15456cee8a2e717a586bfc8f02a91579e7"}, + {file = "labmaze-1.0.6-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:e36b6fadcd78f22057b597c1c77823e806a0987b3bdfbf850e14b6b5b502075e"}, + {file = "labmaze-1.0.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d1a4f8de29c2c3d7f14163759b69cd3f237093b85334c983619c1db5403a223b"}, + {file = "labmaze-1.0.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a394f8bb857fcaa2884b809d63e750841c2662a106cfe8c045f2112d201ac7d5"}, + {file = "labmaze-1.0.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d17abb69d4dfc56183afb5c317e8b2eaca0587abb3aabd2326efd3143c81f4e"}, + {file = "labmaze-1.0.6-cp312-cp312-win_amd64.whl", hash = "sha256:5af997598cc46b1929d1c5a1febc32fd56c75874fe481a2a5982c65cee8450c9"}, + {file = "labmaze-1.0.6-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:a4c5bc6e56baa55ce63b97569afec2f80cab0f6b952752a131e1f83eed190a53"}, + {file = "labmaze-1.0.6-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3955f24fe5f708e1e97495b4cfe284b70ae4fd51be5e17b75a6fc04ffbd67bca"}, + {file = "labmaze-1.0.6-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ed96ddc0bb8d66df36428c94db83949fd84a15867e8250763a4c5e3d82104c54"}, + {file = "labmaze-1.0.6-cp37-cp37m-win_amd64.whl", hash = 
"sha256:3bd0458a29e55aa09f146e28a168d2e00b8ccf19e2259a3f71154cfff3536b1d"}, + {file = "labmaze-1.0.6-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:33f5154edc83dff55a150e54b60c8582fdafc7ec45195049809cbcc01f5e8f34"}, + {file = "labmaze-1.0.6-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0971055ef2a5f7d8517fdc42b67c057093698f1eb911f46faa7018867b73fcc9"}, + {file = "labmaze-1.0.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:de18d09680007302abf49111f3fe822d8435e4fbc4468b9ec07d50a78e267865"}, + {file = "labmaze-1.0.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f18126066db2218a52853c7dd490b4c3d8129fc22eb3a47eb23007524b911d53"}, + {file = "labmaze-1.0.6-cp38-cp38-win_amd64.whl", hash = "sha256:f9aef09a76877342bb4d634b7e05f43b038a49c4f34adfb8f1b8ac57c29472f2"}, + {file = "labmaze-1.0.6-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:5dd28899418f1b8b1c7d1e1b40a4593150a7cfa95ca91e23860b9785b82cc0ee"}, + {file = "labmaze-1.0.6-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:965569f37ee33090b4d4b3aa5aa7c9dcc4f62e2ae5d761e7f73ec76fc9d8aa96"}, + {file = "labmaze-1.0.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:05eccfa98c0e781bc9f939076ae600b2e25ca736e123f2a530606aedec3b531c"}, + {file = "labmaze-1.0.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bee8c94e0fb3fc2d8180214947245c1d74a3489349a9da90b868296e77a521e9"}, + {file = "labmaze-1.0.6-cp39-cp39-win_amd64.whl", hash = "sha256:d486e9ca3a335ad628e3bd48a09c42f1aa5f51040952ef0fe32507afedcd694b"}, + {file = "labmaze-1.0.6.tar.gz", hash = "sha256:2e8de7094042a77d6972f1965cf5c9e8f971f1b34d225752f343190a825ebe73"}, +] + +[package.dependencies] +absl-py = "*" +numpy = ">=1.8.0" +setuptools = "!=50.0.0" + [[package]] name = "lazy-loader" version = "0.3" @@ -1076,6 +1099,99 @@ files = [ {file = "llvmlite-0.42.0.tar.gz", hash = "sha256:f92b09243c0cc3f457da8b983f67bd8e1295d0f5b3746c7a1861d7a99403854a"}, ] +[[package]] +name = "lxml" +version = "5.1.0" +description = "Powerful and Pythonic XML processing library combining libxml2/libxslt with the ElementTree API." 
+optional = false +python-versions = ">=3.6" +files = [ + {file = "lxml-5.1.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:704f5572ff473a5f897745abebc6df40f22d4133c1e0a1f124e4f2bd3330ff7e"}, + {file = "lxml-5.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9d3c0f8567ffe7502d969c2c1b809892dc793b5d0665f602aad19895f8d508da"}, + {file = "lxml-5.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5fcfbebdb0c5d8d18b84118842f31965d59ee3e66996ac842e21f957eb76138c"}, + {file = "lxml-5.1.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2f37c6d7106a9d6f0708d4e164b707037b7380fcd0b04c5bd9cae1fb46a856fb"}, + {file = "lxml-5.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2befa20a13f1a75c751f47e00929fb3433d67eb9923c2c0b364de449121f447c"}, + {file = "lxml-5.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:22b7ee4c35f374e2c20337a95502057964d7e35b996b1c667b5c65c567d2252a"}, + {file = "lxml-5.1.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:bf8443781533b8d37b295016a4b53c1494fa9a03573c09ca5104550c138d5c05"}, + {file = "lxml-5.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:82bddf0e72cb2af3cbba7cec1d2fd11fda0de6be8f4492223d4a268713ef2147"}, + {file = "lxml-5.1.0-cp310-cp310-win32.whl", hash = "sha256:b66aa6357b265670bb574f050ffceefb98549c721cf28351b748be1ef9577d93"}, + {file = "lxml-5.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:4946e7f59b7b6a9e27bef34422f645e9a368cb2be11bf1ef3cafc39a1f6ba68d"}, + {file = "lxml-5.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:14deca1460b4b0f6b01f1ddc9557704e8b365f55c63070463f6c18619ebf964f"}, + {file = "lxml-5.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ed8c3d2cd329bf779b7ed38db176738f3f8be637bb395ce9629fc76f78afe3d4"}, + {file = "lxml-5.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:436a943c2900bb98123b06437cdd30580a61340fbdb7b28aaf345a459c19046a"}, + {file = "lxml-5.1.0-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:acb6b2f96f60f70e7f34efe0c3ea34ca63f19ca63ce90019c6cbca6b676e81fa"}, + {file = "lxml-5.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:af8920ce4a55ff41167ddbc20077f5698c2e710ad3353d32a07d3264f3a2021e"}, + {file = "lxml-5.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7cfced4a069003d8913408e10ca8ed092c49a7f6cefee9bb74b6b3e860683b45"}, + {file = "lxml-5.1.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:9e5ac3437746189a9b4121db2a7b86056ac8786b12e88838696899328fc44bb2"}, + {file = "lxml-5.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f4c9bda132ad108b387c33fabfea47866af87f4ea6ffb79418004f0521e63204"}, + {file = "lxml-5.1.0-cp311-cp311-win32.whl", hash = "sha256:bc64d1b1dab08f679fb89c368f4c05693f58a9faf744c4d390d7ed1d8223869b"}, + {file = "lxml-5.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:a5ab722ae5a873d8dcee1f5f45ddd93c34210aed44ff2dc643b5025981908cda"}, + {file = "lxml-5.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:9aa543980ab1fbf1720969af1d99095a548ea42e00361e727c58a40832439114"}, + {file = "lxml-5.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6f11b77ec0979f7e4dc5ae081325a2946f1fe424148d3945f943ceaede98adb8"}, + {file = "lxml-5.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a36c506e5f8aeb40680491d39ed94670487ce6614b9d27cabe45d94cd5d63e1e"}, + {file = 
"lxml-5.1.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f643ffd2669ffd4b5a3e9b41c909b72b2a1d5e4915da90a77e119b8d48ce867a"}, + {file = "lxml-5.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:16dd953fb719f0ffc5bc067428fc9e88f599e15723a85618c45847c96f11f431"}, + {file = "lxml-5.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:16018f7099245157564d7148165132c70adb272fb5a17c048ba70d9cc542a1a1"}, + {file = "lxml-5.1.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:82cd34f1081ae4ea2ede3d52f71b7be313756e99b4b5f829f89b12da552d3aa3"}, + {file = "lxml-5.1.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:19a1bc898ae9f06bccb7c3e1dfd73897ecbbd2c96afe9095a6026016e5ca97b8"}, + {file = "lxml-5.1.0-cp312-cp312-win32.whl", hash = "sha256:13521a321a25c641b9ea127ef478b580b5ec82aa2e9fc076c86169d161798b01"}, + {file = "lxml-5.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:1ad17c20e3666c035db502c78b86e58ff6b5991906e55bdbef94977700c72623"}, + {file = "lxml-5.1.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:24ef5a4631c0b6cceaf2dbca21687e29725b7c4e171f33a8f8ce23c12558ded1"}, + {file = "lxml-5.1.0-cp36-cp36m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8d2900b7f5318bc7ad8631d3d40190b95ef2aa8cc59473b73b294e4a55e9f30f"}, + {file = "lxml-5.1.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:601f4a75797d7a770daed8b42b97cd1bb1ba18bd51a9382077a6a247a12aa38d"}, + {file = "lxml-5.1.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b4b68c961b5cc402cbd99cca5eb2547e46ce77260eb705f4d117fd9c3f932b95"}, + {file = "lxml-5.1.0-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:afd825e30f8d1f521713a5669b63657bcfe5980a916c95855060048b88e1adb7"}, + {file = "lxml-5.1.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:262bc5f512a66b527d026518507e78c2f9c2bd9eb5c8aeeb9f0eb43fcb69dc67"}, + {file = "lxml-5.1.0-cp36-cp36m-win32.whl", hash = "sha256:e856c1c7255c739434489ec9c8aa9cdf5179785d10ff20add308b5d673bed5cd"}, + {file = "lxml-5.1.0-cp36-cp36m-win_amd64.whl", hash = "sha256:c7257171bb8d4432fe9d6fdde4d55fdbe663a63636a17f7f9aaba9bcb3153ad7"}, + {file = "lxml-5.1.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b9e240ae0ba96477682aa87899d94ddec1cc7926f9df29b1dd57b39e797d5ab5"}, + {file = "lxml-5.1.0-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a96f02ba1bcd330807fc060ed91d1f7a20853da6dd449e5da4b09bfcc08fdcf5"}, + {file = "lxml-5.1.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e3898ae2b58eeafedfe99e542a17859017d72d7f6a63de0f04f99c2cb125936"}, + {file = "lxml-5.1.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:61c5a7edbd7c695e54fca029ceb351fc45cd8860119a0f83e48be44e1c464862"}, + {file = "lxml-5.1.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:3aeca824b38ca78d9ee2ab82bd9883083d0492d9d17df065ba3b94e88e4d7ee6"}, + {file = "lxml-5.1.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:8f52fe6859b9db71ee609b0c0a70fea5f1e71c3462ecf144ca800d3f434f0764"}, + {file = "lxml-5.1.0-cp37-cp37m-win32.whl", hash = "sha256:d42e3a3fc18acc88b838efded0e6ec3edf3e328a58c68fbd36a7263a874906c8"}, + {file = "lxml-5.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:eac68f96539b32fce2c9b47eb7c25bb2582bdaf1bbb360d25f564ee9e04c542b"}, + {file = 
"lxml-5.1.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ae15347a88cf8af0949a9872b57a320d2605ae069bcdf047677318bc0bba45b1"}, + {file = "lxml-5.1.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c26aab6ea9c54d3bed716b8851c8bfc40cb249b8e9880e250d1eddde9f709bf5"}, + {file = "lxml-5.1.0-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:342e95bddec3a698ac24378d61996b3ee5ba9acfeb253986002ac53c9a5f6f84"}, + {file = "lxml-5.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:725e171e0b99a66ec8605ac77fa12239dbe061482ac854d25720e2294652eeaa"}, + {file = "lxml-5.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d184e0d5c918cff04cdde9dbdf9600e960161d773666958c9d7b565ccc60c45"}, + {file = "lxml-5.1.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:98f3f020a2b736566c707c8e034945c02aa94e124c24f77ca097c446f81b01f1"}, + {file = "lxml-5.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:6d48fc57e7c1e3df57be5ae8614bab6d4e7b60f65c5457915c26892c41afc59e"}, + {file = "lxml-5.1.0-cp38-cp38-win32.whl", hash = "sha256:7ec465e6549ed97e9f1e5ed51c657c9ede767bc1c11552f7f4d022c4df4a977a"}, + {file = "lxml-5.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:b21b4031b53d25b0858d4e124f2f9131ffc1530431c6d1321805c90da78388d1"}, + {file = "lxml-5.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:52427a7eadc98f9e62cb1368a5079ae826f94f05755d2d567d93ee1bc3ceb354"}, + {file = "lxml-5.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6a2a2c724d97c1eb8cf966b16ca2915566a4904b9aad2ed9a09c748ffe14f969"}, + {file = "lxml-5.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:843b9c835580d52828d8f69ea4302537337a21e6b4f1ec711a52241ba4a824f3"}, + {file = "lxml-5.1.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9b99f564659cfa704a2dd82d0684207b1aadf7d02d33e54845f9fc78e06b7581"}, + {file = "lxml-5.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f8b0c78e7aac24979ef09b7f50da871c2de2def043d468c4b41f512d831e912"}, + {file = "lxml-5.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9bcf86dfc8ff3e992fed847c077bd875d9e0ba2fa25d859c3a0f0f76f07f0c8d"}, + {file = "lxml-5.1.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:49a9b4af45e8b925e1cd6f3b15bbba2c81e7dba6dce170c677c9cda547411e14"}, + {file = "lxml-5.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:280f3edf15c2a967d923bcfb1f8f15337ad36f93525828b40a0f9d6c2ad24890"}, + {file = "lxml-5.1.0-cp39-cp39-win32.whl", hash = "sha256:ed7326563024b6e91fef6b6c7a1a2ff0a71b97793ac33dbbcf38f6005e51ff6e"}, + {file = "lxml-5.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:8d7b4beebb178e9183138f552238f7e6613162a42164233e2bda00cb3afac58f"}, + {file = "lxml-5.1.0-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:9bd0ae7cc2b85320abd5e0abad5ccee5564ed5f0cc90245d2f9a8ef330a8deae"}, + {file = "lxml-5.1.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d8c1d679df4361408b628f42b26a5d62bd3e9ba7f0c0e7969f925021554755aa"}, + {file = "lxml-5.1.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:2ad3a8ce9e8a767131061a22cd28fdffa3cd2dc193f399ff7b81777f3520e372"}, + {file = "lxml-5.1.0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:304128394c9c22b6569eba2a6d98392b56fbdfbad58f83ea702530be80d0f9df"}, + {file = "lxml-5.1.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:d74fcaf87132ffc0447b3c685a9f862ffb5b43e70ea6beec2fb8057d5d2a1fea"}, + {file = "lxml-5.1.0-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:8cf5877f7ed384dabfdcc37922c3191bf27e55b498fecece9fd5c2c7aaa34c33"}, + {file = "lxml-5.1.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:877efb968c3d7eb2dad540b6cabf2f1d3c0fbf4b2d309a3c141f79c7e0061324"}, + {file = "lxml-5.1.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3f14a4fb1c1c402a22e6a341a24c1341b4a3def81b41cd354386dcb795f83897"}, + {file = "lxml-5.1.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:25663d6e99659544ee8fe1b89b1a8c0aaa5e34b103fab124b17fa958c4a324a6"}, + {file = "lxml-5.1.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:8b9f19df998761babaa7f09e6bc169294eefafd6149aaa272081cbddc7ba4ca3"}, + {file = "lxml-5.1.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5e53d7e6a98b64fe54775d23a7c669763451340c3d44ad5e3a3b48a1efbdc96f"}, + {file = "lxml-5.1.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:c3cd1fc1dc7c376c54440aeaaa0dcc803d2126732ff5c6b68ccd619f2e64be4f"}, + {file = "lxml-5.1.0.tar.gz", hash = "sha256:3eea6ed6e6c918e468e693c41ef07f3c3acc310b70ddd9cc72d9ef84bc9564ca"}, +] + +[package.extras] +cssselect = ["cssselect (>=0.7)"] +html5 = ["html5lib"] +htmlsoup = ["BeautifulSoup4"] +source = ["Cython (>=3.0.7)"] + [[package]] name = "markupsafe" version = "2.1.5" @@ -1188,42 +1304,40 @@ tests = ["pytest (>=4.6)"] [[package]] name = "mujoco" -version = "3.1.2" +version = "2.3.7" description = "MuJoCo Physics Simulator" optional = false python-versions = ">=3.8" files = [ - {file = "mujoco-3.1.2-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:fe6b3542695a5363f348ee45625b3492734f29cdc9f493ca25eae719f974370e"}, - {file = "mujoco-3.1.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f07e2d1f01f1401f1a503187016f8c017d9402618c659e1482243640a1e11288"}, - {file = "mujoco-3.1.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:93863eccc9d77d96ce62dda2a6f61cbd880379e8d774f802568d64b9613fce39"}, - {file = "mujoco-3.1.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3586c642390c16fef58b01a86071cec6814c471586e2f4115c3733c4aec64fb7"}, - {file = "mujoco-3.1.2-cp310-cp310-win_amd64.whl", hash = "sha256:0da77394c664945b78f199c627b609fe091ec0c4641b9d8f713637344a17821a"}, - {file = "mujoco-3.1.2-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:b6f12904d0478c191e4770ecf9006e20953f0488a2411a8ddc62592721c136dc"}, - {file = "mujoco-3.1.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f69b8d42b50c10f8d12df4948fc9d4dd6706841e7b163c1d7ce83448965acb1c"}, - {file = "mujoco-3.1.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:10119e39b1f45fb76b18bea242fea1d6ccf4b2285f8bd5e2cb1e2cbdeb69bdcd"}, - {file = "mujoco-3.1.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a65868506dd45dddfe7be84857e57b49bc102334fc0439aa848a4d4d285d89b"}, - {file = "mujoco-3.1.2-cp311-cp311-win_amd64.whl", hash = "sha256:92bc73972e39539f23a05bb411c45f9be17191fe01171ac15ffafed381ee4366"}, - {file = "mujoco-3.1.2-cp312-cp312-macosx_10_16_x86_64.whl", hash = "sha256:835d6b64ca4dc2f6a83291275fd48bd83edc888039d247958bf5b2c759db4340"}, - {file = "mujoco-3.1.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ce94ca3cf14fc519981674c5b85f1055356dcdcd63bbc0ec6c340084438f27f"}, - {file = "mujoco-3.1.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:250d9de4bd0d31fa4165faf01a1f838c429434f1263faacd95b977580f24eae7"}, - {file = "mujoco-3.1.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6ea009d10bbf0aba9bc835f051d25f07a2c3edbaa06627ac2348766a1f3760b9"}, - {file = "mujoco-3.1.2-cp312-cp312-win_amd64.whl", hash = "sha256:a0460d2ebdad4926f48b8c774da473e011c3b3afd0ccb6b6be1087b788c34011"}, - {file = "mujoco-3.1.2-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:4ca7cae89e258a338e02229edcf8f177b459ac5e9f859ffffa07fc2c9fcfb6aa"}, - {file = "mujoco-3.1.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:33b4fe9b5f891b29ef0fc2b0b975bc3a8a4b87774eecaf4364a83ddc6a7762ba"}, - {file = "mujoco-3.1.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6ed230980f33bafaf1fa8b32ef25b82b069a245de15ee6ce7127e7e054cfad16"}, - {file = "mujoco-3.1.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41cc610ac40f325c9d49d9885ac6cb61822ed938f6c23cb183b261a7a28472ca"}, - {file = "mujoco-3.1.2-cp38-cp38-win_amd64.whl", hash = "sha256:90a172b904a6ca8e6a1be80ab7c393aaff7592843a2a6853a4f97a9204031c41"}, - {file = "mujoco-3.1.2-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:93201291a0c5b573b4cbb19a6b08c99673f9fba167f402174eae5ffa23066d24"}, - {file = "mujoco-3.1.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0398985bb28c2686cdeeaf4ded46e602a49ec12115ac77474144ca940e5261c5"}, - {file = "mujoco-3.1.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d2e76b5cb07ab3088c81966ac774d573df027fa5f4e78c20953a547528a2a698"}, - {file = "mujoco-3.1.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd5c3f4ae858e812cb3f03332693bcdc343b2bce55b164523acf52dea2736c9e"}, - {file = "mujoco-3.1.2-cp39-cp39-win_amd64.whl", hash = "sha256:ca25ff2646b06609526ef8681c0e123cd854a53c9ff23cb91dd5058a2794dab4"}, - {file = "mujoco-3.1.2.tar.gz", hash = "sha256:53530bc1a91903f3fd4b1e99818cc38fbd9911700db29b2c9fc839f23bfacbb8"}, + {file = "mujoco-2.3.7-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:e8714a5ff6a1561b364b7b4648d4c0c8d13e751874cf7401c309b9d23fa9598b"}, + {file = "mujoco-2.3.7-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a934315f858a4e0c4b90a682fde519471cfdd7baa64435179da8cd20d4ae3f99"}, + {file = "mujoco-2.3.7-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:36513024330f88b5f9a43558efef5692b33599bffd5141029b690a27918ffcbe"}, + {file = "mujoco-2.3.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6d4eede8ba8210fbd3d3cd1dbf69e24dd1541aa74c5af5b8adbbbf65504b6dba"}, + {file = "mujoco-2.3.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ab85fafc9d5a091c712947573b7e694512d283876bf7f33ae3f8daad3a20c0db"}, + {file = "mujoco-2.3.7-cp310-cp310-win_amd64.whl", hash = "sha256:f8b7e13fef8c813d91b78f975ed0815157692777907ffa4b4be53a4edb75019b"}, + {file = "mujoco-2.3.7-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:779520216f72a8e370e3f0cdd71b45c3b7384c63331a3189194c930a3e7cff5c"}, + {file = "mujoco-2.3.7-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9d4018053879016282d27ab7a91e292c72d44efb5a88553feacfe5b843dde103"}, + {file = "mujoco-2.3.7-cp311-cp311-macosx_11_0_x86_64.whl", hash = "sha256:3149b16b8122ee62642474bfd2871064e8edc40235471cf5d84be3569afc0312"}, + {file = "mujoco-2.3.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c08660a8d52ef3efde76095f0991e807703a950c1e882d2bcd984b9a846626f7"}, + {file = 
"mujoco-2.3.7-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:426af8965f8636d94a0f75740c3024a62b3e585020ee817ef5208ec844a1ad94"}, + {file = "mujoco-2.3.7-cp311-cp311-win_amd64.whl", hash = "sha256:215415a8e98a4b50625beae859079d5e0810b2039e50420f0ba81763c34abb59"}, + {file = "mujoco-2.3.7-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:8b78d14f4c60cea3c58e046bd4de453fb5b9b33aca6a25fc91d39a53f3a5342a"}, + {file = "mujoco-2.3.7-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5c6f5a51d6f537a4bf294cf73816f3a6384573f8f10a5452b044df2771412a96"}, + {file = "mujoco-2.3.7-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:ea8911e6047f92d7d775701f37e4c093971b6def3160f01d0b6926e29a7e962e"}, + {file = "mujoco-2.3.7-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7473a3de4dd1a8762d569ffb139196b4c5e7eca27d256df97b6cd4c66d2a09b2"}, + {file = "mujoco-2.3.7-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40e7e2d8f93d2495ec74efec84e5118ecc6e1d85157a844789c73c9ac9a4e28e"}, + {file = "mujoco-2.3.7-cp38-cp38-win_amd64.whl", hash = "sha256:720bc228a2023b3b0ed6af78f5b0f8ea36867be321d473321555c57dbf6e4e5b"}, + {file = "mujoco-2.3.7-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:855e79686366442aa410246043b44f7d842d3900d68fe7e37feb42147db9d707"}, + {file = "mujoco-2.3.7-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:98947f4a742d34d36f3c3f83e9167025bb0414bbaa4bd859b0673bdab9959963"}, + {file = "mujoco-2.3.7-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:d42818f2ee5d1632dbce31d136ed5ff868db54b04e4e9aca0c5a3ac329f8a90f"}, + {file = "mujoco-2.3.7-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9237e1ba14bced9449c31199e6d5be49547f3a4c99bc83b196af7ca45fd73b83"}, + {file = "mujoco-2.3.7-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b728ea638245b150e2650c5433e6952e0ed3798c63e47e264574270caea2a3"}, + {file = "mujoco-2.3.7-cp39-cp39-win_amd64.whl", hash = "sha256:9c721a5042b99d948d5f0296a534bcce3f142c777c4d7642f503a539513f3912"}, + {file = "mujoco-2.3.7.tar.gz", hash = "sha256:422041f1ce37c6d151fbced1048df626837e94fe3cd9f813585907046336a7d0"}, ] [package.dependencies] absl-py = "*" -etils = {version = "*", extras = ["epath"]} glfw = "*" numpy = "*" pyopengl = "*" @@ -2016,6 +2130,20 @@ files = [ {file = "PyOpenGL-3.1.7.tar.gz", hash = "sha256:eef31a3888e6984fd4d8e6c9961b184c9813ca82604d37fe3da80eb000a76c86"}, ] +[[package]] +name = "pyparsing" +version = "3.1.2" +description = "pyparsing module - Classes and methods to define and execute parsing grammars" +optional = false +python-versions = ">=3.6.8" +files = [ + {file = "pyparsing-3.1.2-py3-none-any.whl", hash = "sha256:f9db75911801ed778fe61bb643079ff86601aca99fcae6345aa67292038fb742"}, + {file = "pyparsing-3.1.2.tar.gz", hash = "sha256:a1bac0ce561155ecc3ed78ca94d3c9378656ad4c94c1270de543f621420f94ad"}, +] + +[package.extras] +diagrams = ["jinja2", "railroad-diagrams"] + [[package]] name = "pysocks" version = "1.7.1" @@ -3125,4 +3253,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "c4d83579aed1c8c2e54cad7c8ec81b95a09ab8faff74fc9a4cb20bd00e4ddec6" +content-hash = "84cda58ab0670dcb1e2429b342f4f1b3c35f261d1201fc17acad5cc1ef2c6aa8" diff --git a/pyproject.toml b/pyproject.toml index ebce8f32..b89a04e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,13 +41,14 @@ mpmath = "^1.3.0" torch = "^2.2.1" tensordict = 
{git = "https://github.com/pytorch/tensordict"} torchrl = {git = "https://github.com/pytorch/rl", rev = "13bef426dcfa5887c6e5034a6e9697993fa92c37"} -mujoco = "^3.1.2" +mujoco = "2.3.7" mujoco-py = "^2.1.2.14" gym = "^0.26.2" opencv-python = "^4.9.0.80" diffusers = "^0.26.3" torchvision = "^0.17.1" h5py = "^3.10.0" +dm-control = "1.0.14" [tool.poetry.group.dev.dependencies] diff --git a/sbatch.sh b/sbatch.sh index da52c472..cb5b285a 100644 --- a/sbatch.sh +++ b/sbatch.sh @@ -17,6 +17,7 @@ apptainer exec --nv \ ~/apptainer/nvidia_cuda:12.2.2-devel-ubuntu22.04.sif $SHELL source ~/.bashrc -conda activate fowm +#conda activate fowm +conda activate lerobot srun $CMD diff --git a/tests/test_datasets.py b/tests/test_datasets.py index e63ae2c1..71c14951 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -12,10 +12,10 @@ from .utils import init_config # ("simxarm", "lift"), ("pusht", "pusht"), # TODO(aliberts): add aloha when dataset is available on hub - # ("aloha", "sim_insertion_human"), - # ("aloha", "sim_insertion_scripted"), - # ("aloha", "sim_transfer_cube_human"), - # ("aloha", "sim_transfer_cube_scripted"), + ("aloha", "sim_insertion_human"), + ("aloha", "sim_insertion_scripted"), + ("aloha", "sim_transfer_cube_human"), + ("aloha", "sim_transfer_cube_scripted"), ], ) def test_factory(env_name, dataset_id): diff --git a/tests/test_envs.py b/tests/test_envs.py index b51c441b..0f3c3c2c 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -83,6 +83,7 @@ def test_pusht(from_pixels, pixels_only): [ # "simxarm", "pusht", + "aloha", ], ) def test_factory(env_name):