diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e9edad05..2b161310 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -18,6 +18,12 @@ jobs: env: POETRY_VERSION: 1.8.1 DATA_DIR: tests/data + TMPDIR: ~/tmp + TEMP: ~/tmp + TMP: ~/tmp + PYOPENGL_PLATFORM: egl + MUJOCO_GL: egl + LEROBOT_TESTS_DEVICE: cpu steps: #---------------------------------------------- # check-out repo and set-up python @@ -26,11 +32,13 @@ jobs: uses: actions/checkout@v4 with: lfs: true + - name: Set up python id: setup-python uses: actions/setup-python@v5 with: python-version: '3.10' + #---------------------------------------------- # install & configure poetry #---------------------------------------------- @@ -40,6 +48,7 @@ jobs: with: path: ~/.local # the path depends on the OS key: poetry-${{ env.POETRY_VERSION }} # increment to reset cache + - name: Install Poetry if: steps.restore-poetry-cache.outputs.cache-hit != 'true' uses: snok/install-poetry@v1 @@ -47,6 +56,7 @@ jobs: version: ${{ env.POETRY_VERSION }} virtualenvs-create: true installer-parallel: true + - name: Save cached Poetry installation if: | steps.restore-poetry-cache.outputs.cache-hit != 'true' && @@ -56,8 +66,10 @@ jobs: with: path: ~/.local # the path depends on the OS key: poetry-${{ env.POETRY_VERSION }} # increment to reset cache + - name: Configure Poetry run: poetry config virtualenvs.in-project true + #---------------------------------------------- # install dependencies #---------------------------------------------- @@ -67,9 +79,21 @@ jobs: with: path: .venv key: venv-${{ steps.setup-python.outputs.python-version }}-${{ env.POETRY_VERSION }}-${{ hashFiles('**/poetry.lock') }} + + - name: Info + run: | + sudo du -sh /tmp + sudo df -h + - name: Install dependencies if: steps.restore-dependencies-cache.outputs.cache-hit != 'true' - run: poetry install --no-interaction --no-root + run: | + mkdir ~/tmp + echo $TMPDIR + echo $TEMP + echo $TMP + poetry install --no-interaction --no-root + - name: Save cached venv if: | steps.restore-dependencies-cache.outputs.cache-hit != 'true' && @@ -79,11 +103,16 @@ jobs: with: path: .venv key: venv-${{ steps.setup-python.outputs.python-version }}-${{ env.POETRY_VERSION }}-${{ hashFiles('**/poetry.lock') }} + + - name: Install libegl1-mesa-dev (to use MUJOCO_GL=egl) + run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev + #---------------------------------------------- # install project #---------------------------------------------- - name: Install project run: poetry install --no-interaction + #---------------------------------------------- # run tests #---------------------------------------------- @@ -91,6 +120,7 @@ jobs: run: | source .venv/bin/activate pytest tests + - name: Test train pusht end-to-end run: | source .venv/bin/activate @@ -104,6 +134,7 @@ jobs: save_model=true \ save_freq=1 \ hydra.run.dir=tests/outputs/ + - name: Test eval pusht end-to-end run: | source .venv/bin/activate diff --git a/README.md b/README.md index b346a30f..5ccffc35 100644 --- a/README.md +++ b/README.md @@ -59,19 +59,10 @@ env=pusht ## TODO -- [x] priority update doesn't match FOWM or original paper -- [x] self.step=100000 should be updated at every step to adjust to the horizon of the planner -- [ ] prefetch replay buffer to speedup training -- [ ] parallelize env to speed up eval -- [ ] clean checkpointing / loading -- [ ] clean logging -- [ ] clean config -- [ ] clean hyperparameter tuning -- [ ] add pusht -- [ ] add aloha -- [ ] add act -- [ ] add diffusion -- [ ] add aloha 2 +If you are not sure how to contribute or want to know the next features we working on, look on this project page: [LeRobot TODO](https://github.com/users/Cadene/projects/1) + +Ask [Remi Cadene](re.cadene@gmail.com) for access if needed. + ## Profile diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py index d6a51246..514fa038 100644 --- a/lerobot/common/datasets/abstract.py +++ b/lerobot/common/datasets/abstract.py @@ -81,7 +81,10 @@ class AbstractExperienceReplay(TensorDictReplayBuffer): def set_transform(self, transform): if not isinstance(transform, Compose): # required since torchrl calls `len(self._transform)` downstream - self._transform = Compose(transform) + if isinstance(transform, list): + self._transform = Compose(*transform) + else: + self._transform = Compose(transform) else: self._transform = transform diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py index afc28b1c..3b53fed1 100644 --- a/lerobot/common/datasets/aloha.py +++ b/lerobot/common/datasets/aloha.py @@ -73,11 +73,11 @@ def download(data_dir, dataset_id): data_dir.mkdir(parents=True, exist_ok=True) - gdown.download_folder(FOLDER_URLS[dataset_id], output=data_dir) + gdown.download_folder(FOLDER_URLS[dataset_id], output=str(data_dir)) # because of the 50 files limit per directory, two files episode 48 and 49 were missing - gdown.download(EP48_URLS[dataset_id], output=data_dir / "episode_48.hdf5", fuzzy=True) - gdown.download(EP49_URLS[dataset_id], output=data_dir / "episode_49.hdf5", fuzzy=True) + gdown.download(EP48_URLS[dataset_id], output=str(data_dir / "episode_48.hdf5"), fuzzy=True) + gdown.download(EP49_URLS[dataset_id], output=str(data_dir / "episode_49.hdf5"), fuzzy=True) class AlohaExperienceReplay(AbstractExperienceReplay): @@ -124,9 +124,6 @@ class AlohaExperienceReplay(AbstractExperienceReplay): def image_keys(self) -> list: return [("observation", "image", cam) for cam in CAMERAS[self.dataset_id]] - # def _is_downloaded(self) -> bool: - # return False - def _download_and_preproc(self): raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw" if not raw_dir.is_dir(): diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 29f40bc6..fd284ae2 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -5,7 +5,7 @@ from pathlib import Path import torch from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler -from lerobot.common.envs.transforms import NormalizeTransform +from lerobot.common.envs.transforms import NormalizeTransform, Prod DATA_DIR = Path(os.environ.get("DATA_DIR", "data")) @@ -84,6 +84,16 @@ def make_offline_buffer( prefetch=prefetch if isinstance(prefetch, int) else None, ) + if cfg.policy.name == "tdmpc": + img_keys = [] + for key in offline_buffer.image_keys: + img_keys.append(("next", *key)) + img_keys += offline_buffer.image_keys + else: + img_keys = offline_buffer.image_keys + + transforms = [Prod(in_keys=img_keys, prod=1 / 255)] + if normalize: # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec stats = offline_buffer.compute_or_load_stats() @@ -92,11 +102,10 @@ def make_offline_buffer( in_keys = [("observation", "state"), ("action")] if cfg.policy.name == "tdmpc": - for key in offline_buffer.image_keys: - # TODO(rcadene): imagenet normalization is applied inside diffusion policy, but no normalization inside tdmpc - in_keys.append(key) - # since we use next observations in tdmpc - in_keys.append(("next", *key)) + # TODO(rcadene): we add img_keys to the keys to normalize for tdmpc only, since diffusion and act policies normalize the image inside the model for now + in_keys += img_keys + # TODO(racdene): since we use next observations in tdmpc, we also add them to the normalization. We are wasting a bit of compute on this for now. + in_keys += [("next", *key) for key in img_keys] in_keys.append(("next", "observation", "state")) if cfg.policy.name == "diffusion" and cfg.env.name == "pusht": @@ -106,8 +115,11 @@ def make_offline_buffer( stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32) stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32) - transform = NormalizeTransform(stats, in_keys, mode="min_max") - offline_buffer.set_transform(transform) + # TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std + normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max" + transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode)) + + offline_buffer.set_transform(transforms) if not overwrite_sampler: index = torch.arange(0, offline_buffer.num_samples, 1) diff --git a/lerobot/common/datasets/simxarm.py b/lerobot/common/datasets/simxarm.py index 784242cc..1d56850e 100644 --- a/lerobot/common/datasets/simxarm.py +++ b/lerobot/common/datasets/simxarm.py @@ -1,4 +1,5 @@ import pickle +import zipfile from pathlib import Path from typing import Callable @@ -15,6 +16,22 @@ from torchrl.data.replay_buffers.writers import Writer from lerobot.common.datasets.abstract import AbstractExperienceReplay +def download(): + raise NotImplementedError() + import gdown + + url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya" + download_path = "data.zip" + gdown.download(url, download_path, quiet=False) + print("Extracting...") + with zipfile.ZipFile(download_path, "r") as zip_f: + for member in zip_f.namelist(): + if member.startswith("data/xarm") and member.endswith(".pkl"): + print(member) + zip_f.extract(member=member) + Path(download_path).unlink() + + class SimxarmExperienceReplay(AbstractExperienceReplay): available_datasets = [ "xarm_lift_medium", @@ -48,8 +65,8 @@ class SimxarmExperienceReplay(AbstractExperienceReplay): ) def _download_and_preproc(self): - # download - # TODO(rcadene) + # TODO(rcadene): finish download + download() dataset_path = self.data_dir / "buffer.pkl" print(f"Using offline dataset '{dataset_path}'") diff --git a/lerobot/common/envs/abstract.py b/lerobot/common/envs/abstract.py new file mode 100644 index 00000000..0754fb76 --- /dev/null +++ b/lerobot/common/envs/abstract.py @@ -0,0 +1,80 @@ +import abc +from collections import deque +from typing import Optional + +from tensordict import TensorDict +from torchrl.envs import EnvBase + + +class AbstractEnv(EnvBase): + def __init__( + self, + task, + frame_skip: int = 1, + from_pixels: bool = False, + pixels_only: bool = False, + image_size=None, + seed=1337, + device="cpu", + num_prev_obs=1, + num_prev_action=0, + ): + super().__init__(device=device, batch_size=[]) + self.task = task + self.frame_skip = frame_skip + self.from_pixels = from_pixels + self.pixels_only = pixels_only + self.image_size = image_size + self.num_prev_obs = num_prev_obs + self.num_prev_action = num_prev_action + self._rendering_hooks = [] + + if pixels_only: + assert from_pixels + if from_pixels: + assert image_size + + self._make_env() + self._make_spec() + self._current_seed = self.set_seed(seed) + + if self.num_prev_obs > 0: + self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs) + self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs) + if self.num_prev_action > 0: + raise NotImplementedError() + # self._prev_action_queue = deque(maxlen=self.num_prev_action) + + def register_rendering_hook(self, func): + self._rendering_hooks.append(func) + + def call_rendering_hooks(self): + for func in self._rendering_hooks: + func(self) + + def reset_rendering_hooks(self): + self._rendering_hooks = [] + + @abc.abstractmethod + def render(self, mode="rgb_array", width=640, height=480): + raise NotImplementedError() + + @abc.abstractmethod + def _reset(self, tensordict: Optional[TensorDict] = None): + raise NotImplementedError() + + @abc.abstractmethod + def _step(self, tensordict: TensorDict): + raise NotImplementedError() + + @abc.abstractmethod + def _make_env(self): + raise NotImplementedError() + + @abc.abstractmethod + def _make_spec(self): + raise NotImplementedError() + + @abc.abstractmethod + def _set_seed(self, seed: Optional[int]): + raise NotImplementedError() diff --git a/lerobot/common/envs/aloha/assets/bimanual_viperx_end_effector_insertion.xml b/lerobot/common/envs/aloha/assets/bimanual_viperx_end_effector_insertion.xml new file mode 100644 index 00000000..8002838c --- /dev/null +++ b/lerobot/common/envs/aloha/assets/bimanual_viperx_end_effector_insertion.xml @@ -0,0 +1,59 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/lerobot/common/envs/aloha/assets/bimanual_viperx_end_effector_transfer_cube.xml b/lerobot/common/envs/aloha/assets/bimanual_viperx_end_effector_transfer_cube.xml new file mode 100644 index 00000000..05249ad2 --- /dev/null +++ b/lerobot/common/envs/aloha/assets/bimanual_viperx_end_effector_transfer_cube.xml @@ -0,0 +1,48 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/lerobot/common/envs/aloha/assets/bimanual_viperx_insertion.xml b/lerobot/common/envs/aloha/assets/bimanual_viperx_insertion.xml new file mode 100644 index 00000000..511f7947 --- /dev/null +++ b/lerobot/common/envs/aloha/assets/bimanual_viperx_insertion.xml @@ -0,0 +1,53 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/lerobot/common/envs/aloha/assets/bimanual_viperx_transfer_cube.xml b/lerobot/common/envs/aloha/assets/bimanual_viperx_transfer_cube.xml new file mode 100644 index 00000000..2d85a47c --- /dev/null +++ b/lerobot/common/envs/aloha/assets/bimanual_viperx_transfer_cube.xml @@ -0,0 +1,42 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/lerobot/common/envs/aloha/assets/scene.xml b/lerobot/common/envs/aloha/assets/scene.xml new file mode 100644 index 00000000..0f61b8a5 --- /dev/null +++ b/lerobot/common/envs/aloha/assets/scene.xml @@ -0,0 +1,38 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/lerobot/common/envs/aloha/assets/tabletop.stl b/lerobot/common/envs/aloha/assets/tabletop.stl new file mode 100644 index 00000000..ab35cdf7 Binary files /dev/null and b/lerobot/common/envs/aloha/assets/tabletop.stl differ diff --git a/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_left.stl b/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_left.stl new file mode 100644 index 00000000..534c7af9 Binary files /dev/null and b/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_left.stl differ diff --git a/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_right.stl b/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_right.stl new file mode 100644 index 00000000..d6a492c2 Binary files /dev/null and b/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_right.stl differ diff --git a/lerobot/common/envs/aloha/assets/vx300s_10_gripper_finger.stl b/lerobot/common/envs/aloha/assets/vx300s_10_gripper_finger.stl new file mode 100644 index 00000000..d6df86be Binary files /dev/null and b/lerobot/common/envs/aloha/assets/vx300s_10_gripper_finger.stl differ diff --git a/lerobot/common/envs/aloha/assets/vx300s_11_ar_tag.stl b/lerobot/common/envs/aloha/assets/vx300s_11_ar_tag.stl new file mode 100644 index 00000000..193014b6 Binary files /dev/null and b/lerobot/common/envs/aloha/assets/vx300s_11_ar_tag.stl differ diff --git a/lerobot/common/envs/aloha/assets/vx300s_1_base.stl b/lerobot/common/envs/aloha/assets/vx300s_1_base.stl new file mode 100644 index 00000000..5a7efda2 Binary files /dev/null and b/lerobot/common/envs/aloha/assets/vx300s_1_base.stl differ diff --git a/lerobot/common/envs/aloha/assets/vx300s_2_shoulder.stl b/lerobot/common/envs/aloha/assets/vx300s_2_shoulder.stl new file mode 100644 index 00000000..dc22aa7e Binary files /dev/null and b/lerobot/common/envs/aloha/assets/vx300s_2_shoulder.stl differ diff --git a/lerobot/common/envs/aloha/assets/vx300s_3_upper_arm.stl b/lerobot/common/envs/aloha/assets/vx300s_3_upper_arm.stl new file mode 100644 index 00000000..111c586e Binary files /dev/null and b/lerobot/common/envs/aloha/assets/vx300s_3_upper_arm.stl differ diff --git a/lerobot/common/envs/aloha/assets/vx300s_4_upper_forearm.stl b/lerobot/common/envs/aloha/assets/vx300s_4_upper_forearm.stl new file mode 100644 index 00000000..8170d21c Binary files /dev/null and b/lerobot/common/envs/aloha/assets/vx300s_4_upper_forearm.stl differ diff --git a/lerobot/common/envs/aloha/assets/vx300s_5_lower_forearm.stl b/lerobot/common/envs/aloha/assets/vx300s_5_lower_forearm.stl new file mode 100644 index 00000000..39581f83 Binary files /dev/null and b/lerobot/common/envs/aloha/assets/vx300s_5_lower_forearm.stl differ diff --git a/lerobot/common/envs/aloha/assets/vx300s_6_wrist.stl b/lerobot/common/envs/aloha/assets/vx300s_6_wrist.stl new file mode 100644 index 00000000..ab8423e9 Binary files /dev/null and b/lerobot/common/envs/aloha/assets/vx300s_6_wrist.stl differ diff --git a/lerobot/common/envs/aloha/assets/vx300s_7_gripper.stl b/lerobot/common/envs/aloha/assets/vx300s_7_gripper.stl new file mode 100644 index 00000000..043db9ca Binary files /dev/null and b/lerobot/common/envs/aloha/assets/vx300s_7_gripper.stl differ diff --git a/lerobot/common/envs/aloha/assets/vx300s_8_gripper_prop.stl b/lerobot/common/envs/aloha/assets/vx300s_8_gripper_prop.stl new file mode 100644 index 00000000..36099b42 Binary files /dev/null and b/lerobot/common/envs/aloha/assets/vx300s_8_gripper_prop.stl differ diff --git a/lerobot/common/envs/aloha/assets/vx300s_9_gripper_bar.stl b/lerobot/common/envs/aloha/assets/vx300s_9_gripper_bar.stl new file mode 100644 index 00000000..eba3caa2 Binary files /dev/null and b/lerobot/common/envs/aloha/assets/vx300s_9_gripper_bar.stl differ diff --git a/lerobot/common/envs/aloha/assets/vx300s_dependencies.xml b/lerobot/common/envs/aloha/assets/vx300s_dependencies.xml new file mode 100644 index 00000000..93037ab7 --- /dev/null +++ b/lerobot/common/envs/aloha/assets/vx300s_dependencies.xml @@ -0,0 +1,17 @@ + + + + + + + + + + + + + + + + + diff --git a/lerobot/common/envs/aloha/assets/vx300s_left.xml b/lerobot/common/envs/aloha/assets/vx300s_left.xml new file mode 100644 index 00000000..3af6c235 --- /dev/null +++ b/lerobot/common/envs/aloha/assets/vx300s_left.xml @@ -0,0 +1,59 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/lerobot/common/envs/aloha/assets/vx300s_right.xml b/lerobot/common/envs/aloha/assets/vx300s_right.xml new file mode 100644 index 00000000..495df478 --- /dev/null +++ b/lerobot/common/envs/aloha/assets/vx300s_right.xml @@ -0,0 +1,59 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/lerobot/common/envs/aloha/constants.py b/lerobot/common/envs/aloha/constants.py new file mode 100644 index 00000000..e582e5f3 --- /dev/null +++ b/lerobot/common/envs/aloha/constants.py @@ -0,0 +1,163 @@ +from pathlib import Path + +### Simulation envs fixed constants +DT = 0.02 # 0.02 ms -> 1/0.2 = 50 hz +FPS = 50 + + +JOINTS = [ + # absolute joint position + "left_arm_waist", + "left_arm_shoulder", + "left_arm_elbow", + "left_arm_forearm_roll", + "left_arm_wrist_angle", + "left_arm_wrist_rotate", + # normalized gripper position 0: close, 1: open + "left_arm_gripper", + # absolute joint position + "right_arm_waist", + "right_arm_shoulder", + "right_arm_elbow", + "right_arm_forearm_roll", + "right_arm_wrist_angle", + "right_arm_wrist_rotate", + # normalized gripper position 0: close, 1: open + "right_arm_gripper", +] + +ACTIONS = [ + # position and quaternion for end effector + "left_arm_waist", + "left_arm_shoulder", + "left_arm_elbow", + "left_arm_forearm_roll", + "left_arm_wrist_angle", + "left_arm_wrist_rotate", + # normalized gripper position (0: close, 1: open) + "left_arm_gripper", + "right_arm_waist", + "right_arm_shoulder", + "right_arm_elbow", + "right_arm_forearm_roll", + "right_arm_wrist_angle", + "right_arm_wrist_rotate", + # normalized gripper position (0: close, 1: open) + "right_arm_gripper", +] + + +START_ARM_POSE = [ + 0, + -0.96, + 1.16, + 0, + -0.3, + 0, + 0.02239, + -0.02239, + 0, + -0.96, + 1.16, + 0, + -0.3, + 0, + 0.02239, + -0.02239, +] + +ASSETS_DIR = Path(__file__).parent.resolve() / "assets" # note: absolute path + +# Left finger position limits (qpos[7]), right_finger = -1 * left_finger +MASTER_GRIPPER_POSITION_OPEN = 0.02417 +MASTER_GRIPPER_POSITION_CLOSE = 0.01244 +PUPPET_GRIPPER_POSITION_OPEN = 0.05800 +PUPPET_GRIPPER_POSITION_CLOSE = 0.01844 + +# Gripper joint limits (qpos[6]) +MASTER_GRIPPER_JOINT_OPEN = 0.3083 +MASTER_GRIPPER_JOINT_CLOSE = -0.6842 +PUPPET_GRIPPER_JOINT_OPEN = 1.4910 +PUPPET_GRIPPER_JOINT_CLOSE = -0.6213 + +MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE) / 2 + +############################ Helper functions ############################ + + +def normalize_master_gripper_position(x): + return (x - MASTER_GRIPPER_POSITION_CLOSE) / ( + MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE + ) + + +def normalize_puppet_gripper_position(x): + return (x - PUPPET_GRIPPER_POSITION_CLOSE) / ( + PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE + ) + + +def unnormalize_master_gripper_position(x): + return x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE + + +def unnormalize_puppet_gripper_position(x): + return x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE + + +def convert_position_from_master_to_puppet(x): + return unnormalize_puppet_gripper_position(normalize_master_gripper_position(x)) + + +def normalizer_master_gripper_joint(x): + return (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + + +def normalize_puppet_gripper_joint(x): + return (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + + +def unnormalize_master_gripper_joint(x): + return x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE + + +def unnormalize_puppet_gripper_joint(x): + return x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE + + +def convert_join_from_master_to_puppet(x): + return unnormalize_puppet_gripper_joint(normalizer_master_gripper_joint(x)) + + +def normalize_master_gripper_velocity(x): + return x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + + +def normalize_puppet_gripper_velocity(x): + return x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + + +def convert_master_from_position_to_joint(x): + return ( + normalize_master_gripper_position(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + + MASTER_GRIPPER_JOINT_CLOSE + ) + + +def convert_master_from_joint_to_position(x): + return unnormalize_master_gripper_position( + (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + ) + + +def convert_puppet_from_position_to_join(x): + return ( + normalize_puppet_gripper_position(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + + PUPPET_GRIPPER_JOINT_CLOSE + ) + + +def convert_puppet_from_joint_to_position(x): + return unnormalize_puppet_gripper_position( + (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + ) diff --git a/lerobot/common/envs/aloha/env.py b/lerobot/common/envs/aloha/env.py new file mode 100644 index 00000000..1211a37a --- /dev/null +++ b/lerobot/common/envs/aloha/env.py @@ -0,0 +1,311 @@ +import importlib +import logging +from collections import deque +from typing import Optional + +import einops +import numpy as np +import torch +from dm_control import mujoco +from dm_control.rl import control +from tensordict import TensorDict +from torchrl.data.tensor_specs import ( + BoundedTensorSpec, + CompositeSpec, + DiscreteTensorSpec, + UnboundedContinuousTensorSpec, +) + +from lerobot.common.envs.abstract import AbstractEnv +from lerobot.common.envs.aloha.constants import ( + ACTIONS, + ASSETS_DIR, + DT, + JOINTS, +) +from lerobot.common.envs.aloha.tasks.sim import BOX_POSE, InsertionTask, TransferCubeTask +from lerobot.common.envs.aloha.tasks.sim_end_effector import ( + InsertionEndEffectorTask, + TransferCubeEndEffectorTask, +) +from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose +from lerobot.common.utils import set_seed + +_has_gym = importlib.util.find_spec("gym") is not None + + +class AlohaEnv(AbstractEnv): + def __init__( + self, + task, + frame_skip: int = 1, + from_pixels: bool = False, + pixels_only: bool = False, + image_size=None, + seed=1337, + device="cpu", + num_prev_obs=1, + num_prev_action=0, + ): + super().__init__( + task=task, + frame_skip=frame_skip, + from_pixels=from_pixels, + pixels_only=pixels_only, + image_size=image_size, + seed=seed, + device=device, + num_prev_obs=num_prev_obs, + num_prev_action=num_prev_action, + ) + + def _make_env(self): + if not _has_gym: + raise ImportError("Cannot import gym.") + + if not self.from_pixels: + raise NotImplementedError() + + self._env = self._make_env_task(self.task) + + def render(self, mode="rgb_array", width=640, height=480): + # TODO(rcadene): render and visualizer several cameras (e.g. angle, front_close) + image = self._env.physics.render(height=height, width=width, camera_id="top") + return image + + def _make_env_task(self, task_name): + # time limit is controlled by StepCounter in env factory + time_limit = float("inf") + + if "sim_transfer_cube" in task_name: + xml_path = ASSETS_DIR / "bimanual_viperx_transfer_cube.xml" + physics = mujoco.Physics.from_xml_path(str(xml_path)) + task = TransferCubeTask(random=False) + elif "sim_insertion" in task_name: + xml_path = ASSETS_DIR / "bimanual_viperx_insertion.xml" + physics = mujoco.Physics.from_xml_path(str(xml_path)) + task = InsertionTask(random=False) + elif "sim_end_effector_transfer_cube" in task_name: + raise NotImplementedError() + xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_transfer_cube.xml" + physics = mujoco.Physics.from_xml_path(str(xml_path)) + task = TransferCubeEndEffectorTask(random=False) + elif "sim_end_effector_insertion" in task_name: + raise NotImplementedError() + xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_insertion.xml" + physics = mujoco.Physics.from_xml_path(str(xml_path)) + task = InsertionEndEffectorTask(random=False) + else: + raise NotImplementedError(task_name) + + env = control.Environment( + physics, task, time_limit, control_timestep=DT, n_sub_steps=None, flat_observation=False + ) + return env + + def _format_raw_obs(self, raw_obs): + if self.from_pixels: + image = torch.from_numpy(raw_obs["images"]["top"].copy()) + image = einops.rearrange(image, "h w c -> c h w") + assert image.dtype == torch.uint8 + obs = {"image": {"top": image}} + + if not self.pixels_only: + obs["state"] = torch.from_numpy(raw_obs["qpos"]).type(torch.float32) + else: + # TODO(rcadene): + raise NotImplementedError() + # obs = {"state": torch.from_numpy(raw_obs["observation"]).type(torch.float32)} + + return obs + + def _reset(self, tensordict: Optional[TensorDict] = None): + td = tensordict + if td is None or td.is_empty(): + # we need to handle seed iteration, since self._env.reset() rely an internal _seed. + self._current_seed += 1 + self.set_seed(self._current_seed) + + # TODO(rcadene): do not use global variable for this + if "sim_transfer_cube" in self.task: + BOX_POSE[0] = sample_box_pose() # used in sim reset + elif "sim_insertion" in self.task: + BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset + + raw_obs = self._env.reset() + # TODO(rcadene): add assert + # assert self._current_seed == self._env._seed + + obs = self._format_raw_obs(raw_obs.observation) + + if self.num_prev_obs > 0: + stacked_obs = {} + if "image" in obs: + self._prev_obs_image_queue = deque( + [obs["image"]["top"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1) + ) + stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))} + if "state" in obs: + self._prev_obs_state_queue = deque( + [obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1) + ) + stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) + obs = stacked_obs + + td = TensorDict( + { + "observation": TensorDict(obs, batch_size=[]), + "done": torch.tensor([False], dtype=torch.bool), + }, + batch_size=[], + ) + else: + raise NotImplementedError() + + self.call_rendering_hooks() + return td + + def _step(self, tensordict: TensorDict): + td = tensordict + action = td["action"].numpy() + # step expects shape=(4,) so we pad if necessary + # TODO(rcadene): add info["is_success"] and info["success"] ? + sum_reward = 0 + + if action.ndim == 1: + action = einops.repeat(action, "c -> t c", t=self.frame_skip) + else: + if self.frame_skip > 1: + raise NotImplementedError() + + num_action_steps = action.shape[0] + for i in range(num_action_steps): + _, reward, discount, raw_obs = self._env.step(action[i]) + del discount # not used + + # TOOD(rcadene): add an enum + success = done = reward == 4 + sum_reward += reward + obs = self._format_raw_obs(raw_obs) + + if self.num_prev_obs > 0: + stacked_obs = {} + if "image" in obs: + self._prev_obs_image_queue.append(obs["image"]["top"]) + stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))} + if "state" in obs: + self._prev_obs_state_queue.append(obs["state"]) + stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) + obs = stacked_obs + + self.call_rendering_hooks() + + td = TensorDict( + { + "observation": TensorDict(obs, batch_size=[]), + "reward": torch.tensor([sum_reward], dtype=torch.float32), + # succes and done are true when coverage > self.success_threshold in env + "done": torch.tensor([done], dtype=torch.bool), + "success": torch.tensor([success], dtype=torch.bool), + }, + batch_size=[], + ) + return td + + def _make_spec(self): + obs = {} + from omegaconf import OmegaConf + + if self.from_pixels: + if isinstance(self.image_size, int): + image_shape = (3, self.image_size, self.image_size) + elif OmegaConf.is_list(self.image_size): + assert len(self.image_size) == 3 # c h w + assert self.image_size[0] == 3 # c is RGB + image_shape = tuple(self.image_size) + else: + raise ValueError(self.image_size) + if self.num_prev_obs > 0: + image_shape = (self.num_prev_obs + 1, *image_shape) + + obs["image"] = { + "top": BoundedTensorSpec( + low=0, + high=255, + shape=image_shape, + dtype=torch.uint8, + device=self.device, + ) + } + if not self.pixels_only: + state_shape = (len(JOINTS),) + if self.num_prev_obs > 0: + state_shape = (self.num_prev_obs + 1, *state_shape) + + obs["state"] = UnboundedContinuousTensorSpec( + # TODO: add low and high bounds + shape=state_shape, + dtype=torch.float32, + device=self.device, + ) + else: + # TODO(rcadene): add observation_space achieved_goal and desired_goal? + state_shape = (len(JOINTS),) + if self.num_prev_obs > 0: + state_shape = (self.num_prev_obs + 1, *state_shape) + + obs["state"] = UnboundedContinuousTensorSpec( + # TODO: add low and high bounds + shape=state_shape, + dtype=torch.float32, + device=self.device, + ) + self.observation_spec = CompositeSpec({"observation": obs}) + + # TODO(rcadene): valid when controling end effector? + # action_space = self._env.action_spec() + # self.action_spec = BoundedTensorSpec( + # low=action_space.minimum, + # high=action_space.maximum, + # shape=action_space.shape, + # dtype=torch.float32, + # device=self.device, + # ) + + # TODO(rcaene): add bounds (where are they????) + self.action_spec = BoundedTensorSpec( + shape=(len(ACTIONS)), + low=-1, + high=1, + dtype=torch.float32, + device=self.device, + ) + + self.reward_spec = UnboundedContinuousTensorSpec( + shape=(1,), + dtype=torch.float32, + device=self.device, + ) + + self.done_spec = CompositeSpec( + { + "done": DiscreteTensorSpec( + 2, + shape=(1,), + dtype=torch.bool, + device=self.device, + ), + "success": DiscreteTensorSpec( + 2, + shape=(1,), + dtype=torch.bool, + device=self.device, + ), + } + ) + + def _set_seed(self, seed: Optional[int]): + set_seed(seed) + # TODO(rcadene): seed the env + # self._env.seed(seed) + logging.warning("Aloha env is not seeded") diff --git a/lerobot/common/envs/aloha/tasks/sim.py b/lerobot/common/envs/aloha/tasks/sim.py new file mode 100644 index 00000000..ee1d0927 --- /dev/null +++ b/lerobot/common/envs/aloha/tasks/sim.py @@ -0,0 +1,219 @@ +import collections + +import numpy as np +from dm_control.suite import base + +from lerobot.common.envs.aloha.constants import ( + START_ARM_POSE, + normalize_puppet_gripper_position, + normalize_puppet_gripper_velocity, + unnormalize_puppet_gripper_position, +) + +BOX_POSE = [None] # to be changed from outside + +""" +Environment for simulated robot bi-manual manipulation, with joint position control +Action space: [left_arm_qpos (6), # absolute joint position + left_gripper_positions (1), # normalized gripper position (0: close, 1: open) + right_arm_qpos (6), # absolute joint position + right_gripper_positions (1),] # normalized gripper position (0: close, 1: open) + +Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position + left_gripper_position (1), # normalized gripper position (0: close, 1: open) + right_arm_qpos (6), # absolute joint position + right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open) + "qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad) + left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing) + right_arm_qvel (6), # absolute joint velocity (rad) + right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing) + "images": {"main": (480x640x3)} # h, w, c, dtype='uint8' +""" + + +class BimanualViperXTask(base.Task): + def __init__(self, random=None): + super().__init__(random=random) + + def before_step(self, action, physics): + left_arm_action = action[:6] + right_arm_action = action[7 : 7 + 6] + normalized_left_gripper_action = action[6] + normalized_right_gripper_action = action[7 + 6] + + left_gripper_action = unnormalize_puppet_gripper_position(normalized_left_gripper_action) + right_gripper_action = unnormalize_puppet_gripper_position(normalized_right_gripper_action) + + full_left_gripper_action = [left_gripper_action, -left_gripper_action] + full_right_gripper_action = [right_gripper_action, -right_gripper_action] + + env_action = np.concatenate( + [left_arm_action, full_left_gripper_action, right_arm_action, full_right_gripper_action] + ) + super().before_step(env_action, physics) + return + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode.""" + super().initialize_episode(physics) + + @staticmethod + def get_qpos(physics): + qpos_raw = physics.data.qpos.copy() + left_qpos_raw = qpos_raw[:8] + right_qpos_raw = qpos_raw[8:16] + left_arm_qpos = left_qpos_raw[:6] + right_arm_qpos = right_qpos_raw[:6] + left_gripper_qpos = [normalize_puppet_gripper_position(left_qpos_raw[6])] + right_gripper_qpos = [normalize_puppet_gripper_position(right_qpos_raw[6])] + return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos]) + + @staticmethod + def get_qvel(physics): + qvel_raw = physics.data.qvel.copy() + left_qvel_raw = qvel_raw[:8] + right_qvel_raw = qvel_raw[8:16] + left_arm_qvel = left_qvel_raw[:6] + right_arm_qvel = right_qvel_raw[:6] + left_gripper_qvel = [normalize_puppet_gripper_velocity(left_qvel_raw[6])] + right_gripper_qvel = [normalize_puppet_gripper_velocity(right_qvel_raw[6])] + return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel]) + + @staticmethod + def get_env_state(physics): + raise NotImplementedError + + def get_observation(self, physics): + obs = collections.OrderedDict() + obs["qpos"] = self.get_qpos(physics) + obs["qvel"] = self.get_qvel(physics) + obs["env_state"] = self.get_env_state(physics) + obs["images"] = {} + obs["images"]["top"] = physics.render(height=480, width=640, camera_id="top") + obs["images"]["angle"] = physics.render(height=480, width=640, camera_id="angle") + obs["images"]["vis"] = physics.render(height=480, width=640, camera_id="front_close") + + return obs + + def get_reward(self, physics): + # return whether left gripper is holding the box + raise NotImplementedError + + +class TransferCubeTask(BimanualViperXTask): + def __init__(self, random=None): + super().__init__(random=random) + self.max_reward = 4 + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode.""" + # TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside + # reset qpos, control and box position + with physics.reset_context(): + physics.named.data.qpos[:16] = START_ARM_POSE + np.copyto(physics.data.ctrl, START_ARM_POSE) + assert BOX_POSE[0] is not None + physics.named.data.qpos[-7:] = BOX_POSE[0] + # print(f"{BOX_POSE=}") + super().initialize_episode(physics) + + @staticmethod + def get_env_state(physics): + env_state = physics.data.qpos.copy()[16:] + return env_state + + def get_reward(self, physics): + # return whether left gripper is holding the box + all_contact_pairs = [] + for i_contact in range(physics.data.ncon): + id_geom_1 = physics.data.contact[i_contact].geom1 + id_geom_2 = physics.data.contact[i_contact].geom2 + name_geom_1 = physics.model.id2name(id_geom_1, "geom") + name_geom_2 = physics.model.id2name(id_geom_2, "geom") + contact_pair = (name_geom_1, name_geom_2) + all_contact_pairs.append(contact_pair) + + touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs + touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs + touch_table = ("red_box", "table") in all_contact_pairs + + reward = 0 + if touch_right_gripper: + reward = 1 + if touch_right_gripper and not touch_table: # lifted + reward = 2 + if touch_left_gripper: # attempted transfer + reward = 3 + if touch_left_gripper and not touch_table: # successful transfer + reward = 4 + return reward + + +class InsertionTask(BimanualViperXTask): + def __init__(self, random=None): + super().__init__(random=random) + self.max_reward = 4 + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode.""" + # TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside + # reset qpos, control and box position + with physics.reset_context(): + physics.named.data.qpos[:16] = START_ARM_POSE + np.copyto(physics.data.ctrl, START_ARM_POSE) + assert BOX_POSE[0] is not None + physics.named.data.qpos[-7 * 2 :] = BOX_POSE[0] # two objects + # print(f"{BOX_POSE=}") + super().initialize_episode(physics) + + @staticmethod + def get_env_state(physics): + env_state = physics.data.qpos.copy()[16:] + return env_state + + def get_reward(self, physics): + # return whether peg touches the pin + all_contact_pairs = [] + for i_contact in range(physics.data.ncon): + id_geom_1 = physics.data.contact[i_contact].geom1 + id_geom_2 = physics.data.contact[i_contact].geom2 + name_geom_1 = physics.model.id2name(id_geom_1, "geom") + name_geom_2 = physics.model.id2name(id_geom_2, "geom") + contact_pair = (name_geom_1, name_geom_2) + all_contact_pairs.append(contact_pair) + + touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs + touch_left_gripper = ( + ("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs + or ("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs + or ("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs + or ("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs + ) + + peg_touch_table = ("red_peg", "table") in all_contact_pairs + socket_touch_table = ( + ("socket-1", "table") in all_contact_pairs + or ("socket-2", "table") in all_contact_pairs + or ("socket-3", "table") in all_contact_pairs + or ("socket-4", "table") in all_contact_pairs + ) + peg_touch_socket = ( + ("red_peg", "socket-1") in all_contact_pairs + or ("red_peg", "socket-2") in all_contact_pairs + or ("red_peg", "socket-3") in all_contact_pairs + or ("red_peg", "socket-4") in all_contact_pairs + ) + pin_touched = ("red_peg", "pin") in all_contact_pairs + + reward = 0 + if touch_left_gripper and touch_right_gripper: # touch both + reward = 1 + if ( + touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table) + ): # grasp both + reward = 2 + if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching + reward = 3 + if pin_touched: # successful insertion + reward = 4 + return reward diff --git a/lerobot/common/envs/aloha/tasks/sim_end_effector.py b/lerobot/common/envs/aloha/tasks/sim_end_effector.py new file mode 100644 index 00000000..d93c8330 --- /dev/null +++ b/lerobot/common/envs/aloha/tasks/sim_end_effector.py @@ -0,0 +1,263 @@ +import collections + +import numpy as np +from dm_control.suite import base + +from lerobot.common.envs.aloha.constants import ( + PUPPET_GRIPPER_POSITION_CLOSE, + START_ARM_POSE, + normalize_puppet_gripper_position, + normalize_puppet_gripper_velocity, + unnormalize_puppet_gripper_position, +) +from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose + +""" +Environment for simulated robot bi-manual manipulation, with end-effector control. +Action space: [left_arm_pose (7), # position and quaternion for end effector + left_gripper_positions (1), # normalized gripper position (0: close, 1: open) + right_arm_pose (7), # position and quaternion for end effector + right_gripper_positions (1),] # normalized gripper position (0: close, 1: open) + +Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position + left_gripper_position (1), # normalized gripper position (0: close, 1: open) + right_arm_qpos (6), # absolute joint position + right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open) + "qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad) + left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing) + right_arm_qvel (6), # absolute joint velocity (rad) + right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing) + "images": {"main": (480x640x3)} # h, w, c, dtype='uint8' +""" + + +class BimanualViperXEndEffectorTask(base.Task): + def __init__(self, random=None): + super().__init__(random=random) + + def before_step(self, action, physics): + a_len = len(action) // 2 + action_left = action[:a_len] + action_right = action[a_len:] + + # set mocap position and quat + # left + np.copyto(physics.data.mocap_pos[0], action_left[:3]) + np.copyto(physics.data.mocap_quat[0], action_left[3:7]) + # right + np.copyto(physics.data.mocap_pos[1], action_right[:3]) + np.copyto(physics.data.mocap_quat[1], action_right[3:7]) + + # set gripper + g_left_ctrl = unnormalize_puppet_gripper_position(action_left[7]) + g_right_ctrl = unnormalize_puppet_gripper_position(action_right[7]) + np.copyto(physics.data.ctrl, np.array([g_left_ctrl, -g_left_ctrl, g_right_ctrl, -g_right_ctrl])) + + def initialize_robots(self, physics): + # reset joint position + physics.named.data.qpos[:16] = START_ARM_POSE + + # reset mocap to align with end effector + # to obtain these numbers: + # (1) make an ee_sim env and reset to the same start_pose + # (2) get env._physics.named.data.xpos['vx300s_left/gripper_link'] + # get env._physics.named.data.xquat['vx300s_left/gripper_link'] + # repeat the same for right side + np.copyto(physics.data.mocap_pos[0], [-0.31718881, 0.5, 0.29525084]) + np.copyto(physics.data.mocap_quat[0], [1, 0, 0, 0]) + # right + np.copyto(physics.data.mocap_pos[1], np.array([0.31718881, 0.49999888, 0.29525084])) + np.copyto(physics.data.mocap_quat[1], [1, 0, 0, 0]) + + # reset gripper control + close_gripper_control = np.array( + [ + PUPPET_GRIPPER_POSITION_CLOSE, + -PUPPET_GRIPPER_POSITION_CLOSE, + PUPPET_GRIPPER_POSITION_CLOSE, + -PUPPET_GRIPPER_POSITION_CLOSE, + ] + ) + np.copyto(physics.data.ctrl, close_gripper_control) + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode.""" + super().initialize_episode(physics) + + @staticmethod + def get_qpos(physics): + qpos_raw = physics.data.qpos.copy() + left_qpos_raw = qpos_raw[:8] + right_qpos_raw = qpos_raw[8:16] + left_arm_qpos = left_qpos_raw[:6] + right_arm_qpos = right_qpos_raw[:6] + left_gripper_qpos = [normalize_puppet_gripper_position(left_qpos_raw[6])] + right_gripper_qpos = [normalize_puppet_gripper_position(right_qpos_raw[6])] + return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos]) + + @staticmethod + def get_qvel(physics): + qvel_raw = physics.data.qvel.copy() + left_qvel_raw = qvel_raw[:8] + right_qvel_raw = qvel_raw[8:16] + left_arm_qvel = left_qvel_raw[:6] + right_arm_qvel = right_qvel_raw[:6] + left_gripper_qvel = [normalize_puppet_gripper_velocity(left_qvel_raw[6])] + right_gripper_qvel = [normalize_puppet_gripper_velocity(right_qvel_raw[6])] + return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel]) + + @staticmethod + def get_env_state(physics): + raise NotImplementedError + + def get_observation(self, physics): + # note: it is important to do .copy() + obs = collections.OrderedDict() + obs["qpos"] = self.get_qpos(physics) + obs["qvel"] = self.get_qvel(physics) + obs["env_state"] = self.get_env_state(physics) + obs["images"] = {} + obs["images"]["top"] = physics.render(height=480, width=640, camera_id="top") + obs["images"]["angle"] = physics.render(height=480, width=640, camera_id="angle") + obs["images"]["vis"] = physics.render(height=480, width=640, camera_id="front_close") + # used in scripted policy to obtain starting pose + obs["mocap_pose_left"] = np.concatenate( + [physics.data.mocap_pos[0], physics.data.mocap_quat[0]] + ).copy() + obs["mocap_pose_right"] = np.concatenate( + [physics.data.mocap_pos[1], physics.data.mocap_quat[1]] + ).copy() + + # used when replaying joint trajectory + obs["gripper_ctrl"] = physics.data.ctrl.copy() + return obs + + def get_reward(self, physics): + raise NotImplementedError + + +class TransferCubeEndEffectorTask(BimanualViperXEndEffectorTask): + def __init__(self, random=None): + super().__init__(random=random) + self.max_reward = 4 + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode.""" + self.initialize_robots(physics) + # randomize box position + cube_pose = sample_box_pose() + box_start_idx = physics.model.name2id("red_box_joint", "joint") + np.copyto(physics.data.qpos[box_start_idx : box_start_idx + 7], cube_pose) + # print(f"randomized cube position to {cube_position}") + + super().initialize_episode(physics) + + @staticmethod + def get_env_state(physics): + env_state = physics.data.qpos.copy()[16:] + return env_state + + def get_reward(self, physics): + # return whether left gripper is holding the box + all_contact_pairs = [] + for i_contact in range(physics.data.ncon): + id_geom_1 = physics.data.contact[i_contact].geom1 + id_geom_2 = physics.data.contact[i_contact].geom2 + name_geom_1 = physics.model.id2name(id_geom_1, "geom") + name_geom_2 = physics.model.id2name(id_geom_2, "geom") + contact_pair = (name_geom_1, name_geom_2) + all_contact_pairs.append(contact_pair) + + touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs + touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs + touch_table = ("red_box", "table") in all_contact_pairs + + reward = 0 + if touch_right_gripper: + reward = 1 + if touch_right_gripper and not touch_table: # lifted + reward = 2 + if touch_left_gripper: # attempted transfer + reward = 3 + if touch_left_gripper and not touch_table: # successful transfer + reward = 4 + return reward + + +class InsertionEndEffectorTask(BimanualViperXEndEffectorTask): + def __init__(self, random=None): + super().__init__(random=random) + self.max_reward = 4 + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode.""" + self.initialize_robots(physics) + # randomize peg and socket position + peg_pose, socket_pose = sample_insertion_pose() + + def id2index(j_id): + return 16 + (j_id - 16) * 7 # first 16 is robot qpos, 7 is pose dim # hacky + + peg_start_id = physics.model.name2id("red_peg_joint", "joint") + peg_start_idx = id2index(peg_start_id) + np.copyto(physics.data.qpos[peg_start_idx : peg_start_idx + 7], peg_pose) + # print(f"randomized cube position to {cube_position}") + + socket_start_id = physics.model.name2id("blue_socket_joint", "joint") + socket_start_idx = id2index(socket_start_id) + np.copyto(physics.data.qpos[socket_start_idx : socket_start_idx + 7], socket_pose) + # print(f"randomized cube position to {cube_position}") + + super().initialize_episode(physics) + + @staticmethod + def get_env_state(physics): + env_state = physics.data.qpos.copy()[16:] + return env_state + + def get_reward(self, physics): + # return whether peg touches the pin + all_contact_pairs = [] + for i_contact in range(physics.data.ncon): + id_geom_1 = physics.data.contact[i_contact].geom1 + id_geom_2 = physics.data.contact[i_contact].geom2 + name_geom_1 = physics.model.id2name(id_geom_1, "geom") + name_geom_2 = physics.model.id2name(id_geom_2, "geom") + contact_pair = (name_geom_1, name_geom_2) + all_contact_pairs.append(contact_pair) + + touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs + touch_left_gripper = ( + ("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs + or ("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs + or ("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs + or ("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs + ) + + peg_touch_table = ("red_peg", "table") in all_contact_pairs + socket_touch_table = ( + ("socket-1", "table") in all_contact_pairs + or ("socket-2", "table") in all_contact_pairs + or ("socket-3", "table") in all_contact_pairs + or ("socket-4", "table") in all_contact_pairs + ) + peg_touch_socket = ( + ("red_peg", "socket-1") in all_contact_pairs + or ("red_peg", "socket-2") in all_contact_pairs + or ("red_peg", "socket-3") in all_contact_pairs + or ("red_peg", "socket-4") in all_contact_pairs + ) + pin_touched = ("red_peg", "pin") in all_contact_pairs + + reward = 0 + if touch_left_gripper and touch_right_gripper: # touch both + reward = 1 + if ( + touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table) + ): # grasp both + reward = 2 + if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching + reward = 3 + if pin_touched: # successful insertion + reward = 4 + return reward diff --git a/lerobot/common/envs/aloha/utils.py b/lerobot/common/envs/aloha/utils.py new file mode 100644 index 00000000..5ac8b955 --- /dev/null +++ b/lerobot/common/envs/aloha/utils.py @@ -0,0 +1,39 @@ +import numpy as np + + +def sample_box_pose(): + x_range = [0.0, 0.2] + y_range = [0.4, 0.6] + z_range = [0.05, 0.05] + + ranges = np.vstack([x_range, y_range, z_range]) + cube_position = np.random.uniform(ranges[:, 0], ranges[:, 1]) + + cube_quat = np.array([1, 0, 0, 0]) + return np.concatenate([cube_position, cube_quat]) + + +def sample_insertion_pose(): + # Peg + x_range = [0.1, 0.2] + y_range = [0.4, 0.6] + z_range = [0.05, 0.05] + + ranges = np.vstack([x_range, y_range, z_range]) + peg_position = np.random.uniform(ranges[:, 0], ranges[:, 1]) + + peg_quat = np.array([1, 0, 0, 0]) + peg_pose = np.concatenate([peg_position, peg_quat]) + + # Socket + x_range = [-0.2, -0.1] + y_range = [0.4, 0.6] + z_range = [0.05, 0.05] + + ranges = np.vstack([x_range, y_range, z_range]) + socket_position = np.random.uniform(ranges[:, 0], ranges[:, 1]) + + socket_quat = np.array([1, 0, 0, 0]) + socket_pose = np.concatenate([socket_position, socket_quat]) + + return peg_pose, socket_pose diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index 269009db..921cbad7 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -23,6 +23,11 @@ def make_env(cfg, transform=None): # assert kwargs["seed"] > 200, "Seed 0-200 are used for the demonstration dataset, so we don't want to seed the eval env with this range." clsfunc = PushtEnv + elif cfg.env.name == "aloha": + from lerobot.common.envs.aloha.env import AlohaEnv + + kwargs["task"] = cfg.env.task + clsfunc = AlohaEnv else: raise ValueError(cfg.env.name) diff --git a/lerobot/common/envs/pusht/env.py b/lerobot/common/envs/pusht/env.py index ff49f791..4a7ccb2c 100644 --- a/lerobot/common/envs/pusht/env.py +++ b/lerobot/common/envs/pusht/env.py @@ -11,61 +11,52 @@ from torchrl.data.tensor_specs import ( DiscreteTensorSpec, UnboundedContinuousTensorSpec, ) -from torchrl.envs import EnvBase from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform +from lerobot.common.envs.abstract import AbstractEnv from lerobot.common.utils import set_seed _has_gym = importlib.util.find_spec("gym") is not None -class PushtEnv(EnvBase): +class PushtEnv(AbstractEnv): def __init__( self, + task="pusht", frame_skip: int = 1, from_pixels: bool = False, pixels_only: bool = False, image_size=None, seed=1337, device="cpu", - num_prev_obs=0, + num_prev_obs=1, num_prev_action=0, ): - super().__init__(device=device, batch_size=[]) - self.frame_skip = frame_skip - self.from_pixels = from_pixels - self.pixels_only = pixels_only - self.image_size = image_size - self.num_prev_obs = num_prev_obs - self.num_prev_action = num_prev_action - - if pixels_only: - assert from_pixels - if from_pixels: - assert image_size + super().__init__( + task=task, + frame_skip=frame_skip, + from_pixels=from_pixels, + pixels_only=pixels_only, + image_size=image_size, + seed=seed, + device=device, + num_prev_obs=num_prev_obs, + num_prev_action=num_prev_action, + ) + def _make_env(self): if not _has_gym: raise ImportError("Cannot import gym.") # TODO(rcadene) (PushTEnv is similar to PushTImageEnv, but without the image rendering, it's faster to iterate on) # from lerobot.common.envs.pusht.pusht_env import PushTEnv - if not from_pixels: + if not self.from_pixels: raise NotImplementedError("Use PushTEnv, instead of PushTImageEnv") from lerobot.common.envs.pusht.pusht_image_env import PushTImageEnv self._env = PushTImageEnv(render_size=self.image_size) - self._make_spec() - self._current_seed = self.set_seed(seed) - - if self.num_prev_obs > 0: - self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs) - self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs) - if self.num_prev_action > 0: - raise NotImplementedError() - # self._prev_action_queue = deque(maxlen=self.num_prev_action) - def render(self, mode="rgb_array", width=384, height=384): if width != height: raise NotImplementedError() @@ -122,6 +113,8 @@ class PushtEnv(EnvBase): ) else: raise NotImplementedError() + + self.call_rendering_hooks() return td def _step(self, tensordict: TensorDict): @@ -154,6 +147,8 @@ class PushtEnv(EnvBase): stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) obs = stacked_obs + self.call_rendering_hooks() + td = TensorDict( { "observation": TensorDict(obs, batch_size=[]), @@ -175,9 +170,9 @@ class PushtEnv(EnvBase): obs["image"] = BoundedTensorSpec( low=0, - high=1, + high=255, shape=image_shape, - dtype=torch.float32, + dtype=torch.uint8, device=self.device, ) if not self.pixels_only: diff --git a/lerobot/common/envs/pusht/pusht_image_env.py b/lerobot/common/envs/pusht/pusht_image_env.py index 5f7bc03c..0807e849 100644 --- a/lerobot/common/envs/pusht/pusht_image_env.py +++ b/lerobot/common/envs/pusht/pusht_image_env.py @@ -25,7 +25,7 @@ class PushTImageEnv(PushTEnv): img = super()._render_frame(mode="rgb_array") agent_pos = np.array(self.agent.position) - img_obs = np.moveaxis(img.astype(np.float32) / 255, -1, 0) + img_obs = np.moveaxis(img, -1, 0) obs = {"image": img_obs, "agent_pos": agent_pos} # draw action diff --git a/lerobot/common/envs/simxarm.py b/lerobot/common/envs/simxarm.py index 24fd9ba4..d0612625 100644 --- a/lerobot/common/envs/simxarm.py +++ b/lerobot/common/envs/simxarm.py @@ -1,6 +1,8 @@ import importlib +from collections import deque from typing import Optional +import einops import numpy as np import torch from tensordict import TensorDict @@ -10,9 +12,9 @@ from torchrl.data.tensor_specs import ( DiscreteTensorSpec, UnboundedContinuousTensorSpec, ) -from torchrl.envs import EnvBase from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform +from lerobot.common.envs.abstract import AbstractEnv from lerobot.common.utils import set_seed MAX_NUM_ACTIONS = 4 @@ -21,7 +23,7 @@ _has_gym = importlib.util.find_spec("gym") is not None _has_simxarm = importlib.util.find_spec("simxarm") is not None and _has_gym -class SimxarmEnv(EnvBase): +class SimxarmEnv(AbstractEnv): def __init__( self, task, @@ -31,19 +33,22 @@ class SimxarmEnv(EnvBase): image_size=None, seed=1337, device="cpu", + num_prev_obs=0, + num_prev_action=0, ): - super().__init__(device=device, batch_size=[]) - self.task = task - self.frame_skip = frame_skip - self.from_pixels = from_pixels - self.pixels_only = pixels_only - self.image_size = image_size - - if pixels_only: - assert from_pixels - if from_pixels: - assert image_size + super().__init__( + task=task, + frame_skip=frame_skip, + from_pixels=from_pixels, + pixels_only=pixels_only, + image_size=image_size, + seed=seed, + device=device, + num_prev_obs=num_prev_obs, + num_prev_action=num_prev_action, + ) + def _make_env(self): if not _has_simxarm: raise ImportError("Cannot import simxarm.") if not _has_gym: @@ -63,9 +68,6 @@ class SimxarmEnv(EnvBase): if "w" not in TASKS[self.task]["action_space"]: self._action_padding[-1] = 1.0 - self._make_spec() - self.set_seed(seed) - def render(self, mode="rgb_array", width=384, height=384): return self._env.render(mode, width=width, height=height) @@ -90,15 +92,33 @@ class SimxarmEnv(EnvBase): if td is None or td.is_empty(): raw_obs = self._env.reset() + obs = self._format_raw_obs(raw_obs) + + if self.num_prev_obs > 0: + stacked_obs = {} + if "image" in obs: + self._prev_obs_image_queue = deque( + [obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1) + ) + stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue)) + if "state" in obs: + self._prev_obs_state_queue = deque( + [obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1) + ) + stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) + obs = stacked_obs + td = TensorDict( { - "observation": self._format_raw_obs(raw_obs), + "observation": TensorDict(obs, batch_size=[]), "done": torch.tensor([False], dtype=torch.bool), }, batch_size=[], ) else: raise NotImplementedError() + + self.call_rendering_hooks() return td def _step(self, tensordict: TensorDict): @@ -108,10 +128,32 @@ class SimxarmEnv(EnvBase): action = np.concatenate([action, self._action_padding]) # TODO(rcadene): add info["is_success"] and info["success"] ? sum_reward = 0 - for _ in range(self.frame_skip): - raw_obs, reward, done, info = self._env.step(action) + + if action.ndim == 1: + action = einops.repeat(action, "c -> t c", t=self.frame_skip) + else: + if self.frame_skip > 1: + raise NotImplementedError() + + num_action_steps = action.shape[0] + for i in range(num_action_steps): + raw_obs, reward, done, info = self._env.step(action[i]) sum_reward += reward + obs = self._format_raw_obs(raw_obs) + + if self.num_prev_obs > 0: + stacked_obs = {} + if "image" in obs: + self._prev_obs_image_queue.append(obs["image"]) + stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue)) + if "state" in obs: + self._prev_obs_state_queue.append(obs["state"]) + stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) + obs = stacked_obs + + self.call_rendering_hooks() + td = TensorDict( { "observation": self._format_raw_obs(raw_obs), @@ -126,23 +168,36 @@ class SimxarmEnv(EnvBase): def _make_spec(self): obs = {} if self.from_pixels: + image_shape = (3, self.image_size, self.image_size) + if self.num_prev_obs > 0: + image_shape = (self.num_prev_obs + 1, *image_shape) + obs["image"] = BoundedTensorSpec( low=0, high=255, - shape=(3, self.image_size, self.image_size), + shape=image_shape, dtype=torch.uint8, device=self.device, ) if not self.pixels_only: + state_shape = (len(self._env.robot_state),) + if self.num_prev_obs > 0: + state_shape = (self.num_prev_obs + 1, *state_shape) + obs["state"] = UnboundedContinuousTensorSpec( - shape=(len(self._env.robot_state),), + shape=state_shape, dtype=torch.float32, device=self.device, ) else: # TODO(rcadene): add observation_space achieved_goal and desired_goal? + state_shape = self._env.observation_space["observation"].shape + if self.num_prev_obs > 0: + state_shape = (self.num_prev_obs + 1, *state_shape) + obs["state"] = UnboundedContinuousTensorSpec( - shape=self._env.observation_space["observation"].shape, + # TODO: + shape=state_shape, dtype=torch.float32, device=self.device, ) diff --git a/lerobot/common/envs/transforms.py b/lerobot/common/envs/transforms.py index 671c0827..4832c91b 100644 --- a/lerobot/common/envs/transforms.py +++ b/lerobot/common/envs/transforms.py @@ -1,5 +1,6 @@ from typing import Sequence +import torch from tensordict import TensorDictBase from tensordict.nn import dispatch from tensordict.utils import NestedKey @@ -7,19 +8,45 @@ from torchrl.envs.transforms import ObservationTransform, Transform class Prod(ObservationTransform): + invertible = True + def __init__(self, in_keys: Sequence[NestedKey], prod: float): super().__init__() self.in_keys = in_keys self.prod = prod + self.original_dtypes = {} + + def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase: + # _reset is called once when the environment reset to normalize the first observation + tensordict_reset = self._call(tensordict_reset) + return tensordict_reset + + @dispatch(source="in_keys", dest="out_keys") + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + return self._call(tensordict) def _call(self, td): for key in self.in_keys: - td[key] *= self.prod + if td.get(key, None) is None: + continue + self.original_dtypes[key] = td[key].dtype + td[key] = td[key].type(torch.float32) * self.prod + return td + + def _inv_call(self, td: TensorDictBase) -> TensorDictBase: + for key in self.in_keys: + if td.get(key, None) is None: + continue + td[key] = (td[key] / self.prod).type(self.original_dtypes[key]) return td def transform_observation_spec(self, obs_spec): for key in self.in_keys: - obs_spec[key].space.high *= self.prod + if obs_spec.get(key, None) is None: + continue + obs_spec[key].space.high = obs_spec[key].space.high.type(torch.float32) * self.prod + obs_spec[key].space.low = obs_spec[key].space.low.type(torch.float32) * self.prod + obs_spec[key].dtype = torch.float32 return obs_spec diff --git a/lerobot/common/policies/act/backbone.py b/lerobot/common/policies/act/backbone.py new file mode 100644 index 00000000..6399d339 --- /dev/null +++ b/lerobot/common/policies/act/backbone.py @@ -0,0 +1,115 @@ +from typing import List + +import torch +import torchvision +from torch import nn +from torchvision.models._utils import IntermediateLayerGetter + +from .position_encoding import build_position_encoding +from .utils import NestedTensor, is_main_process + + +class FrozenBatchNorm2d(torch.nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. + + Copy-paste from torchvision.misc.ops with added eps before rqsrt, + without which any other policy_models than torchvision.policy_models.resnet[18,34,50,101] + produce nans. + """ + + def __init__(self, n): + super().__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + num_batches_tracked_key = prefix + "num_batches_tracked" + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + + def forward(self, x): + # move reshapes to the beginning + # to make it fuser-friendly + w = self.weight.reshape(1, -1, 1, 1) + b = self.bias.reshape(1, -1, 1, 1) + rv = self.running_var.reshape(1, -1, 1, 1) + rm = self.running_mean.reshape(1, -1, 1, 1) + eps = 1e-5 + scale = w * (rv + eps).rsqrt() + bias = b - rm * scale + return x * scale + bias + + +class BackboneBase(nn.Module): + def __init__( + self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool + ): + super().__init__() + # for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this? + # if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: + # parameter.requires_grad_(False) + if return_interm_layers: + return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} + else: + return_layers = {"layer4": "0"} + self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) + self.num_channels = num_channels + + def forward(self, tensor): + xs = self.body(tensor) + return xs + # out: Dict[str, NestedTensor] = {} + # for name, x in xs.items(): + # m = tensor_list.mask + # assert m is not None + # mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] + # out[name] = NestedTensor(x, mask) + # return out + + +class Backbone(BackboneBase): + """ResNet backbone with frozen BatchNorm.""" + + def __init__(self, name: str, train_backbone: bool, return_interm_layers: bool, dilation: bool): + backbone = getattr(torchvision.models, name)( + replace_stride_with_dilation=[False, False, dilation], + pretrained=is_main_process(), + norm_layer=FrozenBatchNorm2d, + ) # pretrained # TODO do we want frozen batch_norm?? + num_channels = 512 if name in ("resnet18", "resnet34") else 2048 + super().__init__(backbone, train_backbone, num_channels, return_interm_layers) + + +class Joiner(nn.Sequential): + def __init__(self, backbone, position_embedding): + super().__init__(backbone, position_embedding) + + def forward(self, tensor_list: NestedTensor): + xs = self[0](tensor_list) + out: List[NestedTensor] = [] + pos = [] + for _, x in xs.items(): + out.append(x) + # position encoding + pos.append(self[1](x).to(x.dtype)) + + return out, pos + + +def build_backbone(args): + position_embedding = build_position_encoding(args) + train_backbone = args.lr_backbone > 0 + return_interm_layers = args.masks + backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation) + model = Joiner(backbone, position_embedding) + model.num_channels = backbone.num_channels + return model diff --git a/lerobot/common/policies/act/detr_vae.py b/lerobot/common/policies/act/detr_vae.py new file mode 100644 index 00000000..0f2626f7 --- /dev/null +++ b/lerobot/common/policies/act/detr_vae.py @@ -0,0 +1,212 @@ +import numpy as np +import torch +from torch import nn +from torch.autograd import Variable + +from .backbone import build_backbone +from .transformer import TransformerEncoder, TransformerEncoderLayer, build_transformer + + +def reparametrize(mu, logvar): + std = logvar.div(2).exp() + eps = Variable(std.data.new(std.size()).normal_()) + return mu + std * eps + + +def get_sinusoid_encoding_table(n_position, d_hid): + def get_position_angle_vec(position): + return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] + + sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + + return torch.FloatTensor(sinusoid_table).unsqueeze(0) + + +class DETRVAE(nn.Module): + """This is the DETR module that performs object detection""" + + def __init__( + self, backbones, transformer, encoder, state_dim, action_dim, num_queries, camera_names, vae + ): + """Initializes the model. + Parameters: + backbones: torch module of the backbone to be used. See backbone.py + transformer: torch module of the transformer architecture. See transformer.py + state_dim: robot state dimension of the environment + num_queries: number of object queries, ie detection slot. This is the maximal number of objects + DETR can detect in a single image. For COCO, we recommend 100 queries. + aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. + """ + super().__init__() + self.num_queries = num_queries + self.camera_names = camera_names + self.transformer = transformer + self.encoder = encoder + self.vae = vae + hidden_dim = transformer.d_model + self.action_head = nn.Linear(hidden_dim, action_dim) + self.is_pad_head = nn.Linear(hidden_dim, 1) + self.query_embed = nn.Embedding(num_queries, hidden_dim) + if backbones is not None: + self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1) + self.backbones = nn.ModuleList(backbones) + self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim) + else: + # input_dim = 14 + 7 # robot_state + env_state + self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim) + # TODO(rcadene): understand what is env_state, and why it needs to be 7 + self.input_proj_env_state = nn.Linear(state_dim // 2, hidden_dim) + self.pos = torch.nn.Embedding(2, hidden_dim) + self.backbones = None + + # encoder extra parameters + self.latent_dim = 32 # final size of latent z # TODO tune + self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding + self.encoder_action_proj = nn.Linear(14, hidden_dim) # project action to embedding + self.encoder_joint_proj = nn.Linear(14, hidden_dim) # project qpos to embedding + self.latent_proj = nn.Linear( + hidden_dim, self.latent_dim * 2 + ) # project hidden state to latent std, var + self.register_buffer( + "pos_table", get_sinusoid_encoding_table(1 + 1 + num_queries, hidden_dim) + ) # [CLS], qpos, a_seq + + # decoder extra parameters + self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding + self.additional_pos_embed = nn.Embedding( + 2, hidden_dim + ) # learned position embedding for proprio and latent + + def forward(self, qpos, image, env_state, actions=None, is_pad=None): + """ + qpos: batch, qpos_dim + image: batch, num_cam, channel, height, width + env_state: None + actions: batch, seq, action_dim + """ + is_training = actions is not None # train or val + bs, _ = qpos.shape + ### Obtain latent z from action sequence + if self.vae and is_training: + # project action sequence to embedding dim, and concat with a CLS token + action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim) + qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim) + qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim) + cls_embed = self.cls_embed.weight # (1, hidden_dim) + cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim) + encoder_input = torch.cat( + [cls_embed, qpos_embed, action_embed], axis=1 + ) # (bs, seq+1, hidden_dim) + encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim) + # do not mask cls token + # cls_joint_is_pad = torch.full((bs, 2), False).to(qpos.device) # False: not a padding + # is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1) + # obtain position embedding + pos_embed = self.pos_table.clone().detach() + pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim) + # query model + encoder_output = self.encoder(encoder_input, pos=pos_embed) # , src_key_padding_mask=is_pad) + encoder_output = encoder_output[0] # take cls output only + latent_info = self.latent_proj(encoder_output) + mu = latent_info[:, : self.latent_dim] + logvar = latent_info[:, self.latent_dim :] + latent_sample = reparametrize(mu, logvar) + latent_input = self.latent_out_proj(latent_sample) + else: + mu = logvar = None + latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(qpos.device) + latent_input = self.latent_out_proj(latent_sample) + + if self.backbones is not None: + # Image observation features and position embeddings + all_cam_features = [] + all_cam_pos = [] + for cam_id, _ in enumerate(self.camera_names): + features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED + features = features[0] # take the last layer feature + pos = pos[0] + all_cam_features.append(self.input_proj(features)) + all_cam_pos.append(pos) + # proprioception features + proprio_input = self.input_proj_robot_state(qpos) + # fold camera dimension into width dimension + src = torch.cat(all_cam_features, axis=3) + pos = torch.cat(all_cam_pos, axis=3) + hs = self.transformer( + src, + None, + self.query_embed.weight, + pos, + latent_input, + proprio_input, + self.additional_pos_embed.weight, + )[0] + else: + qpos = self.input_proj_robot_state(qpos) + env_state = self.input_proj_env_state(env_state) + transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2 + hs = self.transformer(transformer_input, None, self.query_embed.weight, self.pos.weight)[0] + a_hat = self.action_head(hs) + is_pad_hat = self.is_pad_head(hs) + return a_hat, is_pad_hat, [mu, logvar] + + +def mlp(input_dim, hidden_dim, output_dim, hidden_depth): + if hidden_depth == 0: + mods = [nn.Linear(input_dim, output_dim)] + else: + mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)] + for _ in range(hidden_depth - 1): + mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)] + mods.append(nn.Linear(hidden_dim, output_dim)) + trunk = nn.Sequential(*mods) + return trunk + + +def build_encoder(args): + d_model = args.hidden_dim # 256 + dropout = args.dropout # 0.1 + nhead = args.nheads # 8 + dim_feedforward = args.dim_feedforward # 2048 + num_encoder_layers = args.enc_layers # 4 # TODO shared with VAE decoder + normalize_before = args.pre_norm # False + activation = "relu" + + encoder_layer = TransformerEncoderLayer( + d_model, nhead, dim_feedforward, dropout, activation, normalize_before + ) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) + + return encoder + + +def build(args): + # From state + # backbone = None # from state for now, no need for conv nets + # From image + backbones = [] + backbone = build_backbone(args) + backbones.append(backbone) + + transformer = build_transformer(args) + + encoder = build_encoder(args) + + model = DETRVAE( + backbones, + transformer, + encoder, + state_dim=args.state_dim, + action_dim=args.action_dim, + num_queries=args.num_queries, + camera_names=args.camera_names, + vae=args.vae, + ) + + n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) + print("number of parameters: {:.2f}M".format(n_parameters / 1e6)) + + return model diff --git a/lerobot/common/policies/act/policy.py b/lerobot/common/policies/act/policy.py new file mode 100644 index 00000000..d011cb76 --- /dev/null +++ b/lerobot/common/policies/act/policy.py @@ -0,0 +1,217 @@ +import logging +import time + +import torch +import torch.nn as nn +import torch.nn.functional as F # noqa: N812 +import torchvision.transforms as transforms + +from lerobot.common.policies.act.detr_vae import build + + +def build_act_model_and_optimizer(cfg): + model = build(cfg) + + param_dicts = [ + {"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]}, + { + "params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad], + "lr": cfg.lr_backbone, + }, + ] + optimizer = torch.optim.AdamW(param_dicts, lr=cfg.lr, weight_decay=cfg.weight_decay) + + return model, optimizer + + +def kl_divergence(mu, logvar): + batch_size = mu.size(0) + assert batch_size != 0 + if mu.data.ndimension() == 4: + mu = mu.view(mu.size(0), mu.size(1)) + if logvar.data.ndimension() == 4: + logvar = logvar.view(logvar.size(0), logvar.size(1)) + + klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()) + total_kld = klds.sum(1).mean(0, True) + dimension_wise_kld = klds.mean(0) + mean_kld = klds.mean(1).mean(0, True) + + return total_kld, dimension_wise_kld, mean_kld + + +class ActionChunkingTransformerPolicy(nn.Module): + def __init__(self, cfg, device, n_action_steps=1): + super().__init__() + self.cfg = cfg + self.n_action_steps = n_action_steps + self.device = device + self.model, self.optimizer = build_act_model_and_optimizer(cfg) + self.kl_weight = self.cfg.kl_weight + logging.info(f"KL Weight {self.kl_weight}") + self.to(self.device) + + def update(self, replay_buffer, step): + del step + + start_time = time.time() + + self.train() + + num_slices = self.cfg.batch_size + batch_size = self.cfg.horizon * num_slices + + assert batch_size % self.cfg.horizon == 0 + assert batch_size % num_slices == 0 + + def process_batch(batch, horizon, num_slices): + # trajectory t = 64, horizon h = 16 + # (t h) ... -> t h ... + batch = batch.reshape(num_slices, horizon) + + image = batch["observation", "image", "top"] + image = image[:, 0] # first observation t=0 + # batch, num_cam, channel, height, width + image = image.unsqueeze(1) + assert image.ndim == 5 + image = image.float() + + state = batch["observation", "state"] + state = state[:, 0] # first observation t=0 + # batch, qpos_dim + assert state.ndim == 2 + + action = batch["action"] + # batch, seq, action_dim + assert action.ndim == 3 + assert action.shape[1] == horizon + + if self.cfg.n_obs_steps > 1: + raise NotImplementedError() + # # keep first n observations of the slice corresponding to t=[-1,0] + # image = image[:, : self.cfg.n_obs_steps] + # state = state[:, : self.cfg.n_obs_steps] + + out = { + "obs": { + "image": image.to(self.device, non_blocking=True), + "agent_pos": state.to(self.device, non_blocking=True), + }, + "action": action.to(self.device, non_blocking=True), + } + return out + + batch = replay_buffer.sample(batch_size) + batch = process_batch(batch, self.cfg.horizon, num_slices) + + data_s = time.time() - start_time + + loss = self.compute_loss(batch) + loss.backward() + + grad_norm = torch.nn.utils.clip_grad_norm_( + self.model.parameters(), + self.cfg.grad_clip_norm, + error_if_nonfinite=False, + ) + + self.optimizer.step() + self.optimizer.zero_grad() + # self.lr_scheduler.step() + + info = { + "loss": loss.item(), + "grad_norm": float(grad_norm), + # "lr": self.lr_scheduler.get_last_lr()[0], + "lr": self.cfg.lr, + "data_s": data_s, + "update_s": time.time() - start_time, + } + + return info + + def save(self, fp): + torch.save(self.state_dict(), fp) + + def load(self, fp): + d = torch.load(fp) + self.load_state_dict(d) + + def compute_loss(self, batch): + loss_dict = self._forward( + qpos=batch["obs"]["agent_pos"], + image=batch["obs"]["image"], + actions=batch["action"], + ) + loss = loss_dict["loss"] + return loss + + @torch.no_grad() + def forward(self, observation, step_count): + # TODO(rcadene): remove unused step_count + del step_count + + self.eval() + + # TODO(rcadene): remove unsqueeze hack to add bsize=1 + observation["image", "top"] = observation["image", "top"].unsqueeze(0) + # observation["state"] = observation["state"].unsqueeze(0) + + # TODO(rcadene): remove hack + # add 1 camera dimension + observation["image", "top"] = observation["image", "top"].unsqueeze(1) + + obs_dict = { + "image": observation["image", "top"], + "agent_pos": observation["state"], + } + action = self._forward(qpos=obs_dict["agent_pos"], image=obs_dict["image"]) + + if self.cfg.temporal_agg: + # TODO(rcadene): implement temporal aggregation + raise NotImplementedError() + # all_time_actions[[t], t:t+num_queries] = action + # actions_for_curr_step = all_time_actions[:, t] + # actions_populated = torch.all(actions_for_curr_step != 0, axis=1) + # actions_for_curr_step = actions_for_curr_step[actions_populated] + # k = 0.01 + # exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step))) + # exp_weights = exp_weights / exp_weights.sum() + # exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1) + # raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True) + + # remove bsize=1 + action = action.squeeze(0) + + # take first predicted action or n first actions + action = action[0] if self.n_action_steps == 1 else action[: self.n_action_steps] + return action + + def _forward(self, qpos, image, actions=None, is_pad=None): + env_state = None + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + image = normalize(image) + + is_training = actions is not None + if is_training: # training time + actions = actions[:, : self.model.num_queries] + if is_pad is not None: + is_pad = is_pad[:, : self.model.num_queries] + + a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad) + + all_l1 = F.l1_loss(actions, a_hat, reduction="none") + l1 = all_l1.mean() if is_pad is None else (all_l1 * ~is_pad.unsqueeze(-1)).mean() + + loss_dict = {} + loss_dict["l1"] = l1 + if self.cfg.vae: + total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar) + loss_dict["kl"] = total_kld[0] + loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight + else: + loss_dict["loss"] = loss_dict["l1"] + return loss_dict + else: + action, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior + return action diff --git a/lerobot/common/policies/act/position_encoding.py b/lerobot/common/policies/act/position_encoding.py new file mode 100644 index 00000000..94e862f6 --- /dev/null +++ b/lerobot/common/policies/act/position_encoding.py @@ -0,0 +1,101 @@ +""" +Various positional encodings for the transformer. +""" +import math + +import torch +from torch import nn + +from .utils import NestedTensor + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + + def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, tensor): + x = tensor + # mask = tensor_list.mask + # assert mask is not None + # not_mask = ~mask + + not_mask = torch.ones_like(x[0, [0]]) + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +class PositionEmbeddingLearned(nn.Module): + """ + Absolute pos embedding, learned. + """ + + def __init__(self, num_pos_feats=256): + super().__init__() + self.row_embed = nn.Embedding(50, num_pos_feats) + self.col_embed = nn.Embedding(50, num_pos_feats) + self.reset_parameters() + + def reset_parameters(self): + nn.init.uniform_(self.row_embed.weight) + nn.init.uniform_(self.col_embed.weight) + + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + h, w = x.shape[-2:] + i = torch.arange(w, device=x.device) + j = torch.arange(h, device=x.device) + x_emb = self.col_embed(i) + y_emb = self.row_embed(j) + pos = ( + torch.cat( + [ + x_emb.unsqueeze(0).repeat(h, 1, 1), + y_emb.unsqueeze(1).repeat(1, w, 1), + ], + dim=-1, + ) + .permute(2, 0, 1) + .unsqueeze(0) + .repeat(x.shape[0], 1, 1, 1) + ) + return pos + + +def build_position_encoding(args): + n_steps = args.hidden_dim // 2 + if args.position_embedding in ("v2", "sine"): + # TODO find a better way of exposing other arguments + position_embedding = PositionEmbeddingSine(n_steps, normalize=True) + elif args.position_embedding in ("v3", "learned"): + position_embedding = PositionEmbeddingLearned(n_steps) + else: + raise ValueError(f"not supported {args.position_embedding}") + + return position_embedding diff --git a/lerobot/common/policies/act/transformer.py b/lerobot/common/policies/act/transformer.py new file mode 100644 index 00000000..b2bd3685 --- /dev/null +++ b/lerobot/common/policies/act/transformer.py @@ -0,0 +1,370 @@ +""" +DETR Transformer class. + +Copy-paste from torch.nn.Transformer with modifications: + * positional encodings are passed in MHattention + * extra LN at the end of encoder is removed + * decoder returns a stack of activations from all decoding layers +""" +import copy +from typing import Optional + +import torch +import torch.nn.functional as F # noqa: N812 +from torch import Tensor, nn + + +class Transformer(nn.Module): + def __init__( + self, + d_model=512, + nhead=8, + num_encoder_layers=6, + num_decoder_layers=6, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + return_intermediate_dec=False, + ): + super().__init__() + + encoder_layer = TransformerEncoderLayer( + d_model, nhead, dim_feedforward, dropout, activation, normalize_before + ) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) + + decoder_layer = TransformerDecoderLayer( + d_model, nhead, dim_feedforward, dropout, activation, normalize_before + ) + decoder_norm = nn.LayerNorm(d_model) + self.decoder = TransformerDecoder( + decoder_layer, num_decoder_layers, decoder_norm, return_intermediate=return_intermediate_dec + ) + + self._reset_parameters() + + self.d_model = d_model + self.nhead = nhead + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward( + self, + src, + mask, + query_embed, + pos_embed, + latent_input=None, + proprio_input=None, + additional_pos_embed=None, + ): + # TODO flatten only when input has H and W + if len(src.shape) == 4: # has H and W + # flatten NxCxHxW to HWxNxC + bs, c, h, w = src.shape + src = src.flatten(2).permute(2, 0, 1) + pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1) + query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) + # mask = mask.flatten(1) + + additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim + pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0) + + addition_input = torch.stack([latent_input, proprio_input], axis=0) + src = torch.cat([addition_input, src], axis=0) + else: + assert len(src.shape) == 3 + # flatten NxHWxC to HWxNxC + bs, hw, c = src.shape + src = src.permute(1, 0, 2) + pos_embed = pos_embed.unsqueeze(1).repeat(1, bs, 1) + query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) + + tgt = torch.zeros_like(query_embed) + memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) + hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed) + hs = hs.transpose(1, 2) + return hs + + +class TransformerEncoder(nn.Module): + def __init__(self, encoder_layer, num_layers, norm=None): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward( + self, + src, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + output = src + + for layer in self.layers: + output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerDecoder(nn.Module): + def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + self.return_intermediate = return_intermediate + + def forward( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + output = tgt + + intermediate = [] + + for layer in self.layers: + output = layer( + output, + memory, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, + query_pos=query_pos, + ) + if self.return_intermediate: + intermediate.append(self.norm(output)) + + if self.norm is not None: + output = self.norm(output) + if self.return_intermediate: + intermediate.pop() + intermediate.append(output) + + if self.return_intermediate: + return torch.stack(intermediate) + + return output.unsqueeze(0) + + +class TransformerEncoderLayer(nn.Module): + def __init__( + self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False + ): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post( + self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + q = k = self.with_pos_embed(src, pos) + src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + def forward_pre( + self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + src2 = self.norm1(src) + q = k = self.with_pos_embed(src2, pos) + src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src2 = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) + src = src + self.dropout2(src2) + return src + + def forward( + self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + if self.normalize_before: + return self.forward_pre(src, src_mask, src_key_padding_mask, pos) + return self.forward_post(src, src_mask, src_key_padding_mask, pos) + + +class TransformerDecoderLayer(nn.Module): + def __init__( + self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False + ): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn( + query=self.with_pos_embed(tgt, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + )[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward_pre( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + tgt2 = self.norm1(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt2 = self.norm2(tgt) + tgt2 = self.multihead_attn( + query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + )[0] + tgt = tgt + self.dropout2(tgt2) + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + def forward( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + if self.normalize_before: + return self.forward_pre( + tgt, + memory, + tgt_mask, + memory_mask, + tgt_key_padding_mask, + memory_key_padding_mask, + pos, + query_pos, + ) + return self.forward_post( + tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos + ) + + +def _get_clones(module, n): + return nn.ModuleList([copy.deepcopy(module) for _ in range(n)]) + + +def build_transformer(args): + return Transformer( + d_model=args.hidden_dim, + dropout=args.dropout, + nhead=args.nheads, + dim_feedforward=args.dim_feedforward, + num_encoder_layers=args.enc_layers, + num_decoder_layers=args.dec_layers, + normalize_before=args.pre_norm, + return_intermediate_dec=True, + ) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(f"activation should be relu/gelu, not {activation}.") diff --git a/lerobot/common/policies/act/utils.py b/lerobot/common/policies/act/utils.py new file mode 100644 index 00000000..2ce92094 --- /dev/null +++ b/lerobot/common/policies/act/utils.py @@ -0,0 +1,477 @@ +""" +Misc functions, including distributed helpers. + +Mostly copy-paste from torchvision references. +""" +import datetime +import os +import pickle +import subprocess +import time +from collections import defaultdict, deque +from typing import List, Optional + +import torch +import torch.distributed as dist + +# needed due to empty tensor bug in pytorch and torchvision 0.5 +import torchvision +from packaging import version +from torch import Tensor + +if version.parse(torchvision.__version__) < version.parse("0.7"): + from torchvision.ops import _new_empty_tensor + from torchvision.ops.misc import _output_size + + +class SmoothedValue: + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value + ) + + +def all_gather(data): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + world_size = get_world_size() + if world_size == 1: + return [data] + + # serialized to a Tensor + buffer = pickle.dumps(data) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to("cuda") + + # obtain Tensor size of each rank + local_size = torch.tensor([tensor.numel()], device="cuda") + size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] + dist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) + if local_size != max_size: + padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") + tensor = torch.cat((tensor, padding), dim=0) + dist.all_gather(tensor_list, tensor) + + data_list = [] + for size, tensor in zip(size_list, tensor_list, strict=False): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def reduce_dict(input_dict, average=True): + """ + Args: + input_dict (dict): all the values will be reduced + average (bool): whether to do average or sum + Reduce the values in the dictionary from all processes so that all processes + have the averaged results. Returns a dict with the same fields as + input_dict, after reduction. + """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.all_reduce(values) + if average: + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values, strict=False)} # noqa: C416 + return reduced_dict + + +class MetricLogger: + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append("{}: {}".format(name, str(meter))) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + if not header: + header = "" + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt="{avg:.4f}") + data_time = SmoothedValue(fmt="{avg:.4f}") + space_fmt = ":" + str(len(str(len(iterable)))) + "d" + if torch.cuda.is_available(): + log_msg = self.delimiter.join( + [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + "max mem: {memory:.0f}", + ] + ) + else: + log_msg = self.delimiter.join( + [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + ] + ) + mega_b = 1024.0 * 1024.0 + for i, obj in enumerate(iterable): + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / mega_b, + ) + ) + else: + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + ) + ) + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print("{} Total time: {} ({:.4f} s / it)".format(header, total_time_str, total_time / len(iterable))) + + +def get_sha(): + cwd = os.path.dirname(os.path.abspath(__file__)) + + def _run(command): + return subprocess.check_output(command, cwd=cwd).decode("ascii").strip() + + sha = "N/A" + diff = "clean" + branch = "N/A" + try: + sha = _run(["git", "rev-parse", "HEAD"]) + subprocess.check_output(["git", "diff"], cwd=cwd) + diff = _run(["git", "diff-index", "HEAD"]) + diff = "has uncommited changes" if diff else "clean" + branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"]) + except Exception: + pass + message = f"sha: {sha}, status: {diff}, branch: {branch}" + return message + + +def collate_fn(batch): + batch = list(zip(*batch, strict=False)) + batch[0] = nested_tensor_from_tensor_list(batch[0]) + return tuple(batch) + + +def _max_by_axis(the_list): + # type: (List[List[int]]) -> List[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + +class NestedTensor: + def __init__(self, tensors, mask: Optional[Tensor]): + self.tensors = tensors + self.mask = mask + + def to(self, device): + # type: (Device) -> NestedTensor # noqa + cast_tensor = self.tensors.to(device) + mask = self.mask + if mask is not None: + assert mask is not None + cast_mask = mask.to(device) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + + +def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): + # TODO make this more general + if tensor_list[0].ndim == 3: + if torchvision._is_tracing(): + # nested_tensor_from_tensor_list() does not export well to ONNX + # call _onnx_nested_tensor_from_tensor_list() instead + return _onnx_nested_tensor_from_tensor_list(tensor_list) + + # TODO make it support different-sized images + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) + batch_shape = [len(tensor_list)] + max_size + b, c, h, w = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((b, h, w), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask, strict=False): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + m[: img.shape[1], : img.shape[2]] = False + else: + raise ValueError("not supported") + return NestedTensor(tensor, mask) + + +# _onnx_nested_tensor_from_tensor_list() is an implementation of +# nested_tensor_from_tensor_list() that is supported by ONNX tracing. +@torch.jit.unused +def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: + max_size = [] + for i in range(tensor_list[0].dim()): + max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to( + torch.int64 + ) + max_size.append(max_size_i) + max_size = tuple(max_size) + + # work around for + # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + # m[: img.shape[1], :img.shape[2]] = False + # which is not yet supported in onnx + padded_imgs = [] + padded_masks = [] + for img in tensor_list: + padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape), strict=False)] + padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) + padded_imgs.append(padded_img) + + m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) + padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) + padded_masks.append(padded_mask.to(torch.bool)) + + tensor = torch.stack(padded_imgs) + mask = torch.stack(padded_masks) + + return NestedTensor(tensor, mask=mask) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop("force", False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ["WORLD_SIZE"]) + args.gpu = int(os.environ["LOCAL_RANK"]) + elif "SLURM_PROCID" in os.environ: + args.rank = int(os.environ["SLURM_PROCID"]) + args.gpu = args.rank % torch.cuda.device_count() + else: + print("Not using distributed mode") + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = "nccl" + print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True) + torch.distributed.init_process_group( + backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank + ) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +@torch.no_grad() +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + if target.numel() == 0: + return [torch.zeros([], device=output.device)] + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): + # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor + """ + Equivalent to nn.functional.interpolate, but with support for empty batch sizes. + This will eventually be supported natively by PyTorch, and this + class can go away. + """ + if version.parse(torchvision.__version__) < version.parse("0.7"): + if input.numel() > 0: + return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners) + + output_shape = _output_size(2, input, size, scale_factor) + output_shape = list(input.shape[:-2]) + list(output_shape) + return _new_empty_tensor(input, output_shape) + else: + return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index a956cb4b..c5e45300 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -17,6 +17,12 @@ def make_policy(cfg): n_action_steps=cfg.n_action_steps + cfg.n_latency_steps, **cfg.policy, ) + elif cfg.policy.name == "act": + from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy + + policy = ActionChunkingTransformerPolicy( + cfg.policy, cfg.device, n_action_steps=cfg.n_action_steps + cfg.n_latency_steps + ) else: raise ValueError(cfg.policy.name) diff --git a/lerobot/configs/env/aloha.yaml b/lerobot/configs/env/aloha.yaml index 5b5ecbb7..df464c75 100644 --- a/lerobot/configs/env/aloha.yaml +++ b/lerobot/configs/env/aloha.yaml @@ -15,11 +15,11 @@ env: task: sim_insertion_human from_pixels: True pixels_only: False - image_size: 96 + image_size: [3, 480, 640] action_repeat: 1 - episode_length: 300 + episode_length: 400 fps: ${fps} policy: - state_dim: 2 - action_dim: 2 + state_dim: 14 + action_dim: 14 diff --git a/lerobot/configs/policy/act.yaml b/lerobot/configs/policy/act.yaml new file mode 100644 index 00000000..a52c3f54 --- /dev/null +++ b/lerobot/configs/policy/act.yaml @@ -0,0 +1,58 @@ +# @package _global_ + +offline_steps: 1344000 +online_steps: 0 + +eval_episodes: 1 +eval_freq: 10000 +save_freq: 100000 +log_freq: 250 + +horizon: 100 +n_obs_steps: 1 +n_latency_steps: 0 +# when temporal_agg=False, n_action_steps=horizon +n_action_steps: ${horizon} + +policy: + name: act + + pretrained_model_path: + + lr: 1e-5 + lr_backbone: 1e-5 + weight_decay: 1e-4 + grad_clip_norm: 10 + backbone: resnet18 + num_queries: ${horizon} # chunk_size + horizon: ${horizon} # chunk_size + kl_weight: 10 + hidden_dim: 512 + dim_feedforward: 3200 + enc_layers: 4 + dec_layers: 7 + nheads: 8 + #camera_names: [top, front_close, left_pillar, right_pillar] + camera_names: [top] + position_embedding: sine + masks: false + dilation: false + dropout: 0.1 + pre_norm: false + + vae: true + + batch_size: 8 + + per_alpha: 0.6 + per_beta: 0.4 + + balanced_sampling: false + utd: 1 + + n_obs_steps: ${n_obs_steps} + + temporal_agg: false + + state_dim: ??? + action_dim: ??? diff --git a/lerobot/scripts/download.py b/lerobot/scripts/download.py deleted file mode 100644 index ac935f48..00000000 --- a/lerobot/scripts/download.py +++ /dev/null @@ -1,22 +0,0 @@ -# TODO(rcadene): obsolete remove -import os -import zipfile - -import gdown - - -def download(): - url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya" - download_path = "data.zip" - gdown.download(url, download_path, quiet=False) - print("Extracting...") - with zipfile.ZipFile(download_path, "r") as zip_f: - for member in zip_f.namelist(): - if member.startswith("data/xarm") and member.endswith(".pkl"): - print(member) - zip_f.extract(member=member) - os.remove(download_path) - - -if __name__ == "__main__": - download() diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 6435310a..7ba2812e 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -38,27 +38,18 @@ def eval_policy( successes = [] threads = [] for i in tqdm.tqdm(range(num_episodes)): - tensordict = env.reset() - ep_frames = [] - if save_video or (return_first_video and i == 0): - def rendering_callback(env, td=None): + def render_frame(env): ep_frames.append(env.render()) # noqa: B023 - # render first frame before rollout - rendering_callback(env) - else: - rendering_callback = None + env.register_rendering_hook(render_frame) with torch.inference_mode(): rollout = env.rollout( max_steps=max_steps, policy=policy, - callback=rendering_callback, - auto_reset=False, - tensordict=tensordict, auto_cast_to_device=True, ) # print(", ".join([f"{x:.3f}" for x in rollout["next", "reward"][:,0].tolist()])) @@ -85,6 +76,8 @@ def eval_policy( if return_first_video and i == 0: first_video = stacked_frames.transpose(0, 3, 1, 2) + env.reset_rendering_hooks() + for thread in threads: thread.join() diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index f4b22604..c063caf8 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -1,4 +1,5 @@ import logging +from pathlib import Path import hydra import numpy as np @@ -192,6 +193,8 @@ def train(cfg: dict, out_dir=None, job_name=None): num_episodes=cfg.eval_episodes, max_steps=cfg.env.episode_length // cfg.n_action_steps, return_first_video=True, + video_dir=Path(out_dir) / "eval", + save_video=True, ) log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline) if cfg.wandb.enable: diff --git a/poetry.lock b/poetry.lock index db4f8f3e..59de0ec5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -340,69 +340,69 @@ files = [ [[package]] name = "cython" -version = "3.0.8" +version = "3.0.9" description = "The Cython compiler for writing C extensions in the Python language." optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ - {file = "Cython-3.0.8-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a846e0a38e2b24e9a5c5dc74b0e54c6e29420d88d1dafabc99e0fc0f3e338636"}, - {file = "Cython-3.0.8-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45523fdc2b78d79b32834cc1cc12dc2ca8967af87e22a3ee1bff20e77c7f5520"}, - {file = "Cython-3.0.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baa0b7f3f841fe087410cab66778e2d3fb20ae2d2078a2be3dffe66c6574be39"}, - {file = "Cython-3.0.8-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e87294e33e40c289c77a135f491cd721bd089f193f956f7b8ed5aa2d0b8c558f"}, - {file = "Cython-3.0.8-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:a1df7a129344b1215c20096d33c00193437df1a8fcca25b71f17c23b1a44f782"}, - {file = "Cython-3.0.8-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:13c2a5e57a0358da467d97667297bf820b62a1a87ae47c5f87938b9bb593acbd"}, - {file = "Cython-3.0.8-cp310-cp310-win32.whl", hash = "sha256:96b028f044f5880e3cb18ecdcfc6c8d3ce9d0af28418d5ab464509f26d8adf12"}, - {file = "Cython-3.0.8-cp310-cp310-win_amd64.whl", hash = "sha256:8140597a8b5cc4f119a1190f5a2228a84f5ca6d8d9ec386cfce24663f48b2539"}, - {file = "Cython-3.0.8-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:aae26f9663e50caf9657148403d9874eea41770ecdd6caf381d177c2b1bb82ba"}, - {file = "Cython-3.0.8-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:547eb3cdb2f8c6f48e6865d5a741d9dd051c25b3ce076fbca571727977b28ac3"}, - {file = "Cython-3.0.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5a567d4b9ba70b26db89d75b243529de9e649a2f56384287533cf91512705bee"}, - {file = "Cython-3.0.8-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:51d1426263b0e82fb22bda8ea60dc77a428581cc19e97741011b938445d383f1"}, - {file = "Cython-3.0.8-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c26daaeccda072459b48d211415fd1e5507c06bcd976fa0d5b8b9f1063467d7b"}, - {file = "Cython-3.0.8-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:289ce7838208211cd166e975865fd73b0649bf118170b6cebaedfbdaf4a37795"}, - {file = "Cython-3.0.8-cp311-cp311-win32.whl", hash = "sha256:c8aa05f5e17f8042a3be052c24f2edc013fb8af874b0bf76907d16c51b4e7871"}, - {file = "Cython-3.0.8-cp311-cp311-win_amd64.whl", hash = "sha256:000dc9e135d0eec6ecb2b40a5b02d0868a2f8d2e027a41b0fe16a908a9e6de02"}, - {file = "Cython-3.0.8-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:90d3fe31db55685d8cb97d43b0ec39ef614fcf660f83c77ed06aa670cb0e164f"}, - {file = "Cython-3.0.8-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e24791ddae2324e88e3c902a765595c738f19ae34ee66bfb1a6dac54b1833419"}, - {file = "Cython-3.0.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2f020fa1c0552052e0660790b8153b79e3fc9a15dbd8f1d0b841fe5d204a6ae6"}, - {file = "Cython-3.0.8-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:18bfa387d7a7f77d7b2526af69a65dbd0b731b8d941aaff5becff8e21f6d7717"}, - {file = "Cython-3.0.8-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:fe81b339cffd87c0069c6049b4d33e28bdd1874625ee515785bf42c9fdff3658"}, - {file = "Cython-3.0.8-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:80fd94c076e1e1b1ee40a309be03080b75f413e8997cddcf401a118879863388"}, - {file = "Cython-3.0.8-cp312-cp312-win32.whl", hash = "sha256:85077915a93e359a9b920280d214dc0cf8a62773e1f3d7d30fab8ea4daed670c"}, - {file = "Cython-3.0.8-cp312-cp312-win_amd64.whl", hash = "sha256:0cb2dcc565c7851f75d496f724a384a790fab12d1b82461b663e66605bec429a"}, - {file = "Cython-3.0.8-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:870d2a0a7e3cbd5efa65aecdb38d715ea337a904ea7bb22324036e78fb7068e7"}, - {file = "Cython-3.0.8-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7e8f2454128974905258d86534f4fd4f91d2f1343605657ecab779d80c9d6d5e"}, - {file = "Cython-3.0.8-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c1949d6aa7bc792554bee2b67a9fe41008acbfe22f4f8df7b6ec7b799613a4b3"}, - {file = "Cython-3.0.8-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c9f2c6e1b8f3bcd6cb230bac1843f85114780bb8be8614855b1628b36bb510e0"}, - {file = "Cython-3.0.8-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:05d7eddc668ae7993643f32c7661f25544e791edb745758672ea5b1a82ecffa6"}, - {file = "Cython-3.0.8-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:bfabe115deef4ada5d23c87bddb11289123336dcc14347011832c07db616dd93"}, - {file = "Cython-3.0.8-cp36-cp36m-win32.whl", hash = "sha256:0c38c9f0bcce2df0c3347285863621be904ac6b64c5792d871130569d893efd7"}, - {file = "Cython-3.0.8-cp36-cp36m-win_amd64.whl", hash = "sha256:6c46939c3983217d140999de7c238c3141f56b1ea349e47ca49cae899969aa2c"}, - {file = "Cython-3.0.8-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:115f0a50f752da6c99941b103b5cb090da63eb206abbc7c2ad33856ffc73f064"}, - {file = "Cython-3.0.8-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c9c0f29246734561c90f36e70ed0506b61aa3d044e4cc4cba559065a2a741fae"}, - {file = "Cython-3.0.8-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ab75242869ff71e5665fe5c96f3378e79e792fa3c11762641b6c5afbbbbe026"}, - {file = "Cython-3.0.8-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6717c06e9cfc6c1df18543cd31a21f5d8e378a40f70c851fa2d34f0597037abc"}, - {file = "Cython-3.0.8-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:9d3f74388db378a3c6fd06e79a809ed98df3f56484d317b81ee762dbf3c263e0"}, - {file = "Cython-3.0.8-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:ae7ac561fd8253a9ae96311e91d12af5f701383564edc11d6338a7b60b285a6f"}, - {file = "Cython-3.0.8-cp37-cp37m-win32.whl", hash = "sha256:97b2a45845b993304f1799664fa88da676ee19442b15fdcaa31f9da7e1acc434"}, - {file = "Cython-3.0.8-cp37-cp37m-win_amd64.whl", hash = "sha256:9e2be2b340fea46fb849d378f9b80d3c08ff2e81e2bfbcdb656e2e3cd8c6b2dc"}, - {file = "Cython-3.0.8-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2cde23c555470db3f149ede78b518e8274853745289c956a0e06ad8d982e4db9"}, - {file = "Cython-3.0.8-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7990ca127e1f1beedaf8fc8bf66541d066ef4723ad7d8d47a7cbf842e0f47580"}, - {file = "Cython-3.0.8-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b983c8e6803f016146c26854d9150ddad5662960c804ea7f0c752c9266752f0"}, - {file = "Cython-3.0.8-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a973268d7ca1a2bdf78575e459a94a78e1a0a9bb62a7db0c50041949a73b02ff"}, - {file = "Cython-3.0.8-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:61a237bc9dd23c7faef0fcfce88c11c65d0c9bb73c74ccfa408b3a012073c20e"}, - {file = "Cython-3.0.8-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:3a3d67f079598af49e90ff9655bf85bd358f093d727eb21ca2708f467c489cae"}, - {file = "Cython-3.0.8-cp38-cp38-win32.whl", hash = "sha256:17a642bb01a693e34c914106566f59844b4461665066613913463a719e0dd15d"}, - {file = "Cython-3.0.8-cp38-cp38-win_amd64.whl", hash = "sha256:2cdfc32252f3b6dc7c94032ab744dcedb45286733443c294d8f909a4854e7f83"}, - {file = "Cython-3.0.8-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:fa97893d99385386925d00074654aeae3a98867f298d1e12ceaf38a9054a9bae"}, - {file = "Cython-3.0.8-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f05c0bf9d085c031df8f583f0d506aa3be1692023de18c45d0aaf78685bbb944"}, - {file = "Cython-3.0.8-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de892422582f5758bd8de187e98ac829330ec1007bc42c661f687792999988a7"}, - {file = "Cython-3.0.8-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:314f2355a1f1d06e3c431eaad4708cf10037b5e91e4b231d89c913989d0bdafd"}, - {file = "Cython-3.0.8-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:78825a3774211e7d5089730f00cdf7f473042acc9ceb8b9eeebe13ed3a5541de"}, - {file = "Cython-3.0.8-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:df8093deabc55f37028190cf5e575c26aad23fc673f34b85d5f45076bc37ce39"}, - {file = "Cython-3.0.8-cp39-cp39-win32.whl", hash = "sha256:1aca1b97e0095b3a9a6c33eada3f661a4ed0d499067d121239b193e5ba3bb4f0"}, - {file = "Cython-3.0.8-cp39-cp39-win_amd64.whl", hash = "sha256:16873d78be63bd38ffb759da7ab82814b36f56c769ee02b1d5859560e4c3ac3c"}, - {file = "Cython-3.0.8-py2.py3-none-any.whl", hash = "sha256:171b27051253d3f9108e9759e504ba59ff06e7f7ba944457f94deaf9c21bf0b6"}, - {file = "Cython-3.0.8.tar.gz", hash = "sha256:8333423d8fd5765e7cceea3a9985dd1e0a5dfeb2734629e1a2ed2d6233d39de6"}, + {file = "Cython-3.0.9-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:296bd30d4445ac61b66c9d766567f6e81a6e262835d261e903c60c891a6729d3"}, + {file = "Cython-3.0.9-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f496b52845cb45568a69d6359a2c335135233003e708ea02155c10ce3548aa89"}, + {file = "Cython-3.0.9-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:858c3766b9aa3ab8a413392c72bbab1c144a9766b7c7bfdef64e2e414363fa0c"}, + {file = "Cython-3.0.9-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c0eb1e6ef036028a52525fd9a012a556f6dd4788a0e8755fe864ba0e70cde2ff"}, + {file = "Cython-3.0.9-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:c8191941073ea5896321de3c8c958fd66e5f304b0cd1f22c59edd0b86c4dd90d"}, + {file = "Cython-3.0.9-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e32b016030bc72a8a22a1f21f470a2f57573761a4f00fbfe8347263f4fbdb9f1"}, + {file = "Cython-3.0.9-cp310-cp310-win32.whl", hash = "sha256:d6f3ff1cd6123973fe03e0fb8ee936622f976c0c41138969975824d08886572b"}, + {file = "Cython-3.0.9-cp310-cp310-win_amd64.whl", hash = "sha256:56f3b643dbe14449248bbeb9a63fe3878a24256664bc8c8ef6efd45d102596d8"}, + {file = "Cython-3.0.9-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:35e6665a20d6b8a152d72b7fd87dbb2af6bb6b18a235b71add68122d594dbd41"}, + {file = "Cython-3.0.9-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f92f4960c40ad027bd8c364c50db11104eadc59ffeb9e5b7f605ca2f05946e20"}, + {file = "Cython-3.0.9-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:38df37d0e732fbd9a2fef898788492e82b770c33d1e4ed12444bbc8a3b3f89c0"}, + {file = "Cython-3.0.9-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ad7fd88ebaeaf2e76fd729a8919fae80dab3d6ac0005e28494261d52ff347a8f"}, + {file = "Cython-3.0.9-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:1365d5f76bf4d19df3d19ce932584c9bb76e9fb096185168918ef9b36e06bfa4"}, + {file = "Cython-3.0.9-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c232e7f279388ac9625c3e5a5a9f0078a9334959c5d6458052c65bbbba895e1e"}, + {file = "Cython-3.0.9-cp311-cp311-win32.whl", hash = "sha256:357e2fad46a25030b0c0496487e01a9dc0fdd0c09df0897f554d8ba3c1bc4872"}, + {file = "Cython-3.0.9-cp311-cp311-win_amd64.whl", hash = "sha256:1315aee506506e8d69cf6631d8769e6b10131fdcc0eb66df2698f2a3ddaeeff2"}, + {file = "Cython-3.0.9-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:157973807c2796addbed5fbc4d9c882ab34bbc60dc297ca729504901479d5df7"}, + {file = "Cython-3.0.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:00b105b5d050645dd59e6767bc0f18b48a4aa11c85f42ec7dd8181606f4059e3"}, + {file = "Cython-3.0.9-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ac5536d09bef240cae0416d5a703d298b74c7bbc397da803ac9d344e732d4369"}, + {file = "Cython-3.0.9-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09c44501d476d16aaa4cbc29c87f8c0f54fc20e69b650d59cbfa4863426fc70c"}, + {file = "Cython-3.0.9-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:cc9c3b9f20d8e298618e5ccd32083ca386e785b08f9893fbec4c50b6b85be772"}, + {file = "Cython-3.0.9-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a30d96938c633e3ec37000ac3796525da71254ef109e66bdfd78f29891af6454"}, + {file = "Cython-3.0.9-cp312-cp312-win32.whl", hash = "sha256:757ca93bdd80702546df4d610d2494ef2e74249cac4d5ba9464589fb464bd8a3"}, + {file = "Cython-3.0.9-cp312-cp312-win_amd64.whl", hash = "sha256:1dc320a9905ab95414013f6de805efbff9e17bb5fb3b90bbac533f017bec8136"}, + {file = "Cython-3.0.9-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:4ae349960ebe0da0d33724eaa7f1eb866688fe5434cc67ce4dbc06d6a719fbfc"}, + {file = "Cython-3.0.9-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63d2537bf688247f76ded6dee28ebd26274f019309aef1eb4f2f9c5c482fde2d"}, + {file = "Cython-3.0.9-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:36f5a2dfc724bea1f710b649f02d802d80fc18320c8e6396684ba4a48412445a"}, + {file = "Cython-3.0.9-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:deaf4197d4b0bcd5714a497158ea96a2bd6d0f9636095437448f7e06453cc83d"}, + {file = "Cython-3.0.9-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:000af6deb7412eb7ac0c635ff5e637fb8725dd0a7b88cc58dfc2b3de14e701c4"}, + {file = "Cython-3.0.9-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:15c7f5c2d35bed9aa5f2a51eaac0df23ae72f2dbacf62fc672dd6bfaa75d2d6f"}, + {file = "Cython-3.0.9-cp36-cp36m-win32.whl", hash = "sha256:f49aa4970cd3bec66ac22e701def16dca2a49c59cceba519898dd7526e0be2c0"}, + {file = "Cython-3.0.9-cp36-cp36m-win_amd64.whl", hash = "sha256:4558814fa025b193058d42eeee498a53d6b04b2980d01339fc2444b23fd98e58"}, + {file = "Cython-3.0.9-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:539cd1d74fd61f6cfc310fa6bbbad5adc144627f2b7486a07075d4e002fd6aad"}, + {file = "Cython-3.0.9-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c3232926cd406ee02eabb732206f6e882c3aed9d58f0fea764013d9240405bcf"}, + {file = "Cython-3.0.9-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:33b6ac376538a7fc8c567b85d3c71504308a9318702ec0485dd66c059f3165cb"}, + {file = "Cython-3.0.9-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2cc92504b5d22ac66031ffb827bd3a967fc75a5f0f76ab48bce62df19be6fdfd"}, + {file = "Cython-3.0.9-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:22b8fae756c5c0d8968691bed520876de452f216c28ec896a00739a12dba3bd9"}, + {file = "Cython-3.0.9-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:9cda0d92a09f3520f29bd91009f1194ba9600777c02c30c6d2d4ac65fb63e40d"}, + {file = "Cython-3.0.9-cp37-cp37m-win32.whl", hash = "sha256:ec612418490941ed16c50c8d3784c7bdc4c4b2a10c361259871790b02ec8c1db"}, + {file = "Cython-3.0.9-cp37-cp37m-win_amd64.whl", hash = "sha256:976c8d2bedc91ff6493fc973d38b2dc01020324039e2af0e049704a8e1b22936"}, + {file = "Cython-3.0.9-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:5055988b007c92256b6e9896441c3055556038c3497fcbf8c921a6c1fce90719"}, + {file = "Cython-3.0.9-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d9360606d964c2d0492a866464efcf9d0a92715644eede3f6a2aa696de54a137"}, + {file = "Cython-3.0.9-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:02c6e809f060bed073dc7cba1648077fe3b68208863d517c8b39f3920eecf9dd"}, + {file = "Cython-3.0.9-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:95ed792c966f969cea7489c32ff90150b415c1f3567db8d5a9d489c7c1602dac"}, + {file = "Cython-3.0.9-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:8edd59d22950b400b03ca78d27dc694d2836a92ef0cac4f64cb4b2ff902f7e25"}, + {file = "Cython-3.0.9-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:4cf0ed273bf60e97922fcbbdd380c39693922a597760160b4b4355e6078ca188"}, + {file = "Cython-3.0.9-cp38-cp38-win32.whl", hash = "sha256:5eb9bd4ae12ebb2bc79a193d95aacf090fbd8d7013e11ed5412711650cb34934"}, + {file = "Cython-3.0.9-cp38-cp38-win_amd64.whl", hash = "sha256:44457279da56e0f829bb1fc5a5dc0836e5d498dbcf9b2324f32f7cc9d2ec6569"}, + {file = "Cython-3.0.9-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c4b419a1adc2af43f4660e2f6eaf1e4fac2dbac59490771eb8ac3d6063f22356"}, + {file = "Cython-3.0.9-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f836192140f033b2319a0128936367c295c2b32e23df05b03b672a6015757ea"}, + {file = "Cython-3.0.9-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2fd198c1a7f8e9382904d622cc0efa3c184605881fd5262c64cbb7168c4c1ec5"}, + {file = "Cython-3.0.9-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a274fe9ca5c53fafbcf5c8f262f8ad6896206a466f0eeb40aaf36a7951e957c0"}, + {file = "Cython-3.0.9-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:158c38360bbc5063341b1e78d3737f1251050f89f58a3df0d10fb171c44262be"}, + {file = "Cython-3.0.9-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8bf30b045f7deda0014b042c1b41c1d272facc762ab657529e3b05505888e878"}, + {file = "Cython-3.0.9-cp39-cp39-win32.whl", hash = "sha256:9a001fd95c140c94d934078544ff60a3c46aca2dc86e75a76e4121d3cd1f4b33"}, + {file = "Cython-3.0.9-cp39-cp39-win_amd64.whl", hash = "sha256:530c01c4aebba709c0ec9c7ecefe07177d0b9fd7ffee29450a118d92192ccbdf"}, + {file = "Cython-3.0.9-py2.py3-none-any.whl", hash = "sha256:bf96417714353c5454c2e3238fca9338599330cf51625cdc1ca698684465646f"}, + {file = "Cython-3.0.9.tar.gz", hash = "sha256:a2d354f059d1f055d34cfaa62c5b68bc78ac2ceab6407148d47fb508cf3ba4f3"}, ] [[package]] @@ -488,6 +488,37 @@ files = [ {file = "distlib-0.3.8.tar.gz", hash = "sha256:1530ea13e350031b6312d8580ddb6b27a104275a31106523b8f123787f494f64"}, ] +[[package]] +name = "dm-control" +version = "1.0.14" +description = "Continuous control environments and MuJoCo Python bindings." +optional = false +python-versions = ">=3.8" +files = [ + {file = "dm_control-1.0.14-py3-none-any.whl", hash = "sha256:883c63244a7ebf598700a97564ed19fffd3479ca79efd090aed881609cdb9fc6"}, + {file = "dm_control-1.0.14.tar.gz", hash = "sha256:def1ece747b6f175c581150826b50f1a6134086dab34f8f3fd2d088ea035cf3d"}, +] + +[package.dependencies] +absl-py = ">=0.7.0" +dm-env = "*" +dm-tree = "!=0.1.2" +glfw = "*" +labmaze = "*" +lxml = "*" +mujoco = ">=2.3.7" +numpy = ">=1.9.0" +protobuf = ">=3.19.4" +pyopengl = ">=3.1.4" +pyparsing = ">=3.0.0" +requests = "*" +scipy = "*" +setuptools = "!=50.0.0" +tqdm = "*" + +[package.extras] +hdf5 = ["h5py"] + [[package]] name = "dm-env" version = "1.6" @@ -584,43 +615,6 @@ files = [ {file = "einops-0.7.0.tar.gz", hash = "sha256:b2b04ad6081a3b227080c9bf5e3ace7160357ff03043cd66cc5b2319eb7031d1"}, ] -[[package]] -name = "etils" -version = "1.7.0" -description = "Collection of common python utils" -optional = false -python-versions = ">=3.10" -files = [ - {file = "etils-1.7.0-py3-none-any.whl", hash = "sha256:61af8f7c242171de15e22e5da02d527cb9e677d11f8bcafe18fcc3548eee3e60"}, - {file = "etils-1.7.0.tar.gz", hash = "sha256:97b68fd25e185683215286ef3a54e38199b6245f5fe8be6bedc1189be4256350"}, -] - -[package.dependencies] -fsspec = {version = "*", optional = true, markers = "extra == \"epath\""} -importlib_resources = {version = "*", optional = true, markers = "extra == \"epath\""} -typing_extensions = {version = "*", optional = true, markers = "extra == \"epy\""} -zipp = {version = "*", optional = true, markers = "extra == \"epath\""} - -[package.extras] -all = ["etils[array-types]", "etils[eapp]", "etils[ecolab]", "etils[edc]", "etils[enp]", "etils[epath-gcs]", "etils[epath-s3]", "etils[epath]", "etils[epy]", "etils[etqdm]", "etils[etree-dm]", "etils[etree-jax]", "etils[etree-tf]", "etils[etree]"] -array-types = ["etils[enp]"] -dev = ["chex", "dataclass_array", "optree", "pyink", "pylint (>=2.6.0)", "pytest", "pytest-subtests", "pytest-xdist", "torch"] -docs = ["etils[all,dev]", "sphinx-apitree[ext]"] -eapp = ["absl-py", "etils[epy]", "simple_parsing"] -ecolab = ["etils[enp]", "etils[epy]", "etils[etree]", "jupyter", "mediapy", "numpy", "packaging", "protobuf"] -edc = ["etils[epy]"] -enp = ["etils[epy]", "numpy"] -epath = ["etils[epy]", "fsspec", "importlib_resources", "typing_extensions", "zipp"] -epath-gcs = ["etils[epath]", "gcsfs"] -epath-s3 = ["etils[epath]", "s3fs"] -epy = ["typing_extensions"] -etqdm = ["absl-py", "etils[epy]", "tqdm"] -etree = ["etils[array-types]", "etils[enp]", "etils[epy]", "etils[etqdm]"] -etree-dm = ["dm-tree", "etils[etree]"] -etree-jax = ["etils[etree]", "jax[cpu]"] -etree-tf = ["etils[etree]", "tensorflow"] -lazy-imports = ["etils[ecolab]"] - [[package]] name = "exceptiongroup" version = "1.2.0" @@ -846,13 +840,13 @@ numpy = ">=1.17.3" [[package]] name = "huggingface-hub" -version = "0.21.3" +version = "0.21.4" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" optional = false python-versions = ">=3.8.0" files = [ - {file = "huggingface_hub-0.21.3-py3-none-any.whl", hash = "sha256:b183144336fdf2810a8c109822e0bb6ef1fd61c65da6fb60e8c3f658b7144016"}, - {file = "huggingface_hub-0.21.3.tar.gz", hash = "sha256:26a15b604e4fc7bad37c467b76456543ec849386cbca9cd7e1e135f53e500423"}, + {file = "huggingface_hub-0.21.4-py3-none-any.whl", hash = "sha256:df37c2c37fc6c82163cdd8a67ede261687d80d1e262526d6c0ce73b6b3630a7b"}, + {file = "huggingface_hub-0.21.4.tar.gz", hash = "sha256:e1f4968c93726565a80edf6dc309763c7b546d0cfe79aa221206034d50155531"}, ] [package.dependencies] @@ -971,37 +965,22 @@ setuptools = "*" [[package]] name = "importlib-metadata" -version = "7.0.1" +version = "7.0.2" description = "Read metadata from Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "importlib_metadata-7.0.1-py3-none-any.whl", hash = "sha256:4805911c3a4ec7c3966410053e9ec6a1fecd629117df5adee56dfc9432a1081e"}, - {file = "importlib_metadata-7.0.1.tar.gz", hash = "sha256:f238736bb06590ae52ac1fab06a3a9ef1d8dce2b7a35b5ab329371d6c8f5d2cc"}, + {file = "importlib_metadata-7.0.2-py3-none-any.whl", hash = "sha256:f4bc4c0c070c490abf4ce96d715f68e95923320370efb66143df00199bb6c100"}, + {file = "importlib_metadata-7.0.2.tar.gz", hash = "sha256:198f568f3230878cb1b44fbd7975f87906c22336dba2e4a7f05278c281fbd792"}, ] [package.dependencies] zipp = ">=0.5" [package.extras] -docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] perf = ["ipython"] -testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"] - -[[package]] -name = "importlib-resources" -version = "6.1.2" -description = "Read resources from Python packages" -optional = false -python-versions = ">=3.8" -files = [ - {file = "importlib_resources-6.1.2-py3-none-any.whl", hash = "sha256:9a0a862501dc38b68adebc82970140c9e4209fc99601782925178f8386339938"}, - {file = "importlib_resources-6.1.2.tar.gz", hash = "sha256:308abf8474e2dba5f867d279237cd4076482c3de7104a40b41426370e891549b"}, -] - -[package.extras] -docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] -testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-ruff (>=0.2.1)", "zipp (>=3.17)"] +testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"] [[package]] name = "iniconfig" @@ -1031,6 +1010,50 @@ MarkupSafe = ">=2.0" [package.extras] i18n = ["Babel (>=2.7)"] +[[package]] +name = "labmaze" +version = "1.0.6" +description = "LabMaze: DeepMind Lab's text maze generator." +optional = false +python-versions = "*" +files = [ + {file = "labmaze-1.0.6-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:b2ddef976dfd8d992b19cfa6c633f2eba7576d759c2082da534e3f727479a84a"}, + {file = "labmaze-1.0.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:157efaa93228c8ccce5cae337902dd652093e0fba9d3a0f6506e4bee272bb66f"}, + {file = "labmaze-1.0.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b3ce98b9541c5fe6a306e411e7d018121dd646f2c9978d763fad86f9f30c5f57"}, + {file = "labmaze-1.0.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4e6433bd49bc541791de8191040526fddfebb77151620eb04203453f43ee486a"}, + {file = "labmaze-1.0.6-cp310-cp310-win_amd64.whl", hash = "sha256:6a507fc35961f1b1479708e2716f65e0d0611cefb55f31a77be29ce2339b6fef"}, + {file = "labmaze-1.0.6-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:a0c2cb9dec971814ea9c5d7150af15fa3964482131fa969e0afb94bd224348af"}, + {file = "labmaze-1.0.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2c6ba9538d819543f4be448d36b4926a3881e53646a2b331ebb5a1f353047d05"}, + {file = "labmaze-1.0.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70635d1cdb0147a02efb6b3f607a52cdc51723bc3dcc42717a0d4ef55fa0a987"}, + {file = "labmaze-1.0.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ff472793238bd9b6dabea8094594d6074ad3c111455de3afcae72f6c40c6817e"}, + {file = "labmaze-1.0.6-cp311-cp311-win_amd64.whl", hash = "sha256:2317e65e12fa3d1abecda7e0488dab15456cee8a2e717a586bfc8f02a91579e7"}, + {file = "labmaze-1.0.6-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:e36b6fadcd78f22057b597c1c77823e806a0987b3bdfbf850e14b6b5b502075e"}, + {file = "labmaze-1.0.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d1a4f8de29c2c3d7f14163759b69cd3f237093b85334c983619c1db5403a223b"}, + {file = "labmaze-1.0.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a394f8bb857fcaa2884b809d63e750841c2662a106cfe8c045f2112d201ac7d5"}, + {file = "labmaze-1.0.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d17abb69d4dfc56183afb5c317e8b2eaca0587abb3aabd2326efd3143c81f4e"}, + {file = "labmaze-1.0.6-cp312-cp312-win_amd64.whl", hash = "sha256:5af997598cc46b1929d1c5a1febc32fd56c75874fe481a2a5982c65cee8450c9"}, + {file = "labmaze-1.0.6-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:a4c5bc6e56baa55ce63b97569afec2f80cab0f6b952752a131e1f83eed190a53"}, + {file = "labmaze-1.0.6-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3955f24fe5f708e1e97495b4cfe284b70ae4fd51be5e17b75a6fc04ffbd67bca"}, + {file = "labmaze-1.0.6-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ed96ddc0bb8d66df36428c94db83949fd84a15867e8250763a4c5e3d82104c54"}, + {file = "labmaze-1.0.6-cp37-cp37m-win_amd64.whl", hash = "sha256:3bd0458a29e55aa09f146e28a168d2e00b8ccf19e2259a3f71154cfff3536b1d"}, + {file = "labmaze-1.0.6-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:33f5154edc83dff55a150e54b60c8582fdafc7ec45195049809cbcc01f5e8f34"}, + {file = "labmaze-1.0.6-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0971055ef2a5f7d8517fdc42b67c057093698f1eb911f46faa7018867b73fcc9"}, + {file = "labmaze-1.0.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:de18d09680007302abf49111f3fe822d8435e4fbc4468b9ec07d50a78e267865"}, + {file = "labmaze-1.0.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f18126066db2218a52853c7dd490b4c3d8129fc22eb3a47eb23007524b911d53"}, + {file = "labmaze-1.0.6-cp38-cp38-win_amd64.whl", hash = "sha256:f9aef09a76877342bb4d634b7e05f43b038a49c4f34adfb8f1b8ac57c29472f2"}, + {file = "labmaze-1.0.6-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:5dd28899418f1b8b1c7d1e1b40a4593150a7cfa95ca91e23860b9785b82cc0ee"}, + {file = "labmaze-1.0.6-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:965569f37ee33090b4d4b3aa5aa7c9dcc4f62e2ae5d761e7f73ec76fc9d8aa96"}, + {file = "labmaze-1.0.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:05eccfa98c0e781bc9f939076ae600b2e25ca736e123f2a530606aedec3b531c"}, + {file = "labmaze-1.0.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bee8c94e0fb3fc2d8180214947245c1d74a3489349a9da90b868296e77a521e9"}, + {file = "labmaze-1.0.6-cp39-cp39-win_amd64.whl", hash = "sha256:d486e9ca3a335ad628e3bd48a09c42f1aa5f51040952ef0fe32507afedcd694b"}, + {file = "labmaze-1.0.6.tar.gz", hash = "sha256:2e8de7094042a77d6972f1965cf5c9e8f971f1b34d225752f343190a825ebe73"}, +] + +[package.dependencies] +absl-py = "*" +numpy = ">=1.8.0" +setuptools = "!=50.0.0" + [[package]] name = "lazy-loader" version = "0.3" @@ -1076,6 +1099,99 @@ files = [ {file = "llvmlite-0.42.0.tar.gz", hash = "sha256:f92b09243c0cc3f457da8b983f67bd8e1295d0f5b3746c7a1861d7a99403854a"}, ] +[[package]] +name = "lxml" +version = "5.1.0" +description = "Powerful and Pythonic XML processing library combining libxml2/libxslt with the ElementTree API." +optional = false +python-versions = ">=3.6" +files = [ + {file = "lxml-5.1.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:704f5572ff473a5f897745abebc6df40f22d4133c1e0a1f124e4f2bd3330ff7e"}, + {file = "lxml-5.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9d3c0f8567ffe7502d969c2c1b809892dc793b5d0665f602aad19895f8d508da"}, + {file = "lxml-5.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5fcfbebdb0c5d8d18b84118842f31965d59ee3e66996ac842e21f957eb76138c"}, + {file = "lxml-5.1.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2f37c6d7106a9d6f0708d4e164b707037b7380fcd0b04c5bd9cae1fb46a856fb"}, + {file = "lxml-5.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2befa20a13f1a75c751f47e00929fb3433d67eb9923c2c0b364de449121f447c"}, + {file = "lxml-5.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:22b7ee4c35f374e2c20337a95502057964d7e35b996b1c667b5c65c567d2252a"}, + {file = "lxml-5.1.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:bf8443781533b8d37b295016a4b53c1494fa9a03573c09ca5104550c138d5c05"}, + {file = "lxml-5.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:82bddf0e72cb2af3cbba7cec1d2fd11fda0de6be8f4492223d4a268713ef2147"}, + {file = "lxml-5.1.0-cp310-cp310-win32.whl", hash = "sha256:b66aa6357b265670bb574f050ffceefb98549c721cf28351b748be1ef9577d93"}, + {file = "lxml-5.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:4946e7f59b7b6a9e27bef34422f645e9a368cb2be11bf1ef3cafc39a1f6ba68d"}, + {file = "lxml-5.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:14deca1460b4b0f6b01f1ddc9557704e8b365f55c63070463f6c18619ebf964f"}, + {file = "lxml-5.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ed8c3d2cd329bf779b7ed38db176738f3f8be637bb395ce9629fc76f78afe3d4"}, + {file = "lxml-5.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:436a943c2900bb98123b06437cdd30580a61340fbdb7b28aaf345a459c19046a"}, + {file = "lxml-5.1.0-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:acb6b2f96f60f70e7f34efe0c3ea34ca63f19ca63ce90019c6cbca6b676e81fa"}, + {file = "lxml-5.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:af8920ce4a55ff41167ddbc20077f5698c2e710ad3353d32a07d3264f3a2021e"}, + {file = "lxml-5.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7cfced4a069003d8913408e10ca8ed092c49a7f6cefee9bb74b6b3e860683b45"}, + {file = "lxml-5.1.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:9e5ac3437746189a9b4121db2a7b86056ac8786b12e88838696899328fc44bb2"}, + {file = "lxml-5.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f4c9bda132ad108b387c33fabfea47866af87f4ea6ffb79418004f0521e63204"}, + {file = "lxml-5.1.0-cp311-cp311-win32.whl", hash = "sha256:bc64d1b1dab08f679fb89c368f4c05693f58a9faf744c4d390d7ed1d8223869b"}, + {file = "lxml-5.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:a5ab722ae5a873d8dcee1f5f45ddd93c34210aed44ff2dc643b5025981908cda"}, + {file = "lxml-5.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:9aa543980ab1fbf1720969af1d99095a548ea42e00361e727c58a40832439114"}, + {file = "lxml-5.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6f11b77ec0979f7e4dc5ae081325a2946f1fe424148d3945f943ceaede98adb8"}, + {file = "lxml-5.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a36c506e5f8aeb40680491d39ed94670487ce6614b9d27cabe45d94cd5d63e1e"}, + {file = "lxml-5.1.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f643ffd2669ffd4b5a3e9b41c909b72b2a1d5e4915da90a77e119b8d48ce867a"}, + {file = "lxml-5.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:16dd953fb719f0ffc5bc067428fc9e88f599e15723a85618c45847c96f11f431"}, + {file = "lxml-5.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:16018f7099245157564d7148165132c70adb272fb5a17c048ba70d9cc542a1a1"}, + {file = "lxml-5.1.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:82cd34f1081ae4ea2ede3d52f71b7be313756e99b4b5f829f89b12da552d3aa3"}, + {file = "lxml-5.1.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:19a1bc898ae9f06bccb7c3e1dfd73897ecbbd2c96afe9095a6026016e5ca97b8"}, + {file = "lxml-5.1.0-cp312-cp312-win32.whl", hash = "sha256:13521a321a25c641b9ea127ef478b580b5ec82aa2e9fc076c86169d161798b01"}, + {file = "lxml-5.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:1ad17c20e3666c035db502c78b86e58ff6b5991906e55bdbef94977700c72623"}, + {file = "lxml-5.1.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:24ef5a4631c0b6cceaf2dbca21687e29725b7c4e171f33a8f8ce23c12558ded1"}, + {file = "lxml-5.1.0-cp36-cp36m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8d2900b7f5318bc7ad8631d3d40190b95ef2aa8cc59473b73b294e4a55e9f30f"}, + {file = "lxml-5.1.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:601f4a75797d7a770daed8b42b97cd1bb1ba18bd51a9382077a6a247a12aa38d"}, + {file = "lxml-5.1.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b4b68c961b5cc402cbd99cca5eb2547e46ce77260eb705f4d117fd9c3f932b95"}, + {file = "lxml-5.1.0-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:afd825e30f8d1f521713a5669b63657bcfe5980a916c95855060048b88e1adb7"}, + {file = "lxml-5.1.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:262bc5f512a66b527d026518507e78c2f9c2bd9eb5c8aeeb9f0eb43fcb69dc67"}, + {file = "lxml-5.1.0-cp36-cp36m-win32.whl", hash = "sha256:e856c1c7255c739434489ec9c8aa9cdf5179785d10ff20add308b5d673bed5cd"}, + {file = "lxml-5.1.0-cp36-cp36m-win_amd64.whl", hash = "sha256:c7257171bb8d4432fe9d6fdde4d55fdbe663a63636a17f7f9aaba9bcb3153ad7"}, + {file = "lxml-5.1.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b9e240ae0ba96477682aa87899d94ddec1cc7926f9df29b1dd57b39e797d5ab5"}, + {file = "lxml-5.1.0-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a96f02ba1bcd330807fc060ed91d1f7a20853da6dd449e5da4b09bfcc08fdcf5"}, + {file = "lxml-5.1.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e3898ae2b58eeafedfe99e542a17859017d72d7f6a63de0f04f99c2cb125936"}, + {file = "lxml-5.1.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:61c5a7edbd7c695e54fca029ceb351fc45cd8860119a0f83e48be44e1c464862"}, + {file = "lxml-5.1.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:3aeca824b38ca78d9ee2ab82bd9883083d0492d9d17df065ba3b94e88e4d7ee6"}, + {file = "lxml-5.1.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:8f52fe6859b9db71ee609b0c0a70fea5f1e71c3462ecf144ca800d3f434f0764"}, + {file = "lxml-5.1.0-cp37-cp37m-win32.whl", hash = "sha256:d42e3a3fc18acc88b838efded0e6ec3edf3e328a58c68fbd36a7263a874906c8"}, + {file = "lxml-5.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:eac68f96539b32fce2c9b47eb7c25bb2582bdaf1bbb360d25f564ee9e04c542b"}, + {file = "lxml-5.1.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ae15347a88cf8af0949a9872b57a320d2605ae069bcdf047677318bc0bba45b1"}, + {file = "lxml-5.1.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c26aab6ea9c54d3bed716b8851c8bfc40cb249b8e9880e250d1eddde9f709bf5"}, + {file = "lxml-5.1.0-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:342e95bddec3a698ac24378d61996b3ee5ba9acfeb253986002ac53c9a5f6f84"}, + {file = "lxml-5.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:725e171e0b99a66ec8605ac77fa12239dbe061482ac854d25720e2294652eeaa"}, + {file = "lxml-5.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d184e0d5c918cff04cdde9dbdf9600e960161d773666958c9d7b565ccc60c45"}, + {file = "lxml-5.1.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:98f3f020a2b736566c707c8e034945c02aa94e124c24f77ca097c446f81b01f1"}, + {file = "lxml-5.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:6d48fc57e7c1e3df57be5ae8614bab6d4e7b60f65c5457915c26892c41afc59e"}, + {file = "lxml-5.1.0-cp38-cp38-win32.whl", hash = "sha256:7ec465e6549ed97e9f1e5ed51c657c9ede767bc1c11552f7f4d022c4df4a977a"}, + {file = "lxml-5.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:b21b4031b53d25b0858d4e124f2f9131ffc1530431c6d1321805c90da78388d1"}, + {file = "lxml-5.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:52427a7eadc98f9e62cb1368a5079ae826f94f05755d2d567d93ee1bc3ceb354"}, + {file = "lxml-5.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6a2a2c724d97c1eb8cf966b16ca2915566a4904b9aad2ed9a09c748ffe14f969"}, + {file = "lxml-5.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:843b9c835580d52828d8f69ea4302537337a21e6b4f1ec711a52241ba4a824f3"}, + {file = "lxml-5.1.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9b99f564659cfa704a2dd82d0684207b1aadf7d02d33e54845f9fc78e06b7581"}, + {file = "lxml-5.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f8b0c78e7aac24979ef09b7f50da871c2de2def043d468c4b41f512d831e912"}, + {file = "lxml-5.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9bcf86dfc8ff3e992fed847c077bd875d9e0ba2fa25d859c3a0f0f76f07f0c8d"}, + {file = "lxml-5.1.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:49a9b4af45e8b925e1cd6f3b15bbba2c81e7dba6dce170c677c9cda547411e14"}, + {file = "lxml-5.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:280f3edf15c2a967d923bcfb1f8f15337ad36f93525828b40a0f9d6c2ad24890"}, + {file = "lxml-5.1.0-cp39-cp39-win32.whl", hash = "sha256:ed7326563024b6e91fef6b6c7a1a2ff0a71b97793ac33dbbcf38f6005e51ff6e"}, + {file = "lxml-5.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:8d7b4beebb178e9183138f552238f7e6613162a42164233e2bda00cb3afac58f"}, + {file = "lxml-5.1.0-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:9bd0ae7cc2b85320abd5e0abad5ccee5564ed5f0cc90245d2f9a8ef330a8deae"}, + {file = "lxml-5.1.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d8c1d679df4361408b628f42b26a5d62bd3e9ba7f0c0e7969f925021554755aa"}, + {file = "lxml-5.1.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:2ad3a8ce9e8a767131061a22cd28fdffa3cd2dc193f399ff7b81777f3520e372"}, + {file = "lxml-5.1.0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:304128394c9c22b6569eba2a6d98392b56fbdfbad58f83ea702530be80d0f9df"}, + {file = "lxml-5.1.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d74fcaf87132ffc0447b3c685a9f862ffb5b43e70ea6beec2fb8057d5d2a1fea"}, + {file = "lxml-5.1.0-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:8cf5877f7ed384dabfdcc37922c3191bf27e55b498fecece9fd5c2c7aaa34c33"}, + {file = "lxml-5.1.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:877efb968c3d7eb2dad540b6cabf2f1d3c0fbf4b2d309a3c141f79c7e0061324"}, + {file = "lxml-5.1.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3f14a4fb1c1c402a22e6a341a24c1341b4a3def81b41cd354386dcb795f83897"}, + {file = "lxml-5.1.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:25663d6e99659544ee8fe1b89b1a8c0aaa5e34b103fab124b17fa958c4a324a6"}, + {file = "lxml-5.1.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:8b9f19df998761babaa7f09e6bc169294eefafd6149aaa272081cbddc7ba4ca3"}, + {file = "lxml-5.1.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5e53d7e6a98b64fe54775d23a7c669763451340c3d44ad5e3a3b48a1efbdc96f"}, + {file = "lxml-5.1.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:c3cd1fc1dc7c376c54440aeaaa0dcc803d2126732ff5c6b68ccd619f2e64be4f"}, + {file = "lxml-5.1.0.tar.gz", hash = "sha256:3eea6ed6e6c918e468e693c41ef07f3c3acc310b70ddd9cc72d9ef84bc9564ca"}, +] + +[package.extras] +cssselect = ["cssselect (>=0.7)"] +html5 = ["html5lib"] +htmlsoup = ["BeautifulSoup4"] +source = ["Cython (>=3.0.7)"] + [[package]] name = "markupsafe" version = "2.1.5" @@ -1188,42 +1304,40 @@ tests = ["pytest (>=4.6)"] [[package]] name = "mujoco" -version = "3.1.2" +version = "2.3.7" description = "MuJoCo Physics Simulator" optional = false python-versions = ">=3.8" files = [ - {file = "mujoco-3.1.2-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:fe6b3542695a5363f348ee45625b3492734f29cdc9f493ca25eae719f974370e"}, - {file = "mujoco-3.1.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f07e2d1f01f1401f1a503187016f8c017d9402618c659e1482243640a1e11288"}, - {file = "mujoco-3.1.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:93863eccc9d77d96ce62dda2a6f61cbd880379e8d774f802568d64b9613fce39"}, - {file = "mujoco-3.1.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3586c642390c16fef58b01a86071cec6814c471586e2f4115c3733c4aec64fb7"}, - {file = "mujoco-3.1.2-cp310-cp310-win_amd64.whl", hash = "sha256:0da77394c664945b78f199c627b609fe091ec0c4641b9d8f713637344a17821a"}, - {file = "mujoco-3.1.2-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:b6f12904d0478c191e4770ecf9006e20953f0488a2411a8ddc62592721c136dc"}, - {file = "mujoco-3.1.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f69b8d42b50c10f8d12df4948fc9d4dd6706841e7b163c1d7ce83448965acb1c"}, - {file = "mujoco-3.1.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:10119e39b1f45fb76b18bea242fea1d6ccf4b2285f8bd5e2cb1e2cbdeb69bdcd"}, - {file = "mujoco-3.1.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a65868506dd45dddfe7be84857e57b49bc102334fc0439aa848a4d4d285d89b"}, - {file = "mujoco-3.1.2-cp311-cp311-win_amd64.whl", hash = "sha256:92bc73972e39539f23a05bb411c45f9be17191fe01171ac15ffafed381ee4366"}, - {file = "mujoco-3.1.2-cp312-cp312-macosx_10_16_x86_64.whl", hash = "sha256:835d6b64ca4dc2f6a83291275fd48bd83edc888039d247958bf5b2c759db4340"}, - {file = "mujoco-3.1.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ce94ca3cf14fc519981674c5b85f1055356dcdcd63bbc0ec6c340084438f27f"}, - {file = "mujoco-3.1.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:250d9de4bd0d31fa4165faf01a1f838c429434f1263faacd95b977580f24eae7"}, - {file = "mujoco-3.1.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6ea009d10bbf0aba9bc835f051d25f07a2c3edbaa06627ac2348766a1f3760b9"}, - {file = "mujoco-3.1.2-cp312-cp312-win_amd64.whl", hash = "sha256:a0460d2ebdad4926f48b8c774da473e011c3b3afd0ccb6b6be1087b788c34011"}, - {file = "mujoco-3.1.2-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:4ca7cae89e258a338e02229edcf8f177b459ac5e9f859ffffa07fc2c9fcfb6aa"}, - {file = "mujoco-3.1.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:33b4fe9b5f891b29ef0fc2b0b975bc3a8a4b87774eecaf4364a83ddc6a7762ba"}, - {file = "mujoco-3.1.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6ed230980f33bafaf1fa8b32ef25b82b069a245de15ee6ce7127e7e054cfad16"}, - {file = "mujoco-3.1.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41cc610ac40f325c9d49d9885ac6cb61822ed938f6c23cb183b261a7a28472ca"}, - {file = "mujoco-3.1.2-cp38-cp38-win_amd64.whl", hash = "sha256:90a172b904a6ca8e6a1be80ab7c393aaff7592843a2a6853a4f97a9204031c41"}, - {file = "mujoco-3.1.2-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:93201291a0c5b573b4cbb19a6b08c99673f9fba167f402174eae5ffa23066d24"}, - {file = "mujoco-3.1.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0398985bb28c2686cdeeaf4ded46e602a49ec12115ac77474144ca940e5261c5"}, - {file = "mujoco-3.1.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d2e76b5cb07ab3088c81966ac774d573df027fa5f4e78c20953a547528a2a698"}, - {file = "mujoco-3.1.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd5c3f4ae858e812cb3f03332693bcdc343b2bce55b164523acf52dea2736c9e"}, - {file = "mujoco-3.1.2-cp39-cp39-win_amd64.whl", hash = "sha256:ca25ff2646b06609526ef8681c0e123cd854a53c9ff23cb91dd5058a2794dab4"}, - {file = "mujoco-3.1.2.tar.gz", hash = "sha256:53530bc1a91903f3fd4b1e99818cc38fbd9911700db29b2c9fc839f23bfacbb8"}, + {file = "mujoco-2.3.7-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:e8714a5ff6a1561b364b7b4648d4c0c8d13e751874cf7401c309b9d23fa9598b"}, + {file = "mujoco-2.3.7-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a934315f858a4e0c4b90a682fde519471cfdd7baa64435179da8cd20d4ae3f99"}, + {file = "mujoco-2.3.7-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:36513024330f88b5f9a43558efef5692b33599bffd5141029b690a27918ffcbe"}, + {file = "mujoco-2.3.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6d4eede8ba8210fbd3d3cd1dbf69e24dd1541aa74c5af5b8adbbbf65504b6dba"}, + {file = "mujoco-2.3.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ab85fafc9d5a091c712947573b7e694512d283876bf7f33ae3f8daad3a20c0db"}, + {file = "mujoco-2.3.7-cp310-cp310-win_amd64.whl", hash = "sha256:f8b7e13fef8c813d91b78f975ed0815157692777907ffa4b4be53a4edb75019b"}, + {file = "mujoco-2.3.7-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:779520216f72a8e370e3f0cdd71b45c3b7384c63331a3189194c930a3e7cff5c"}, + {file = "mujoco-2.3.7-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9d4018053879016282d27ab7a91e292c72d44efb5a88553feacfe5b843dde103"}, + {file = "mujoco-2.3.7-cp311-cp311-macosx_11_0_x86_64.whl", hash = "sha256:3149b16b8122ee62642474bfd2871064e8edc40235471cf5d84be3569afc0312"}, + {file = "mujoco-2.3.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c08660a8d52ef3efde76095f0991e807703a950c1e882d2bcd984b9a846626f7"}, + {file = "mujoco-2.3.7-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:426af8965f8636d94a0f75740c3024a62b3e585020ee817ef5208ec844a1ad94"}, + {file = "mujoco-2.3.7-cp311-cp311-win_amd64.whl", hash = "sha256:215415a8e98a4b50625beae859079d5e0810b2039e50420f0ba81763c34abb59"}, + {file = "mujoco-2.3.7-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:8b78d14f4c60cea3c58e046bd4de453fb5b9b33aca6a25fc91d39a53f3a5342a"}, + {file = "mujoco-2.3.7-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5c6f5a51d6f537a4bf294cf73816f3a6384573f8f10a5452b044df2771412a96"}, + {file = "mujoco-2.3.7-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:ea8911e6047f92d7d775701f37e4c093971b6def3160f01d0b6926e29a7e962e"}, + {file = "mujoco-2.3.7-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7473a3de4dd1a8762d569ffb139196b4c5e7eca27d256df97b6cd4c66d2a09b2"}, + {file = "mujoco-2.3.7-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40e7e2d8f93d2495ec74efec84e5118ecc6e1d85157a844789c73c9ac9a4e28e"}, + {file = "mujoco-2.3.7-cp38-cp38-win_amd64.whl", hash = "sha256:720bc228a2023b3b0ed6af78f5b0f8ea36867be321d473321555c57dbf6e4e5b"}, + {file = "mujoco-2.3.7-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:855e79686366442aa410246043b44f7d842d3900d68fe7e37feb42147db9d707"}, + {file = "mujoco-2.3.7-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:98947f4a742d34d36f3c3f83e9167025bb0414bbaa4bd859b0673bdab9959963"}, + {file = "mujoco-2.3.7-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:d42818f2ee5d1632dbce31d136ed5ff868db54b04e4e9aca0c5a3ac329f8a90f"}, + {file = "mujoco-2.3.7-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9237e1ba14bced9449c31199e6d5be49547f3a4c99bc83b196af7ca45fd73b83"}, + {file = "mujoco-2.3.7-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b728ea638245b150e2650c5433e6952e0ed3798c63e47e264574270caea2a3"}, + {file = "mujoco-2.3.7-cp39-cp39-win_amd64.whl", hash = "sha256:9c721a5042b99d948d5f0296a534bcce3f142c777c4d7642f503a539513f3912"}, + {file = "mujoco-2.3.7.tar.gz", hash = "sha256:422041f1ce37c6d151fbced1048df626837e94fe3cd9f813585907046336a7d0"}, ] [package.dependencies] absl-py = "*" -etils = {version = "*", extras = ["epath"]} glfw = "*" numpy = "*" pyopengl = "*" @@ -1519,13 +1633,13 @@ files = [ [[package]] name = "nvidia-nvjitlink-cu12" -version = "12.3.101" +version = "12.4.99" description = "Nvidia JIT LTO Library" optional = false python-versions = ">=3" files = [ - {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux1_x86_64.whl", hash = "sha256:64335a8088e2b9d196ae8665430bc6a2b7e6ef2eb877a9c735c804bd4ff6467c"}, - {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-win_amd64.whl", hash = "sha256:1b2e317e437433753530792f13eece58f0aec21a2b05903be7bffe58a606cbd1"}, + {file = "nvidia_nvjitlink_cu12-12.4.99-py3-none-manylinux2014_x86_64.whl", hash = "sha256:c6428836d20fe7e327191c175791d38570e10762edc588fb46749217cd444c74"}, + {file = "nvidia_nvjitlink_cu12-12.4.99-py3-none-win_amd64.whl", hash = "sha256:991905ffa2144cb603d8ca7962d75c35334ae82bf92820b6ba78157277da1ad2"}, ] [[package]] @@ -1580,13 +1694,13 @@ numpy = [ [[package]] name = "packaging" -version = "23.2" +version = "24.0" description = "Core utilities for Python packages" optional = false python-versions = ">=3.7" files = [ - {file = "packaging-23.2-py3-none-any.whl", hash = "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7"}, - {file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"}, + {file = "packaging-24.0-py3-none-any.whl", hash = "sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5"}, + {file = "packaging-24.0.tar.gz", hash = "sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9"}, ] [[package]] @@ -2016,6 +2130,20 @@ files = [ {file = "PyOpenGL-3.1.7.tar.gz", hash = "sha256:eef31a3888e6984fd4d8e6c9961b184c9813ca82604d37fe3da80eb000a76c86"}, ] +[[package]] +name = "pyparsing" +version = "3.1.2" +description = "pyparsing module - Classes and methods to define and execute parsing grammars" +optional = false +python-versions = ">=3.6.8" +files = [ + {file = "pyparsing-3.1.2-py3-none-any.whl", hash = "sha256:f9db75911801ed778fe61bb643079ff86601aca99fcae6345aa67292038fb742"}, + {file = "pyparsing-3.1.2.tar.gz", hash = "sha256:a1bac0ce561155ecc3ed78ca94d3c9378656ad4c94c1270de543f621420f94ad"}, +] + +[package.extras] +diagrams = ["jinja2", "railroad-diagrams"] + [[package]] name = "pysocks" version = "1.7.1" @@ -2030,13 +2158,13 @@ files = [ [[package]] name = "pytest" -version = "8.1.0" +version = "8.1.1" description = "pytest: simple powerful testing with Python" optional = false python-versions = ">=3.8" files = [ - {file = "pytest-8.1.0-py3-none-any.whl", hash = "sha256:ee32db7af8de4629a455806befa90559f307424c07b8413ccfc30bf5b221dd7e"}, - {file = "pytest-8.1.0.tar.gz", hash = "sha256:f8fa04ab8f98d185113ae60ea6d79c22f8143b14bc1caeced44a0ab844928323"}, + {file = "pytest-8.1.1-py3-none-any.whl", hash = "sha256:2a8386cfc11fa9d2c50ee7b2a57e7d898ef90470a7a34c4b949ff59662bb78b7"}, + {file = "pytest-8.1.1.tar.gz", hash = "sha256:ac978141a75948948817d360297b7aae0fcb9d6ff6bc9ec6d514b85d5a65c044"}, ] [package.dependencies] @@ -2052,13 +2180,13 @@ testing = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygm [[package]] name = "python-dateutil" -version = "2.8.2" +version = "2.9.0.post0" description = "Extensions to the standard Python datetime module" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" files = [ - {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"}, - {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"}, + {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"}, + {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"}, ] [package.dependencies] @@ -2483,13 +2611,13 @@ test = ["asv", "gmpy2", "hypothesis", "mpmath", "pooch", "pytest", "pytest-cov", [[package]] name = "sentry-sdk" -version = "1.40.6" +version = "1.41.0" description = "Python client for Sentry (https://sentry.io)" optional = false python-versions = "*" files = [ - {file = "sentry-sdk-1.40.6.tar.gz", hash = "sha256:f143f3fb4bb57c90abef6e2ad06b5f6f02b2ca13e4060ec5c0549c7a9ccce3fa"}, - {file = "sentry_sdk-1.40.6-py2.py3-none-any.whl", hash = "sha256:becda09660df63e55f307570e9817c664392655a7328bbc414b507e9cb874c67"}, + {file = "sentry-sdk-1.41.0.tar.gz", hash = "sha256:4f2d6c43c07925d8cd10dfbd0970ea7cb784f70e79523cca9dbcd72df38e5a46"}, + {file = "sentry_sdk-1.41.0-py2.py3-none-any.whl", hash = "sha256:be4f8f4b29a80b6a3b71f0f31487beb9e296391da20af8504498a328befed53f"}, ] [package.dependencies] @@ -2515,7 +2643,7 @@ huey = ["huey (>=2)"] loguru = ["loguru (>=0.5)"] opentelemetry = ["opentelemetry-distro (>=0.35b0)"] opentelemetry-experimental = ["opentelemetry-distro (>=0.40b0,<1.0)", "opentelemetry-instrumentation-aiohttp-client (>=0.40b0,<1.0)", "opentelemetry-instrumentation-django (>=0.40b0,<1.0)", "opentelemetry-instrumentation-fastapi (>=0.40b0,<1.0)", "opentelemetry-instrumentation-flask (>=0.40b0,<1.0)", "opentelemetry-instrumentation-requests (>=0.40b0,<1.0)", "opentelemetry-instrumentation-sqlite3 (>=0.40b0,<1.0)", "opentelemetry-instrumentation-urllib (>=0.40b0,<1.0)"] -pure-eval = ["asttokens", "executing", "pure_eval"] +pure-eval = ["asttokens", "executing", "pure-eval"] pymongo = ["pymongo (>=3.1)"] pyspark = ["pyspark (>=2.4.4)"] quart = ["blinker (>=1.1)", "quart (>=0.16.1)"] @@ -2769,7 +2897,7 @@ tests = ["pytest", "pytest-benchmark", "pytest-instafail", "pytest-rerunfailures type = "git" url = "https://github.com/pytorch/tensordict" reference = "HEAD" -resolved_reference = "551331d83e2979dd4505db1f49895740e6e5c95f" +resolved_reference = "ed22554d6860731610df784b2f5d09f31d3dbc7a" [[package]] name = "termcolor" @@ -2888,13 +3016,13 @@ tensordict = ">=0.4.0" torch = ">=2.1.0" [package.extras] -all = ["ale-py", "atari-py", "dm-control", "git", "gym", "gym[accept-rom-license]", "gymnasium", "h5py", "huggingface-hub", "hydra-core (>=1.1)", "hydra-submitit-launcher", "minari", "moviepy", "mujoco", "pandas", "pettingzoo (>=1.24.1)", "pillow", "pygame", "pytest", "pytest-instafail", "pyyaml", "requests", "scikit-learn", "scipy", "tensorboard", "torchsnapshot", "torchvision", "tqdm", "vmas (>=1.2.10)", "wandb"] +all = ["ale-py", "atari-py", "dm_control", "git", "gym", "gym[accept-rom-license]", "gymnasium", "h5py", "huggingface_hub", "hydra-core (>=1.1)", "hydra-submitit-launcher", "minari", "moviepy", "mujoco", "pandas", "pettingzoo (>=1.24.1)", "pillow", "pygame", "pytest", "pytest-instafail", "pyyaml", "requests", "scikit-learn", "scipy", "tensorboard", "torchsnapshot", "torchvision", "tqdm", "vmas (>=1.2.10)", "wandb"] atari = ["ale-py", "atari-py", "gym", "gym[accept-rom-license]", "pygame"] checkpointing = ["torchsnapshot"] -dm-control = ["dm-control"] +dm-control = ["dm_control"] gym-continuous = ["gymnasium", "mujoco"] marl = ["pettingzoo (>=1.24.1)", "vmas (>=1.2.10)"] -offline-data = ["h5py", "huggingface-hub", "minari", "pandas", "pillow", "requests", "scikit-learn", "torchvision", "tqdm"] +offline-data = ["h5py", "huggingface_hub", "minari", "pandas", "pillow", "requests", "scikit-learn", "torchvision", "tqdm"] rendering = ["moviepy"] tests = ["pytest", "pytest-instafail", "pyyaml", "scipy"] utils = ["git", "hydra-core (>=1.1)", "hydra-submitit-launcher", "tensorboard", "tqdm", "wandb"] @@ -3051,13 +3179,13 @@ test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess [[package]] name = "wandb" -version = "0.16.3" +version = "0.16.4" description = "A CLI and library for interacting with the Weights & Biases API." optional = false python-versions = ">=3.7" files = [ - {file = "wandb-0.16.3-py3-none-any.whl", hash = "sha256:b8907ddd775c27dc6c12687386a86b5d6acf291060f9ae680bbc61cc8fc03237"}, - {file = "wandb-0.16.3.tar.gz", hash = "sha256:d789acda32053b18b7a160d0595837e45a3c8a79d25e1fe1f051875303f480ec"}, + {file = "wandb-0.16.4-py3-none-any.whl", hash = "sha256:bb9eb5aa2c2c85e11c76040c4271366f54d4975167aa6320ba86c3f2d97fe5fa"}, + {file = "wandb-0.16.4.tar.gz", hash = "sha256:8752c67d1347a4c29777e64dc1e1a742a66c5ecde03aebadf2b0d62183fa307c"}, ] [package.dependencies] @@ -3078,8 +3206,9 @@ async = ["httpx (>=0.23.0)"] aws = ["boto3"] azure = ["azure-identity", "azure-storage-blob"] gcp = ["google-cloud-storage"] +importers = ["filelock", "mlflow", "polars", "rich", "tenacity"] kubeflow = ["google-cloud-storage", "kubernetes", "minio", "sh"] -launch = ["PyYAML (>=6.0.0)", "awscli", "azure-containerregistry", "azure-identity", "azure-storage-blob", "boto3", "botocore", "chardet", "google-auth", "google-cloud-aiplatform", "google-cloud-artifact-registry", "google-cloud-compute", "google-cloud-storage", "iso8601", "kubernetes", "kubernetes-asyncio", "nbconvert", "nbformat", "optuna", "pydantic", "typing-extensions"] +launch = ["PyYAML (>=6.0.0)", "awscli", "azure-containerregistry", "azure-identity", "azure-storage-blob", "boto3", "botocore", "chardet", "google-auth", "google-cloud-aiplatform", "google-cloud-artifact-registry", "google-cloud-compute", "google-cloud-storage", "iso8601", "kubernetes", "kubernetes-asyncio", "nbconvert", "nbformat", "optuna", "pydantic", "tomli", "typing-extensions"] media = ["bokeh", "moviepy", "numpy", "pillow", "plotly (>=5.18.0)", "rdkit-pypi", "soundfile"] models = ["cloudpickle"] perf = ["orjson"] @@ -3088,13 +3217,13 @@ sweeps = ["sweeps (>=0.2.0)"] [[package]] name = "zarr" -version = "2.17.0" +version = "2.17.1" description = "An implementation of chunked, compressed, N-dimensional arrays for Python" optional = false python-versions = ">=3.9" files = [ - {file = "zarr-2.17.0-py3-none-any.whl", hash = "sha256:d287cb61019c4a0a0f386f76eeaa7f0b1160b1cb90cf96173a4b6cbc135df6e1"}, - {file = "zarr-2.17.0.tar.gz", hash = "sha256:6390a2b8af31babaab4c963efc45bf1da7f9500c9aafac193f84cf019a7c66b0"}, + {file = "zarr-2.17.1-py3-none-any.whl", hash = "sha256:e25df2741a6e92645f3890f30f3136d5b57a0f8f831094b024bbcab5f2797bc7"}, + {file = "zarr-2.17.1.tar.gz", hash = "sha256:564b3aa072122546fe69a0fa21736f466b20fad41754334b62619f088ce46261"}, ] [package.dependencies] @@ -3125,4 +3254,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "c4d83579aed1c8c2e54cad7c8ec81b95a09ab8faff74fc9a4cb20bd00e4ddec6" +content-hash = "3d82309a7b2388d774b56ceb6f6906ef0732d8cedda0d76cc84a30e239949be8" diff --git a/pyproject.toml b/pyproject.toml index 398f63ef..85af7f82 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,13 +42,14 @@ mpmath = "^1.3.0" torch = "^2.2.1" tensordict = {git = "https://github.com/pytorch/tensordict"} torchrl = {git = "https://github.com/pytorch/rl", rev = "13bef426dcfa5887c6e5034a6e9697993fa92c37"} -mujoco = "^3.1.2" +mujoco = "2.3.7" mujoco-py = "^2.1.2.14" gym = "^0.26.2" opencv-python = "^4.9.0.80" diffusers = "^0.26.3" torchvision = "^0.17.1" h5py = "^3.10.0" +dm-control = "1.0.14" [tool.poetry.group.dev.dependencies] diff --git a/sbatch.sh b/sbatch.sh index da52c472..cb5b285a 100644 --- a/sbatch.sh +++ b/sbatch.sh @@ -17,6 +17,7 @@ apptainer exec --nv \ ~/apptainer/nvidia_cuda:12.2.2-devel-ubuntu22.04.sif $SHELL source ~/.bashrc -conda activate fowm +#conda activate fowm +conda activate lerobot srun $CMD diff --git a/tests/data/aloha_sim_insertion_human/action.memmap b/tests/data/aloha_sim_insertion_human/action.memmap new file mode 100644 index 00000000..f64b2989 --- /dev/null +++ b/tests/data/aloha_sim_insertion_human/action.memmap @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d789deddb081a9f4b626342391de8f48949d38fb5fdead87b5c0737b46c0877a +size 2800 diff --git a/tests/data/aloha_sim_insertion_human/episode.memmap b/tests/data/aloha_sim_insertion_human/episode.memmap new file mode 100644 index 00000000..af9fb07f --- /dev/null +++ b/tests/data/aloha_sim_insertion_human/episode.memmap @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7a12e561363385e9dfeeab326368731c030ed4b374e7f5897ac819159d2884c5 +size 400 diff --git a/tests/data/aloha_sim_insertion_human/frame_id.memmap b/tests/data/aloha_sim_insertion_human/frame_id.memmap new file mode 100644 index 00000000..dc2f585c --- /dev/null +++ b/tests/data/aloha_sim_insertion_human/frame_id.memmap @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c202d9cfc7858fd49d522047e16948359bbbb2eda2d3825d552e45a78d5f8585 +size 400 diff --git a/tests/data/aloha_sim_insertion_human/meta.json b/tests/data/aloha_sim_insertion_human/meta.json new file mode 100644 index 00000000..2a0cf0a2 --- /dev/null +++ b/tests/data/aloha_sim_insertion_human/meta.json @@ -0,0 +1 @@ +{"action": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "episode": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "frame_id": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/aloha_sim_insertion_human/next/done.memmap b/tests/data/aloha_sim_insertion_human/next/done.memmap new file mode 100644 index 00000000..44fd709f --- /dev/null +++ b/tests/data/aloha_sim_insertion_human/next/done.memmap @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cc2786e1f9910a9d811400edcddaf7075195f7a16b216dcbefba3bc7c4f2ae51 +size 50 diff --git a/tests/data/aloha_sim_insertion_human/next/meta.json b/tests/data/aloha_sim_insertion_human/next/meta.json new file mode 100644 index 00000000..3bfa9bd7 --- /dev/null +++ b/tests/data/aloha_sim_insertion_human/next/meta.json @@ -0,0 +1 @@ +{"done": {"device": "cpu", "shape": [50, 1], "dtype": "torch.bool"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/aloha_sim_insertion_human/next/observation/image/meta.json b/tests/data/aloha_sim_insertion_human/next/observation/image/meta.json new file mode 100644 index 00000000..cb29a5ab --- /dev/null +++ b/tests/data/aloha_sim_insertion_human/next/observation/image/meta.json @@ -0,0 +1 @@ +{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/aloha_sim_insertion_human/next/observation/image/top.memmap b/tests/data/aloha_sim_insertion_human/next/observation/image/top.memmap new file mode 100644 index 00000000..d3d8bd1c --- /dev/null +++ b/tests/data/aloha_sim_insertion_human/next/observation/image/top.memmap @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5c632e3cb06be729e5d673e3ecca1d6f6527b0f48cfe3dc03d7eea4f9eb3bbd7 +size 46080000 diff --git a/tests/data/aloha_sim_insertion_human/next/observation/meta.json b/tests/data/aloha_sim_insertion_human/next/observation/meta.json new file mode 100644 index 00000000..65ce1ca2 --- /dev/null +++ b/tests/data/aloha_sim_insertion_human/next/observation/meta.json @@ -0,0 +1 @@ +{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/aloha_sim_insertion_human/next/observation/state.memmap b/tests/data/aloha_sim_insertion_human/next/observation/state.memmap new file mode 100644 index 00000000..1f087a60 --- /dev/null +++ b/tests/data/aloha_sim_insertion_human/next/observation/state.memmap @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e231f2e07e1cd030137ea2e938b570b112db2c694c6d21b37ceb8f8559e19088 +size 2800 diff --git a/tests/data/aloha_sim_insertion_human/observation/image/meta.json b/tests/data/aloha_sim_insertion_human/observation/image/meta.json new file mode 100644 index 00000000..cb29a5ab --- /dev/null +++ b/tests/data/aloha_sim_insertion_human/observation/image/meta.json @@ -0,0 +1 @@ +{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/aloha_sim_insertion_human/observation/image/top.memmap b/tests/data/aloha_sim_insertion_human/observation/image/top.memmap new file mode 100644 index 00000000..00c0b783 --- /dev/null +++ b/tests/data/aloha_sim_insertion_human/observation/image/top.memmap @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a1ba64c89f4fcf9135fe34c26abf582dd5f0d573506db5c96af3ffe40a52c818 +size 46080000 diff --git a/tests/data/aloha_sim_insertion_human/observation/meta.json b/tests/data/aloha_sim_insertion_human/observation/meta.json new file mode 100644 index 00000000..65ce1ca2 --- /dev/null +++ b/tests/data/aloha_sim_insertion_human/observation/meta.json @@ -0,0 +1 @@ +{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/aloha_sim_insertion_human/observation/state.memmap b/tests/data/aloha_sim_insertion_human/observation/state.memmap new file mode 100644 index 00000000..a1131179 --- /dev/null +++ b/tests/data/aloha_sim_insertion_human/observation/state.memmap @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:85405686bc065c6ab6c915907920a0391a57cf097b74de058a8c30be0548ade5 +size 2800 diff --git a/tests/data/aloha_sim_insertion_human/stats.pth b/tests/data/aloha_sim_insertion_human/stats.pth new file mode 100644 index 00000000..869d26cd Binary files /dev/null and b/tests/data/aloha_sim_insertion_human/stats.pth differ diff --git a/tests/data/aloha_sim_insertion_scripted/action.memmap b/tests/data/aloha_sim_insertion_scripted/action.memmap new file mode 100644 index 00000000..e4068b75 --- /dev/null +++ b/tests/data/aloha_sim_insertion_scripted/action.memmap @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1f5fe053b760e8471885b82c10f4a6ea40874098036337ae5cc300c4775546be +size 2800 diff --git a/tests/data/aloha_sim_insertion_scripted/episode.memmap b/tests/data/aloha_sim_insertion_scripted/episode.memmap new file mode 100644 index 00000000..af9fb07f --- /dev/null +++ b/tests/data/aloha_sim_insertion_scripted/episode.memmap @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7a12e561363385e9dfeeab326368731c030ed4b374e7f5897ac819159d2884c5 +size 400 diff --git a/tests/data/aloha_sim_insertion_scripted/frame_id.memmap b/tests/data/aloha_sim_insertion_scripted/frame_id.memmap new file mode 100644 index 00000000..dc2f585c --- /dev/null +++ b/tests/data/aloha_sim_insertion_scripted/frame_id.memmap @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c202d9cfc7858fd49d522047e16948359bbbb2eda2d3825d552e45a78d5f8585 +size 400 diff --git a/tests/data/aloha_sim_insertion_scripted/meta.json b/tests/data/aloha_sim_insertion_scripted/meta.json new file mode 100644 index 00000000..2a0cf0a2 --- /dev/null +++ b/tests/data/aloha_sim_insertion_scripted/meta.json @@ -0,0 +1 @@ +{"action": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "episode": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "frame_id": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/aloha_sim_insertion_scripted/next/done.memmap b/tests/data/aloha_sim_insertion_scripted/next/done.memmap new file mode 100644 index 00000000..44fd709f --- /dev/null +++ b/tests/data/aloha_sim_insertion_scripted/next/done.memmap @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cc2786e1f9910a9d811400edcddaf7075195f7a16b216dcbefba3bc7c4f2ae51 +size 50 diff --git a/tests/data/aloha_sim_insertion_scripted/next/meta.json b/tests/data/aloha_sim_insertion_scripted/next/meta.json new file mode 100644 index 00000000..3bfa9bd7 --- /dev/null +++ b/tests/data/aloha_sim_insertion_scripted/next/meta.json @@ -0,0 +1 @@ +{"done": {"device": "cpu", "shape": [50, 1], "dtype": "torch.bool"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/aloha_sim_insertion_scripted/next/observation/image/meta.json b/tests/data/aloha_sim_insertion_scripted/next/observation/image/meta.json new file mode 100644 index 00000000..cb29a5ab --- /dev/null +++ b/tests/data/aloha_sim_insertion_scripted/next/observation/image/meta.json @@ -0,0 +1 @@ +{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/aloha_sim_insertion_scripted/next/observation/image/top.memmap b/tests/data/aloha_sim_insertion_scripted/next/observation/image/top.memmap new file mode 100644 index 00000000..83911729 --- /dev/null +++ b/tests/data/aloha_sim_insertion_scripted/next/observation/image/top.memmap @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:daed2bb10498ba2557983d0d7e89399882fea7585e7ceff910e23c621bfdbf88 +size 46080000 diff --git a/tests/data/aloha_sim_insertion_scripted/next/observation/meta.json b/tests/data/aloha_sim_insertion_scripted/next/observation/meta.json new file mode 100644 index 00000000..65ce1ca2 --- /dev/null +++ b/tests/data/aloha_sim_insertion_scripted/next/observation/meta.json @@ -0,0 +1 @@ +{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/aloha_sim_insertion_scripted/next/observation/state.memmap b/tests/data/aloha_sim_insertion_scripted/next/observation/state.memmap new file mode 100644 index 00000000..aef69da0 --- /dev/null +++ b/tests/data/aloha_sim_insertion_scripted/next/observation/state.memmap @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bbad0302af70112ee312efe0eb0f44a2f1c8f6c5ef82ea4fb34625cdafbef057 +size 2800 diff --git a/tests/data/aloha_sim_insertion_scripted/observation/image/meta.json b/tests/data/aloha_sim_insertion_scripted/observation/image/meta.json new file mode 100644 index 00000000..cb29a5ab --- /dev/null +++ b/tests/data/aloha_sim_insertion_scripted/observation/image/meta.json @@ -0,0 +1 @@ +{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/aloha_sim_insertion_scripted/observation/image/top.memmap b/tests/data/aloha_sim_insertion_scripted/observation/image/top.memmap new file mode 100644 index 00000000..f9f0a759 --- /dev/null +++ b/tests/data/aloha_sim_insertion_scripted/observation/image/top.memmap @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aba55ebb9dd004bf68444b9ebf024ed7713436099c06a0b8e541100ecbc69290 +size 46080000 diff --git a/tests/data/aloha_sim_insertion_scripted/observation/meta.json b/tests/data/aloha_sim_insertion_scripted/observation/meta.json new file mode 100644 index 00000000..65ce1ca2 --- /dev/null +++ b/tests/data/aloha_sim_insertion_scripted/observation/meta.json @@ -0,0 +1 @@ +{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/aloha_sim_insertion_scripted/observation/state.memmap b/tests/data/aloha_sim_insertion_scripted/observation/state.memmap new file mode 100644 index 00000000..91875055 --- /dev/null +++ b/tests/data/aloha_sim_insertion_scripted/observation/state.memmap @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dd4e7e14abf57561ca9839c910581266be90956e41bfb3bb21362ea0c321e77d +size 2800 diff --git a/tests/data/aloha_sim_insertion_scripted/stats.pth b/tests/data/aloha_sim_insertion_scripted/stats.pth new file mode 100644 index 00000000..0f3f9d2f Binary files /dev/null and b/tests/data/aloha_sim_insertion_scripted/stats.pth differ diff --git a/tests/data/aloha_sim_transfer_cube_human/action.memmap b/tests/data/aloha_sim_transfer_cube_human/action.memmap new file mode 100644 index 00000000..9b4fef33 --- /dev/null +++ b/tests/data/aloha_sim_transfer_cube_human/action.memmap @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:14fed0eed3d529a8ac0dd25a6d41585020772d02f9137fc9d604713b2f0f7076 +size 2800 diff --git a/tests/data/aloha_sim_transfer_cube_human/episode.memmap b/tests/data/aloha_sim_transfer_cube_human/episode.memmap new file mode 100644 index 00000000..af9fb07f --- /dev/null +++ b/tests/data/aloha_sim_transfer_cube_human/episode.memmap @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7a12e561363385e9dfeeab326368731c030ed4b374e7f5897ac819159d2884c5 +size 400 diff --git a/tests/data/aloha_sim_transfer_cube_human/frame_id.memmap b/tests/data/aloha_sim_transfer_cube_human/frame_id.memmap new file mode 100644 index 00000000..dc2f585c --- /dev/null +++ b/tests/data/aloha_sim_transfer_cube_human/frame_id.memmap @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c202d9cfc7858fd49d522047e16948359bbbb2eda2d3825d552e45a78d5f8585 +size 400 diff --git a/tests/data/aloha_sim_transfer_cube_human/meta.json b/tests/data/aloha_sim_transfer_cube_human/meta.json new file mode 100644 index 00000000..2a0cf0a2 --- /dev/null +++ b/tests/data/aloha_sim_transfer_cube_human/meta.json @@ -0,0 +1 @@ +{"action": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "episode": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "frame_id": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/aloha_sim_transfer_cube_human/next/done.memmap b/tests/data/aloha_sim_transfer_cube_human/next/done.memmap new file mode 100644 index 00000000..44fd709f --- /dev/null +++ b/tests/data/aloha_sim_transfer_cube_human/next/done.memmap @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cc2786e1f9910a9d811400edcddaf7075195f7a16b216dcbefba3bc7c4f2ae51 +size 50 diff --git a/tests/data/aloha_sim_transfer_cube_human/next/meta.json b/tests/data/aloha_sim_transfer_cube_human/next/meta.json new file mode 100644 index 00000000..3bfa9bd7 --- /dev/null +++ b/tests/data/aloha_sim_transfer_cube_human/next/meta.json @@ -0,0 +1 @@ +{"done": {"device": "cpu", "shape": [50, 1], "dtype": "torch.bool"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/aloha_sim_transfer_cube_human/next/observation/image/meta.json b/tests/data/aloha_sim_transfer_cube_human/next/observation/image/meta.json new file mode 100644 index 00000000..cb29a5ab --- /dev/null +++ b/tests/data/aloha_sim_transfer_cube_human/next/observation/image/meta.json @@ -0,0 +1 @@ +{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/aloha_sim_transfer_cube_human/next/observation/image/top.memmap b/tests/data/aloha_sim_transfer_cube_human/next/observation/image/top.memmap new file mode 100644 index 00000000..cd2e7c06 --- /dev/null +++ b/tests/data/aloha_sim_transfer_cube_human/next/observation/image/top.memmap @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2f713ea7fc19e592ea409a5e0bdfde403e5b86f834cbabe3463b791e8437fafc +size 46080000 diff --git a/tests/data/aloha_sim_transfer_cube_human/next/observation/meta.json b/tests/data/aloha_sim_transfer_cube_human/next/observation/meta.json new file mode 100644 index 00000000..65ce1ca2 --- /dev/null +++ b/tests/data/aloha_sim_transfer_cube_human/next/observation/meta.json @@ -0,0 +1 @@ +{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/aloha_sim_transfer_cube_human/next/observation/state.memmap b/tests/data/aloha_sim_transfer_cube_human/next/observation/state.memmap new file mode 100644 index 00000000..37feaad6 --- /dev/null +++ b/tests/data/aloha_sim_transfer_cube_human/next/observation/state.memmap @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8c103e2c9d63c9f7cf9645bd24d9a2c4e8e08825dc75e230ebc793b8f9c213b0 +size 2800 diff --git a/tests/data/aloha_sim_transfer_cube_human/observation/image/meta.json b/tests/data/aloha_sim_transfer_cube_human/observation/image/meta.json new file mode 100644 index 00000000..cb29a5ab --- /dev/null +++ b/tests/data/aloha_sim_transfer_cube_human/observation/image/meta.json @@ -0,0 +1 @@ +{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/aloha_sim_transfer_cube_human/observation/image/top.memmap b/tests/data/aloha_sim_transfer_cube_human/observation/image/top.memmap new file mode 100644 index 00000000..1188590c --- /dev/null +++ b/tests/data/aloha_sim_transfer_cube_human/observation/image/top.memmap @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7dbf4aa01b184d0eaa21ea999078d7cff86e1ca484a109614176fdc49f1ee05c +size 46080000 diff --git a/tests/data/aloha_sim_transfer_cube_human/observation/meta.json b/tests/data/aloha_sim_transfer_cube_human/observation/meta.json new file mode 100644 index 00000000..65ce1ca2 --- /dev/null +++ b/tests/data/aloha_sim_transfer_cube_human/observation/meta.json @@ -0,0 +1 @@ +{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/aloha_sim_transfer_cube_human/observation/state.memmap b/tests/data/aloha_sim_transfer_cube_human/observation/state.memmap new file mode 100644 index 00000000..9ef4cfd6 --- /dev/null +++ b/tests/data/aloha_sim_transfer_cube_human/observation/state.memmap @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4fa0b9c870d4615037b6fee9e9e85e54d84352e173f2c7c1035232272fe2a3dd +size 2800 diff --git a/tests/data/aloha_sim_transfer_cube_human/stats.pth b/tests/data/aloha_sim_transfer_cube_human/stats.pth new file mode 100644 index 00000000..0fa5667c Binary files /dev/null and b/tests/data/aloha_sim_transfer_cube_human/stats.pth differ diff --git a/tests/data/aloha_sim_transfer_cube_scripted/action.memmap b/tests/data/aloha_sim_transfer_cube_scripted/action.memmap new file mode 100644 index 00000000..8ac0c726 --- /dev/null +++ b/tests/data/aloha_sim_transfer_cube_scripted/action.memmap @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c0e199a82e2b7462e84406dbced5448a99f1dad9ce172771dfc3feb4b8597115 +size 2800 diff --git a/tests/data/aloha_sim_transfer_cube_scripted/episode.memmap b/tests/data/aloha_sim_transfer_cube_scripted/episode.memmap new file mode 100644 index 00000000..af9fb07f --- /dev/null +++ b/tests/data/aloha_sim_transfer_cube_scripted/episode.memmap @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7a12e561363385e9dfeeab326368731c030ed4b374e7f5897ac819159d2884c5 +size 400 diff --git a/tests/data/aloha_sim_transfer_cube_scripted/frame_id.memmap b/tests/data/aloha_sim_transfer_cube_scripted/frame_id.memmap new file mode 100644 index 00000000..dc2f585c --- /dev/null +++ b/tests/data/aloha_sim_transfer_cube_scripted/frame_id.memmap @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c202d9cfc7858fd49d522047e16948359bbbb2eda2d3825d552e45a78d5f8585 +size 400 diff --git a/tests/data/aloha_sim_transfer_cube_scripted/meta.json b/tests/data/aloha_sim_transfer_cube_scripted/meta.json new file mode 100644 index 00000000..2a0cf0a2 --- /dev/null +++ b/tests/data/aloha_sim_transfer_cube_scripted/meta.json @@ -0,0 +1 @@ +{"action": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "episode": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "frame_id": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/aloha_sim_transfer_cube_scripted/next/done.memmap b/tests/data/aloha_sim_transfer_cube_scripted/next/done.memmap new file mode 100644 index 00000000..44fd709f --- /dev/null +++ b/tests/data/aloha_sim_transfer_cube_scripted/next/done.memmap @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cc2786e1f9910a9d811400edcddaf7075195f7a16b216dcbefba3bc7c4f2ae51 +size 50 diff --git a/tests/data/aloha_sim_transfer_cube_scripted/next/meta.json b/tests/data/aloha_sim_transfer_cube_scripted/next/meta.json new file mode 100644 index 00000000..3bfa9bd7 --- /dev/null +++ b/tests/data/aloha_sim_transfer_cube_scripted/next/meta.json @@ -0,0 +1 @@ +{"done": {"device": "cpu", "shape": [50, 1], "dtype": "torch.bool"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/aloha_sim_transfer_cube_scripted/next/observation/image/meta.json b/tests/data/aloha_sim_transfer_cube_scripted/next/observation/image/meta.json new file mode 100644 index 00000000..cb29a5ab --- /dev/null +++ b/tests/data/aloha_sim_transfer_cube_scripted/next/observation/image/meta.json @@ -0,0 +1 @@ +{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/aloha_sim_transfer_cube_scripted/next/observation/image/top.memmap b/tests/data/aloha_sim_transfer_cube_scripted/next/observation/image/top.memmap new file mode 100644 index 00000000..8e5f533e --- /dev/null +++ b/tests/data/aloha_sim_transfer_cube_scripted/next/observation/image/top.memmap @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:30b44d38cc4d68e06c716a875d39cbdbeacbfdc1657d6366f58c279efd27c52b +size 46080000 diff --git a/tests/data/aloha_sim_transfer_cube_scripted/next/observation/meta.json b/tests/data/aloha_sim_transfer_cube_scripted/next/observation/meta.json new file mode 100644 index 00000000..65ce1ca2 --- /dev/null +++ b/tests/data/aloha_sim_transfer_cube_scripted/next/observation/meta.json @@ -0,0 +1 @@ +{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/aloha_sim_transfer_cube_scripted/next/observation/state.memmap b/tests/data/aloha_sim_transfer_cube_scripted/next/observation/state.memmap new file mode 100644 index 00000000..e88320d1 --- /dev/null +++ b/tests/data/aloha_sim_transfer_cube_scripted/next/observation/state.memmap @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9f484e7ea4f5f612dd53ee2c0f7891b8f7b2168a54fc81941ac2f2447260c294 +size 2800 diff --git a/tests/data/aloha_sim_transfer_cube_scripted/observation/image/meta.json b/tests/data/aloha_sim_transfer_cube_scripted/observation/image/meta.json new file mode 100644 index 00000000..cb29a5ab --- /dev/null +++ b/tests/data/aloha_sim_transfer_cube_scripted/observation/image/meta.json @@ -0,0 +1 @@ +{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/aloha_sim_transfer_cube_scripted/observation/image/top.memmap b/tests/data/aloha_sim_transfer_cube_scripted/observation/image/top.memmap new file mode 100644 index 00000000..d415da0a --- /dev/null +++ b/tests/data/aloha_sim_transfer_cube_scripted/observation/image/top.memmap @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:09600206f56cc5b52dfb896204b0044c4e830da368da141d7bd10e52181f6835 +size 46080000 diff --git a/tests/data/aloha_sim_transfer_cube_scripted/observation/meta.json b/tests/data/aloha_sim_transfer_cube_scripted/observation/meta.json new file mode 100644 index 00000000..65ce1ca2 --- /dev/null +++ b/tests/data/aloha_sim_transfer_cube_scripted/observation/meta.json @@ -0,0 +1 @@ +{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/aloha_sim_transfer_cube_scripted/observation/state.memmap b/tests/data/aloha_sim_transfer_cube_scripted/observation/state.memmap new file mode 100644 index 00000000..be3436fb --- /dev/null +++ b/tests/data/aloha_sim_transfer_cube_scripted/observation/state.memmap @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:60dcb547cf9a6372b78a455217a2408b6bece4371fba1df2a302b334d45c42a8 +size 2800 diff --git a/tests/data/aloha_sim_transfer_cube_scripted/stats.pth b/tests/data/aloha_sim_transfer_cube_scripted/stats.pth new file mode 100644 index 00000000..b00756e6 Binary files /dev/null and b/tests/data/aloha_sim_transfer_cube_scripted/stats.pth differ diff --git a/tests/data/pusht/stats.pth b/tests/data/pusht/stats.pth index c34e6a94..037e02f0 100644 Binary files a/tests/data/pusht/stats.pth and b/tests/data/pusht/stats.pth differ diff --git a/tests/test_datasets.py b/tests/test_datasets.py index e63ae2c1..b7d1e6f8 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -1,8 +1,9 @@ import pytest +import torch from lerobot.common.datasets.factory import make_offline_buffer -from .utils import init_config +from .utils import DEVICE, init_config @pytest.mark.parametrize( @@ -12,12 +13,18 @@ from .utils import init_config # ("simxarm", "lift"), ("pusht", "pusht"), # TODO(aliberts): add aloha when dataset is available on hub - # ("aloha", "sim_insertion_human"), - # ("aloha", "sim_insertion_scripted"), - # ("aloha", "sim_transfer_cube_human"), - # ("aloha", "sim_transfer_cube_scripted"), + ("aloha", "sim_insertion_human"), + ("aloha", "sim_insertion_scripted"), + ("aloha", "sim_transfer_cube_human"), + ("aloha", "sim_transfer_cube_scripted"), ], ) def test_factory(env_name, dataset_id): - cfg = init_config(overrides=[f"env={env_name}", f"env.task={dataset_id}"]) + cfg = init_config(overrides=[f"env={env_name}", f"env.task={dataset_id}", f"device={DEVICE}"]) offline_buffer = make_offline_buffer(cfg) + for key in offline_buffer.image_keys: + img = offline_buffer[0].get(key) + assert img.dtype == torch.float32 + # TODO(rcadene): we assume for now that image normalization takes place in the model + assert img.max() <= 1.0 + assert img.min() >= 0.0 diff --git a/tests/test_envs.py b/tests/test_envs.py index b51c441b..7776ba3c 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -1,12 +1,15 @@ +import os import pytest from tensordict import TensorDict +import torch from torchrl.envs.utils import check_env_specs, step_mdp +from lerobot.common.datasets.factory import make_offline_buffer from lerobot.common.envs.factory import make_env from lerobot.common.envs.pusht.env import PushtEnv from lerobot.common.envs.simxarm import SimxarmEnv -from .utils import init_config +from .utils import DEVICE, init_config def print_spec_rollout(env): @@ -83,9 +86,24 @@ def test_pusht(from_pixels, pixels_only): [ # "simxarm", "pusht", + "aloha", ], ) def test_factory(env_name): - cfg = init_config(overrides=[f"env={env_name}"]) + cfg = init_config(overrides=[f"env={env_name}", f"device={DEVICE}"]) + + offline_buffer = make_offline_buffer(cfg) + env = make_env(cfg) + for key in offline_buffer.image_keys: + assert env.reset().get(key).dtype == torch.uint8 + check_env_specs(env) + + env = make_env(cfg, transform=offline_buffer.transform) + for key in offline_buffer.image_keys: + img = env.reset().get(key) + assert img.dtype == torch.float32 + # TODO(rcadene): we assume for now that image normalization takes place in the model + assert img.max() <= 1.0 + assert img.min() >= 0.0 check_env_specs(env) diff --git a/tests/test_policies.py b/tests/test_policies.py index 03f20bd0..f00429bc 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -2,7 +2,7 @@ import pytest from lerobot.common.policies.factory import make_policy -from .utils import init_config +from .utils import DEVICE, init_config @pytest.mark.parametrize( @@ -19,6 +19,7 @@ def test_factory(env_name, policy_name): overrides=[ f"env={env_name}", f"policy={policy_name}", + f"device={DEVICE}", ] ) policy = make_policy(cfg) diff --git a/tests/utils.py b/tests/utils.py index 40dc6de0..55709330 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,8 +1,10 @@ +import os import hydra from hydra import compose, initialize CONFIG_PATH = "../lerobot/configs" +DEVICE = os.environ.get('LEROBOT_TESTS_DEVICE', "cuda") def init_config(config_name="default", overrides=None): hydra.core.global_hydra.GlobalHydra.instance().clear()