diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index e9edad05..2b161310 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -18,6 +18,12 @@ jobs:
env:
POETRY_VERSION: 1.8.1
DATA_DIR: tests/data
+ TMPDIR: ~/tmp
+ TEMP: ~/tmp
+ TMP: ~/tmp
+ PYOPENGL_PLATFORM: egl
+ MUJOCO_GL: egl
+ LEROBOT_TESTS_DEVICE: cpu
steps:
#----------------------------------------------
# check-out repo and set-up python
@@ -26,11 +32,13 @@ jobs:
uses: actions/checkout@v4
with:
lfs: true
+
- name: Set up python
id: setup-python
uses: actions/setup-python@v5
with:
python-version: '3.10'
+
#----------------------------------------------
# install & configure poetry
#----------------------------------------------
@@ -40,6 +48,7 @@ jobs:
with:
path: ~/.local # the path depends on the OS
key: poetry-${{ env.POETRY_VERSION }} # increment to reset cache
+
- name: Install Poetry
if: steps.restore-poetry-cache.outputs.cache-hit != 'true'
uses: snok/install-poetry@v1
@@ -47,6 +56,7 @@ jobs:
version: ${{ env.POETRY_VERSION }}
virtualenvs-create: true
installer-parallel: true
+
- name: Save cached Poetry installation
if: |
steps.restore-poetry-cache.outputs.cache-hit != 'true' &&
@@ -56,8 +66,10 @@ jobs:
with:
path: ~/.local # the path depends on the OS
key: poetry-${{ env.POETRY_VERSION }} # increment to reset cache
+
- name: Configure Poetry
run: poetry config virtualenvs.in-project true
+
#----------------------------------------------
# install dependencies
#----------------------------------------------
@@ -67,9 +79,21 @@ jobs:
with:
path: .venv
key: venv-${{ steps.setup-python.outputs.python-version }}-${{ env.POETRY_VERSION }}-${{ hashFiles('**/poetry.lock') }}
+
+ - name: Info
+ run: |
+ sudo du -sh /tmp
+ sudo df -h
+
- name: Install dependencies
if: steps.restore-dependencies-cache.outputs.cache-hit != 'true'
- run: poetry install --no-interaction --no-root
+ run: |
+ mkdir ~/tmp
+ echo $TMPDIR
+ echo $TEMP
+ echo $TMP
+ poetry install --no-interaction --no-root
+
- name: Save cached venv
if: |
steps.restore-dependencies-cache.outputs.cache-hit != 'true' &&
@@ -79,11 +103,16 @@ jobs:
with:
path: .venv
key: venv-${{ steps.setup-python.outputs.python-version }}-${{ env.POETRY_VERSION }}-${{ hashFiles('**/poetry.lock') }}
+
+ - name: Install libegl1-mesa-dev (to use MUJOCO_GL=egl)
+ run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev
+
#----------------------------------------------
# install project
#----------------------------------------------
- name: Install project
run: poetry install --no-interaction
+
#----------------------------------------------
# run tests
#----------------------------------------------
@@ -91,6 +120,7 @@ jobs:
run: |
source .venv/bin/activate
pytest tests
+
- name: Test train pusht end-to-end
run: |
source .venv/bin/activate
@@ -104,6 +134,7 @@ jobs:
save_model=true \
save_freq=1 \
hydra.run.dir=tests/outputs/
+
- name: Test eval pusht end-to-end
run: |
source .venv/bin/activate
diff --git a/README.md b/README.md
index b346a30f..5ccffc35 100644
--- a/README.md
+++ b/README.md
@@ -59,19 +59,10 @@ env=pusht
## TODO
-- [x] priority update doesn't match FOWM or original paper
-- [x] self.step=100000 should be updated at every step to adjust to the horizon of the planner
-- [ ] prefetch replay buffer to speedup training
-- [ ] parallelize env to speed up eval
-- [ ] clean checkpointing / loading
-- [ ] clean logging
-- [ ] clean config
-- [ ] clean hyperparameter tuning
-- [ ] add pusht
-- [ ] add aloha
-- [ ] add act
-- [ ] add diffusion
-- [ ] add aloha 2
+If you are not sure how to contribute or want to know the next features we working on, look on this project page: [LeRobot TODO](https://github.com/users/Cadene/projects/1)
+
+Ask [Remi Cadene](re.cadene@gmail.com) for access if needed.
+
## Profile
diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py
index d6a51246..514fa038 100644
--- a/lerobot/common/datasets/abstract.py
+++ b/lerobot/common/datasets/abstract.py
@@ -81,7 +81,10 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
def set_transform(self, transform):
if not isinstance(transform, Compose):
# required since torchrl calls `len(self._transform)` downstream
- self._transform = Compose(transform)
+ if isinstance(transform, list):
+ self._transform = Compose(*transform)
+ else:
+ self._transform = Compose(transform)
else:
self._transform = transform
diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py
index afc28b1c..3b53fed1 100644
--- a/lerobot/common/datasets/aloha.py
+++ b/lerobot/common/datasets/aloha.py
@@ -73,11 +73,11 @@ def download(data_dir, dataset_id):
data_dir.mkdir(parents=True, exist_ok=True)
- gdown.download_folder(FOLDER_URLS[dataset_id], output=data_dir)
+ gdown.download_folder(FOLDER_URLS[dataset_id], output=str(data_dir))
# because of the 50 files limit per directory, two files episode 48 and 49 were missing
- gdown.download(EP48_URLS[dataset_id], output=data_dir / "episode_48.hdf5", fuzzy=True)
- gdown.download(EP49_URLS[dataset_id], output=data_dir / "episode_49.hdf5", fuzzy=True)
+ gdown.download(EP48_URLS[dataset_id], output=str(data_dir / "episode_48.hdf5"), fuzzy=True)
+ gdown.download(EP49_URLS[dataset_id], output=str(data_dir / "episode_49.hdf5"), fuzzy=True)
class AlohaExperienceReplay(AbstractExperienceReplay):
@@ -124,9 +124,6 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
def image_keys(self) -> list:
return [("observation", "image", cam) for cam in CAMERAS[self.dataset_id]]
- # def _is_downloaded(self) -> bool:
- # return False
-
def _download_and_preproc(self):
raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw"
if not raw_dir.is_dir():
diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py
index 29f40bc6..fd284ae2 100644
--- a/lerobot/common/datasets/factory.py
+++ b/lerobot/common/datasets/factory.py
@@ -5,7 +5,7 @@ from pathlib import Path
import torch
from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler
-from lerobot.common.envs.transforms import NormalizeTransform
+from lerobot.common.envs.transforms import NormalizeTransform, Prod
DATA_DIR = Path(os.environ.get("DATA_DIR", "data"))
@@ -84,6 +84,16 @@ def make_offline_buffer(
prefetch=prefetch if isinstance(prefetch, int) else None,
)
+ if cfg.policy.name == "tdmpc":
+ img_keys = []
+ for key in offline_buffer.image_keys:
+ img_keys.append(("next", *key))
+ img_keys += offline_buffer.image_keys
+ else:
+ img_keys = offline_buffer.image_keys
+
+ transforms = [Prod(in_keys=img_keys, prod=1 / 255)]
+
if normalize:
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
stats = offline_buffer.compute_or_load_stats()
@@ -92,11 +102,10 @@ def make_offline_buffer(
in_keys = [("observation", "state"), ("action")]
if cfg.policy.name == "tdmpc":
- for key in offline_buffer.image_keys:
- # TODO(rcadene): imagenet normalization is applied inside diffusion policy, but no normalization inside tdmpc
- in_keys.append(key)
- # since we use next observations in tdmpc
- in_keys.append(("next", *key))
+ # TODO(rcadene): we add img_keys to the keys to normalize for tdmpc only, since diffusion and act policies normalize the image inside the model for now
+ in_keys += img_keys
+ # TODO(racdene): since we use next observations in tdmpc, we also add them to the normalization. We are wasting a bit of compute on this for now.
+ in_keys += [("next", *key) for key in img_keys]
in_keys.append(("next", "observation", "state"))
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
@@ -106,8 +115,11 @@ def make_offline_buffer(
stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
- transform = NormalizeTransform(stats, in_keys, mode="min_max")
- offline_buffer.set_transform(transform)
+ # TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
+ normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
+ transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode))
+
+ offline_buffer.set_transform(transforms)
if not overwrite_sampler:
index = torch.arange(0, offline_buffer.num_samples, 1)
diff --git a/lerobot/common/datasets/simxarm.py b/lerobot/common/datasets/simxarm.py
index 784242cc..1d56850e 100644
--- a/lerobot/common/datasets/simxarm.py
+++ b/lerobot/common/datasets/simxarm.py
@@ -1,4 +1,5 @@
import pickle
+import zipfile
from pathlib import Path
from typing import Callable
@@ -15,6 +16,22 @@ from torchrl.data.replay_buffers.writers import Writer
from lerobot.common.datasets.abstract import AbstractExperienceReplay
+def download():
+ raise NotImplementedError()
+ import gdown
+
+ url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
+ download_path = "data.zip"
+ gdown.download(url, download_path, quiet=False)
+ print("Extracting...")
+ with zipfile.ZipFile(download_path, "r") as zip_f:
+ for member in zip_f.namelist():
+ if member.startswith("data/xarm") and member.endswith(".pkl"):
+ print(member)
+ zip_f.extract(member=member)
+ Path(download_path).unlink()
+
+
class SimxarmExperienceReplay(AbstractExperienceReplay):
available_datasets = [
"xarm_lift_medium",
@@ -48,8 +65,8 @@ class SimxarmExperienceReplay(AbstractExperienceReplay):
)
def _download_and_preproc(self):
- # download
- # TODO(rcadene)
+ # TODO(rcadene): finish download
+ download()
dataset_path = self.data_dir / "buffer.pkl"
print(f"Using offline dataset '{dataset_path}'")
diff --git a/lerobot/common/envs/abstract.py b/lerobot/common/envs/abstract.py
new file mode 100644
index 00000000..0754fb76
--- /dev/null
+++ b/lerobot/common/envs/abstract.py
@@ -0,0 +1,80 @@
+import abc
+from collections import deque
+from typing import Optional
+
+from tensordict import TensorDict
+from torchrl.envs import EnvBase
+
+
+class AbstractEnv(EnvBase):
+ def __init__(
+ self,
+ task,
+ frame_skip: int = 1,
+ from_pixels: bool = False,
+ pixels_only: bool = False,
+ image_size=None,
+ seed=1337,
+ device="cpu",
+ num_prev_obs=1,
+ num_prev_action=0,
+ ):
+ super().__init__(device=device, batch_size=[])
+ self.task = task
+ self.frame_skip = frame_skip
+ self.from_pixels = from_pixels
+ self.pixels_only = pixels_only
+ self.image_size = image_size
+ self.num_prev_obs = num_prev_obs
+ self.num_prev_action = num_prev_action
+ self._rendering_hooks = []
+
+ if pixels_only:
+ assert from_pixels
+ if from_pixels:
+ assert image_size
+
+ self._make_env()
+ self._make_spec()
+ self._current_seed = self.set_seed(seed)
+
+ if self.num_prev_obs > 0:
+ self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
+ self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs)
+ if self.num_prev_action > 0:
+ raise NotImplementedError()
+ # self._prev_action_queue = deque(maxlen=self.num_prev_action)
+
+ def register_rendering_hook(self, func):
+ self._rendering_hooks.append(func)
+
+ def call_rendering_hooks(self):
+ for func in self._rendering_hooks:
+ func(self)
+
+ def reset_rendering_hooks(self):
+ self._rendering_hooks = []
+
+ @abc.abstractmethod
+ def render(self, mode="rgb_array", width=640, height=480):
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def _reset(self, tensordict: Optional[TensorDict] = None):
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def _step(self, tensordict: TensorDict):
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def _make_env(self):
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def _make_spec(self):
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def _set_seed(self, seed: Optional[int]):
+ raise NotImplementedError()
diff --git a/lerobot/common/envs/aloha/assets/bimanual_viperx_end_effector_insertion.xml b/lerobot/common/envs/aloha/assets/bimanual_viperx_end_effector_insertion.xml
new file mode 100644
index 00000000..8002838c
--- /dev/null
+++ b/lerobot/common/envs/aloha/assets/bimanual_viperx_end_effector_insertion.xml
@@ -0,0 +1,59 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/lerobot/common/envs/aloha/assets/bimanual_viperx_end_effector_transfer_cube.xml b/lerobot/common/envs/aloha/assets/bimanual_viperx_end_effector_transfer_cube.xml
new file mode 100644
index 00000000..05249ad2
--- /dev/null
+++ b/lerobot/common/envs/aloha/assets/bimanual_viperx_end_effector_transfer_cube.xml
@@ -0,0 +1,48 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/lerobot/common/envs/aloha/assets/bimanual_viperx_insertion.xml b/lerobot/common/envs/aloha/assets/bimanual_viperx_insertion.xml
new file mode 100644
index 00000000..511f7947
--- /dev/null
+++ b/lerobot/common/envs/aloha/assets/bimanual_viperx_insertion.xml
@@ -0,0 +1,53 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/lerobot/common/envs/aloha/assets/bimanual_viperx_transfer_cube.xml b/lerobot/common/envs/aloha/assets/bimanual_viperx_transfer_cube.xml
new file mode 100644
index 00000000..2d85a47c
--- /dev/null
+++ b/lerobot/common/envs/aloha/assets/bimanual_viperx_transfer_cube.xml
@@ -0,0 +1,42 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/lerobot/common/envs/aloha/assets/scene.xml b/lerobot/common/envs/aloha/assets/scene.xml
new file mode 100644
index 00000000..0f61b8a5
--- /dev/null
+++ b/lerobot/common/envs/aloha/assets/scene.xml
@@ -0,0 +1,38 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/lerobot/common/envs/aloha/assets/tabletop.stl b/lerobot/common/envs/aloha/assets/tabletop.stl
new file mode 100644
index 00000000..ab35cdf7
Binary files /dev/null and b/lerobot/common/envs/aloha/assets/tabletop.stl differ
diff --git a/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_left.stl b/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_left.stl
new file mode 100644
index 00000000..534c7af9
Binary files /dev/null and b/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_left.stl differ
diff --git a/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_right.stl b/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_right.stl
new file mode 100644
index 00000000..d6a492c2
Binary files /dev/null and b/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_right.stl differ
diff --git a/lerobot/common/envs/aloha/assets/vx300s_10_gripper_finger.stl b/lerobot/common/envs/aloha/assets/vx300s_10_gripper_finger.stl
new file mode 100644
index 00000000..d6df86be
Binary files /dev/null and b/lerobot/common/envs/aloha/assets/vx300s_10_gripper_finger.stl differ
diff --git a/lerobot/common/envs/aloha/assets/vx300s_11_ar_tag.stl b/lerobot/common/envs/aloha/assets/vx300s_11_ar_tag.stl
new file mode 100644
index 00000000..193014b6
Binary files /dev/null and b/lerobot/common/envs/aloha/assets/vx300s_11_ar_tag.stl differ
diff --git a/lerobot/common/envs/aloha/assets/vx300s_1_base.stl b/lerobot/common/envs/aloha/assets/vx300s_1_base.stl
new file mode 100644
index 00000000..5a7efda2
Binary files /dev/null and b/lerobot/common/envs/aloha/assets/vx300s_1_base.stl differ
diff --git a/lerobot/common/envs/aloha/assets/vx300s_2_shoulder.stl b/lerobot/common/envs/aloha/assets/vx300s_2_shoulder.stl
new file mode 100644
index 00000000..dc22aa7e
Binary files /dev/null and b/lerobot/common/envs/aloha/assets/vx300s_2_shoulder.stl differ
diff --git a/lerobot/common/envs/aloha/assets/vx300s_3_upper_arm.stl b/lerobot/common/envs/aloha/assets/vx300s_3_upper_arm.stl
new file mode 100644
index 00000000..111c586e
Binary files /dev/null and b/lerobot/common/envs/aloha/assets/vx300s_3_upper_arm.stl differ
diff --git a/lerobot/common/envs/aloha/assets/vx300s_4_upper_forearm.stl b/lerobot/common/envs/aloha/assets/vx300s_4_upper_forearm.stl
new file mode 100644
index 00000000..8170d21c
Binary files /dev/null and b/lerobot/common/envs/aloha/assets/vx300s_4_upper_forearm.stl differ
diff --git a/lerobot/common/envs/aloha/assets/vx300s_5_lower_forearm.stl b/lerobot/common/envs/aloha/assets/vx300s_5_lower_forearm.stl
new file mode 100644
index 00000000..39581f83
Binary files /dev/null and b/lerobot/common/envs/aloha/assets/vx300s_5_lower_forearm.stl differ
diff --git a/lerobot/common/envs/aloha/assets/vx300s_6_wrist.stl b/lerobot/common/envs/aloha/assets/vx300s_6_wrist.stl
new file mode 100644
index 00000000..ab8423e9
Binary files /dev/null and b/lerobot/common/envs/aloha/assets/vx300s_6_wrist.stl differ
diff --git a/lerobot/common/envs/aloha/assets/vx300s_7_gripper.stl b/lerobot/common/envs/aloha/assets/vx300s_7_gripper.stl
new file mode 100644
index 00000000..043db9ca
Binary files /dev/null and b/lerobot/common/envs/aloha/assets/vx300s_7_gripper.stl differ
diff --git a/lerobot/common/envs/aloha/assets/vx300s_8_gripper_prop.stl b/lerobot/common/envs/aloha/assets/vx300s_8_gripper_prop.stl
new file mode 100644
index 00000000..36099b42
Binary files /dev/null and b/lerobot/common/envs/aloha/assets/vx300s_8_gripper_prop.stl differ
diff --git a/lerobot/common/envs/aloha/assets/vx300s_9_gripper_bar.stl b/lerobot/common/envs/aloha/assets/vx300s_9_gripper_bar.stl
new file mode 100644
index 00000000..eba3caa2
Binary files /dev/null and b/lerobot/common/envs/aloha/assets/vx300s_9_gripper_bar.stl differ
diff --git a/lerobot/common/envs/aloha/assets/vx300s_dependencies.xml b/lerobot/common/envs/aloha/assets/vx300s_dependencies.xml
new file mode 100644
index 00000000..93037ab7
--- /dev/null
+++ b/lerobot/common/envs/aloha/assets/vx300s_dependencies.xml
@@ -0,0 +1,17 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/lerobot/common/envs/aloha/assets/vx300s_left.xml b/lerobot/common/envs/aloha/assets/vx300s_left.xml
new file mode 100644
index 00000000..3af6c235
--- /dev/null
+++ b/lerobot/common/envs/aloha/assets/vx300s_left.xml
@@ -0,0 +1,59 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/lerobot/common/envs/aloha/assets/vx300s_right.xml b/lerobot/common/envs/aloha/assets/vx300s_right.xml
new file mode 100644
index 00000000..495df478
--- /dev/null
+++ b/lerobot/common/envs/aloha/assets/vx300s_right.xml
@@ -0,0 +1,59 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/lerobot/common/envs/aloha/constants.py b/lerobot/common/envs/aloha/constants.py
new file mode 100644
index 00000000..e582e5f3
--- /dev/null
+++ b/lerobot/common/envs/aloha/constants.py
@@ -0,0 +1,163 @@
+from pathlib import Path
+
+### Simulation envs fixed constants
+DT = 0.02 # 0.02 ms -> 1/0.2 = 50 hz
+FPS = 50
+
+
+JOINTS = [
+ # absolute joint position
+ "left_arm_waist",
+ "left_arm_shoulder",
+ "left_arm_elbow",
+ "left_arm_forearm_roll",
+ "left_arm_wrist_angle",
+ "left_arm_wrist_rotate",
+ # normalized gripper position 0: close, 1: open
+ "left_arm_gripper",
+ # absolute joint position
+ "right_arm_waist",
+ "right_arm_shoulder",
+ "right_arm_elbow",
+ "right_arm_forearm_roll",
+ "right_arm_wrist_angle",
+ "right_arm_wrist_rotate",
+ # normalized gripper position 0: close, 1: open
+ "right_arm_gripper",
+]
+
+ACTIONS = [
+ # position and quaternion for end effector
+ "left_arm_waist",
+ "left_arm_shoulder",
+ "left_arm_elbow",
+ "left_arm_forearm_roll",
+ "left_arm_wrist_angle",
+ "left_arm_wrist_rotate",
+ # normalized gripper position (0: close, 1: open)
+ "left_arm_gripper",
+ "right_arm_waist",
+ "right_arm_shoulder",
+ "right_arm_elbow",
+ "right_arm_forearm_roll",
+ "right_arm_wrist_angle",
+ "right_arm_wrist_rotate",
+ # normalized gripper position (0: close, 1: open)
+ "right_arm_gripper",
+]
+
+
+START_ARM_POSE = [
+ 0,
+ -0.96,
+ 1.16,
+ 0,
+ -0.3,
+ 0,
+ 0.02239,
+ -0.02239,
+ 0,
+ -0.96,
+ 1.16,
+ 0,
+ -0.3,
+ 0,
+ 0.02239,
+ -0.02239,
+]
+
+ASSETS_DIR = Path(__file__).parent.resolve() / "assets" # note: absolute path
+
+# Left finger position limits (qpos[7]), right_finger = -1 * left_finger
+MASTER_GRIPPER_POSITION_OPEN = 0.02417
+MASTER_GRIPPER_POSITION_CLOSE = 0.01244
+PUPPET_GRIPPER_POSITION_OPEN = 0.05800
+PUPPET_GRIPPER_POSITION_CLOSE = 0.01844
+
+# Gripper joint limits (qpos[6])
+MASTER_GRIPPER_JOINT_OPEN = 0.3083
+MASTER_GRIPPER_JOINT_CLOSE = -0.6842
+PUPPET_GRIPPER_JOINT_OPEN = 1.4910
+PUPPET_GRIPPER_JOINT_CLOSE = -0.6213
+
+MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE) / 2
+
+############################ Helper functions ############################
+
+
+def normalize_master_gripper_position(x):
+ return (x - MASTER_GRIPPER_POSITION_CLOSE) / (
+ MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE
+ )
+
+
+def normalize_puppet_gripper_position(x):
+ return (x - PUPPET_GRIPPER_POSITION_CLOSE) / (
+ PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE
+ )
+
+
+def unnormalize_master_gripper_position(x):
+ return x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE
+
+
+def unnormalize_puppet_gripper_position(x):
+ return x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE
+
+
+def convert_position_from_master_to_puppet(x):
+ return unnormalize_puppet_gripper_position(normalize_master_gripper_position(x))
+
+
+def normalizer_master_gripper_joint(x):
+ return (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
+
+
+def normalize_puppet_gripper_joint(x):
+ return (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
+
+
+def unnormalize_master_gripper_joint(x):
+ return x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
+
+
+def unnormalize_puppet_gripper_joint(x):
+ return x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
+
+
+def convert_join_from_master_to_puppet(x):
+ return unnormalize_puppet_gripper_joint(normalizer_master_gripper_joint(x))
+
+
+def normalize_master_gripper_velocity(x):
+ return x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)
+
+
+def normalize_puppet_gripper_velocity(x):
+ return x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)
+
+
+def convert_master_from_position_to_joint(x):
+ return (
+ normalize_master_gripper_position(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
+ + MASTER_GRIPPER_JOINT_CLOSE
+ )
+
+
+def convert_master_from_joint_to_position(x):
+ return unnormalize_master_gripper_position(
+ (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
+ )
+
+
+def convert_puppet_from_position_to_join(x):
+ return (
+ normalize_puppet_gripper_position(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
+ + PUPPET_GRIPPER_JOINT_CLOSE
+ )
+
+
+def convert_puppet_from_joint_to_position(x):
+ return unnormalize_puppet_gripper_position(
+ (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
+ )
diff --git a/lerobot/common/envs/aloha/env.py b/lerobot/common/envs/aloha/env.py
new file mode 100644
index 00000000..1211a37a
--- /dev/null
+++ b/lerobot/common/envs/aloha/env.py
@@ -0,0 +1,311 @@
+import importlib
+import logging
+from collections import deque
+from typing import Optional
+
+import einops
+import numpy as np
+import torch
+from dm_control import mujoco
+from dm_control.rl import control
+from tensordict import TensorDict
+from torchrl.data.tensor_specs import (
+ BoundedTensorSpec,
+ CompositeSpec,
+ DiscreteTensorSpec,
+ UnboundedContinuousTensorSpec,
+)
+
+from lerobot.common.envs.abstract import AbstractEnv
+from lerobot.common.envs.aloha.constants import (
+ ACTIONS,
+ ASSETS_DIR,
+ DT,
+ JOINTS,
+)
+from lerobot.common.envs.aloha.tasks.sim import BOX_POSE, InsertionTask, TransferCubeTask
+from lerobot.common.envs.aloha.tasks.sim_end_effector import (
+ InsertionEndEffectorTask,
+ TransferCubeEndEffectorTask,
+)
+from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose
+from lerobot.common.utils import set_seed
+
+_has_gym = importlib.util.find_spec("gym") is not None
+
+
+class AlohaEnv(AbstractEnv):
+ def __init__(
+ self,
+ task,
+ frame_skip: int = 1,
+ from_pixels: bool = False,
+ pixels_only: bool = False,
+ image_size=None,
+ seed=1337,
+ device="cpu",
+ num_prev_obs=1,
+ num_prev_action=0,
+ ):
+ super().__init__(
+ task=task,
+ frame_skip=frame_skip,
+ from_pixels=from_pixels,
+ pixels_only=pixels_only,
+ image_size=image_size,
+ seed=seed,
+ device=device,
+ num_prev_obs=num_prev_obs,
+ num_prev_action=num_prev_action,
+ )
+
+ def _make_env(self):
+ if not _has_gym:
+ raise ImportError("Cannot import gym.")
+
+ if not self.from_pixels:
+ raise NotImplementedError()
+
+ self._env = self._make_env_task(self.task)
+
+ def render(self, mode="rgb_array", width=640, height=480):
+ # TODO(rcadene): render and visualizer several cameras (e.g. angle, front_close)
+ image = self._env.physics.render(height=height, width=width, camera_id="top")
+ return image
+
+ def _make_env_task(self, task_name):
+ # time limit is controlled by StepCounter in env factory
+ time_limit = float("inf")
+
+ if "sim_transfer_cube" in task_name:
+ xml_path = ASSETS_DIR / "bimanual_viperx_transfer_cube.xml"
+ physics = mujoco.Physics.from_xml_path(str(xml_path))
+ task = TransferCubeTask(random=False)
+ elif "sim_insertion" in task_name:
+ xml_path = ASSETS_DIR / "bimanual_viperx_insertion.xml"
+ physics = mujoco.Physics.from_xml_path(str(xml_path))
+ task = InsertionTask(random=False)
+ elif "sim_end_effector_transfer_cube" in task_name:
+ raise NotImplementedError()
+ xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_transfer_cube.xml"
+ physics = mujoco.Physics.from_xml_path(str(xml_path))
+ task = TransferCubeEndEffectorTask(random=False)
+ elif "sim_end_effector_insertion" in task_name:
+ raise NotImplementedError()
+ xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_insertion.xml"
+ physics = mujoco.Physics.from_xml_path(str(xml_path))
+ task = InsertionEndEffectorTask(random=False)
+ else:
+ raise NotImplementedError(task_name)
+
+ env = control.Environment(
+ physics, task, time_limit, control_timestep=DT, n_sub_steps=None, flat_observation=False
+ )
+ return env
+
+ def _format_raw_obs(self, raw_obs):
+ if self.from_pixels:
+ image = torch.from_numpy(raw_obs["images"]["top"].copy())
+ image = einops.rearrange(image, "h w c -> c h w")
+ assert image.dtype == torch.uint8
+ obs = {"image": {"top": image}}
+
+ if not self.pixels_only:
+ obs["state"] = torch.from_numpy(raw_obs["qpos"]).type(torch.float32)
+ else:
+ # TODO(rcadene):
+ raise NotImplementedError()
+ # obs = {"state": torch.from_numpy(raw_obs["observation"]).type(torch.float32)}
+
+ return obs
+
+ def _reset(self, tensordict: Optional[TensorDict] = None):
+ td = tensordict
+ if td is None or td.is_empty():
+ # we need to handle seed iteration, since self._env.reset() rely an internal _seed.
+ self._current_seed += 1
+ self.set_seed(self._current_seed)
+
+ # TODO(rcadene): do not use global variable for this
+ if "sim_transfer_cube" in self.task:
+ BOX_POSE[0] = sample_box_pose() # used in sim reset
+ elif "sim_insertion" in self.task:
+ BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset
+
+ raw_obs = self._env.reset()
+ # TODO(rcadene): add assert
+ # assert self._current_seed == self._env._seed
+
+ obs = self._format_raw_obs(raw_obs.observation)
+
+ if self.num_prev_obs > 0:
+ stacked_obs = {}
+ if "image" in obs:
+ self._prev_obs_image_queue = deque(
+ [obs["image"]["top"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
+ )
+ stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
+ if "state" in obs:
+ self._prev_obs_state_queue = deque(
+ [obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
+ )
+ stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
+ obs = stacked_obs
+
+ td = TensorDict(
+ {
+ "observation": TensorDict(obs, batch_size=[]),
+ "done": torch.tensor([False], dtype=torch.bool),
+ },
+ batch_size=[],
+ )
+ else:
+ raise NotImplementedError()
+
+ self.call_rendering_hooks()
+ return td
+
+ def _step(self, tensordict: TensorDict):
+ td = tensordict
+ action = td["action"].numpy()
+ # step expects shape=(4,) so we pad if necessary
+ # TODO(rcadene): add info["is_success"] and info["success"] ?
+ sum_reward = 0
+
+ if action.ndim == 1:
+ action = einops.repeat(action, "c -> t c", t=self.frame_skip)
+ else:
+ if self.frame_skip > 1:
+ raise NotImplementedError()
+
+ num_action_steps = action.shape[0]
+ for i in range(num_action_steps):
+ _, reward, discount, raw_obs = self._env.step(action[i])
+ del discount # not used
+
+ # TOOD(rcadene): add an enum
+ success = done = reward == 4
+ sum_reward += reward
+ obs = self._format_raw_obs(raw_obs)
+
+ if self.num_prev_obs > 0:
+ stacked_obs = {}
+ if "image" in obs:
+ self._prev_obs_image_queue.append(obs["image"]["top"])
+ stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
+ if "state" in obs:
+ self._prev_obs_state_queue.append(obs["state"])
+ stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
+ obs = stacked_obs
+
+ self.call_rendering_hooks()
+
+ td = TensorDict(
+ {
+ "observation": TensorDict(obs, batch_size=[]),
+ "reward": torch.tensor([sum_reward], dtype=torch.float32),
+ # succes and done are true when coverage > self.success_threshold in env
+ "done": torch.tensor([done], dtype=torch.bool),
+ "success": torch.tensor([success], dtype=torch.bool),
+ },
+ batch_size=[],
+ )
+ return td
+
+ def _make_spec(self):
+ obs = {}
+ from omegaconf import OmegaConf
+
+ if self.from_pixels:
+ if isinstance(self.image_size, int):
+ image_shape = (3, self.image_size, self.image_size)
+ elif OmegaConf.is_list(self.image_size):
+ assert len(self.image_size) == 3 # c h w
+ assert self.image_size[0] == 3 # c is RGB
+ image_shape = tuple(self.image_size)
+ else:
+ raise ValueError(self.image_size)
+ if self.num_prev_obs > 0:
+ image_shape = (self.num_prev_obs + 1, *image_shape)
+
+ obs["image"] = {
+ "top": BoundedTensorSpec(
+ low=0,
+ high=255,
+ shape=image_shape,
+ dtype=torch.uint8,
+ device=self.device,
+ )
+ }
+ if not self.pixels_only:
+ state_shape = (len(JOINTS),)
+ if self.num_prev_obs > 0:
+ state_shape = (self.num_prev_obs + 1, *state_shape)
+
+ obs["state"] = UnboundedContinuousTensorSpec(
+ # TODO: add low and high bounds
+ shape=state_shape,
+ dtype=torch.float32,
+ device=self.device,
+ )
+ else:
+ # TODO(rcadene): add observation_space achieved_goal and desired_goal?
+ state_shape = (len(JOINTS),)
+ if self.num_prev_obs > 0:
+ state_shape = (self.num_prev_obs + 1, *state_shape)
+
+ obs["state"] = UnboundedContinuousTensorSpec(
+ # TODO: add low and high bounds
+ shape=state_shape,
+ dtype=torch.float32,
+ device=self.device,
+ )
+ self.observation_spec = CompositeSpec({"observation": obs})
+
+ # TODO(rcadene): valid when controling end effector?
+ # action_space = self._env.action_spec()
+ # self.action_spec = BoundedTensorSpec(
+ # low=action_space.minimum,
+ # high=action_space.maximum,
+ # shape=action_space.shape,
+ # dtype=torch.float32,
+ # device=self.device,
+ # )
+
+ # TODO(rcaene): add bounds (where are they????)
+ self.action_spec = BoundedTensorSpec(
+ shape=(len(ACTIONS)),
+ low=-1,
+ high=1,
+ dtype=torch.float32,
+ device=self.device,
+ )
+
+ self.reward_spec = UnboundedContinuousTensorSpec(
+ shape=(1,),
+ dtype=torch.float32,
+ device=self.device,
+ )
+
+ self.done_spec = CompositeSpec(
+ {
+ "done": DiscreteTensorSpec(
+ 2,
+ shape=(1,),
+ dtype=torch.bool,
+ device=self.device,
+ ),
+ "success": DiscreteTensorSpec(
+ 2,
+ shape=(1,),
+ dtype=torch.bool,
+ device=self.device,
+ ),
+ }
+ )
+
+ def _set_seed(self, seed: Optional[int]):
+ set_seed(seed)
+ # TODO(rcadene): seed the env
+ # self._env.seed(seed)
+ logging.warning("Aloha env is not seeded")
diff --git a/lerobot/common/envs/aloha/tasks/sim.py b/lerobot/common/envs/aloha/tasks/sim.py
new file mode 100644
index 00000000..ee1d0927
--- /dev/null
+++ b/lerobot/common/envs/aloha/tasks/sim.py
@@ -0,0 +1,219 @@
+import collections
+
+import numpy as np
+from dm_control.suite import base
+
+from lerobot.common.envs.aloha.constants import (
+ START_ARM_POSE,
+ normalize_puppet_gripper_position,
+ normalize_puppet_gripper_velocity,
+ unnormalize_puppet_gripper_position,
+)
+
+BOX_POSE = [None] # to be changed from outside
+
+"""
+Environment for simulated robot bi-manual manipulation, with joint position control
+Action space: [left_arm_qpos (6), # absolute joint position
+ left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
+ right_arm_qpos (6), # absolute joint position
+ right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
+
+Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
+ left_gripper_position (1), # normalized gripper position (0: close, 1: open)
+ right_arm_qpos (6), # absolute joint position
+ right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
+ "qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
+ left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
+ right_arm_qvel (6), # absolute joint velocity (rad)
+ right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
+ "images": {"main": (480x640x3)} # h, w, c, dtype='uint8'
+"""
+
+
+class BimanualViperXTask(base.Task):
+ def __init__(self, random=None):
+ super().__init__(random=random)
+
+ def before_step(self, action, physics):
+ left_arm_action = action[:6]
+ right_arm_action = action[7 : 7 + 6]
+ normalized_left_gripper_action = action[6]
+ normalized_right_gripper_action = action[7 + 6]
+
+ left_gripper_action = unnormalize_puppet_gripper_position(normalized_left_gripper_action)
+ right_gripper_action = unnormalize_puppet_gripper_position(normalized_right_gripper_action)
+
+ full_left_gripper_action = [left_gripper_action, -left_gripper_action]
+ full_right_gripper_action = [right_gripper_action, -right_gripper_action]
+
+ env_action = np.concatenate(
+ [left_arm_action, full_left_gripper_action, right_arm_action, full_right_gripper_action]
+ )
+ super().before_step(env_action, physics)
+ return
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode."""
+ super().initialize_episode(physics)
+
+ @staticmethod
+ def get_qpos(physics):
+ qpos_raw = physics.data.qpos.copy()
+ left_qpos_raw = qpos_raw[:8]
+ right_qpos_raw = qpos_raw[8:16]
+ left_arm_qpos = left_qpos_raw[:6]
+ right_arm_qpos = right_qpos_raw[:6]
+ left_gripper_qpos = [normalize_puppet_gripper_position(left_qpos_raw[6])]
+ right_gripper_qpos = [normalize_puppet_gripper_position(right_qpos_raw[6])]
+ return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
+
+ @staticmethod
+ def get_qvel(physics):
+ qvel_raw = physics.data.qvel.copy()
+ left_qvel_raw = qvel_raw[:8]
+ right_qvel_raw = qvel_raw[8:16]
+ left_arm_qvel = left_qvel_raw[:6]
+ right_arm_qvel = right_qvel_raw[:6]
+ left_gripper_qvel = [normalize_puppet_gripper_velocity(left_qvel_raw[6])]
+ right_gripper_qvel = [normalize_puppet_gripper_velocity(right_qvel_raw[6])]
+ return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
+
+ @staticmethod
+ def get_env_state(physics):
+ raise NotImplementedError
+
+ def get_observation(self, physics):
+ obs = collections.OrderedDict()
+ obs["qpos"] = self.get_qpos(physics)
+ obs["qvel"] = self.get_qvel(physics)
+ obs["env_state"] = self.get_env_state(physics)
+ obs["images"] = {}
+ obs["images"]["top"] = physics.render(height=480, width=640, camera_id="top")
+ obs["images"]["angle"] = physics.render(height=480, width=640, camera_id="angle")
+ obs["images"]["vis"] = physics.render(height=480, width=640, camera_id="front_close")
+
+ return obs
+
+ def get_reward(self, physics):
+ # return whether left gripper is holding the box
+ raise NotImplementedError
+
+
+class TransferCubeTask(BimanualViperXTask):
+ def __init__(self, random=None):
+ super().__init__(random=random)
+ self.max_reward = 4
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode."""
+ # TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside
+ # reset qpos, control and box position
+ with physics.reset_context():
+ physics.named.data.qpos[:16] = START_ARM_POSE
+ np.copyto(physics.data.ctrl, START_ARM_POSE)
+ assert BOX_POSE[0] is not None
+ physics.named.data.qpos[-7:] = BOX_POSE[0]
+ # print(f"{BOX_POSE=}")
+ super().initialize_episode(physics)
+
+ @staticmethod
+ def get_env_state(physics):
+ env_state = physics.data.qpos.copy()[16:]
+ return env_state
+
+ def get_reward(self, physics):
+ # return whether left gripper is holding the box
+ all_contact_pairs = []
+ for i_contact in range(physics.data.ncon):
+ id_geom_1 = physics.data.contact[i_contact].geom1
+ id_geom_2 = physics.data.contact[i_contact].geom2
+ name_geom_1 = physics.model.id2name(id_geom_1, "geom")
+ name_geom_2 = physics.model.id2name(id_geom_2, "geom")
+ contact_pair = (name_geom_1, name_geom_2)
+ all_contact_pairs.append(contact_pair)
+
+ touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
+ touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
+ touch_table = ("red_box", "table") in all_contact_pairs
+
+ reward = 0
+ if touch_right_gripper:
+ reward = 1
+ if touch_right_gripper and not touch_table: # lifted
+ reward = 2
+ if touch_left_gripper: # attempted transfer
+ reward = 3
+ if touch_left_gripper and not touch_table: # successful transfer
+ reward = 4
+ return reward
+
+
+class InsertionTask(BimanualViperXTask):
+ def __init__(self, random=None):
+ super().__init__(random=random)
+ self.max_reward = 4
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode."""
+ # TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside
+ # reset qpos, control and box position
+ with physics.reset_context():
+ physics.named.data.qpos[:16] = START_ARM_POSE
+ np.copyto(physics.data.ctrl, START_ARM_POSE)
+ assert BOX_POSE[0] is not None
+ physics.named.data.qpos[-7 * 2 :] = BOX_POSE[0] # two objects
+ # print(f"{BOX_POSE=}")
+ super().initialize_episode(physics)
+
+ @staticmethod
+ def get_env_state(physics):
+ env_state = physics.data.qpos.copy()[16:]
+ return env_state
+
+ def get_reward(self, physics):
+ # return whether peg touches the pin
+ all_contact_pairs = []
+ for i_contact in range(physics.data.ncon):
+ id_geom_1 = physics.data.contact[i_contact].geom1
+ id_geom_2 = physics.data.contact[i_contact].geom2
+ name_geom_1 = physics.model.id2name(id_geom_1, "geom")
+ name_geom_2 = physics.model.id2name(id_geom_2, "geom")
+ contact_pair = (name_geom_1, name_geom_2)
+ all_contact_pairs.append(contact_pair)
+
+ touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
+ touch_left_gripper = (
+ ("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
+ or ("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
+ or ("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
+ or ("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
+ )
+
+ peg_touch_table = ("red_peg", "table") in all_contact_pairs
+ socket_touch_table = (
+ ("socket-1", "table") in all_contact_pairs
+ or ("socket-2", "table") in all_contact_pairs
+ or ("socket-3", "table") in all_contact_pairs
+ or ("socket-4", "table") in all_contact_pairs
+ )
+ peg_touch_socket = (
+ ("red_peg", "socket-1") in all_contact_pairs
+ or ("red_peg", "socket-2") in all_contact_pairs
+ or ("red_peg", "socket-3") in all_contact_pairs
+ or ("red_peg", "socket-4") in all_contact_pairs
+ )
+ pin_touched = ("red_peg", "pin") in all_contact_pairs
+
+ reward = 0
+ if touch_left_gripper and touch_right_gripper: # touch both
+ reward = 1
+ if (
+ touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table)
+ ): # grasp both
+ reward = 2
+ if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching
+ reward = 3
+ if pin_touched: # successful insertion
+ reward = 4
+ return reward
diff --git a/lerobot/common/envs/aloha/tasks/sim_end_effector.py b/lerobot/common/envs/aloha/tasks/sim_end_effector.py
new file mode 100644
index 00000000..d93c8330
--- /dev/null
+++ b/lerobot/common/envs/aloha/tasks/sim_end_effector.py
@@ -0,0 +1,263 @@
+import collections
+
+import numpy as np
+from dm_control.suite import base
+
+from lerobot.common.envs.aloha.constants import (
+ PUPPET_GRIPPER_POSITION_CLOSE,
+ START_ARM_POSE,
+ normalize_puppet_gripper_position,
+ normalize_puppet_gripper_velocity,
+ unnormalize_puppet_gripper_position,
+)
+from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose
+
+"""
+Environment for simulated robot bi-manual manipulation, with end-effector control.
+Action space: [left_arm_pose (7), # position and quaternion for end effector
+ left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
+ right_arm_pose (7), # position and quaternion for end effector
+ right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
+
+Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
+ left_gripper_position (1), # normalized gripper position (0: close, 1: open)
+ right_arm_qpos (6), # absolute joint position
+ right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
+ "qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
+ left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
+ right_arm_qvel (6), # absolute joint velocity (rad)
+ right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
+ "images": {"main": (480x640x3)} # h, w, c, dtype='uint8'
+"""
+
+
+class BimanualViperXEndEffectorTask(base.Task):
+ def __init__(self, random=None):
+ super().__init__(random=random)
+
+ def before_step(self, action, physics):
+ a_len = len(action) // 2
+ action_left = action[:a_len]
+ action_right = action[a_len:]
+
+ # set mocap position and quat
+ # left
+ np.copyto(physics.data.mocap_pos[0], action_left[:3])
+ np.copyto(physics.data.mocap_quat[0], action_left[3:7])
+ # right
+ np.copyto(physics.data.mocap_pos[1], action_right[:3])
+ np.copyto(physics.data.mocap_quat[1], action_right[3:7])
+
+ # set gripper
+ g_left_ctrl = unnormalize_puppet_gripper_position(action_left[7])
+ g_right_ctrl = unnormalize_puppet_gripper_position(action_right[7])
+ np.copyto(physics.data.ctrl, np.array([g_left_ctrl, -g_left_ctrl, g_right_ctrl, -g_right_ctrl]))
+
+ def initialize_robots(self, physics):
+ # reset joint position
+ physics.named.data.qpos[:16] = START_ARM_POSE
+
+ # reset mocap to align with end effector
+ # to obtain these numbers:
+ # (1) make an ee_sim env and reset to the same start_pose
+ # (2) get env._physics.named.data.xpos['vx300s_left/gripper_link']
+ # get env._physics.named.data.xquat['vx300s_left/gripper_link']
+ # repeat the same for right side
+ np.copyto(physics.data.mocap_pos[0], [-0.31718881, 0.5, 0.29525084])
+ np.copyto(physics.data.mocap_quat[0], [1, 0, 0, 0])
+ # right
+ np.copyto(physics.data.mocap_pos[1], np.array([0.31718881, 0.49999888, 0.29525084]))
+ np.copyto(physics.data.mocap_quat[1], [1, 0, 0, 0])
+
+ # reset gripper control
+ close_gripper_control = np.array(
+ [
+ PUPPET_GRIPPER_POSITION_CLOSE,
+ -PUPPET_GRIPPER_POSITION_CLOSE,
+ PUPPET_GRIPPER_POSITION_CLOSE,
+ -PUPPET_GRIPPER_POSITION_CLOSE,
+ ]
+ )
+ np.copyto(physics.data.ctrl, close_gripper_control)
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode."""
+ super().initialize_episode(physics)
+
+ @staticmethod
+ def get_qpos(physics):
+ qpos_raw = physics.data.qpos.copy()
+ left_qpos_raw = qpos_raw[:8]
+ right_qpos_raw = qpos_raw[8:16]
+ left_arm_qpos = left_qpos_raw[:6]
+ right_arm_qpos = right_qpos_raw[:6]
+ left_gripper_qpos = [normalize_puppet_gripper_position(left_qpos_raw[6])]
+ right_gripper_qpos = [normalize_puppet_gripper_position(right_qpos_raw[6])]
+ return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
+
+ @staticmethod
+ def get_qvel(physics):
+ qvel_raw = physics.data.qvel.copy()
+ left_qvel_raw = qvel_raw[:8]
+ right_qvel_raw = qvel_raw[8:16]
+ left_arm_qvel = left_qvel_raw[:6]
+ right_arm_qvel = right_qvel_raw[:6]
+ left_gripper_qvel = [normalize_puppet_gripper_velocity(left_qvel_raw[6])]
+ right_gripper_qvel = [normalize_puppet_gripper_velocity(right_qvel_raw[6])]
+ return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
+
+ @staticmethod
+ def get_env_state(physics):
+ raise NotImplementedError
+
+ def get_observation(self, physics):
+ # note: it is important to do .copy()
+ obs = collections.OrderedDict()
+ obs["qpos"] = self.get_qpos(physics)
+ obs["qvel"] = self.get_qvel(physics)
+ obs["env_state"] = self.get_env_state(physics)
+ obs["images"] = {}
+ obs["images"]["top"] = physics.render(height=480, width=640, camera_id="top")
+ obs["images"]["angle"] = physics.render(height=480, width=640, camera_id="angle")
+ obs["images"]["vis"] = physics.render(height=480, width=640, camera_id="front_close")
+ # used in scripted policy to obtain starting pose
+ obs["mocap_pose_left"] = np.concatenate(
+ [physics.data.mocap_pos[0], physics.data.mocap_quat[0]]
+ ).copy()
+ obs["mocap_pose_right"] = np.concatenate(
+ [physics.data.mocap_pos[1], physics.data.mocap_quat[1]]
+ ).copy()
+
+ # used when replaying joint trajectory
+ obs["gripper_ctrl"] = physics.data.ctrl.copy()
+ return obs
+
+ def get_reward(self, physics):
+ raise NotImplementedError
+
+
+class TransferCubeEndEffectorTask(BimanualViperXEndEffectorTask):
+ def __init__(self, random=None):
+ super().__init__(random=random)
+ self.max_reward = 4
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode."""
+ self.initialize_robots(physics)
+ # randomize box position
+ cube_pose = sample_box_pose()
+ box_start_idx = physics.model.name2id("red_box_joint", "joint")
+ np.copyto(physics.data.qpos[box_start_idx : box_start_idx + 7], cube_pose)
+ # print(f"randomized cube position to {cube_position}")
+
+ super().initialize_episode(physics)
+
+ @staticmethod
+ def get_env_state(physics):
+ env_state = physics.data.qpos.copy()[16:]
+ return env_state
+
+ def get_reward(self, physics):
+ # return whether left gripper is holding the box
+ all_contact_pairs = []
+ for i_contact in range(physics.data.ncon):
+ id_geom_1 = physics.data.contact[i_contact].geom1
+ id_geom_2 = physics.data.contact[i_contact].geom2
+ name_geom_1 = physics.model.id2name(id_geom_1, "geom")
+ name_geom_2 = physics.model.id2name(id_geom_2, "geom")
+ contact_pair = (name_geom_1, name_geom_2)
+ all_contact_pairs.append(contact_pair)
+
+ touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
+ touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
+ touch_table = ("red_box", "table") in all_contact_pairs
+
+ reward = 0
+ if touch_right_gripper:
+ reward = 1
+ if touch_right_gripper and not touch_table: # lifted
+ reward = 2
+ if touch_left_gripper: # attempted transfer
+ reward = 3
+ if touch_left_gripper and not touch_table: # successful transfer
+ reward = 4
+ return reward
+
+
+class InsertionEndEffectorTask(BimanualViperXEndEffectorTask):
+ def __init__(self, random=None):
+ super().__init__(random=random)
+ self.max_reward = 4
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode."""
+ self.initialize_robots(physics)
+ # randomize peg and socket position
+ peg_pose, socket_pose = sample_insertion_pose()
+
+ def id2index(j_id):
+ return 16 + (j_id - 16) * 7 # first 16 is robot qpos, 7 is pose dim # hacky
+
+ peg_start_id = physics.model.name2id("red_peg_joint", "joint")
+ peg_start_idx = id2index(peg_start_id)
+ np.copyto(physics.data.qpos[peg_start_idx : peg_start_idx + 7], peg_pose)
+ # print(f"randomized cube position to {cube_position}")
+
+ socket_start_id = physics.model.name2id("blue_socket_joint", "joint")
+ socket_start_idx = id2index(socket_start_id)
+ np.copyto(physics.data.qpos[socket_start_idx : socket_start_idx + 7], socket_pose)
+ # print(f"randomized cube position to {cube_position}")
+
+ super().initialize_episode(physics)
+
+ @staticmethod
+ def get_env_state(physics):
+ env_state = physics.data.qpos.copy()[16:]
+ return env_state
+
+ def get_reward(self, physics):
+ # return whether peg touches the pin
+ all_contact_pairs = []
+ for i_contact in range(physics.data.ncon):
+ id_geom_1 = physics.data.contact[i_contact].geom1
+ id_geom_2 = physics.data.contact[i_contact].geom2
+ name_geom_1 = physics.model.id2name(id_geom_1, "geom")
+ name_geom_2 = physics.model.id2name(id_geom_2, "geom")
+ contact_pair = (name_geom_1, name_geom_2)
+ all_contact_pairs.append(contact_pair)
+
+ touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
+ touch_left_gripper = (
+ ("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
+ or ("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
+ or ("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
+ or ("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
+ )
+
+ peg_touch_table = ("red_peg", "table") in all_contact_pairs
+ socket_touch_table = (
+ ("socket-1", "table") in all_contact_pairs
+ or ("socket-2", "table") in all_contact_pairs
+ or ("socket-3", "table") in all_contact_pairs
+ or ("socket-4", "table") in all_contact_pairs
+ )
+ peg_touch_socket = (
+ ("red_peg", "socket-1") in all_contact_pairs
+ or ("red_peg", "socket-2") in all_contact_pairs
+ or ("red_peg", "socket-3") in all_contact_pairs
+ or ("red_peg", "socket-4") in all_contact_pairs
+ )
+ pin_touched = ("red_peg", "pin") in all_contact_pairs
+
+ reward = 0
+ if touch_left_gripper and touch_right_gripper: # touch both
+ reward = 1
+ if (
+ touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table)
+ ): # grasp both
+ reward = 2
+ if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching
+ reward = 3
+ if pin_touched: # successful insertion
+ reward = 4
+ return reward
diff --git a/lerobot/common/envs/aloha/utils.py b/lerobot/common/envs/aloha/utils.py
new file mode 100644
index 00000000..5ac8b955
--- /dev/null
+++ b/lerobot/common/envs/aloha/utils.py
@@ -0,0 +1,39 @@
+import numpy as np
+
+
+def sample_box_pose():
+ x_range = [0.0, 0.2]
+ y_range = [0.4, 0.6]
+ z_range = [0.05, 0.05]
+
+ ranges = np.vstack([x_range, y_range, z_range])
+ cube_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
+
+ cube_quat = np.array([1, 0, 0, 0])
+ return np.concatenate([cube_position, cube_quat])
+
+
+def sample_insertion_pose():
+ # Peg
+ x_range = [0.1, 0.2]
+ y_range = [0.4, 0.6]
+ z_range = [0.05, 0.05]
+
+ ranges = np.vstack([x_range, y_range, z_range])
+ peg_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
+
+ peg_quat = np.array([1, 0, 0, 0])
+ peg_pose = np.concatenate([peg_position, peg_quat])
+
+ # Socket
+ x_range = [-0.2, -0.1]
+ y_range = [0.4, 0.6]
+ z_range = [0.05, 0.05]
+
+ ranges = np.vstack([x_range, y_range, z_range])
+ socket_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
+
+ socket_quat = np.array([1, 0, 0, 0])
+ socket_pose = np.concatenate([socket_position, socket_quat])
+
+ return peg_pose, socket_pose
diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py
index 269009db..921cbad7 100644
--- a/lerobot/common/envs/factory.py
+++ b/lerobot/common/envs/factory.py
@@ -23,6 +23,11 @@ def make_env(cfg, transform=None):
# assert kwargs["seed"] > 200, "Seed 0-200 are used for the demonstration dataset, so we don't want to seed the eval env with this range."
clsfunc = PushtEnv
+ elif cfg.env.name == "aloha":
+ from lerobot.common.envs.aloha.env import AlohaEnv
+
+ kwargs["task"] = cfg.env.task
+ clsfunc = AlohaEnv
else:
raise ValueError(cfg.env.name)
diff --git a/lerobot/common/envs/pusht/env.py b/lerobot/common/envs/pusht/env.py
index ff49f791..4a7ccb2c 100644
--- a/lerobot/common/envs/pusht/env.py
+++ b/lerobot/common/envs/pusht/env.py
@@ -11,61 +11,52 @@ from torchrl.data.tensor_specs import (
DiscreteTensorSpec,
UnboundedContinuousTensorSpec,
)
-from torchrl.envs import EnvBase
from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform
+from lerobot.common.envs.abstract import AbstractEnv
from lerobot.common.utils import set_seed
_has_gym = importlib.util.find_spec("gym") is not None
-class PushtEnv(EnvBase):
+class PushtEnv(AbstractEnv):
def __init__(
self,
+ task="pusht",
frame_skip: int = 1,
from_pixels: bool = False,
pixels_only: bool = False,
image_size=None,
seed=1337,
device="cpu",
- num_prev_obs=0,
+ num_prev_obs=1,
num_prev_action=0,
):
- super().__init__(device=device, batch_size=[])
- self.frame_skip = frame_skip
- self.from_pixels = from_pixels
- self.pixels_only = pixels_only
- self.image_size = image_size
- self.num_prev_obs = num_prev_obs
- self.num_prev_action = num_prev_action
-
- if pixels_only:
- assert from_pixels
- if from_pixels:
- assert image_size
+ super().__init__(
+ task=task,
+ frame_skip=frame_skip,
+ from_pixels=from_pixels,
+ pixels_only=pixels_only,
+ image_size=image_size,
+ seed=seed,
+ device=device,
+ num_prev_obs=num_prev_obs,
+ num_prev_action=num_prev_action,
+ )
+ def _make_env(self):
if not _has_gym:
raise ImportError("Cannot import gym.")
# TODO(rcadene) (PushTEnv is similar to PushTImageEnv, but without the image rendering, it's faster to iterate on)
# from lerobot.common.envs.pusht.pusht_env import PushTEnv
- if not from_pixels:
+ if not self.from_pixels:
raise NotImplementedError("Use PushTEnv, instead of PushTImageEnv")
from lerobot.common.envs.pusht.pusht_image_env import PushTImageEnv
self._env = PushTImageEnv(render_size=self.image_size)
- self._make_spec()
- self._current_seed = self.set_seed(seed)
-
- if self.num_prev_obs > 0:
- self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
- self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs)
- if self.num_prev_action > 0:
- raise NotImplementedError()
- # self._prev_action_queue = deque(maxlen=self.num_prev_action)
-
def render(self, mode="rgb_array", width=384, height=384):
if width != height:
raise NotImplementedError()
@@ -122,6 +113,8 @@ class PushtEnv(EnvBase):
)
else:
raise NotImplementedError()
+
+ self.call_rendering_hooks()
return td
def _step(self, tensordict: TensorDict):
@@ -154,6 +147,8 @@ class PushtEnv(EnvBase):
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs
+ self.call_rendering_hooks()
+
td = TensorDict(
{
"observation": TensorDict(obs, batch_size=[]),
@@ -175,9 +170,9 @@ class PushtEnv(EnvBase):
obs["image"] = BoundedTensorSpec(
low=0,
- high=1,
+ high=255,
shape=image_shape,
- dtype=torch.float32,
+ dtype=torch.uint8,
device=self.device,
)
if not self.pixels_only:
diff --git a/lerobot/common/envs/pusht/pusht_image_env.py b/lerobot/common/envs/pusht/pusht_image_env.py
index 5f7bc03c..0807e849 100644
--- a/lerobot/common/envs/pusht/pusht_image_env.py
+++ b/lerobot/common/envs/pusht/pusht_image_env.py
@@ -25,7 +25,7 @@ class PushTImageEnv(PushTEnv):
img = super()._render_frame(mode="rgb_array")
agent_pos = np.array(self.agent.position)
- img_obs = np.moveaxis(img.astype(np.float32) / 255, -1, 0)
+ img_obs = np.moveaxis(img, -1, 0)
obs = {"image": img_obs, "agent_pos": agent_pos}
# draw action
diff --git a/lerobot/common/envs/simxarm.py b/lerobot/common/envs/simxarm.py
index 24fd9ba4..d0612625 100644
--- a/lerobot/common/envs/simxarm.py
+++ b/lerobot/common/envs/simxarm.py
@@ -1,6 +1,8 @@
import importlib
+from collections import deque
from typing import Optional
+import einops
import numpy as np
import torch
from tensordict import TensorDict
@@ -10,9 +12,9 @@ from torchrl.data.tensor_specs import (
DiscreteTensorSpec,
UnboundedContinuousTensorSpec,
)
-from torchrl.envs import EnvBase
from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform
+from lerobot.common.envs.abstract import AbstractEnv
from lerobot.common.utils import set_seed
MAX_NUM_ACTIONS = 4
@@ -21,7 +23,7 @@ _has_gym = importlib.util.find_spec("gym") is not None
_has_simxarm = importlib.util.find_spec("simxarm") is not None and _has_gym
-class SimxarmEnv(EnvBase):
+class SimxarmEnv(AbstractEnv):
def __init__(
self,
task,
@@ -31,19 +33,22 @@ class SimxarmEnv(EnvBase):
image_size=None,
seed=1337,
device="cpu",
+ num_prev_obs=0,
+ num_prev_action=0,
):
- super().__init__(device=device, batch_size=[])
- self.task = task
- self.frame_skip = frame_skip
- self.from_pixels = from_pixels
- self.pixels_only = pixels_only
- self.image_size = image_size
-
- if pixels_only:
- assert from_pixels
- if from_pixels:
- assert image_size
+ super().__init__(
+ task=task,
+ frame_skip=frame_skip,
+ from_pixels=from_pixels,
+ pixels_only=pixels_only,
+ image_size=image_size,
+ seed=seed,
+ device=device,
+ num_prev_obs=num_prev_obs,
+ num_prev_action=num_prev_action,
+ )
+ def _make_env(self):
if not _has_simxarm:
raise ImportError("Cannot import simxarm.")
if not _has_gym:
@@ -63,9 +68,6 @@ class SimxarmEnv(EnvBase):
if "w" not in TASKS[self.task]["action_space"]:
self._action_padding[-1] = 1.0
- self._make_spec()
- self.set_seed(seed)
-
def render(self, mode="rgb_array", width=384, height=384):
return self._env.render(mode, width=width, height=height)
@@ -90,15 +92,33 @@ class SimxarmEnv(EnvBase):
if td is None or td.is_empty():
raw_obs = self._env.reset()
+ obs = self._format_raw_obs(raw_obs)
+
+ if self.num_prev_obs > 0:
+ stacked_obs = {}
+ if "image" in obs:
+ self._prev_obs_image_queue = deque(
+ [obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
+ )
+ stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
+ if "state" in obs:
+ self._prev_obs_state_queue = deque(
+ [obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
+ )
+ stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
+ obs = stacked_obs
+
td = TensorDict(
{
- "observation": self._format_raw_obs(raw_obs),
+ "observation": TensorDict(obs, batch_size=[]),
"done": torch.tensor([False], dtype=torch.bool),
},
batch_size=[],
)
else:
raise NotImplementedError()
+
+ self.call_rendering_hooks()
return td
def _step(self, tensordict: TensorDict):
@@ -108,10 +128,32 @@ class SimxarmEnv(EnvBase):
action = np.concatenate([action, self._action_padding])
# TODO(rcadene): add info["is_success"] and info["success"] ?
sum_reward = 0
- for _ in range(self.frame_skip):
- raw_obs, reward, done, info = self._env.step(action)
+
+ if action.ndim == 1:
+ action = einops.repeat(action, "c -> t c", t=self.frame_skip)
+ else:
+ if self.frame_skip > 1:
+ raise NotImplementedError()
+
+ num_action_steps = action.shape[0]
+ for i in range(num_action_steps):
+ raw_obs, reward, done, info = self._env.step(action[i])
sum_reward += reward
+ obs = self._format_raw_obs(raw_obs)
+
+ if self.num_prev_obs > 0:
+ stacked_obs = {}
+ if "image" in obs:
+ self._prev_obs_image_queue.append(obs["image"])
+ stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
+ if "state" in obs:
+ self._prev_obs_state_queue.append(obs["state"])
+ stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
+ obs = stacked_obs
+
+ self.call_rendering_hooks()
+
td = TensorDict(
{
"observation": self._format_raw_obs(raw_obs),
@@ -126,23 +168,36 @@ class SimxarmEnv(EnvBase):
def _make_spec(self):
obs = {}
if self.from_pixels:
+ image_shape = (3, self.image_size, self.image_size)
+ if self.num_prev_obs > 0:
+ image_shape = (self.num_prev_obs + 1, *image_shape)
+
obs["image"] = BoundedTensorSpec(
low=0,
high=255,
- shape=(3, self.image_size, self.image_size),
+ shape=image_shape,
dtype=torch.uint8,
device=self.device,
)
if not self.pixels_only:
+ state_shape = (len(self._env.robot_state),)
+ if self.num_prev_obs > 0:
+ state_shape = (self.num_prev_obs + 1, *state_shape)
+
obs["state"] = UnboundedContinuousTensorSpec(
- shape=(len(self._env.robot_state),),
+ shape=state_shape,
dtype=torch.float32,
device=self.device,
)
else:
# TODO(rcadene): add observation_space achieved_goal and desired_goal?
+ state_shape = self._env.observation_space["observation"].shape
+ if self.num_prev_obs > 0:
+ state_shape = (self.num_prev_obs + 1, *state_shape)
+
obs["state"] = UnboundedContinuousTensorSpec(
- shape=self._env.observation_space["observation"].shape,
+ # TODO:
+ shape=state_shape,
dtype=torch.float32,
device=self.device,
)
diff --git a/lerobot/common/envs/transforms.py b/lerobot/common/envs/transforms.py
index 671c0827..4832c91b 100644
--- a/lerobot/common/envs/transforms.py
+++ b/lerobot/common/envs/transforms.py
@@ -1,5 +1,6 @@
from typing import Sequence
+import torch
from tensordict import TensorDictBase
from tensordict.nn import dispatch
from tensordict.utils import NestedKey
@@ -7,19 +8,45 @@ from torchrl.envs.transforms import ObservationTransform, Transform
class Prod(ObservationTransform):
+ invertible = True
+
def __init__(self, in_keys: Sequence[NestedKey], prod: float):
super().__init__()
self.in_keys = in_keys
self.prod = prod
+ self.original_dtypes = {}
+
+ def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
+ # _reset is called once when the environment reset to normalize the first observation
+ tensordict_reset = self._call(tensordict_reset)
+ return tensordict_reset
+
+ @dispatch(source="in_keys", dest="out_keys")
+ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+ return self._call(tensordict)
def _call(self, td):
for key in self.in_keys:
- td[key] *= self.prod
+ if td.get(key, None) is None:
+ continue
+ self.original_dtypes[key] = td[key].dtype
+ td[key] = td[key].type(torch.float32) * self.prod
+ return td
+
+ def _inv_call(self, td: TensorDictBase) -> TensorDictBase:
+ for key in self.in_keys:
+ if td.get(key, None) is None:
+ continue
+ td[key] = (td[key] / self.prod).type(self.original_dtypes[key])
return td
def transform_observation_spec(self, obs_spec):
for key in self.in_keys:
- obs_spec[key].space.high *= self.prod
+ if obs_spec.get(key, None) is None:
+ continue
+ obs_spec[key].space.high = obs_spec[key].space.high.type(torch.float32) * self.prod
+ obs_spec[key].space.low = obs_spec[key].space.low.type(torch.float32) * self.prod
+ obs_spec[key].dtype = torch.float32
return obs_spec
diff --git a/lerobot/common/policies/act/backbone.py b/lerobot/common/policies/act/backbone.py
new file mode 100644
index 00000000..6399d339
--- /dev/null
+++ b/lerobot/common/policies/act/backbone.py
@@ -0,0 +1,115 @@
+from typing import List
+
+import torch
+import torchvision
+from torch import nn
+from torchvision.models._utils import IntermediateLayerGetter
+
+from .position_encoding import build_position_encoding
+from .utils import NestedTensor, is_main_process
+
+
+class FrozenBatchNorm2d(torch.nn.Module):
+ """
+ BatchNorm2d where the batch statistics and the affine parameters are fixed.
+
+ Copy-paste from torchvision.misc.ops with added eps before rqsrt,
+ without which any other policy_models than torchvision.policy_models.resnet[18,34,50,101]
+ produce nans.
+ """
+
+ def __init__(self, n):
+ super().__init__()
+ self.register_buffer("weight", torch.ones(n))
+ self.register_buffer("bias", torch.zeros(n))
+ self.register_buffer("running_mean", torch.zeros(n))
+ self.register_buffer("running_var", torch.ones(n))
+
+ def _load_from_state_dict(
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+ ):
+ num_batches_tracked_key = prefix + "num_batches_tracked"
+ if num_batches_tracked_key in state_dict:
+ del state_dict[num_batches_tracked_key]
+
+ super()._load_from_state_dict(
+ state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+ )
+
+ def forward(self, x):
+ # move reshapes to the beginning
+ # to make it fuser-friendly
+ w = self.weight.reshape(1, -1, 1, 1)
+ b = self.bias.reshape(1, -1, 1, 1)
+ rv = self.running_var.reshape(1, -1, 1, 1)
+ rm = self.running_mean.reshape(1, -1, 1, 1)
+ eps = 1e-5
+ scale = w * (rv + eps).rsqrt()
+ bias = b - rm * scale
+ return x * scale + bias
+
+
+class BackboneBase(nn.Module):
+ def __init__(
+ self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool
+ ):
+ super().__init__()
+ # for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this?
+ # if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
+ # parameter.requires_grad_(False)
+ if return_interm_layers:
+ return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
+ else:
+ return_layers = {"layer4": "0"}
+ self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
+ self.num_channels = num_channels
+
+ def forward(self, tensor):
+ xs = self.body(tensor)
+ return xs
+ # out: Dict[str, NestedTensor] = {}
+ # for name, x in xs.items():
+ # m = tensor_list.mask
+ # assert m is not None
+ # mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
+ # out[name] = NestedTensor(x, mask)
+ # return out
+
+
+class Backbone(BackboneBase):
+ """ResNet backbone with frozen BatchNorm."""
+
+ def __init__(self, name: str, train_backbone: bool, return_interm_layers: bool, dilation: bool):
+ backbone = getattr(torchvision.models, name)(
+ replace_stride_with_dilation=[False, False, dilation],
+ pretrained=is_main_process(),
+ norm_layer=FrozenBatchNorm2d,
+ ) # pretrained # TODO do we want frozen batch_norm??
+ num_channels = 512 if name in ("resnet18", "resnet34") else 2048
+ super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
+
+
+class Joiner(nn.Sequential):
+ def __init__(self, backbone, position_embedding):
+ super().__init__(backbone, position_embedding)
+
+ def forward(self, tensor_list: NestedTensor):
+ xs = self[0](tensor_list)
+ out: List[NestedTensor] = []
+ pos = []
+ for _, x in xs.items():
+ out.append(x)
+ # position encoding
+ pos.append(self[1](x).to(x.dtype))
+
+ return out, pos
+
+
+def build_backbone(args):
+ position_embedding = build_position_encoding(args)
+ train_backbone = args.lr_backbone > 0
+ return_interm_layers = args.masks
+ backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
+ model = Joiner(backbone, position_embedding)
+ model.num_channels = backbone.num_channels
+ return model
diff --git a/lerobot/common/policies/act/detr_vae.py b/lerobot/common/policies/act/detr_vae.py
new file mode 100644
index 00000000..0f2626f7
--- /dev/null
+++ b/lerobot/common/policies/act/detr_vae.py
@@ -0,0 +1,212 @@
+import numpy as np
+import torch
+from torch import nn
+from torch.autograd import Variable
+
+from .backbone import build_backbone
+from .transformer import TransformerEncoder, TransformerEncoderLayer, build_transformer
+
+
+def reparametrize(mu, logvar):
+ std = logvar.div(2).exp()
+ eps = Variable(std.data.new(std.size()).normal_())
+ return mu + std * eps
+
+
+def get_sinusoid_encoding_table(n_position, d_hid):
+ def get_position_angle_vec(position):
+ return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
+
+ sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
+ sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
+ sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
+
+ return torch.FloatTensor(sinusoid_table).unsqueeze(0)
+
+
+class DETRVAE(nn.Module):
+ """This is the DETR module that performs object detection"""
+
+ def __init__(
+ self, backbones, transformer, encoder, state_dim, action_dim, num_queries, camera_names, vae
+ ):
+ """Initializes the model.
+ Parameters:
+ backbones: torch module of the backbone to be used. See backbone.py
+ transformer: torch module of the transformer architecture. See transformer.py
+ state_dim: robot state dimension of the environment
+ num_queries: number of object queries, ie detection slot. This is the maximal number of objects
+ DETR can detect in a single image. For COCO, we recommend 100 queries.
+ aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
+ """
+ super().__init__()
+ self.num_queries = num_queries
+ self.camera_names = camera_names
+ self.transformer = transformer
+ self.encoder = encoder
+ self.vae = vae
+ hidden_dim = transformer.d_model
+ self.action_head = nn.Linear(hidden_dim, action_dim)
+ self.is_pad_head = nn.Linear(hidden_dim, 1)
+ self.query_embed = nn.Embedding(num_queries, hidden_dim)
+ if backbones is not None:
+ self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
+ self.backbones = nn.ModuleList(backbones)
+ self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
+ else:
+ # input_dim = 14 + 7 # robot_state + env_state
+ self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
+ # TODO(rcadene): understand what is env_state, and why it needs to be 7
+ self.input_proj_env_state = nn.Linear(state_dim // 2, hidden_dim)
+ self.pos = torch.nn.Embedding(2, hidden_dim)
+ self.backbones = None
+
+ # encoder extra parameters
+ self.latent_dim = 32 # final size of latent z # TODO tune
+ self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
+ self.encoder_action_proj = nn.Linear(14, hidden_dim) # project action to embedding
+ self.encoder_joint_proj = nn.Linear(14, hidden_dim) # project qpos to embedding
+ self.latent_proj = nn.Linear(
+ hidden_dim, self.latent_dim * 2
+ ) # project hidden state to latent std, var
+ self.register_buffer(
+ "pos_table", get_sinusoid_encoding_table(1 + 1 + num_queries, hidden_dim)
+ ) # [CLS], qpos, a_seq
+
+ # decoder extra parameters
+ self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding
+ self.additional_pos_embed = nn.Embedding(
+ 2, hidden_dim
+ ) # learned position embedding for proprio and latent
+
+ def forward(self, qpos, image, env_state, actions=None, is_pad=None):
+ """
+ qpos: batch, qpos_dim
+ image: batch, num_cam, channel, height, width
+ env_state: None
+ actions: batch, seq, action_dim
+ """
+ is_training = actions is not None # train or val
+ bs, _ = qpos.shape
+ ### Obtain latent z from action sequence
+ if self.vae and is_training:
+ # project action sequence to embedding dim, and concat with a CLS token
+ action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim)
+ qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim)
+ qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim)
+ cls_embed = self.cls_embed.weight # (1, hidden_dim)
+ cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim)
+ encoder_input = torch.cat(
+ [cls_embed, qpos_embed, action_embed], axis=1
+ ) # (bs, seq+1, hidden_dim)
+ encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim)
+ # do not mask cls token
+ # cls_joint_is_pad = torch.full((bs, 2), False).to(qpos.device) # False: not a padding
+ # is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1)
+ # obtain position embedding
+ pos_embed = self.pos_table.clone().detach()
+ pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim)
+ # query model
+ encoder_output = self.encoder(encoder_input, pos=pos_embed) # , src_key_padding_mask=is_pad)
+ encoder_output = encoder_output[0] # take cls output only
+ latent_info = self.latent_proj(encoder_output)
+ mu = latent_info[:, : self.latent_dim]
+ logvar = latent_info[:, self.latent_dim :]
+ latent_sample = reparametrize(mu, logvar)
+ latent_input = self.latent_out_proj(latent_sample)
+ else:
+ mu = logvar = None
+ latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(qpos.device)
+ latent_input = self.latent_out_proj(latent_sample)
+
+ if self.backbones is not None:
+ # Image observation features and position embeddings
+ all_cam_features = []
+ all_cam_pos = []
+ for cam_id, _ in enumerate(self.camera_names):
+ features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED
+ features = features[0] # take the last layer feature
+ pos = pos[0]
+ all_cam_features.append(self.input_proj(features))
+ all_cam_pos.append(pos)
+ # proprioception features
+ proprio_input = self.input_proj_robot_state(qpos)
+ # fold camera dimension into width dimension
+ src = torch.cat(all_cam_features, axis=3)
+ pos = torch.cat(all_cam_pos, axis=3)
+ hs = self.transformer(
+ src,
+ None,
+ self.query_embed.weight,
+ pos,
+ latent_input,
+ proprio_input,
+ self.additional_pos_embed.weight,
+ )[0]
+ else:
+ qpos = self.input_proj_robot_state(qpos)
+ env_state = self.input_proj_env_state(env_state)
+ transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2
+ hs = self.transformer(transformer_input, None, self.query_embed.weight, self.pos.weight)[0]
+ a_hat = self.action_head(hs)
+ is_pad_hat = self.is_pad_head(hs)
+ return a_hat, is_pad_hat, [mu, logvar]
+
+
+def mlp(input_dim, hidden_dim, output_dim, hidden_depth):
+ if hidden_depth == 0:
+ mods = [nn.Linear(input_dim, output_dim)]
+ else:
+ mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
+ for _ in range(hidden_depth - 1):
+ mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)]
+ mods.append(nn.Linear(hidden_dim, output_dim))
+ trunk = nn.Sequential(*mods)
+ return trunk
+
+
+def build_encoder(args):
+ d_model = args.hidden_dim # 256
+ dropout = args.dropout # 0.1
+ nhead = args.nheads # 8
+ dim_feedforward = args.dim_feedforward # 2048
+ num_encoder_layers = args.enc_layers # 4 # TODO shared with VAE decoder
+ normalize_before = args.pre_norm # False
+ activation = "relu"
+
+ encoder_layer = TransformerEncoderLayer(
+ d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+ )
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
+ encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
+
+ return encoder
+
+
+def build(args):
+ # From state
+ # backbone = None # from state for now, no need for conv nets
+ # From image
+ backbones = []
+ backbone = build_backbone(args)
+ backbones.append(backbone)
+
+ transformer = build_transformer(args)
+
+ encoder = build_encoder(args)
+
+ model = DETRVAE(
+ backbones,
+ transformer,
+ encoder,
+ state_dim=args.state_dim,
+ action_dim=args.action_dim,
+ num_queries=args.num_queries,
+ camera_names=args.camera_names,
+ vae=args.vae,
+ )
+
+ n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ print("number of parameters: {:.2f}M".format(n_parameters / 1e6))
+
+ return model
diff --git a/lerobot/common/policies/act/policy.py b/lerobot/common/policies/act/policy.py
new file mode 100644
index 00000000..d011cb76
--- /dev/null
+++ b/lerobot/common/policies/act/policy.py
@@ -0,0 +1,217 @@
+import logging
+import time
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F # noqa: N812
+import torchvision.transforms as transforms
+
+from lerobot.common.policies.act.detr_vae import build
+
+
+def build_act_model_and_optimizer(cfg):
+ model = build(cfg)
+
+ param_dicts = [
+ {"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
+ {
+ "params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
+ "lr": cfg.lr_backbone,
+ },
+ ]
+ optimizer = torch.optim.AdamW(param_dicts, lr=cfg.lr, weight_decay=cfg.weight_decay)
+
+ return model, optimizer
+
+
+def kl_divergence(mu, logvar):
+ batch_size = mu.size(0)
+ assert batch_size != 0
+ if mu.data.ndimension() == 4:
+ mu = mu.view(mu.size(0), mu.size(1))
+ if logvar.data.ndimension() == 4:
+ logvar = logvar.view(logvar.size(0), logvar.size(1))
+
+ klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
+ total_kld = klds.sum(1).mean(0, True)
+ dimension_wise_kld = klds.mean(0)
+ mean_kld = klds.mean(1).mean(0, True)
+
+ return total_kld, dimension_wise_kld, mean_kld
+
+
+class ActionChunkingTransformerPolicy(nn.Module):
+ def __init__(self, cfg, device, n_action_steps=1):
+ super().__init__()
+ self.cfg = cfg
+ self.n_action_steps = n_action_steps
+ self.device = device
+ self.model, self.optimizer = build_act_model_and_optimizer(cfg)
+ self.kl_weight = self.cfg.kl_weight
+ logging.info(f"KL Weight {self.kl_weight}")
+ self.to(self.device)
+
+ def update(self, replay_buffer, step):
+ del step
+
+ start_time = time.time()
+
+ self.train()
+
+ num_slices = self.cfg.batch_size
+ batch_size = self.cfg.horizon * num_slices
+
+ assert batch_size % self.cfg.horizon == 0
+ assert batch_size % num_slices == 0
+
+ def process_batch(batch, horizon, num_slices):
+ # trajectory t = 64, horizon h = 16
+ # (t h) ... -> t h ...
+ batch = batch.reshape(num_slices, horizon)
+
+ image = batch["observation", "image", "top"]
+ image = image[:, 0] # first observation t=0
+ # batch, num_cam, channel, height, width
+ image = image.unsqueeze(1)
+ assert image.ndim == 5
+ image = image.float()
+
+ state = batch["observation", "state"]
+ state = state[:, 0] # first observation t=0
+ # batch, qpos_dim
+ assert state.ndim == 2
+
+ action = batch["action"]
+ # batch, seq, action_dim
+ assert action.ndim == 3
+ assert action.shape[1] == horizon
+
+ if self.cfg.n_obs_steps > 1:
+ raise NotImplementedError()
+ # # keep first n observations of the slice corresponding to t=[-1,0]
+ # image = image[:, : self.cfg.n_obs_steps]
+ # state = state[:, : self.cfg.n_obs_steps]
+
+ out = {
+ "obs": {
+ "image": image.to(self.device, non_blocking=True),
+ "agent_pos": state.to(self.device, non_blocking=True),
+ },
+ "action": action.to(self.device, non_blocking=True),
+ }
+ return out
+
+ batch = replay_buffer.sample(batch_size)
+ batch = process_batch(batch, self.cfg.horizon, num_slices)
+
+ data_s = time.time() - start_time
+
+ loss = self.compute_loss(batch)
+ loss.backward()
+
+ grad_norm = torch.nn.utils.clip_grad_norm_(
+ self.model.parameters(),
+ self.cfg.grad_clip_norm,
+ error_if_nonfinite=False,
+ )
+
+ self.optimizer.step()
+ self.optimizer.zero_grad()
+ # self.lr_scheduler.step()
+
+ info = {
+ "loss": loss.item(),
+ "grad_norm": float(grad_norm),
+ # "lr": self.lr_scheduler.get_last_lr()[0],
+ "lr": self.cfg.lr,
+ "data_s": data_s,
+ "update_s": time.time() - start_time,
+ }
+
+ return info
+
+ def save(self, fp):
+ torch.save(self.state_dict(), fp)
+
+ def load(self, fp):
+ d = torch.load(fp)
+ self.load_state_dict(d)
+
+ def compute_loss(self, batch):
+ loss_dict = self._forward(
+ qpos=batch["obs"]["agent_pos"],
+ image=batch["obs"]["image"],
+ actions=batch["action"],
+ )
+ loss = loss_dict["loss"]
+ return loss
+
+ @torch.no_grad()
+ def forward(self, observation, step_count):
+ # TODO(rcadene): remove unused step_count
+ del step_count
+
+ self.eval()
+
+ # TODO(rcadene): remove unsqueeze hack to add bsize=1
+ observation["image", "top"] = observation["image", "top"].unsqueeze(0)
+ # observation["state"] = observation["state"].unsqueeze(0)
+
+ # TODO(rcadene): remove hack
+ # add 1 camera dimension
+ observation["image", "top"] = observation["image", "top"].unsqueeze(1)
+
+ obs_dict = {
+ "image": observation["image", "top"],
+ "agent_pos": observation["state"],
+ }
+ action = self._forward(qpos=obs_dict["agent_pos"], image=obs_dict["image"])
+
+ if self.cfg.temporal_agg:
+ # TODO(rcadene): implement temporal aggregation
+ raise NotImplementedError()
+ # all_time_actions[[t], t:t+num_queries] = action
+ # actions_for_curr_step = all_time_actions[:, t]
+ # actions_populated = torch.all(actions_for_curr_step != 0, axis=1)
+ # actions_for_curr_step = actions_for_curr_step[actions_populated]
+ # k = 0.01
+ # exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
+ # exp_weights = exp_weights / exp_weights.sum()
+ # exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
+ # raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
+
+ # remove bsize=1
+ action = action.squeeze(0)
+
+ # take first predicted action or n first actions
+ action = action[0] if self.n_action_steps == 1 else action[: self.n_action_steps]
+ return action
+
+ def _forward(self, qpos, image, actions=None, is_pad=None):
+ env_state = None
+ normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ image = normalize(image)
+
+ is_training = actions is not None
+ if is_training: # training time
+ actions = actions[:, : self.model.num_queries]
+ if is_pad is not None:
+ is_pad = is_pad[:, : self.model.num_queries]
+
+ a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
+
+ all_l1 = F.l1_loss(actions, a_hat, reduction="none")
+ l1 = all_l1.mean() if is_pad is None else (all_l1 * ~is_pad.unsqueeze(-1)).mean()
+
+ loss_dict = {}
+ loss_dict["l1"] = l1
+ if self.cfg.vae:
+ total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
+ loss_dict["kl"] = total_kld[0]
+ loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight
+ else:
+ loss_dict["loss"] = loss_dict["l1"]
+ return loss_dict
+ else:
+ action, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior
+ return action
diff --git a/lerobot/common/policies/act/position_encoding.py b/lerobot/common/policies/act/position_encoding.py
new file mode 100644
index 00000000..94e862f6
--- /dev/null
+++ b/lerobot/common/policies/act/position_encoding.py
@@ -0,0 +1,101 @@
+"""
+Various positional encodings for the transformer.
+"""
+import math
+
+import torch
+from torch import nn
+
+from .utils import NestedTensor
+
+
+class PositionEmbeddingSine(nn.Module):
+ """
+ This is a more standard version of the position embedding, very similar to the one
+ used by the Attention is all you need paper, generalized to work on images.
+ """
+
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
+ super().__init__()
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+
+ def forward(self, tensor):
+ x = tensor
+ # mask = tensor_list.mask
+ # assert mask is not None
+ # not_mask = ~mask
+
+ not_mask = torch.ones_like(x[0, [0]])
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ return pos
+
+
+class PositionEmbeddingLearned(nn.Module):
+ """
+ Absolute pos embedding, learned.
+ """
+
+ def __init__(self, num_pos_feats=256):
+ super().__init__()
+ self.row_embed = nn.Embedding(50, num_pos_feats)
+ self.col_embed = nn.Embedding(50, num_pos_feats)
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ nn.init.uniform_(self.row_embed.weight)
+ nn.init.uniform_(self.col_embed.weight)
+
+ def forward(self, tensor_list: NestedTensor):
+ x = tensor_list.tensors
+ h, w = x.shape[-2:]
+ i = torch.arange(w, device=x.device)
+ j = torch.arange(h, device=x.device)
+ x_emb = self.col_embed(i)
+ y_emb = self.row_embed(j)
+ pos = (
+ torch.cat(
+ [
+ x_emb.unsqueeze(0).repeat(h, 1, 1),
+ y_emb.unsqueeze(1).repeat(1, w, 1),
+ ],
+ dim=-1,
+ )
+ .permute(2, 0, 1)
+ .unsqueeze(0)
+ .repeat(x.shape[0], 1, 1, 1)
+ )
+ return pos
+
+
+def build_position_encoding(args):
+ n_steps = args.hidden_dim // 2
+ if args.position_embedding in ("v2", "sine"):
+ # TODO find a better way of exposing other arguments
+ position_embedding = PositionEmbeddingSine(n_steps, normalize=True)
+ elif args.position_embedding in ("v3", "learned"):
+ position_embedding = PositionEmbeddingLearned(n_steps)
+ else:
+ raise ValueError(f"not supported {args.position_embedding}")
+
+ return position_embedding
diff --git a/lerobot/common/policies/act/transformer.py b/lerobot/common/policies/act/transformer.py
new file mode 100644
index 00000000..b2bd3685
--- /dev/null
+++ b/lerobot/common/policies/act/transformer.py
@@ -0,0 +1,370 @@
+"""
+DETR Transformer class.
+
+Copy-paste from torch.nn.Transformer with modifications:
+ * positional encodings are passed in MHattention
+ * extra LN at the end of encoder is removed
+ * decoder returns a stack of activations from all decoding layers
+"""
+import copy
+from typing import Optional
+
+import torch
+import torch.nn.functional as F # noqa: N812
+from torch import Tensor, nn
+
+
+class Transformer(nn.Module):
+ def __init__(
+ self,
+ d_model=512,
+ nhead=8,
+ num_encoder_layers=6,
+ num_decoder_layers=6,
+ dim_feedforward=2048,
+ dropout=0.1,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=False,
+ ):
+ super().__init__()
+
+ encoder_layer = TransformerEncoderLayer(
+ d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+ )
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
+ self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
+
+ decoder_layer = TransformerDecoderLayer(
+ d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+ )
+ decoder_norm = nn.LayerNorm(d_model)
+ self.decoder = TransformerDecoder(
+ decoder_layer, num_decoder_layers, decoder_norm, return_intermediate=return_intermediate_dec
+ )
+
+ self._reset_parameters()
+
+ self.d_model = d_model
+ self.nhead = nhead
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def forward(
+ self,
+ src,
+ mask,
+ query_embed,
+ pos_embed,
+ latent_input=None,
+ proprio_input=None,
+ additional_pos_embed=None,
+ ):
+ # TODO flatten only when input has H and W
+ if len(src.shape) == 4: # has H and W
+ # flatten NxCxHxW to HWxNxC
+ bs, c, h, w = src.shape
+ src = src.flatten(2).permute(2, 0, 1)
+ pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)
+ query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
+ # mask = mask.flatten(1)
+
+ additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim
+ pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)
+
+ addition_input = torch.stack([latent_input, proprio_input], axis=0)
+ src = torch.cat([addition_input, src], axis=0)
+ else:
+ assert len(src.shape) == 3
+ # flatten NxHWxC to HWxNxC
+ bs, hw, c = src.shape
+ src = src.permute(1, 0, 2)
+ pos_embed = pos_embed.unsqueeze(1).repeat(1, bs, 1)
+ query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
+
+ tgt = torch.zeros_like(query_embed)
+ memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
+ hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed)
+ hs = hs.transpose(1, 2)
+ return hs
+
+
+class TransformerEncoder(nn.Module):
+ def __init__(self, encoder_layer, num_layers, norm=None):
+ super().__init__()
+ self.layers = _get_clones(encoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.norm = norm
+
+ def forward(
+ self,
+ src,
+ mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ ):
+ output = src
+
+ for layer in self.layers:
+ output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos)
+
+ if self.norm is not None:
+ output = self.norm(output)
+
+ return output
+
+
+class TransformerDecoder(nn.Module):
+ def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
+ super().__init__()
+ self.layers = _get_clones(decoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.norm = norm
+ self.return_intermediate = return_intermediate
+
+ def forward(
+ self,
+ tgt,
+ memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None,
+ ):
+ output = tgt
+
+ intermediate = []
+
+ for layer in self.layers:
+ output = layer(
+ output,
+ memory,
+ tgt_mask=tgt_mask,
+ memory_mask=memory_mask,
+ tgt_key_padding_mask=tgt_key_padding_mask,
+ memory_key_padding_mask=memory_key_padding_mask,
+ pos=pos,
+ query_pos=query_pos,
+ )
+ if self.return_intermediate:
+ intermediate.append(self.norm(output))
+
+ if self.norm is not None:
+ output = self.norm(output)
+ if self.return_intermediate:
+ intermediate.pop()
+ intermediate.append(output)
+
+ if self.return_intermediate:
+ return torch.stack(intermediate)
+
+ return output.unsqueeze(0)
+
+
+class TransformerEncoderLayer(nn.Module):
+ def __init__(
+ self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False
+ ):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+ self.normalize_before = normalize_before
+
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+ return tensor if pos is None else tensor + pos
+
+ def forward_post(
+ self,
+ src,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ ):
+ q = k = self.with_pos_embed(src, pos)
+ src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
+ src = src + self.dropout1(src2)
+ src = self.norm1(src)
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+ src = src + self.dropout2(src2)
+ src = self.norm2(src)
+ return src
+
+ def forward_pre(
+ self,
+ src,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ ):
+ src2 = self.norm1(src)
+ q = k = self.with_pos_embed(src2, pos)
+ src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
+ src = src + self.dropout1(src2)
+ src2 = self.norm2(src)
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
+ src = src + self.dropout2(src2)
+ return src
+
+ def forward(
+ self,
+ src,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ ):
+ if self.normalize_before:
+ return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
+ return self.forward_post(src, src_mask, src_key_padding_mask, pos)
+
+
+class TransformerDecoderLayer(nn.Module):
+ def __init__(
+ self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False
+ ):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+ self.norm3 = nn.LayerNorm(d_model)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+ self.dropout3 = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+ self.normalize_before = normalize_before
+
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+ return tensor if pos is None else tensor + pos
+
+ def forward_post(
+ self,
+ tgt,
+ memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None,
+ ):
+ q = k = self.with_pos_embed(tgt, query_pos)
+ tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
+ tgt = tgt + self.dropout1(tgt2)
+ tgt = self.norm1(tgt)
+ tgt2 = self.multihead_attn(
+ query=self.with_pos_embed(tgt, query_pos),
+ key=self.with_pos_embed(memory, pos),
+ value=memory,
+ attn_mask=memory_mask,
+ key_padding_mask=memory_key_padding_mask,
+ )[0]
+ tgt = tgt + self.dropout2(tgt2)
+ tgt = self.norm2(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+ tgt = tgt + self.dropout3(tgt2)
+ tgt = self.norm3(tgt)
+ return tgt
+
+ def forward_pre(
+ self,
+ tgt,
+ memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None,
+ ):
+ tgt2 = self.norm1(tgt)
+ q = k = self.with_pos_embed(tgt2, query_pos)
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
+ tgt = tgt + self.dropout1(tgt2)
+ tgt2 = self.norm2(tgt)
+ tgt2 = self.multihead_attn(
+ query=self.with_pos_embed(tgt2, query_pos),
+ key=self.with_pos_embed(memory, pos),
+ value=memory,
+ attn_mask=memory_mask,
+ key_padding_mask=memory_key_padding_mask,
+ )[0]
+ tgt = tgt + self.dropout2(tgt2)
+ tgt2 = self.norm3(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+ tgt = tgt + self.dropout3(tgt2)
+ return tgt
+
+ def forward(
+ self,
+ tgt,
+ memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None,
+ ):
+ if self.normalize_before:
+ return self.forward_pre(
+ tgt,
+ memory,
+ tgt_mask,
+ memory_mask,
+ tgt_key_padding_mask,
+ memory_key_padding_mask,
+ pos,
+ query_pos,
+ )
+ return self.forward_post(
+ tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos
+ )
+
+
+def _get_clones(module, n):
+ return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])
+
+
+def build_transformer(args):
+ return Transformer(
+ d_model=args.hidden_dim,
+ dropout=args.dropout,
+ nhead=args.nheads,
+ dim_feedforward=args.dim_feedforward,
+ num_encoder_layers=args.enc_layers,
+ num_decoder_layers=args.dec_layers,
+ normalize_before=args.pre_norm,
+ return_intermediate_dec=True,
+ )
+
+
+def _get_activation_fn(activation):
+ """Return an activation function given a string"""
+ if activation == "relu":
+ return F.relu
+ if activation == "gelu":
+ return F.gelu
+ if activation == "glu":
+ return F.glu
+ raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
diff --git a/lerobot/common/policies/act/utils.py b/lerobot/common/policies/act/utils.py
new file mode 100644
index 00000000..2ce92094
--- /dev/null
+++ b/lerobot/common/policies/act/utils.py
@@ -0,0 +1,477 @@
+"""
+Misc functions, including distributed helpers.
+
+Mostly copy-paste from torchvision references.
+"""
+import datetime
+import os
+import pickle
+import subprocess
+import time
+from collections import defaultdict, deque
+from typing import List, Optional
+
+import torch
+import torch.distributed as dist
+
+# needed due to empty tensor bug in pytorch and torchvision 0.5
+import torchvision
+from packaging import version
+from torch import Tensor
+
+if version.parse(torchvision.__version__) < version.parse("0.7"):
+ from torchvision.ops import _new_empty_tensor
+ from torchvision.ops.misc import _output_size
+
+
+class SmoothedValue:
+ """Track a series of values and provide access to smoothed values over a
+ window or the global series average.
+ """
+
+ def __init__(self, window_size=20, fmt=None):
+ if fmt is None:
+ fmt = "{median:.4f} ({global_avg:.4f})"
+ self.deque = deque(maxlen=window_size)
+ self.total = 0.0
+ self.count = 0
+ self.fmt = fmt
+
+ def update(self, value, n=1):
+ self.deque.append(value)
+ self.count += n
+ self.total += value * n
+
+ def synchronize_between_processes(self):
+ """
+ Warning: does not synchronize the deque!
+ """
+ if not is_dist_avail_and_initialized():
+ return
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
+ dist.barrier()
+ dist.all_reduce(t)
+ t = t.tolist()
+ self.count = int(t[0])
+ self.total = t[1]
+
+ @property
+ def median(self):
+ d = torch.tensor(list(self.deque))
+ return d.median().item()
+
+ @property
+ def avg(self):
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
+ return d.mean().item()
+
+ @property
+ def global_avg(self):
+ return self.total / self.count
+
+ @property
+ def max(self):
+ return max(self.deque)
+
+ @property
+ def value(self):
+ return self.deque[-1]
+
+ def __str__(self):
+ return self.fmt.format(
+ median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
+ )
+
+
+def all_gather(data):
+ """
+ Run all_gather on arbitrary picklable data (not necessarily tensors)
+ Args:
+ data: any picklable object
+ Returns:
+ list[data]: list of data gathered from each rank
+ """
+ world_size = get_world_size()
+ if world_size == 1:
+ return [data]
+
+ # serialized to a Tensor
+ buffer = pickle.dumps(data)
+ storage = torch.ByteStorage.from_buffer(buffer)
+ tensor = torch.ByteTensor(storage).to("cuda")
+
+ # obtain Tensor size of each rank
+ local_size = torch.tensor([tensor.numel()], device="cuda")
+ size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
+ dist.all_gather(size_list, local_size)
+ size_list = [int(size.item()) for size in size_list]
+ max_size = max(size_list)
+
+ # receiving Tensor from all ranks
+ # we pad the tensor because torch all_gather does not support
+ # gathering tensors of different shapes
+ tensor_list = []
+ for _ in size_list:
+ tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
+ if local_size != max_size:
+ padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
+ tensor = torch.cat((tensor, padding), dim=0)
+ dist.all_gather(tensor_list, tensor)
+
+ data_list = []
+ for size, tensor in zip(size_list, tensor_list, strict=False):
+ buffer = tensor.cpu().numpy().tobytes()[:size]
+ data_list.append(pickle.loads(buffer))
+
+ return data_list
+
+
+def reduce_dict(input_dict, average=True):
+ """
+ Args:
+ input_dict (dict): all the values will be reduced
+ average (bool): whether to do average or sum
+ Reduce the values in the dictionary from all processes so that all processes
+ have the averaged results. Returns a dict with the same fields as
+ input_dict, after reduction.
+ """
+ world_size = get_world_size()
+ if world_size < 2:
+ return input_dict
+ with torch.no_grad():
+ names = []
+ values = []
+ # sort the keys so that they are consistent across processes
+ for k in sorted(input_dict.keys()):
+ names.append(k)
+ values.append(input_dict[k])
+ values = torch.stack(values, dim=0)
+ dist.all_reduce(values)
+ if average:
+ values /= world_size
+ reduced_dict = {k: v for k, v in zip(names, values, strict=False)} # noqa: C416
+ return reduced_dict
+
+
+class MetricLogger:
+ def __init__(self, delimiter="\t"):
+ self.meters = defaultdict(SmoothedValue)
+ self.delimiter = delimiter
+
+ def update(self, **kwargs):
+ for k, v in kwargs.items():
+ if isinstance(v, torch.Tensor):
+ v = v.item()
+ assert isinstance(v, (float, int))
+ self.meters[k].update(v)
+
+ def __getattr__(self, attr):
+ if attr in self.meters:
+ return self.meters[attr]
+ if attr in self.__dict__:
+ return self.__dict__[attr]
+ raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))
+
+ def __str__(self):
+ loss_str = []
+ for name, meter in self.meters.items():
+ loss_str.append("{}: {}".format(name, str(meter)))
+ return self.delimiter.join(loss_str)
+
+ def synchronize_between_processes(self):
+ for meter in self.meters.values():
+ meter.synchronize_between_processes()
+
+ def add_meter(self, name, meter):
+ self.meters[name] = meter
+
+ def log_every(self, iterable, print_freq, header=None):
+ if not header:
+ header = ""
+ start_time = time.time()
+ end = time.time()
+ iter_time = SmoothedValue(fmt="{avg:.4f}")
+ data_time = SmoothedValue(fmt="{avg:.4f}")
+ space_fmt = ":" + str(len(str(len(iterable)))) + "d"
+ if torch.cuda.is_available():
+ log_msg = self.delimiter.join(
+ [
+ header,
+ "[{0" + space_fmt + "}/{1}]",
+ "eta: {eta}",
+ "{meters}",
+ "time: {time}",
+ "data: {data}",
+ "max mem: {memory:.0f}",
+ ]
+ )
+ else:
+ log_msg = self.delimiter.join(
+ [
+ header,
+ "[{0" + space_fmt + "}/{1}]",
+ "eta: {eta}",
+ "{meters}",
+ "time: {time}",
+ "data: {data}",
+ ]
+ )
+ mega_b = 1024.0 * 1024.0
+ for i, obj in enumerate(iterable):
+ data_time.update(time.time() - end)
+ yield obj
+ iter_time.update(time.time() - end)
+ if i % print_freq == 0 or i == len(iterable) - 1:
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+ if torch.cuda.is_available():
+ print(
+ log_msg.format(
+ i,
+ len(iterable),
+ eta=eta_string,
+ meters=str(self),
+ time=str(iter_time),
+ data=str(data_time),
+ memory=torch.cuda.max_memory_allocated() / mega_b,
+ )
+ )
+ else:
+ print(
+ log_msg.format(
+ i,
+ len(iterable),
+ eta=eta_string,
+ meters=str(self),
+ time=str(iter_time),
+ data=str(data_time),
+ )
+ )
+ end = time.time()
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print("{} Total time: {} ({:.4f} s / it)".format(header, total_time_str, total_time / len(iterable)))
+
+
+def get_sha():
+ cwd = os.path.dirname(os.path.abspath(__file__))
+
+ def _run(command):
+ return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()
+
+ sha = "N/A"
+ diff = "clean"
+ branch = "N/A"
+ try:
+ sha = _run(["git", "rev-parse", "HEAD"])
+ subprocess.check_output(["git", "diff"], cwd=cwd)
+ diff = _run(["git", "diff-index", "HEAD"])
+ diff = "has uncommited changes" if diff else "clean"
+ branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
+ except Exception:
+ pass
+ message = f"sha: {sha}, status: {diff}, branch: {branch}"
+ return message
+
+
+def collate_fn(batch):
+ batch = list(zip(*batch, strict=False))
+ batch[0] = nested_tensor_from_tensor_list(batch[0])
+ return tuple(batch)
+
+
+def _max_by_axis(the_list):
+ # type: (List[List[int]]) -> List[int]
+ maxes = the_list[0]
+ for sublist in the_list[1:]:
+ for index, item in enumerate(sublist):
+ maxes[index] = max(maxes[index], item)
+ return maxes
+
+
+class NestedTensor:
+ def __init__(self, tensors, mask: Optional[Tensor]):
+ self.tensors = tensors
+ self.mask = mask
+
+ def to(self, device):
+ # type: (Device) -> NestedTensor # noqa
+ cast_tensor = self.tensors.to(device)
+ mask = self.mask
+ if mask is not None:
+ assert mask is not None
+ cast_mask = mask.to(device)
+ else:
+ cast_mask = None
+ return NestedTensor(cast_tensor, cast_mask)
+
+ def decompose(self):
+ return self.tensors, self.mask
+
+ def __repr__(self):
+ return str(self.tensors)
+
+
+def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
+ # TODO make this more general
+ if tensor_list[0].ndim == 3:
+ if torchvision._is_tracing():
+ # nested_tensor_from_tensor_list() does not export well to ONNX
+ # call _onnx_nested_tensor_from_tensor_list() instead
+ return _onnx_nested_tensor_from_tensor_list(tensor_list)
+
+ # TODO make it support different-sized images
+ max_size = _max_by_axis([list(img.shape) for img in tensor_list])
+ # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
+ batch_shape = [len(tensor_list)] + max_size
+ b, c, h, w = batch_shape
+ dtype = tensor_list[0].dtype
+ device = tensor_list[0].device
+ tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
+ mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
+ for img, pad_img, m in zip(tensor_list, tensor, mask, strict=False):
+ pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+ m[: img.shape[1], : img.shape[2]] = False
+ else:
+ raise ValueError("not supported")
+ return NestedTensor(tensor, mask)
+
+
+# _onnx_nested_tensor_from_tensor_list() is an implementation of
+# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
+@torch.jit.unused
+def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
+ max_size = []
+ for i in range(tensor_list[0].dim()):
+ max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(
+ torch.int64
+ )
+ max_size.append(max_size_i)
+ max_size = tuple(max_size)
+
+ # work around for
+ # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+ # m[: img.shape[1], :img.shape[2]] = False
+ # which is not yet supported in onnx
+ padded_imgs = []
+ padded_masks = []
+ for img in tensor_list:
+ padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape), strict=False)]
+ padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
+ padded_imgs.append(padded_img)
+
+ m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
+ padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
+ padded_masks.append(padded_mask.to(torch.bool))
+
+ tensor = torch.stack(padded_imgs)
+ mask = torch.stack(padded_masks)
+
+ return NestedTensor(tensor, mask=mask)
+
+
+def setup_for_distributed(is_master):
+ """
+ This function disables printing when not in master process
+ """
+ import builtins as __builtin__
+
+ builtin_print = __builtin__.print
+
+ def print(*args, **kwargs):
+ force = kwargs.pop("force", False)
+ if is_master or force:
+ builtin_print(*args, **kwargs)
+
+ __builtin__.print = print
+
+
+def is_dist_avail_and_initialized():
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
+
+
+def get_world_size():
+ if not is_dist_avail_and_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank():
+ if not is_dist_avail_and_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def is_main_process():
+ return get_rank() == 0
+
+
+def save_on_master(*args, **kwargs):
+ if is_main_process():
+ torch.save(*args, **kwargs)
+
+
+def init_distributed_mode(args):
+ if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
+ args.rank = int(os.environ["RANK"])
+ args.world_size = int(os.environ["WORLD_SIZE"])
+ args.gpu = int(os.environ["LOCAL_RANK"])
+ elif "SLURM_PROCID" in os.environ:
+ args.rank = int(os.environ["SLURM_PROCID"])
+ args.gpu = args.rank % torch.cuda.device_count()
+ else:
+ print("Not using distributed mode")
+ args.distributed = False
+ return
+
+ args.distributed = True
+
+ torch.cuda.set_device(args.gpu)
+ args.dist_backend = "nccl"
+ print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True)
+ torch.distributed.init_process_group(
+ backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
+ )
+ torch.distributed.barrier()
+ setup_for_distributed(args.rank == 0)
+
+
+@torch.no_grad()
+def accuracy(output, target, topk=(1,)):
+ """Computes the precision@k for the specified values of k"""
+ if target.numel() == 0:
+ return [torch.zeros([], device=output.device)]
+ maxk = max(topk)
+ batch_size = target.size(0)
+
+ _, pred = output.topk(maxk, 1, True, True)
+ pred = pred.t()
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+ res = []
+ for k in topk:
+ correct_k = correct[:k].view(-1).float().sum(0)
+ res.append(correct_k.mul_(100.0 / batch_size))
+ return res
+
+
+def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
+ # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
+ """
+ Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
+ This will eventually be supported natively by PyTorch, and this
+ class can go away.
+ """
+ if version.parse(torchvision.__version__) < version.parse("0.7"):
+ if input.numel() > 0:
+ return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners)
+
+ output_shape = _output_size(2, input, size, scale_factor)
+ output_shape = list(input.shape[:-2]) + list(output_shape)
+ return _new_empty_tensor(input, output_shape)
+ else:
+ return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py
index a956cb4b..c5e45300 100644
--- a/lerobot/common/policies/factory.py
+++ b/lerobot/common/policies/factory.py
@@ -17,6 +17,12 @@ def make_policy(cfg):
n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
**cfg.policy,
)
+ elif cfg.policy.name == "act":
+ from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
+
+ policy = ActionChunkingTransformerPolicy(
+ cfg.policy, cfg.device, n_action_steps=cfg.n_action_steps + cfg.n_latency_steps
+ )
else:
raise ValueError(cfg.policy.name)
diff --git a/lerobot/configs/env/aloha.yaml b/lerobot/configs/env/aloha.yaml
index 5b5ecbb7..df464c75 100644
--- a/lerobot/configs/env/aloha.yaml
+++ b/lerobot/configs/env/aloha.yaml
@@ -15,11 +15,11 @@ env:
task: sim_insertion_human
from_pixels: True
pixels_only: False
- image_size: 96
+ image_size: [3, 480, 640]
action_repeat: 1
- episode_length: 300
+ episode_length: 400
fps: ${fps}
policy:
- state_dim: 2
- action_dim: 2
+ state_dim: 14
+ action_dim: 14
diff --git a/lerobot/configs/policy/act.yaml b/lerobot/configs/policy/act.yaml
new file mode 100644
index 00000000..a52c3f54
--- /dev/null
+++ b/lerobot/configs/policy/act.yaml
@@ -0,0 +1,58 @@
+# @package _global_
+
+offline_steps: 1344000
+online_steps: 0
+
+eval_episodes: 1
+eval_freq: 10000
+save_freq: 100000
+log_freq: 250
+
+horizon: 100
+n_obs_steps: 1
+n_latency_steps: 0
+# when temporal_agg=False, n_action_steps=horizon
+n_action_steps: ${horizon}
+
+policy:
+ name: act
+
+ pretrained_model_path:
+
+ lr: 1e-5
+ lr_backbone: 1e-5
+ weight_decay: 1e-4
+ grad_clip_norm: 10
+ backbone: resnet18
+ num_queries: ${horizon} # chunk_size
+ horizon: ${horizon} # chunk_size
+ kl_weight: 10
+ hidden_dim: 512
+ dim_feedforward: 3200
+ enc_layers: 4
+ dec_layers: 7
+ nheads: 8
+ #camera_names: [top, front_close, left_pillar, right_pillar]
+ camera_names: [top]
+ position_embedding: sine
+ masks: false
+ dilation: false
+ dropout: 0.1
+ pre_norm: false
+
+ vae: true
+
+ batch_size: 8
+
+ per_alpha: 0.6
+ per_beta: 0.4
+
+ balanced_sampling: false
+ utd: 1
+
+ n_obs_steps: ${n_obs_steps}
+
+ temporal_agg: false
+
+ state_dim: ???
+ action_dim: ???
diff --git a/lerobot/scripts/download.py b/lerobot/scripts/download.py
deleted file mode 100644
index ac935f48..00000000
--- a/lerobot/scripts/download.py
+++ /dev/null
@@ -1,22 +0,0 @@
-# TODO(rcadene): obsolete remove
-import os
-import zipfile
-
-import gdown
-
-
-def download():
- url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
- download_path = "data.zip"
- gdown.download(url, download_path, quiet=False)
- print("Extracting...")
- with zipfile.ZipFile(download_path, "r") as zip_f:
- for member in zip_f.namelist():
- if member.startswith("data/xarm") and member.endswith(".pkl"):
- print(member)
- zip_f.extract(member=member)
- os.remove(download_path)
-
-
-if __name__ == "__main__":
- download()
diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index 6435310a..7ba2812e 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -38,27 +38,18 @@ def eval_policy(
successes = []
threads = []
for i in tqdm.tqdm(range(num_episodes)):
- tensordict = env.reset()
-
ep_frames = []
-
if save_video or (return_first_video and i == 0):
- def rendering_callback(env, td=None):
+ def render_frame(env):
ep_frames.append(env.render()) # noqa: B023
- # render first frame before rollout
- rendering_callback(env)
- else:
- rendering_callback = None
+ env.register_rendering_hook(render_frame)
with torch.inference_mode():
rollout = env.rollout(
max_steps=max_steps,
policy=policy,
- callback=rendering_callback,
- auto_reset=False,
- tensordict=tensordict,
auto_cast_to_device=True,
)
# print(", ".join([f"{x:.3f}" for x in rollout["next", "reward"][:,0].tolist()]))
@@ -85,6 +76,8 @@ def eval_policy(
if return_first_video and i == 0:
first_video = stacked_frames.transpose(0, 3, 1, 2)
+ env.reset_rendering_hooks()
+
for thread in threads:
thread.join()
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index f4b22604..c063caf8 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -1,4 +1,5 @@
import logging
+from pathlib import Path
import hydra
import numpy as np
@@ -192,6 +193,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
num_episodes=cfg.eval_episodes,
max_steps=cfg.env.episode_length // cfg.n_action_steps,
return_first_video=True,
+ video_dir=Path(out_dir) / "eval",
+ save_video=True,
)
log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline)
if cfg.wandb.enable:
diff --git a/poetry.lock b/poetry.lock
index db4f8f3e..59de0ec5 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -340,69 +340,69 @@ files = [
[[package]]
name = "cython"
-version = "3.0.8"
+version = "3.0.9"
description = "The Cython compiler for writing C extensions in the Python language."
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
files = [
- {file = "Cython-3.0.8-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a846e0a38e2b24e9a5c5dc74b0e54c6e29420d88d1dafabc99e0fc0f3e338636"},
- {file = "Cython-3.0.8-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45523fdc2b78d79b32834cc1cc12dc2ca8967af87e22a3ee1bff20e77c7f5520"},
- {file = "Cython-3.0.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baa0b7f3f841fe087410cab66778e2d3fb20ae2d2078a2be3dffe66c6574be39"},
- {file = "Cython-3.0.8-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e87294e33e40c289c77a135f491cd721bd089f193f956f7b8ed5aa2d0b8c558f"},
- {file = "Cython-3.0.8-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:a1df7a129344b1215c20096d33c00193437df1a8fcca25b71f17c23b1a44f782"},
- {file = "Cython-3.0.8-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:13c2a5e57a0358da467d97667297bf820b62a1a87ae47c5f87938b9bb593acbd"},
- {file = "Cython-3.0.8-cp310-cp310-win32.whl", hash = "sha256:96b028f044f5880e3cb18ecdcfc6c8d3ce9d0af28418d5ab464509f26d8adf12"},
- {file = "Cython-3.0.8-cp310-cp310-win_amd64.whl", hash = "sha256:8140597a8b5cc4f119a1190f5a2228a84f5ca6d8d9ec386cfce24663f48b2539"},
- {file = "Cython-3.0.8-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:aae26f9663e50caf9657148403d9874eea41770ecdd6caf381d177c2b1bb82ba"},
- {file = "Cython-3.0.8-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:547eb3cdb2f8c6f48e6865d5a741d9dd051c25b3ce076fbca571727977b28ac3"},
- {file = "Cython-3.0.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5a567d4b9ba70b26db89d75b243529de9e649a2f56384287533cf91512705bee"},
- {file = "Cython-3.0.8-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:51d1426263b0e82fb22bda8ea60dc77a428581cc19e97741011b938445d383f1"},
- {file = "Cython-3.0.8-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c26daaeccda072459b48d211415fd1e5507c06bcd976fa0d5b8b9f1063467d7b"},
- {file = "Cython-3.0.8-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:289ce7838208211cd166e975865fd73b0649bf118170b6cebaedfbdaf4a37795"},
- {file = "Cython-3.0.8-cp311-cp311-win32.whl", hash = "sha256:c8aa05f5e17f8042a3be052c24f2edc013fb8af874b0bf76907d16c51b4e7871"},
- {file = "Cython-3.0.8-cp311-cp311-win_amd64.whl", hash = "sha256:000dc9e135d0eec6ecb2b40a5b02d0868a2f8d2e027a41b0fe16a908a9e6de02"},
- {file = "Cython-3.0.8-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:90d3fe31db55685d8cb97d43b0ec39ef614fcf660f83c77ed06aa670cb0e164f"},
- {file = "Cython-3.0.8-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e24791ddae2324e88e3c902a765595c738f19ae34ee66bfb1a6dac54b1833419"},
- {file = "Cython-3.0.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2f020fa1c0552052e0660790b8153b79e3fc9a15dbd8f1d0b841fe5d204a6ae6"},
- {file = "Cython-3.0.8-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:18bfa387d7a7f77d7b2526af69a65dbd0b731b8d941aaff5becff8e21f6d7717"},
- {file = "Cython-3.0.8-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:fe81b339cffd87c0069c6049b4d33e28bdd1874625ee515785bf42c9fdff3658"},
- {file = "Cython-3.0.8-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:80fd94c076e1e1b1ee40a309be03080b75f413e8997cddcf401a118879863388"},
- {file = "Cython-3.0.8-cp312-cp312-win32.whl", hash = "sha256:85077915a93e359a9b920280d214dc0cf8a62773e1f3d7d30fab8ea4daed670c"},
- {file = "Cython-3.0.8-cp312-cp312-win_amd64.whl", hash = "sha256:0cb2dcc565c7851f75d496f724a384a790fab12d1b82461b663e66605bec429a"},
- {file = "Cython-3.0.8-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:870d2a0a7e3cbd5efa65aecdb38d715ea337a904ea7bb22324036e78fb7068e7"},
- {file = "Cython-3.0.8-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7e8f2454128974905258d86534f4fd4f91d2f1343605657ecab779d80c9d6d5e"},
- {file = "Cython-3.0.8-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c1949d6aa7bc792554bee2b67a9fe41008acbfe22f4f8df7b6ec7b799613a4b3"},
- {file = "Cython-3.0.8-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c9f2c6e1b8f3bcd6cb230bac1843f85114780bb8be8614855b1628b36bb510e0"},
- {file = "Cython-3.0.8-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:05d7eddc668ae7993643f32c7661f25544e791edb745758672ea5b1a82ecffa6"},
- {file = "Cython-3.0.8-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:bfabe115deef4ada5d23c87bddb11289123336dcc14347011832c07db616dd93"},
- {file = "Cython-3.0.8-cp36-cp36m-win32.whl", hash = "sha256:0c38c9f0bcce2df0c3347285863621be904ac6b64c5792d871130569d893efd7"},
- {file = "Cython-3.0.8-cp36-cp36m-win_amd64.whl", hash = "sha256:6c46939c3983217d140999de7c238c3141f56b1ea349e47ca49cae899969aa2c"},
- {file = "Cython-3.0.8-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:115f0a50f752da6c99941b103b5cb090da63eb206abbc7c2ad33856ffc73f064"},
- {file = "Cython-3.0.8-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c9c0f29246734561c90f36e70ed0506b61aa3d044e4cc4cba559065a2a741fae"},
- {file = "Cython-3.0.8-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ab75242869ff71e5665fe5c96f3378e79e792fa3c11762641b6c5afbbbbe026"},
- {file = "Cython-3.0.8-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6717c06e9cfc6c1df18543cd31a21f5d8e378a40f70c851fa2d34f0597037abc"},
- {file = "Cython-3.0.8-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:9d3f74388db378a3c6fd06e79a809ed98df3f56484d317b81ee762dbf3c263e0"},
- {file = "Cython-3.0.8-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:ae7ac561fd8253a9ae96311e91d12af5f701383564edc11d6338a7b60b285a6f"},
- {file = "Cython-3.0.8-cp37-cp37m-win32.whl", hash = "sha256:97b2a45845b993304f1799664fa88da676ee19442b15fdcaa31f9da7e1acc434"},
- {file = "Cython-3.0.8-cp37-cp37m-win_amd64.whl", hash = "sha256:9e2be2b340fea46fb849d378f9b80d3c08ff2e81e2bfbcdb656e2e3cd8c6b2dc"},
- {file = "Cython-3.0.8-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2cde23c555470db3f149ede78b518e8274853745289c956a0e06ad8d982e4db9"},
- {file = "Cython-3.0.8-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7990ca127e1f1beedaf8fc8bf66541d066ef4723ad7d8d47a7cbf842e0f47580"},
- {file = "Cython-3.0.8-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b983c8e6803f016146c26854d9150ddad5662960c804ea7f0c752c9266752f0"},
- {file = "Cython-3.0.8-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a973268d7ca1a2bdf78575e459a94a78e1a0a9bb62a7db0c50041949a73b02ff"},
- {file = "Cython-3.0.8-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:61a237bc9dd23c7faef0fcfce88c11c65d0c9bb73c74ccfa408b3a012073c20e"},
- {file = "Cython-3.0.8-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:3a3d67f079598af49e90ff9655bf85bd358f093d727eb21ca2708f467c489cae"},
- {file = "Cython-3.0.8-cp38-cp38-win32.whl", hash = "sha256:17a642bb01a693e34c914106566f59844b4461665066613913463a719e0dd15d"},
- {file = "Cython-3.0.8-cp38-cp38-win_amd64.whl", hash = "sha256:2cdfc32252f3b6dc7c94032ab744dcedb45286733443c294d8f909a4854e7f83"},
- {file = "Cython-3.0.8-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:fa97893d99385386925d00074654aeae3a98867f298d1e12ceaf38a9054a9bae"},
- {file = "Cython-3.0.8-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f05c0bf9d085c031df8f583f0d506aa3be1692023de18c45d0aaf78685bbb944"},
- {file = "Cython-3.0.8-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de892422582f5758bd8de187e98ac829330ec1007bc42c661f687792999988a7"},
- {file = "Cython-3.0.8-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:314f2355a1f1d06e3c431eaad4708cf10037b5e91e4b231d89c913989d0bdafd"},
- {file = "Cython-3.0.8-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:78825a3774211e7d5089730f00cdf7f473042acc9ceb8b9eeebe13ed3a5541de"},
- {file = "Cython-3.0.8-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:df8093deabc55f37028190cf5e575c26aad23fc673f34b85d5f45076bc37ce39"},
- {file = "Cython-3.0.8-cp39-cp39-win32.whl", hash = "sha256:1aca1b97e0095b3a9a6c33eada3f661a4ed0d499067d121239b193e5ba3bb4f0"},
- {file = "Cython-3.0.8-cp39-cp39-win_amd64.whl", hash = "sha256:16873d78be63bd38ffb759da7ab82814b36f56c769ee02b1d5859560e4c3ac3c"},
- {file = "Cython-3.0.8-py2.py3-none-any.whl", hash = "sha256:171b27051253d3f9108e9759e504ba59ff06e7f7ba944457f94deaf9c21bf0b6"},
- {file = "Cython-3.0.8.tar.gz", hash = "sha256:8333423d8fd5765e7cceea3a9985dd1e0a5dfeb2734629e1a2ed2d6233d39de6"},
+ {file = "Cython-3.0.9-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:296bd30d4445ac61b66c9d766567f6e81a6e262835d261e903c60c891a6729d3"},
+ {file = "Cython-3.0.9-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f496b52845cb45568a69d6359a2c335135233003e708ea02155c10ce3548aa89"},
+ {file = "Cython-3.0.9-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:858c3766b9aa3ab8a413392c72bbab1c144a9766b7c7bfdef64e2e414363fa0c"},
+ {file = "Cython-3.0.9-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c0eb1e6ef036028a52525fd9a012a556f6dd4788a0e8755fe864ba0e70cde2ff"},
+ {file = "Cython-3.0.9-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:c8191941073ea5896321de3c8c958fd66e5f304b0cd1f22c59edd0b86c4dd90d"},
+ {file = "Cython-3.0.9-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e32b016030bc72a8a22a1f21f470a2f57573761a4f00fbfe8347263f4fbdb9f1"},
+ {file = "Cython-3.0.9-cp310-cp310-win32.whl", hash = "sha256:d6f3ff1cd6123973fe03e0fb8ee936622f976c0c41138969975824d08886572b"},
+ {file = "Cython-3.0.9-cp310-cp310-win_amd64.whl", hash = "sha256:56f3b643dbe14449248bbeb9a63fe3878a24256664bc8c8ef6efd45d102596d8"},
+ {file = "Cython-3.0.9-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:35e6665a20d6b8a152d72b7fd87dbb2af6bb6b18a235b71add68122d594dbd41"},
+ {file = "Cython-3.0.9-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f92f4960c40ad027bd8c364c50db11104eadc59ffeb9e5b7f605ca2f05946e20"},
+ {file = "Cython-3.0.9-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:38df37d0e732fbd9a2fef898788492e82b770c33d1e4ed12444bbc8a3b3f89c0"},
+ {file = "Cython-3.0.9-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ad7fd88ebaeaf2e76fd729a8919fae80dab3d6ac0005e28494261d52ff347a8f"},
+ {file = "Cython-3.0.9-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:1365d5f76bf4d19df3d19ce932584c9bb76e9fb096185168918ef9b36e06bfa4"},
+ {file = "Cython-3.0.9-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c232e7f279388ac9625c3e5a5a9f0078a9334959c5d6458052c65bbbba895e1e"},
+ {file = "Cython-3.0.9-cp311-cp311-win32.whl", hash = "sha256:357e2fad46a25030b0c0496487e01a9dc0fdd0c09df0897f554d8ba3c1bc4872"},
+ {file = "Cython-3.0.9-cp311-cp311-win_amd64.whl", hash = "sha256:1315aee506506e8d69cf6631d8769e6b10131fdcc0eb66df2698f2a3ddaeeff2"},
+ {file = "Cython-3.0.9-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:157973807c2796addbed5fbc4d9c882ab34bbc60dc297ca729504901479d5df7"},
+ {file = "Cython-3.0.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:00b105b5d050645dd59e6767bc0f18b48a4aa11c85f42ec7dd8181606f4059e3"},
+ {file = "Cython-3.0.9-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ac5536d09bef240cae0416d5a703d298b74c7bbc397da803ac9d344e732d4369"},
+ {file = "Cython-3.0.9-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09c44501d476d16aaa4cbc29c87f8c0f54fc20e69b650d59cbfa4863426fc70c"},
+ {file = "Cython-3.0.9-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:cc9c3b9f20d8e298618e5ccd32083ca386e785b08f9893fbec4c50b6b85be772"},
+ {file = "Cython-3.0.9-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a30d96938c633e3ec37000ac3796525da71254ef109e66bdfd78f29891af6454"},
+ {file = "Cython-3.0.9-cp312-cp312-win32.whl", hash = "sha256:757ca93bdd80702546df4d610d2494ef2e74249cac4d5ba9464589fb464bd8a3"},
+ {file = "Cython-3.0.9-cp312-cp312-win_amd64.whl", hash = "sha256:1dc320a9905ab95414013f6de805efbff9e17bb5fb3b90bbac533f017bec8136"},
+ {file = "Cython-3.0.9-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:4ae349960ebe0da0d33724eaa7f1eb866688fe5434cc67ce4dbc06d6a719fbfc"},
+ {file = "Cython-3.0.9-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63d2537bf688247f76ded6dee28ebd26274f019309aef1eb4f2f9c5c482fde2d"},
+ {file = "Cython-3.0.9-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:36f5a2dfc724bea1f710b649f02d802d80fc18320c8e6396684ba4a48412445a"},
+ {file = "Cython-3.0.9-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:deaf4197d4b0bcd5714a497158ea96a2bd6d0f9636095437448f7e06453cc83d"},
+ {file = "Cython-3.0.9-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:000af6deb7412eb7ac0c635ff5e637fb8725dd0a7b88cc58dfc2b3de14e701c4"},
+ {file = "Cython-3.0.9-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:15c7f5c2d35bed9aa5f2a51eaac0df23ae72f2dbacf62fc672dd6bfaa75d2d6f"},
+ {file = "Cython-3.0.9-cp36-cp36m-win32.whl", hash = "sha256:f49aa4970cd3bec66ac22e701def16dca2a49c59cceba519898dd7526e0be2c0"},
+ {file = "Cython-3.0.9-cp36-cp36m-win_amd64.whl", hash = "sha256:4558814fa025b193058d42eeee498a53d6b04b2980d01339fc2444b23fd98e58"},
+ {file = "Cython-3.0.9-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:539cd1d74fd61f6cfc310fa6bbbad5adc144627f2b7486a07075d4e002fd6aad"},
+ {file = "Cython-3.0.9-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c3232926cd406ee02eabb732206f6e882c3aed9d58f0fea764013d9240405bcf"},
+ {file = "Cython-3.0.9-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:33b6ac376538a7fc8c567b85d3c71504308a9318702ec0485dd66c059f3165cb"},
+ {file = "Cython-3.0.9-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2cc92504b5d22ac66031ffb827bd3a967fc75a5f0f76ab48bce62df19be6fdfd"},
+ {file = "Cython-3.0.9-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:22b8fae756c5c0d8968691bed520876de452f216c28ec896a00739a12dba3bd9"},
+ {file = "Cython-3.0.9-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:9cda0d92a09f3520f29bd91009f1194ba9600777c02c30c6d2d4ac65fb63e40d"},
+ {file = "Cython-3.0.9-cp37-cp37m-win32.whl", hash = "sha256:ec612418490941ed16c50c8d3784c7bdc4c4b2a10c361259871790b02ec8c1db"},
+ {file = "Cython-3.0.9-cp37-cp37m-win_amd64.whl", hash = "sha256:976c8d2bedc91ff6493fc973d38b2dc01020324039e2af0e049704a8e1b22936"},
+ {file = "Cython-3.0.9-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:5055988b007c92256b6e9896441c3055556038c3497fcbf8c921a6c1fce90719"},
+ {file = "Cython-3.0.9-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d9360606d964c2d0492a866464efcf9d0a92715644eede3f6a2aa696de54a137"},
+ {file = "Cython-3.0.9-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:02c6e809f060bed073dc7cba1648077fe3b68208863d517c8b39f3920eecf9dd"},
+ {file = "Cython-3.0.9-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:95ed792c966f969cea7489c32ff90150b415c1f3567db8d5a9d489c7c1602dac"},
+ {file = "Cython-3.0.9-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:8edd59d22950b400b03ca78d27dc694d2836a92ef0cac4f64cb4b2ff902f7e25"},
+ {file = "Cython-3.0.9-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:4cf0ed273bf60e97922fcbbdd380c39693922a597760160b4b4355e6078ca188"},
+ {file = "Cython-3.0.9-cp38-cp38-win32.whl", hash = "sha256:5eb9bd4ae12ebb2bc79a193d95aacf090fbd8d7013e11ed5412711650cb34934"},
+ {file = "Cython-3.0.9-cp38-cp38-win_amd64.whl", hash = "sha256:44457279da56e0f829bb1fc5a5dc0836e5d498dbcf9b2324f32f7cc9d2ec6569"},
+ {file = "Cython-3.0.9-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c4b419a1adc2af43f4660e2f6eaf1e4fac2dbac59490771eb8ac3d6063f22356"},
+ {file = "Cython-3.0.9-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f836192140f033b2319a0128936367c295c2b32e23df05b03b672a6015757ea"},
+ {file = "Cython-3.0.9-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2fd198c1a7f8e9382904d622cc0efa3c184605881fd5262c64cbb7168c4c1ec5"},
+ {file = "Cython-3.0.9-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a274fe9ca5c53fafbcf5c8f262f8ad6896206a466f0eeb40aaf36a7951e957c0"},
+ {file = "Cython-3.0.9-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:158c38360bbc5063341b1e78d3737f1251050f89f58a3df0d10fb171c44262be"},
+ {file = "Cython-3.0.9-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8bf30b045f7deda0014b042c1b41c1d272facc762ab657529e3b05505888e878"},
+ {file = "Cython-3.0.9-cp39-cp39-win32.whl", hash = "sha256:9a001fd95c140c94d934078544ff60a3c46aca2dc86e75a76e4121d3cd1f4b33"},
+ {file = "Cython-3.0.9-cp39-cp39-win_amd64.whl", hash = "sha256:530c01c4aebba709c0ec9c7ecefe07177d0b9fd7ffee29450a118d92192ccbdf"},
+ {file = "Cython-3.0.9-py2.py3-none-any.whl", hash = "sha256:bf96417714353c5454c2e3238fca9338599330cf51625cdc1ca698684465646f"},
+ {file = "Cython-3.0.9.tar.gz", hash = "sha256:a2d354f059d1f055d34cfaa62c5b68bc78ac2ceab6407148d47fb508cf3ba4f3"},
]
[[package]]
@@ -488,6 +488,37 @@ files = [
{file = "distlib-0.3.8.tar.gz", hash = "sha256:1530ea13e350031b6312d8580ddb6b27a104275a31106523b8f123787f494f64"},
]
+[[package]]
+name = "dm-control"
+version = "1.0.14"
+description = "Continuous control environments and MuJoCo Python bindings."
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "dm_control-1.0.14-py3-none-any.whl", hash = "sha256:883c63244a7ebf598700a97564ed19fffd3479ca79efd090aed881609cdb9fc6"},
+ {file = "dm_control-1.0.14.tar.gz", hash = "sha256:def1ece747b6f175c581150826b50f1a6134086dab34f8f3fd2d088ea035cf3d"},
+]
+
+[package.dependencies]
+absl-py = ">=0.7.0"
+dm-env = "*"
+dm-tree = "!=0.1.2"
+glfw = "*"
+labmaze = "*"
+lxml = "*"
+mujoco = ">=2.3.7"
+numpy = ">=1.9.0"
+protobuf = ">=3.19.4"
+pyopengl = ">=3.1.4"
+pyparsing = ">=3.0.0"
+requests = "*"
+scipy = "*"
+setuptools = "!=50.0.0"
+tqdm = "*"
+
+[package.extras]
+hdf5 = ["h5py"]
+
[[package]]
name = "dm-env"
version = "1.6"
@@ -584,43 +615,6 @@ files = [
{file = "einops-0.7.0.tar.gz", hash = "sha256:b2b04ad6081a3b227080c9bf5e3ace7160357ff03043cd66cc5b2319eb7031d1"},
]
-[[package]]
-name = "etils"
-version = "1.7.0"
-description = "Collection of common python utils"
-optional = false
-python-versions = ">=3.10"
-files = [
- {file = "etils-1.7.0-py3-none-any.whl", hash = "sha256:61af8f7c242171de15e22e5da02d527cb9e677d11f8bcafe18fcc3548eee3e60"},
- {file = "etils-1.7.0.tar.gz", hash = "sha256:97b68fd25e185683215286ef3a54e38199b6245f5fe8be6bedc1189be4256350"},
-]
-
-[package.dependencies]
-fsspec = {version = "*", optional = true, markers = "extra == \"epath\""}
-importlib_resources = {version = "*", optional = true, markers = "extra == \"epath\""}
-typing_extensions = {version = "*", optional = true, markers = "extra == \"epy\""}
-zipp = {version = "*", optional = true, markers = "extra == \"epath\""}
-
-[package.extras]
-all = ["etils[array-types]", "etils[eapp]", "etils[ecolab]", "etils[edc]", "etils[enp]", "etils[epath-gcs]", "etils[epath-s3]", "etils[epath]", "etils[epy]", "etils[etqdm]", "etils[etree-dm]", "etils[etree-jax]", "etils[etree-tf]", "etils[etree]"]
-array-types = ["etils[enp]"]
-dev = ["chex", "dataclass_array", "optree", "pyink", "pylint (>=2.6.0)", "pytest", "pytest-subtests", "pytest-xdist", "torch"]
-docs = ["etils[all,dev]", "sphinx-apitree[ext]"]
-eapp = ["absl-py", "etils[epy]", "simple_parsing"]
-ecolab = ["etils[enp]", "etils[epy]", "etils[etree]", "jupyter", "mediapy", "numpy", "packaging", "protobuf"]
-edc = ["etils[epy]"]
-enp = ["etils[epy]", "numpy"]
-epath = ["etils[epy]", "fsspec", "importlib_resources", "typing_extensions", "zipp"]
-epath-gcs = ["etils[epath]", "gcsfs"]
-epath-s3 = ["etils[epath]", "s3fs"]
-epy = ["typing_extensions"]
-etqdm = ["absl-py", "etils[epy]", "tqdm"]
-etree = ["etils[array-types]", "etils[enp]", "etils[epy]", "etils[etqdm]"]
-etree-dm = ["dm-tree", "etils[etree]"]
-etree-jax = ["etils[etree]", "jax[cpu]"]
-etree-tf = ["etils[etree]", "tensorflow"]
-lazy-imports = ["etils[ecolab]"]
-
[[package]]
name = "exceptiongroup"
version = "1.2.0"
@@ -846,13 +840,13 @@ numpy = ">=1.17.3"
[[package]]
name = "huggingface-hub"
-version = "0.21.3"
+version = "0.21.4"
description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
optional = false
python-versions = ">=3.8.0"
files = [
- {file = "huggingface_hub-0.21.3-py3-none-any.whl", hash = "sha256:b183144336fdf2810a8c109822e0bb6ef1fd61c65da6fb60e8c3f658b7144016"},
- {file = "huggingface_hub-0.21.3.tar.gz", hash = "sha256:26a15b604e4fc7bad37c467b76456543ec849386cbca9cd7e1e135f53e500423"},
+ {file = "huggingface_hub-0.21.4-py3-none-any.whl", hash = "sha256:df37c2c37fc6c82163cdd8a67ede261687d80d1e262526d6c0ce73b6b3630a7b"},
+ {file = "huggingface_hub-0.21.4.tar.gz", hash = "sha256:e1f4968c93726565a80edf6dc309763c7b546d0cfe79aa221206034d50155531"},
]
[package.dependencies]
@@ -971,37 +965,22 @@ setuptools = "*"
[[package]]
name = "importlib-metadata"
-version = "7.0.1"
+version = "7.0.2"
description = "Read metadata from Python packages"
optional = false
python-versions = ">=3.8"
files = [
- {file = "importlib_metadata-7.0.1-py3-none-any.whl", hash = "sha256:4805911c3a4ec7c3966410053e9ec6a1fecd629117df5adee56dfc9432a1081e"},
- {file = "importlib_metadata-7.0.1.tar.gz", hash = "sha256:f238736bb06590ae52ac1fab06a3a9ef1d8dce2b7a35b5ab329371d6c8f5d2cc"},
+ {file = "importlib_metadata-7.0.2-py3-none-any.whl", hash = "sha256:f4bc4c0c070c490abf4ce96d715f68e95923320370efb66143df00199bb6c100"},
+ {file = "importlib_metadata-7.0.2.tar.gz", hash = "sha256:198f568f3230878cb1b44fbd7975f87906c22336dba2e4a7f05278c281fbd792"},
]
[package.dependencies]
zipp = ">=0.5"
[package.extras]
-docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"]
+docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
perf = ["ipython"]
-testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"]
-
-[[package]]
-name = "importlib-resources"
-version = "6.1.2"
-description = "Read resources from Python packages"
-optional = false
-python-versions = ">=3.8"
-files = [
- {file = "importlib_resources-6.1.2-py3-none-any.whl", hash = "sha256:9a0a862501dc38b68adebc82970140c9e4209fc99601782925178f8386339938"},
- {file = "importlib_resources-6.1.2.tar.gz", hash = "sha256:308abf8474e2dba5f867d279237cd4076482c3de7104a40b41426370e891549b"},
-]
-
-[package.extras]
-docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"]
-testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-ruff (>=0.2.1)", "zipp (>=3.17)"]
+testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"]
[[package]]
name = "iniconfig"
@@ -1031,6 +1010,50 @@ MarkupSafe = ">=2.0"
[package.extras]
i18n = ["Babel (>=2.7)"]
+[[package]]
+name = "labmaze"
+version = "1.0.6"
+description = "LabMaze: DeepMind Lab's text maze generator."
+optional = false
+python-versions = "*"
+files = [
+ {file = "labmaze-1.0.6-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:b2ddef976dfd8d992b19cfa6c633f2eba7576d759c2082da534e3f727479a84a"},
+ {file = "labmaze-1.0.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:157efaa93228c8ccce5cae337902dd652093e0fba9d3a0f6506e4bee272bb66f"},
+ {file = "labmaze-1.0.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b3ce98b9541c5fe6a306e411e7d018121dd646f2c9978d763fad86f9f30c5f57"},
+ {file = "labmaze-1.0.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4e6433bd49bc541791de8191040526fddfebb77151620eb04203453f43ee486a"},
+ {file = "labmaze-1.0.6-cp310-cp310-win_amd64.whl", hash = "sha256:6a507fc35961f1b1479708e2716f65e0d0611cefb55f31a77be29ce2339b6fef"},
+ {file = "labmaze-1.0.6-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:a0c2cb9dec971814ea9c5d7150af15fa3964482131fa969e0afb94bd224348af"},
+ {file = "labmaze-1.0.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2c6ba9538d819543f4be448d36b4926a3881e53646a2b331ebb5a1f353047d05"},
+ {file = "labmaze-1.0.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70635d1cdb0147a02efb6b3f607a52cdc51723bc3dcc42717a0d4ef55fa0a987"},
+ {file = "labmaze-1.0.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ff472793238bd9b6dabea8094594d6074ad3c111455de3afcae72f6c40c6817e"},
+ {file = "labmaze-1.0.6-cp311-cp311-win_amd64.whl", hash = "sha256:2317e65e12fa3d1abecda7e0488dab15456cee8a2e717a586bfc8f02a91579e7"},
+ {file = "labmaze-1.0.6-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:e36b6fadcd78f22057b597c1c77823e806a0987b3bdfbf850e14b6b5b502075e"},
+ {file = "labmaze-1.0.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d1a4f8de29c2c3d7f14163759b69cd3f237093b85334c983619c1db5403a223b"},
+ {file = "labmaze-1.0.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a394f8bb857fcaa2884b809d63e750841c2662a106cfe8c045f2112d201ac7d5"},
+ {file = "labmaze-1.0.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d17abb69d4dfc56183afb5c317e8b2eaca0587abb3aabd2326efd3143c81f4e"},
+ {file = "labmaze-1.0.6-cp312-cp312-win_amd64.whl", hash = "sha256:5af997598cc46b1929d1c5a1febc32fd56c75874fe481a2a5982c65cee8450c9"},
+ {file = "labmaze-1.0.6-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:a4c5bc6e56baa55ce63b97569afec2f80cab0f6b952752a131e1f83eed190a53"},
+ {file = "labmaze-1.0.6-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3955f24fe5f708e1e97495b4cfe284b70ae4fd51be5e17b75a6fc04ffbd67bca"},
+ {file = "labmaze-1.0.6-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ed96ddc0bb8d66df36428c94db83949fd84a15867e8250763a4c5e3d82104c54"},
+ {file = "labmaze-1.0.6-cp37-cp37m-win_amd64.whl", hash = "sha256:3bd0458a29e55aa09f146e28a168d2e00b8ccf19e2259a3f71154cfff3536b1d"},
+ {file = "labmaze-1.0.6-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:33f5154edc83dff55a150e54b60c8582fdafc7ec45195049809cbcc01f5e8f34"},
+ {file = "labmaze-1.0.6-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0971055ef2a5f7d8517fdc42b67c057093698f1eb911f46faa7018867b73fcc9"},
+ {file = "labmaze-1.0.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:de18d09680007302abf49111f3fe822d8435e4fbc4468b9ec07d50a78e267865"},
+ {file = "labmaze-1.0.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f18126066db2218a52853c7dd490b4c3d8129fc22eb3a47eb23007524b911d53"},
+ {file = "labmaze-1.0.6-cp38-cp38-win_amd64.whl", hash = "sha256:f9aef09a76877342bb4d634b7e05f43b038a49c4f34adfb8f1b8ac57c29472f2"},
+ {file = "labmaze-1.0.6-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:5dd28899418f1b8b1c7d1e1b40a4593150a7cfa95ca91e23860b9785b82cc0ee"},
+ {file = "labmaze-1.0.6-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:965569f37ee33090b4d4b3aa5aa7c9dcc4f62e2ae5d761e7f73ec76fc9d8aa96"},
+ {file = "labmaze-1.0.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:05eccfa98c0e781bc9f939076ae600b2e25ca736e123f2a530606aedec3b531c"},
+ {file = "labmaze-1.0.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bee8c94e0fb3fc2d8180214947245c1d74a3489349a9da90b868296e77a521e9"},
+ {file = "labmaze-1.0.6-cp39-cp39-win_amd64.whl", hash = "sha256:d486e9ca3a335ad628e3bd48a09c42f1aa5f51040952ef0fe32507afedcd694b"},
+ {file = "labmaze-1.0.6.tar.gz", hash = "sha256:2e8de7094042a77d6972f1965cf5c9e8f971f1b34d225752f343190a825ebe73"},
+]
+
+[package.dependencies]
+absl-py = "*"
+numpy = ">=1.8.0"
+setuptools = "!=50.0.0"
+
[[package]]
name = "lazy-loader"
version = "0.3"
@@ -1076,6 +1099,99 @@ files = [
{file = "llvmlite-0.42.0.tar.gz", hash = "sha256:f92b09243c0cc3f457da8b983f67bd8e1295d0f5b3746c7a1861d7a99403854a"},
]
+[[package]]
+name = "lxml"
+version = "5.1.0"
+description = "Powerful and Pythonic XML processing library combining libxml2/libxslt with the ElementTree API."
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "lxml-5.1.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:704f5572ff473a5f897745abebc6df40f22d4133c1e0a1f124e4f2bd3330ff7e"},
+ {file = "lxml-5.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9d3c0f8567ffe7502d969c2c1b809892dc793b5d0665f602aad19895f8d508da"},
+ {file = "lxml-5.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5fcfbebdb0c5d8d18b84118842f31965d59ee3e66996ac842e21f957eb76138c"},
+ {file = "lxml-5.1.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2f37c6d7106a9d6f0708d4e164b707037b7380fcd0b04c5bd9cae1fb46a856fb"},
+ {file = "lxml-5.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2befa20a13f1a75c751f47e00929fb3433d67eb9923c2c0b364de449121f447c"},
+ {file = "lxml-5.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:22b7ee4c35f374e2c20337a95502057964d7e35b996b1c667b5c65c567d2252a"},
+ {file = "lxml-5.1.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:bf8443781533b8d37b295016a4b53c1494fa9a03573c09ca5104550c138d5c05"},
+ {file = "lxml-5.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:82bddf0e72cb2af3cbba7cec1d2fd11fda0de6be8f4492223d4a268713ef2147"},
+ {file = "lxml-5.1.0-cp310-cp310-win32.whl", hash = "sha256:b66aa6357b265670bb574f050ffceefb98549c721cf28351b748be1ef9577d93"},
+ {file = "lxml-5.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:4946e7f59b7b6a9e27bef34422f645e9a368cb2be11bf1ef3cafc39a1f6ba68d"},
+ {file = "lxml-5.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:14deca1460b4b0f6b01f1ddc9557704e8b365f55c63070463f6c18619ebf964f"},
+ {file = "lxml-5.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ed8c3d2cd329bf779b7ed38db176738f3f8be637bb395ce9629fc76f78afe3d4"},
+ {file = "lxml-5.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:436a943c2900bb98123b06437cdd30580a61340fbdb7b28aaf345a459c19046a"},
+ {file = "lxml-5.1.0-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:acb6b2f96f60f70e7f34efe0c3ea34ca63f19ca63ce90019c6cbca6b676e81fa"},
+ {file = "lxml-5.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:af8920ce4a55ff41167ddbc20077f5698c2e710ad3353d32a07d3264f3a2021e"},
+ {file = "lxml-5.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7cfced4a069003d8913408e10ca8ed092c49a7f6cefee9bb74b6b3e860683b45"},
+ {file = "lxml-5.1.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:9e5ac3437746189a9b4121db2a7b86056ac8786b12e88838696899328fc44bb2"},
+ {file = "lxml-5.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f4c9bda132ad108b387c33fabfea47866af87f4ea6ffb79418004f0521e63204"},
+ {file = "lxml-5.1.0-cp311-cp311-win32.whl", hash = "sha256:bc64d1b1dab08f679fb89c368f4c05693f58a9faf744c4d390d7ed1d8223869b"},
+ {file = "lxml-5.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:a5ab722ae5a873d8dcee1f5f45ddd93c34210aed44ff2dc643b5025981908cda"},
+ {file = "lxml-5.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:9aa543980ab1fbf1720969af1d99095a548ea42e00361e727c58a40832439114"},
+ {file = "lxml-5.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6f11b77ec0979f7e4dc5ae081325a2946f1fe424148d3945f943ceaede98adb8"},
+ {file = "lxml-5.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a36c506e5f8aeb40680491d39ed94670487ce6614b9d27cabe45d94cd5d63e1e"},
+ {file = "lxml-5.1.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f643ffd2669ffd4b5a3e9b41c909b72b2a1d5e4915da90a77e119b8d48ce867a"},
+ {file = "lxml-5.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:16dd953fb719f0ffc5bc067428fc9e88f599e15723a85618c45847c96f11f431"},
+ {file = "lxml-5.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:16018f7099245157564d7148165132c70adb272fb5a17c048ba70d9cc542a1a1"},
+ {file = "lxml-5.1.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:82cd34f1081ae4ea2ede3d52f71b7be313756e99b4b5f829f89b12da552d3aa3"},
+ {file = "lxml-5.1.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:19a1bc898ae9f06bccb7c3e1dfd73897ecbbd2c96afe9095a6026016e5ca97b8"},
+ {file = "lxml-5.1.0-cp312-cp312-win32.whl", hash = "sha256:13521a321a25c641b9ea127ef478b580b5ec82aa2e9fc076c86169d161798b01"},
+ {file = "lxml-5.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:1ad17c20e3666c035db502c78b86e58ff6b5991906e55bdbef94977700c72623"},
+ {file = "lxml-5.1.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:24ef5a4631c0b6cceaf2dbca21687e29725b7c4e171f33a8f8ce23c12558ded1"},
+ {file = "lxml-5.1.0-cp36-cp36m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8d2900b7f5318bc7ad8631d3d40190b95ef2aa8cc59473b73b294e4a55e9f30f"},
+ {file = "lxml-5.1.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:601f4a75797d7a770daed8b42b97cd1bb1ba18bd51a9382077a6a247a12aa38d"},
+ {file = "lxml-5.1.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b4b68c961b5cc402cbd99cca5eb2547e46ce77260eb705f4d117fd9c3f932b95"},
+ {file = "lxml-5.1.0-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:afd825e30f8d1f521713a5669b63657bcfe5980a916c95855060048b88e1adb7"},
+ {file = "lxml-5.1.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:262bc5f512a66b527d026518507e78c2f9c2bd9eb5c8aeeb9f0eb43fcb69dc67"},
+ {file = "lxml-5.1.0-cp36-cp36m-win32.whl", hash = "sha256:e856c1c7255c739434489ec9c8aa9cdf5179785d10ff20add308b5d673bed5cd"},
+ {file = "lxml-5.1.0-cp36-cp36m-win_amd64.whl", hash = "sha256:c7257171bb8d4432fe9d6fdde4d55fdbe663a63636a17f7f9aaba9bcb3153ad7"},
+ {file = "lxml-5.1.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b9e240ae0ba96477682aa87899d94ddec1cc7926f9df29b1dd57b39e797d5ab5"},
+ {file = "lxml-5.1.0-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a96f02ba1bcd330807fc060ed91d1f7a20853da6dd449e5da4b09bfcc08fdcf5"},
+ {file = "lxml-5.1.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e3898ae2b58eeafedfe99e542a17859017d72d7f6a63de0f04f99c2cb125936"},
+ {file = "lxml-5.1.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:61c5a7edbd7c695e54fca029ceb351fc45cd8860119a0f83e48be44e1c464862"},
+ {file = "lxml-5.1.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:3aeca824b38ca78d9ee2ab82bd9883083d0492d9d17df065ba3b94e88e4d7ee6"},
+ {file = "lxml-5.1.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:8f52fe6859b9db71ee609b0c0a70fea5f1e71c3462ecf144ca800d3f434f0764"},
+ {file = "lxml-5.1.0-cp37-cp37m-win32.whl", hash = "sha256:d42e3a3fc18acc88b838efded0e6ec3edf3e328a58c68fbd36a7263a874906c8"},
+ {file = "lxml-5.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:eac68f96539b32fce2c9b47eb7c25bb2582bdaf1bbb360d25f564ee9e04c542b"},
+ {file = "lxml-5.1.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ae15347a88cf8af0949a9872b57a320d2605ae069bcdf047677318bc0bba45b1"},
+ {file = "lxml-5.1.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c26aab6ea9c54d3bed716b8851c8bfc40cb249b8e9880e250d1eddde9f709bf5"},
+ {file = "lxml-5.1.0-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:342e95bddec3a698ac24378d61996b3ee5ba9acfeb253986002ac53c9a5f6f84"},
+ {file = "lxml-5.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:725e171e0b99a66ec8605ac77fa12239dbe061482ac854d25720e2294652eeaa"},
+ {file = "lxml-5.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d184e0d5c918cff04cdde9dbdf9600e960161d773666958c9d7b565ccc60c45"},
+ {file = "lxml-5.1.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:98f3f020a2b736566c707c8e034945c02aa94e124c24f77ca097c446f81b01f1"},
+ {file = "lxml-5.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:6d48fc57e7c1e3df57be5ae8614bab6d4e7b60f65c5457915c26892c41afc59e"},
+ {file = "lxml-5.1.0-cp38-cp38-win32.whl", hash = "sha256:7ec465e6549ed97e9f1e5ed51c657c9ede767bc1c11552f7f4d022c4df4a977a"},
+ {file = "lxml-5.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:b21b4031b53d25b0858d4e124f2f9131ffc1530431c6d1321805c90da78388d1"},
+ {file = "lxml-5.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:52427a7eadc98f9e62cb1368a5079ae826f94f05755d2d567d93ee1bc3ceb354"},
+ {file = "lxml-5.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6a2a2c724d97c1eb8cf966b16ca2915566a4904b9aad2ed9a09c748ffe14f969"},
+ {file = "lxml-5.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:843b9c835580d52828d8f69ea4302537337a21e6b4f1ec711a52241ba4a824f3"},
+ {file = "lxml-5.1.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9b99f564659cfa704a2dd82d0684207b1aadf7d02d33e54845f9fc78e06b7581"},
+ {file = "lxml-5.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f8b0c78e7aac24979ef09b7f50da871c2de2def043d468c4b41f512d831e912"},
+ {file = "lxml-5.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9bcf86dfc8ff3e992fed847c077bd875d9e0ba2fa25d859c3a0f0f76f07f0c8d"},
+ {file = "lxml-5.1.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:49a9b4af45e8b925e1cd6f3b15bbba2c81e7dba6dce170c677c9cda547411e14"},
+ {file = "lxml-5.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:280f3edf15c2a967d923bcfb1f8f15337ad36f93525828b40a0f9d6c2ad24890"},
+ {file = "lxml-5.1.0-cp39-cp39-win32.whl", hash = "sha256:ed7326563024b6e91fef6b6c7a1a2ff0a71b97793ac33dbbcf38f6005e51ff6e"},
+ {file = "lxml-5.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:8d7b4beebb178e9183138f552238f7e6613162a42164233e2bda00cb3afac58f"},
+ {file = "lxml-5.1.0-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:9bd0ae7cc2b85320abd5e0abad5ccee5564ed5f0cc90245d2f9a8ef330a8deae"},
+ {file = "lxml-5.1.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d8c1d679df4361408b628f42b26a5d62bd3e9ba7f0c0e7969f925021554755aa"},
+ {file = "lxml-5.1.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:2ad3a8ce9e8a767131061a22cd28fdffa3cd2dc193f399ff7b81777f3520e372"},
+ {file = "lxml-5.1.0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:304128394c9c22b6569eba2a6d98392b56fbdfbad58f83ea702530be80d0f9df"},
+ {file = "lxml-5.1.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d74fcaf87132ffc0447b3c685a9f862ffb5b43e70ea6beec2fb8057d5d2a1fea"},
+ {file = "lxml-5.1.0-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:8cf5877f7ed384dabfdcc37922c3191bf27e55b498fecece9fd5c2c7aaa34c33"},
+ {file = "lxml-5.1.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:877efb968c3d7eb2dad540b6cabf2f1d3c0fbf4b2d309a3c141f79c7e0061324"},
+ {file = "lxml-5.1.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3f14a4fb1c1c402a22e6a341a24c1341b4a3def81b41cd354386dcb795f83897"},
+ {file = "lxml-5.1.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:25663d6e99659544ee8fe1b89b1a8c0aaa5e34b103fab124b17fa958c4a324a6"},
+ {file = "lxml-5.1.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:8b9f19df998761babaa7f09e6bc169294eefafd6149aaa272081cbddc7ba4ca3"},
+ {file = "lxml-5.1.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5e53d7e6a98b64fe54775d23a7c669763451340c3d44ad5e3a3b48a1efbdc96f"},
+ {file = "lxml-5.1.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:c3cd1fc1dc7c376c54440aeaaa0dcc803d2126732ff5c6b68ccd619f2e64be4f"},
+ {file = "lxml-5.1.0.tar.gz", hash = "sha256:3eea6ed6e6c918e468e693c41ef07f3c3acc310b70ddd9cc72d9ef84bc9564ca"},
+]
+
+[package.extras]
+cssselect = ["cssselect (>=0.7)"]
+html5 = ["html5lib"]
+htmlsoup = ["BeautifulSoup4"]
+source = ["Cython (>=3.0.7)"]
+
[[package]]
name = "markupsafe"
version = "2.1.5"
@@ -1188,42 +1304,40 @@ tests = ["pytest (>=4.6)"]
[[package]]
name = "mujoco"
-version = "3.1.2"
+version = "2.3.7"
description = "MuJoCo Physics Simulator"
optional = false
python-versions = ">=3.8"
files = [
- {file = "mujoco-3.1.2-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:fe6b3542695a5363f348ee45625b3492734f29cdc9f493ca25eae719f974370e"},
- {file = "mujoco-3.1.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f07e2d1f01f1401f1a503187016f8c017d9402618c659e1482243640a1e11288"},
- {file = "mujoco-3.1.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:93863eccc9d77d96ce62dda2a6f61cbd880379e8d774f802568d64b9613fce39"},
- {file = "mujoco-3.1.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3586c642390c16fef58b01a86071cec6814c471586e2f4115c3733c4aec64fb7"},
- {file = "mujoco-3.1.2-cp310-cp310-win_amd64.whl", hash = "sha256:0da77394c664945b78f199c627b609fe091ec0c4641b9d8f713637344a17821a"},
- {file = "mujoco-3.1.2-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:b6f12904d0478c191e4770ecf9006e20953f0488a2411a8ddc62592721c136dc"},
- {file = "mujoco-3.1.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f69b8d42b50c10f8d12df4948fc9d4dd6706841e7b163c1d7ce83448965acb1c"},
- {file = "mujoco-3.1.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:10119e39b1f45fb76b18bea242fea1d6ccf4b2285f8bd5e2cb1e2cbdeb69bdcd"},
- {file = "mujoco-3.1.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a65868506dd45dddfe7be84857e57b49bc102334fc0439aa848a4d4d285d89b"},
- {file = "mujoco-3.1.2-cp311-cp311-win_amd64.whl", hash = "sha256:92bc73972e39539f23a05bb411c45f9be17191fe01171ac15ffafed381ee4366"},
- {file = "mujoco-3.1.2-cp312-cp312-macosx_10_16_x86_64.whl", hash = "sha256:835d6b64ca4dc2f6a83291275fd48bd83edc888039d247958bf5b2c759db4340"},
- {file = "mujoco-3.1.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ce94ca3cf14fc519981674c5b85f1055356dcdcd63bbc0ec6c340084438f27f"},
- {file = "mujoco-3.1.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:250d9de4bd0d31fa4165faf01a1f838c429434f1263faacd95b977580f24eae7"},
- {file = "mujoco-3.1.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6ea009d10bbf0aba9bc835f051d25f07a2c3edbaa06627ac2348766a1f3760b9"},
- {file = "mujoco-3.1.2-cp312-cp312-win_amd64.whl", hash = "sha256:a0460d2ebdad4926f48b8c774da473e011c3b3afd0ccb6b6be1087b788c34011"},
- {file = "mujoco-3.1.2-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:4ca7cae89e258a338e02229edcf8f177b459ac5e9f859ffffa07fc2c9fcfb6aa"},
- {file = "mujoco-3.1.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:33b4fe9b5f891b29ef0fc2b0b975bc3a8a4b87774eecaf4364a83ddc6a7762ba"},
- {file = "mujoco-3.1.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6ed230980f33bafaf1fa8b32ef25b82b069a245de15ee6ce7127e7e054cfad16"},
- {file = "mujoco-3.1.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41cc610ac40f325c9d49d9885ac6cb61822ed938f6c23cb183b261a7a28472ca"},
- {file = "mujoco-3.1.2-cp38-cp38-win_amd64.whl", hash = "sha256:90a172b904a6ca8e6a1be80ab7c393aaff7592843a2a6853a4f97a9204031c41"},
- {file = "mujoco-3.1.2-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:93201291a0c5b573b4cbb19a6b08c99673f9fba167f402174eae5ffa23066d24"},
- {file = "mujoco-3.1.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0398985bb28c2686cdeeaf4ded46e602a49ec12115ac77474144ca940e5261c5"},
- {file = "mujoco-3.1.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d2e76b5cb07ab3088c81966ac774d573df027fa5f4e78c20953a547528a2a698"},
- {file = "mujoco-3.1.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd5c3f4ae858e812cb3f03332693bcdc343b2bce55b164523acf52dea2736c9e"},
- {file = "mujoco-3.1.2-cp39-cp39-win_amd64.whl", hash = "sha256:ca25ff2646b06609526ef8681c0e123cd854a53c9ff23cb91dd5058a2794dab4"},
- {file = "mujoco-3.1.2.tar.gz", hash = "sha256:53530bc1a91903f3fd4b1e99818cc38fbd9911700db29b2c9fc839f23bfacbb8"},
+ {file = "mujoco-2.3.7-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:e8714a5ff6a1561b364b7b4648d4c0c8d13e751874cf7401c309b9d23fa9598b"},
+ {file = "mujoco-2.3.7-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a934315f858a4e0c4b90a682fde519471cfdd7baa64435179da8cd20d4ae3f99"},
+ {file = "mujoco-2.3.7-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:36513024330f88b5f9a43558efef5692b33599bffd5141029b690a27918ffcbe"},
+ {file = "mujoco-2.3.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6d4eede8ba8210fbd3d3cd1dbf69e24dd1541aa74c5af5b8adbbbf65504b6dba"},
+ {file = "mujoco-2.3.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ab85fafc9d5a091c712947573b7e694512d283876bf7f33ae3f8daad3a20c0db"},
+ {file = "mujoco-2.3.7-cp310-cp310-win_amd64.whl", hash = "sha256:f8b7e13fef8c813d91b78f975ed0815157692777907ffa4b4be53a4edb75019b"},
+ {file = "mujoco-2.3.7-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:779520216f72a8e370e3f0cdd71b45c3b7384c63331a3189194c930a3e7cff5c"},
+ {file = "mujoco-2.3.7-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9d4018053879016282d27ab7a91e292c72d44efb5a88553feacfe5b843dde103"},
+ {file = "mujoco-2.3.7-cp311-cp311-macosx_11_0_x86_64.whl", hash = "sha256:3149b16b8122ee62642474bfd2871064e8edc40235471cf5d84be3569afc0312"},
+ {file = "mujoco-2.3.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c08660a8d52ef3efde76095f0991e807703a950c1e882d2bcd984b9a846626f7"},
+ {file = "mujoco-2.3.7-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:426af8965f8636d94a0f75740c3024a62b3e585020ee817ef5208ec844a1ad94"},
+ {file = "mujoco-2.3.7-cp311-cp311-win_amd64.whl", hash = "sha256:215415a8e98a4b50625beae859079d5e0810b2039e50420f0ba81763c34abb59"},
+ {file = "mujoco-2.3.7-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:8b78d14f4c60cea3c58e046bd4de453fb5b9b33aca6a25fc91d39a53f3a5342a"},
+ {file = "mujoco-2.3.7-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5c6f5a51d6f537a4bf294cf73816f3a6384573f8f10a5452b044df2771412a96"},
+ {file = "mujoco-2.3.7-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:ea8911e6047f92d7d775701f37e4c093971b6def3160f01d0b6926e29a7e962e"},
+ {file = "mujoco-2.3.7-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7473a3de4dd1a8762d569ffb139196b4c5e7eca27d256df97b6cd4c66d2a09b2"},
+ {file = "mujoco-2.3.7-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40e7e2d8f93d2495ec74efec84e5118ecc6e1d85157a844789c73c9ac9a4e28e"},
+ {file = "mujoco-2.3.7-cp38-cp38-win_amd64.whl", hash = "sha256:720bc228a2023b3b0ed6af78f5b0f8ea36867be321d473321555c57dbf6e4e5b"},
+ {file = "mujoco-2.3.7-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:855e79686366442aa410246043b44f7d842d3900d68fe7e37feb42147db9d707"},
+ {file = "mujoco-2.3.7-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:98947f4a742d34d36f3c3f83e9167025bb0414bbaa4bd859b0673bdab9959963"},
+ {file = "mujoco-2.3.7-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:d42818f2ee5d1632dbce31d136ed5ff868db54b04e4e9aca0c5a3ac329f8a90f"},
+ {file = "mujoco-2.3.7-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9237e1ba14bced9449c31199e6d5be49547f3a4c99bc83b196af7ca45fd73b83"},
+ {file = "mujoco-2.3.7-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b728ea638245b150e2650c5433e6952e0ed3798c63e47e264574270caea2a3"},
+ {file = "mujoco-2.3.7-cp39-cp39-win_amd64.whl", hash = "sha256:9c721a5042b99d948d5f0296a534bcce3f142c777c4d7642f503a539513f3912"},
+ {file = "mujoco-2.3.7.tar.gz", hash = "sha256:422041f1ce37c6d151fbced1048df626837e94fe3cd9f813585907046336a7d0"},
]
[package.dependencies]
absl-py = "*"
-etils = {version = "*", extras = ["epath"]}
glfw = "*"
numpy = "*"
pyopengl = "*"
@@ -1519,13 +1633,13 @@ files = [
[[package]]
name = "nvidia-nvjitlink-cu12"
-version = "12.3.101"
+version = "12.4.99"
description = "Nvidia JIT LTO Library"
optional = false
python-versions = ">=3"
files = [
- {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux1_x86_64.whl", hash = "sha256:64335a8088e2b9d196ae8665430bc6a2b7e6ef2eb877a9c735c804bd4ff6467c"},
- {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-win_amd64.whl", hash = "sha256:1b2e317e437433753530792f13eece58f0aec21a2b05903be7bffe58a606cbd1"},
+ {file = "nvidia_nvjitlink_cu12-12.4.99-py3-none-manylinux2014_x86_64.whl", hash = "sha256:c6428836d20fe7e327191c175791d38570e10762edc588fb46749217cd444c74"},
+ {file = "nvidia_nvjitlink_cu12-12.4.99-py3-none-win_amd64.whl", hash = "sha256:991905ffa2144cb603d8ca7962d75c35334ae82bf92820b6ba78157277da1ad2"},
]
[[package]]
@@ -1580,13 +1694,13 @@ numpy = [
[[package]]
name = "packaging"
-version = "23.2"
+version = "24.0"
description = "Core utilities for Python packages"
optional = false
python-versions = ">=3.7"
files = [
- {file = "packaging-23.2-py3-none-any.whl", hash = "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7"},
- {file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"},
+ {file = "packaging-24.0-py3-none-any.whl", hash = "sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5"},
+ {file = "packaging-24.0.tar.gz", hash = "sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9"},
]
[[package]]
@@ -2016,6 +2130,20 @@ files = [
{file = "PyOpenGL-3.1.7.tar.gz", hash = "sha256:eef31a3888e6984fd4d8e6c9961b184c9813ca82604d37fe3da80eb000a76c86"},
]
+[[package]]
+name = "pyparsing"
+version = "3.1.2"
+description = "pyparsing module - Classes and methods to define and execute parsing grammars"
+optional = false
+python-versions = ">=3.6.8"
+files = [
+ {file = "pyparsing-3.1.2-py3-none-any.whl", hash = "sha256:f9db75911801ed778fe61bb643079ff86601aca99fcae6345aa67292038fb742"},
+ {file = "pyparsing-3.1.2.tar.gz", hash = "sha256:a1bac0ce561155ecc3ed78ca94d3c9378656ad4c94c1270de543f621420f94ad"},
+]
+
+[package.extras]
+diagrams = ["jinja2", "railroad-diagrams"]
+
[[package]]
name = "pysocks"
version = "1.7.1"
@@ -2030,13 +2158,13 @@ files = [
[[package]]
name = "pytest"
-version = "8.1.0"
+version = "8.1.1"
description = "pytest: simple powerful testing with Python"
optional = false
python-versions = ">=3.8"
files = [
- {file = "pytest-8.1.0-py3-none-any.whl", hash = "sha256:ee32db7af8de4629a455806befa90559f307424c07b8413ccfc30bf5b221dd7e"},
- {file = "pytest-8.1.0.tar.gz", hash = "sha256:f8fa04ab8f98d185113ae60ea6d79c22f8143b14bc1caeced44a0ab844928323"},
+ {file = "pytest-8.1.1-py3-none-any.whl", hash = "sha256:2a8386cfc11fa9d2c50ee7b2a57e7d898ef90470a7a34c4b949ff59662bb78b7"},
+ {file = "pytest-8.1.1.tar.gz", hash = "sha256:ac978141a75948948817d360297b7aae0fcb9d6ff6bc9ec6d514b85d5a65c044"},
]
[package.dependencies]
@@ -2052,13 +2180,13 @@ testing = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygm
[[package]]
name = "python-dateutil"
-version = "2.8.2"
+version = "2.9.0.post0"
description = "Extensions to the standard Python datetime module"
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
files = [
- {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"},
- {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"},
+ {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"},
+ {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"},
]
[package.dependencies]
@@ -2483,13 +2611,13 @@ test = ["asv", "gmpy2", "hypothesis", "mpmath", "pooch", "pytest", "pytest-cov",
[[package]]
name = "sentry-sdk"
-version = "1.40.6"
+version = "1.41.0"
description = "Python client for Sentry (https://sentry.io)"
optional = false
python-versions = "*"
files = [
- {file = "sentry-sdk-1.40.6.tar.gz", hash = "sha256:f143f3fb4bb57c90abef6e2ad06b5f6f02b2ca13e4060ec5c0549c7a9ccce3fa"},
- {file = "sentry_sdk-1.40.6-py2.py3-none-any.whl", hash = "sha256:becda09660df63e55f307570e9817c664392655a7328bbc414b507e9cb874c67"},
+ {file = "sentry-sdk-1.41.0.tar.gz", hash = "sha256:4f2d6c43c07925d8cd10dfbd0970ea7cb784f70e79523cca9dbcd72df38e5a46"},
+ {file = "sentry_sdk-1.41.0-py2.py3-none-any.whl", hash = "sha256:be4f8f4b29a80b6a3b71f0f31487beb9e296391da20af8504498a328befed53f"},
]
[package.dependencies]
@@ -2515,7 +2643,7 @@ huey = ["huey (>=2)"]
loguru = ["loguru (>=0.5)"]
opentelemetry = ["opentelemetry-distro (>=0.35b0)"]
opentelemetry-experimental = ["opentelemetry-distro (>=0.40b0,<1.0)", "opentelemetry-instrumentation-aiohttp-client (>=0.40b0,<1.0)", "opentelemetry-instrumentation-django (>=0.40b0,<1.0)", "opentelemetry-instrumentation-fastapi (>=0.40b0,<1.0)", "opentelemetry-instrumentation-flask (>=0.40b0,<1.0)", "opentelemetry-instrumentation-requests (>=0.40b0,<1.0)", "opentelemetry-instrumentation-sqlite3 (>=0.40b0,<1.0)", "opentelemetry-instrumentation-urllib (>=0.40b0,<1.0)"]
-pure-eval = ["asttokens", "executing", "pure_eval"]
+pure-eval = ["asttokens", "executing", "pure-eval"]
pymongo = ["pymongo (>=3.1)"]
pyspark = ["pyspark (>=2.4.4)"]
quart = ["blinker (>=1.1)", "quart (>=0.16.1)"]
@@ -2769,7 +2897,7 @@ tests = ["pytest", "pytest-benchmark", "pytest-instafail", "pytest-rerunfailures
type = "git"
url = "https://github.com/pytorch/tensordict"
reference = "HEAD"
-resolved_reference = "551331d83e2979dd4505db1f49895740e6e5c95f"
+resolved_reference = "ed22554d6860731610df784b2f5d09f31d3dbc7a"
[[package]]
name = "termcolor"
@@ -2888,13 +3016,13 @@ tensordict = ">=0.4.0"
torch = ">=2.1.0"
[package.extras]
-all = ["ale-py", "atari-py", "dm-control", "git", "gym", "gym[accept-rom-license]", "gymnasium", "h5py", "huggingface-hub", "hydra-core (>=1.1)", "hydra-submitit-launcher", "minari", "moviepy", "mujoco", "pandas", "pettingzoo (>=1.24.1)", "pillow", "pygame", "pytest", "pytest-instafail", "pyyaml", "requests", "scikit-learn", "scipy", "tensorboard", "torchsnapshot", "torchvision", "tqdm", "vmas (>=1.2.10)", "wandb"]
+all = ["ale-py", "atari-py", "dm_control", "git", "gym", "gym[accept-rom-license]", "gymnasium", "h5py", "huggingface_hub", "hydra-core (>=1.1)", "hydra-submitit-launcher", "minari", "moviepy", "mujoco", "pandas", "pettingzoo (>=1.24.1)", "pillow", "pygame", "pytest", "pytest-instafail", "pyyaml", "requests", "scikit-learn", "scipy", "tensorboard", "torchsnapshot", "torchvision", "tqdm", "vmas (>=1.2.10)", "wandb"]
atari = ["ale-py", "atari-py", "gym", "gym[accept-rom-license]", "pygame"]
checkpointing = ["torchsnapshot"]
-dm-control = ["dm-control"]
+dm-control = ["dm_control"]
gym-continuous = ["gymnasium", "mujoco"]
marl = ["pettingzoo (>=1.24.1)", "vmas (>=1.2.10)"]
-offline-data = ["h5py", "huggingface-hub", "minari", "pandas", "pillow", "requests", "scikit-learn", "torchvision", "tqdm"]
+offline-data = ["h5py", "huggingface_hub", "minari", "pandas", "pillow", "requests", "scikit-learn", "torchvision", "tqdm"]
rendering = ["moviepy"]
tests = ["pytest", "pytest-instafail", "pyyaml", "scipy"]
utils = ["git", "hydra-core (>=1.1)", "hydra-submitit-launcher", "tensorboard", "tqdm", "wandb"]
@@ -3051,13 +3179,13 @@ test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess
[[package]]
name = "wandb"
-version = "0.16.3"
+version = "0.16.4"
description = "A CLI and library for interacting with the Weights & Biases API."
optional = false
python-versions = ">=3.7"
files = [
- {file = "wandb-0.16.3-py3-none-any.whl", hash = "sha256:b8907ddd775c27dc6c12687386a86b5d6acf291060f9ae680bbc61cc8fc03237"},
- {file = "wandb-0.16.3.tar.gz", hash = "sha256:d789acda32053b18b7a160d0595837e45a3c8a79d25e1fe1f051875303f480ec"},
+ {file = "wandb-0.16.4-py3-none-any.whl", hash = "sha256:bb9eb5aa2c2c85e11c76040c4271366f54d4975167aa6320ba86c3f2d97fe5fa"},
+ {file = "wandb-0.16.4.tar.gz", hash = "sha256:8752c67d1347a4c29777e64dc1e1a742a66c5ecde03aebadf2b0d62183fa307c"},
]
[package.dependencies]
@@ -3078,8 +3206,9 @@ async = ["httpx (>=0.23.0)"]
aws = ["boto3"]
azure = ["azure-identity", "azure-storage-blob"]
gcp = ["google-cloud-storage"]
+importers = ["filelock", "mlflow", "polars", "rich", "tenacity"]
kubeflow = ["google-cloud-storage", "kubernetes", "minio", "sh"]
-launch = ["PyYAML (>=6.0.0)", "awscli", "azure-containerregistry", "azure-identity", "azure-storage-blob", "boto3", "botocore", "chardet", "google-auth", "google-cloud-aiplatform", "google-cloud-artifact-registry", "google-cloud-compute", "google-cloud-storage", "iso8601", "kubernetes", "kubernetes-asyncio", "nbconvert", "nbformat", "optuna", "pydantic", "typing-extensions"]
+launch = ["PyYAML (>=6.0.0)", "awscli", "azure-containerregistry", "azure-identity", "azure-storage-blob", "boto3", "botocore", "chardet", "google-auth", "google-cloud-aiplatform", "google-cloud-artifact-registry", "google-cloud-compute", "google-cloud-storage", "iso8601", "kubernetes", "kubernetes-asyncio", "nbconvert", "nbformat", "optuna", "pydantic", "tomli", "typing-extensions"]
media = ["bokeh", "moviepy", "numpy", "pillow", "plotly (>=5.18.0)", "rdkit-pypi", "soundfile"]
models = ["cloudpickle"]
perf = ["orjson"]
@@ -3088,13 +3217,13 @@ sweeps = ["sweeps (>=0.2.0)"]
[[package]]
name = "zarr"
-version = "2.17.0"
+version = "2.17.1"
description = "An implementation of chunked, compressed, N-dimensional arrays for Python"
optional = false
python-versions = ">=3.9"
files = [
- {file = "zarr-2.17.0-py3-none-any.whl", hash = "sha256:d287cb61019c4a0a0f386f76eeaa7f0b1160b1cb90cf96173a4b6cbc135df6e1"},
- {file = "zarr-2.17.0.tar.gz", hash = "sha256:6390a2b8af31babaab4c963efc45bf1da7f9500c9aafac193f84cf019a7c66b0"},
+ {file = "zarr-2.17.1-py3-none-any.whl", hash = "sha256:e25df2741a6e92645f3890f30f3136d5b57a0f8f831094b024bbcab5f2797bc7"},
+ {file = "zarr-2.17.1.tar.gz", hash = "sha256:564b3aa072122546fe69a0fa21736f466b20fad41754334b62619f088ce46261"},
]
[package.dependencies]
@@ -3125,4 +3254,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
-content-hash = "c4d83579aed1c8c2e54cad7c8ec81b95a09ab8faff74fc9a4cb20bd00e4ddec6"
+content-hash = "3d82309a7b2388d774b56ceb6f6906ef0732d8cedda0d76cc84a30e239949be8"
diff --git a/pyproject.toml b/pyproject.toml
index 398f63ef..85af7f82 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -42,13 +42,14 @@ mpmath = "^1.3.0"
torch = "^2.2.1"
tensordict = {git = "https://github.com/pytorch/tensordict"}
torchrl = {git = "https://github.com/pytorch/rl", rev = "13bef426dcfa5887c6e5034a6e9697993fa92c37"}
-mujoco = "^3.1.2"
+mujoco = "2.3.7"
mujoco-py = "^2.1.2.14"
gym = "^0.26.2"
opencv-python = "^4.9.0.80"
diffusers = "^0.26.3"
torchvision = "^0.17.1"
h5py = "^3.10.0"
+dm-control = "1.0.14"
[tool.poetry.group.dev.dependencies]
diff --git a/sbatch.sh b/sbatch.sh
index da52c472..cb5b285a 100644
--- a/sbatch.sh
+++ b/sbatch.sh
@@ -17,6 +17,7 @@ apptainer exec --nv \
~/apptainer/nvidia_cuda:12.2.2-devel-ubuntu22.04.sif $SHELL
source ~/.bashrc
-conda activate fowm
+#conda activate fowm
+conda activate lerobot
srun $CMD
diff --git a/tests/data/aloha_sim_insertion_human/action.memmap b/tests/data/aloha_sim_insertion_human/action.memmap
new file mode 100644
index 00000000..f64b2989
--- /dev/null
+++ b/tests/data/aloha_sim_insertion_human/action.memmap
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d789deddb081a9f4b626342391de8f48949d38fb5fdead87b5c0737b46c0877a
+size 2800
diff --git a/tests/data/aloha_sim_insertion_human/episode.memmap b/tests/data/aloha_sim_insertion_human/episode.memmap
new file mode 100644
index 00000000..af9fb07f
--- /dev/null
+++ b/tests/data/aloha_sim_insertion_human/episode.memmap
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7a12e561363385e9dfeeab326368731c030ed4b374e7f5897ac819159d2884c5
+size 400
diff --git a/tests/data/aloha_sim_insertion_human/frame_id.memmap b/tests/data/aloha_sim_insertion_human/frame_id.memmap
new file mode 100644
index 00000000..dc2f585c
--- /dev/null
+++ b/tests/data/aloha_sim_insertion_human/frame_id.memmap
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c202d9cfc7858fd49d522047e16948359bbbb2eda2d3825d552e45a78d5f8585
+size 400
diff --git a/tests/data/aloha_sim_insertion_human/meta.json b/tests/data/aloha_sim_insertion_human/meta.json
new file mode 100644
index 00000000..2a0cf0a2
--- /dev/null
+++ b/tests/data/aloha_sim_insertion_human/meta.json
@@ -0,0 +1 @@
+{"action": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "episode": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "frame_id": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_insertion_human/next/done.memmap b/tests/data/aloha_sim_insertion_human/next/done.memmap
new file mode 100644
index 00000000..44fd709f
--- /dev/null
+++ b/tests/data/aloha_sim_insertion_human/next/done.memmap
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cc2786e1f9910a9d811400edcddaf7075195f7a16b216dcbefba3bc7c4f2ae51
+size 50
diff --git a/tests/data/aloha_sim_insertion_human/next/meta.json b/tests/data/aloha_sim_insertion_human/next/meta.json
new file mode 100644
index 00000000..3bfa9bd7
--- /dev/null
+++ b/tests/data/aloha_sim_insertion_human/next/meta.json
@@ -0,0 +1 @@
+{"done": {"device": "cpu", "shape": [50, 1], "dtype": "torch.bool"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_insertion_human/next/observation/image/meta.json b/tests/data/aloha_sim_insertion_human/next/observation/image/meta.json
new file mode 100644
index 00000000..cb29a5ab
--- /dev/null
+++ b/tests/data/aloha_sim_insertion_human/next/observation/image/meta.json
@@ -0,0 +1 @@
+{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_insertion_human/next/observation/image/top.memmap b/tests/data/aloha_sim_insertion_human/next/observation/image/top.memmap
new file mode 100644
index 00000000..d3d8bd1c
--- /dev/null
+++ b/tests/data/aloha_sim_insertion_human/next/observation/image/top.memmap
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5c632e3cb06be729e5d673e3ecca1d6f6527b0f48cfe3dc03d7eea4f9eb3bbd7
+size 46080000
diff --git a/tests/data/aloha_sim_insertion_human/next/observation/meta.json b/tests/data/aloha_sim_insertion_human/next/observation/meta.json
new file mode 100644
index 00000000..65ce1ca2
--- /dev/null
+++ b/tests/data/aloha_sim_insertion_human/next/observation/meta.json
@@ -0,0 +1 @@
+{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_insertion_human/next/observation/state.memmap b/tests/data/aloha_sim_insertion_human/next/observation/state.memmap
new file mode 100644
index 00000000..1f087a60
--- /dev/null
+++ b/tests/data/aloha_sim_insertion_human/next/observation/state.memmap
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e231f2e07e1cd030137ea2e938b570b112db2c694c6d21b37ceb8f8559e19088
+size 2800
diff --git a/tests/data/aloha_sim_insertion_human/observation/image/meta.json b/tests/data/aloha_sim_insertion_human/observation/image/meta.json
new file mode 100644
index 00000000..cb29a5ab
--- /dev/null
+++ b/tests/data/aloha_sim_insertion_human/observation/image/meta.json
@@ -0,0 +1 @@
+{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_insertion_human/observation/image/top.memmap b/tests/data/aloha_sim_insertion_human/observation/image/top.memmap
new file mode 100644
index 00000000..00c0b783
--- /dev/null
+++ b/tests/data/aloha_sim_insertion_human/observation/image/top.memmap
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a1ba64c89f4fcf9135fe34c26abf582dd5f0d573506db5c96af3ffe40a52c818
+size 46080000
diff --git a/tests/data/aloha_sim_insertion_human/observation/meta.json b/tests/data/aloha_sim_insertion_human/observation/meta.json
new file mode 100644
index 00000000..65ce1ca2
--- /dev/null
+++ b/tests/data/aloha_sim_insertion_human/observation/meta.json
@@ -0,0 +1 @@
+{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_insertion_human/observation/state.memmap b/tests/data/aloha_sim_insertion_human/observation/state.memmap
new file mode 100644
index 00000000..a1131179
--- /dev/null
+++ b/tests/data/aloha_sim_insertion_human/observation/state.memmap
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:85405686bc065c6ab6c915907920a0391a57cf097b74de058a8c30be0548ade5
+size 2800
diff --git a/tests/data/aloha_sim_insertion_human/stats.pth b/tests/data/aloha_sim_insertion_human/stats.pth
new file mode 100644
index 00000000..869d26cd
Binary files /dev/null and b/tests/data/aloha_sim_insertion_human/stats.pth differ
diff --git a/tests/data/aloha_sim_insertion_scripted/action.memmap b/tests/data/aloha_sim_insertion_scripted/action.memmap
new file mode 100644
index 00000000..e4068b75
--- /dev/null
+++ b/tests/data/aloha_sim_insertion_scripted/action.memmap
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1f5fe053b760e8471885b82c10f4a6ea40874098036337ae5cc300c4775546be
+size 2800
diff --git a/tests/data/aloha_sim_insertion_scripted/episode.memmap b/tests/data/aloha_sim_insertion_scripted/episode.memmap
new file mode 100644
index 00000000..af9fb07f
--- /dev/null
+++ b/tests/data/aloha_sim_insertion_scripted/episode.memmap
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7a12e561363385e9dfeeab326368731c030ed4b374e7f5897ac819159d2884c5
+size 400
diff --git a/tests/data/aloha_sim_insertion_scripted/frame_id.memmap b/tests/data/aloha_sim_insertion_scripted/frame_id.memmap
new file mode 100644
index 00000000..dc2f585c
--- /dev/null
+++ b/tests/data/aloha_sim_insertion_scripted/frame_id.memmap
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c202d9cfc7858fd49d522047e16948359bbbb2eda2d3825d552e45a78d5f8585
+size 400
diff --git a/tests/data/aloha_sim_insertion_scripted/meta.json b/tests/data/aloha_sim_insertion_scripted/meta.json
new file mode 100644
index 00000000..2a0cf0a2
--- /dev/null
+++ b/tests/data/aloha_sim_insertion_scripted/meta.json
@@ -0,0 +1 @@
+{"action": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "episode": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "frame_id": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_insertion_scripted/next/done.memmap b/tests/data/aloha_sim_insertion_scripted/next/done.memmap
new file mode 100644
index 00000000..44fd709f
--- /dev/null
+++ b/tests/data/aloha_sim_insertion_scripted/next/done.memmap
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cc2786e1f9910a9d811400edcddaf7075195f7a16b216dcbefba3bc7c4f2ae51
+size 50
diff --git a/tests/data/aloha_sim_insertion_scripted/next/meta.json b/tests/data/aloha_sim_insertion_scripted/next/meta.json
new file mode 100644
index 00000000..3bfa9bd7
--- /dev/null
+++ b/tests/data/aloha_sim_insertion_scripted/next/meta.json
@@ -0,0 +1 @@
+{"done": {"device": "cpu", "shape": [50, 1], "dtype": "torch.bool"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_insertion_scripted/next/observation/image/meta.json b/tests/data/aloha_sim_insertion_scripted/next/observation/image/meta.json
new file mode 100644
index 00000000..cb29a5ab
--- /dev/null
+++ b/tests/data/aloha_sim_insertion_scripted/next/observation/image/meta.json
@@ -0,0 +1 @@
+{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_insertion_scripted/next/observation/image/top.memmap b/tests/data/aloha_sim_insertion_scripted/next/observation/image/top.memmap
new file mode 100644
index 00000000..83911729
--- /dev/null
+++ b/tests/data/aloha_sim_insertion_scripted/next/observation/image/top.memmap
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:daed2bb10498ba2557983d0d7e89399882fea7585e7ceff910e23c621bfdbf88
+size 46080000
diff --git a/tests/data/aloha_sim_insertion_scripted/next/observation/meta.json b/tests/data/aloha_sim_insertion_scripted/next/observation/meta.json
new file mode 100644
index 00000000..65ce1ca2
--- /dev/null
+++ b/tests/data/aloha_sim_insertion_scripted/next/observation/meta.json
@@ -0,0 +1 @@
+{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_insertion_scripted/next/observation/state.memmap b/tests/data/aloha_sim_insertion_scripted/next/observation/state.memmap
new file mode 100644
index 00000000..aef69da0
--- /dev/null
+++ b/tests/data/aloha_sim_insertion_scripted/next/observation/state.memmap
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bbad0302af70112ee312efe0eb0f44a2f1c8f6c5ef82ea4fb34625cdafbef057
+size 2800
diff --git a/tests/data/aloha_sim_insertion_scripted/observation/image/meta.json b/tests/data/aloha_sim_insertion_scripted/observation/image/meta.json
new file mode 100644
index 00000000..cb29a5ab
--- /dev/null
+++ b/tests/data/aloha_sim_insertion_scripted/observation/image/meta.json
@@ -0,0 +1 @@
+{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_insertion_scripted/observation/image/top.memmap b/tests/data/aloha_sim_insertion_scripted/observation/image/top.memmap
new file mode 100644
index 00000000..f9f0a759
--- /dev/null
+++ b/tests/data/aloha_sim_insertion_scripted/observation/image/top.memmap
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aba55ebb9dd004bf68444b9ebf024ed7713436099c06a0b8e541100ecbc69290
+size 46080000
diff --git a/tests/data/aloha_sim_insertion_scripted/observation/meta.json b/tests/data/aloha_sim_insertion_scripted/observation/meta.json
new file mode 100644
index 00000000..65ce1ca2
--- /dev/null
+++ b/tests/data/aloha_sim_insertion_scripted/observation/meta.json
@@ -0,0 +1 @@
+{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_insertion_scripted/observation/state.memmap b/tests/data/aloha_sim_insertion_scripted/observation/state.memmap
new file mode 100644
index 00000000..91875055
--- /dev/null
+++ b/tests/data/aloha_sim_insertion_scripted/observation/state.memmap
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dd4e7e14abf57561ca9839c910581266be90956e41bfb3bb21362ea0c321e77d
+size 2800
diff --git a/tests/data/aloha_sim_insertion_scripted/stats.pth b/tests/data/aloha_sim_insertion_scripted/stats.pth
new file mode 100644
index 00000000..0f3f9d2f
Binary files /dev/null and b/tests/data/aloha_sim_insertion_scripted/stats.pth differ
diff --git a/tests/data/aloha_sim_transfer_cube_human/action.memmap b/tests/data/aloha_sim_transfer_cube_human/action.memmap
new file mode 100644
index 00000000..9b4fef33
--- /dev/null
+++ b/tests/data/aloha_sim_transfer_cube_human/action.memmap
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:14fed0eed3d529a8ac0dd25a6d41585020772d02f9137fc9d604713b2f0f7076
+size 2800
diff --git a/tests/data/aloha_sim_transfer_cube_human/episode.memmap b/tests/data/aloha_sim_transfer_cube_human/episode.memmap
new file mode 100644
index 00000000..af9fb07f
--- /dev/null
+++ b/tests/data/aloha_sim_transfer_cube_human/episode.memmap
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7a12e561363385e9dfeeab326368731c030ed4b374e7f5897ac819159d2884c5
+size 400
diff --git a/tests/data/aloha_sim_transfer_cube_human/frame_id.memmap b/tests/data/aloha_sim_transfer_cube_human/frame_id.memmap
new file mode 100644
index 00000000..dc2f585c
--- /dev/null
+++ b/tests/data/aloha_sim_transfer_cube_human/frame_id.memmap
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c202d9cfc7858fd49d522047e16948359bbbb2eda2d3825d552e45a78d5f8585
+size 400
diff --git a/tests/data/aloha_sim_transfer_cube_human/meta.json b/tests/data/aloha_sim_transfer_cube_human/meta.json
new file mode 100644
index 00000000..2a0cf0a2
--- /dev/null
+++ b/tests/data/aloha_sim_transfer_cube_human/meta.json
@@ -0,0 +1 @@
+{"action": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "episode": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "frame_id": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_transfer_cube_human/next/done.memmap b/tests/data/aloha_sim_transfer_cube_human/next/done.memmap
new file mode 100644
index 00000000..44fd709f
--- /dev/null
+++ b/tests/data/aloha_sim_transfer_cube_human/next/done.memmap
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cc2786e1f9910a9d811400edcddaf7075195f7a16b216dcbefba3bc7c4f2ae51
+size 50
diff --git a/tests/data/aloha_sim_transfer_cube_human/next/meta.json b/tests/data/aloha_sim_transfer_cube_human/next/meta.json
new file mode 100644
index 00000000..3bfa9bd7
--- /dev/null
+++ b/tests/data/aloha_sim_transfer_cube_human/next/meta.json
@@ -0,0 +1 @@
+{"done": {"device": "cpu", "shape": [50, 1], "dtype": "torch.bool"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_transfer_cube_human/next/observation/image/meta.json b/tests/data/aloha_sim_transfer_cube_human/next/observation/image/meta.json
new file mode 100644
index 00000000..cb29a5ab
--- /dev/null
+++ b/tests/data/aloha_sim_transfer_cube_human/next/observation/image/meta.json
@@ -0,0 +1 @@
+{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_transfer_cube_human/next/observation/image/top.memmap b/tests/data/aloha_sim_transfer_cube_human/next/observation/image/top.memmap
new file mode 100644
index 00000000..cd2e7c06
--- /dev/null
+++ b/tests/data/aloha_sim_transfer_cube_human/next/observation/image/top.memmap
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2f713ea7fc19e592ea409a5e0bdfde403e5b86f834cbabe3463b791e8437fafc
+size 46080000
diff --git a/tests/data/aloha_sim_transfer_cube_human/next/observation/meta.json b/tests/data/aloha_sim_transfer_cube_human/next/observation/meta.json
new file mode 100644
index 00000000..65ce1ca2
--- /dev/null
+++ b/tests/data/aloha_sim_transfer_cube_human/next/observation/meta.json
@@ -0,0 +1 @@
+{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_transfer_cube_human/next/observation/state.memmap b/tests/data/aloha_sim_transfer_cube_human/next/observation/state.memmap
new file mode 100644
index 00000000..37feaad6
--- /dev/null
+++ b/tests/data/aloha_sim_transfer_cube_human/next/observation/state.memmap
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8c103e2c9d63c9f7cf9645bd24d9a2c4e8e08825dc75e230ebc793b8f9c213b0
+size 2800
diff --git a/tests/data/aloha_sim_transfer_cube_human/observation/image/meta.json b/tests/data/aloha_sim_transfer_cube_human/observation/image/meta.json
new file mode 100644
index 00000000..cb29a5ab
--- /dev/null
+++ b/tests/data/aloha_sim_transfer_cube_human/observation/image/meta.json
@@ -0,0 +1 @@
+{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_transfer_cube_human/observation/image/top.memmap b/tests/data/aloha_sim_transfer_cube_human/observation/image/top.memmap
new file mode 100644
index 00000000..1188590c
--- /dev/null
+++ b/tests/data/aloha_sim_transfer_cube_human/observation/image/top.memmap
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7dbf4aa01b184d0eaa21ea999078d7cff86e1ca484a109614176fdc49f1ee05c
+size 46080000
diff --git a/tests/data/aloha_sim_transfer_cube_human/observation/meta.json b/tests/data/aloha_sim_transfer_cube_human/observation/meta.json
new file mode 100644
index 00000000..65ce1ca2
--- /dev/null
+++ b/tests/data/aloha_sim_transfer_cube_human/observation/meta.json
@@ -0,0 +1 @@
+{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_transfer_cube_human/observation/state.memmap b/tests/data/aloha_sim_transfer_cube_human/observation/state.memmap
new file mode 100644
index 00000000..9ef4cfd6
--- /dev/null
+++ b/tests/data/aloha_sim_transfer_cube_human/observation/state.memmap
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4fa0b9c870d4615037b6fee9e9e85e54d84352e173f2c7c1035232272fe2a3dd
+size 2800
diff --git a/tests/data/aloha_sim_transfer_cube_human/stats.pth b/tests/data/aloha_sim_transfer_cube_human/stats.pth
new file mode 100644
index 00000000..0fa5667c
Binary files /dev/null and b/tests/data/aloha_sim_transfer_cube_human/stats.pth differ
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/action.memmap b/tests/data/aloha_sim_transfer_cube_scripted/action.memmap
new file mode 100644
index 00000000..8ac0c726
--- /dev/null
+++ b/tests/data/aloha_sim_transfer_cube_scripted/action.memmap
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c0e199a82e2b7462e84406dbced5448a99f1dad9ce172771dfc3feb4b8597115
+size 2800
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/episode.memmap b/tests/data/aloha_sim_transfer_cube_scripted/episode.memmap
new file mode 100644
index 00000000..af9fb07f
--- /dev/null
+++ b/tests/data/aloha_sim_transfer_cube_scripted/episode.memmap
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7a12e561363385e9dfeeab326368731c030ed4b374e7f5897ac819159d2884c5
+size 400
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/frame_id.memmap b/tests/data/aloha_sim_transfer_cube_scripted/frame_id.memmap
new file mode 100644
index 00000000..dc2f585c
--- /dev/null
+++ b/tests/data/aloha_sim_transfer_cube_scripted/frame_id.memmap
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c202d9cfc7858fd49d522047e16948359bbbb2eda2d3825d552e45a78d5f8585
+size 400
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/meta.json b/tests/data/aloha_sim_transfer_cube_scripted/meta.json
new file mode 100644
index 00000000..2a0cf0a2
--- /dev/null
+++ b/tests/data/aloha_sim_transfer_cube_scripted/meta.json
@@ -0,0 +1 @@
+{"action": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "episode": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "frame_id": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/next/done.memmap b/tests/data/aloha_sim_transfer_cube_scripted/next/done.memmap
new file mode 100644
index 00000000..44fd709f
--- /dev/null
+++ b/tests/data/aloha_sim_transfer_cube_scripted/next/done.memmap
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cc2786e1f9910a9d811400edcddaf7075195f7a16b216dcbefba3bc7c4f2ae51
+size 50
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/next/meta.json b/tests/data/aloha_sim_transfer_cube_scripted/next/meta.json
new file mode 100644
index 00000000..3bfa9bd7
--- /dev/null
+++ b/tests/data/aloha_sim_transfer_cube_scripted/next/meta.json
@@ -0,0 +1 @@
+{"done": {"device": "cpu", "shape": [50, 1], "dtype": "torch.bool"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/next/observation/image/meta.json b/tests/data/aloha_sim_transfer_cube_scripted/next/observation/image/meta.json
new file mode 100644
index 00000000..cb29a5ab
--- /dev/null
+++ b/tests/data/aloha_sim_transfer_cube_scripted/next/observation/image/meta.json
@@ -0,0 +1 @@
+{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/next/observation/image/top.memmap b/tests/data/aloha_sim_transfer_cube_scripted/next/observation/image/top.memmap
new file mode 100644
index 00000000..8e5f533e
--- /dev/null
+++ b/tests/data/aloha_sim_transfer_cube_scripted/next/observation/image/top.memmap
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:30b44d38cc4d68e06c716a875d39cbdbeacbfdc1657d6366f58c279efd27c52b
+size 46080000
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/next/observation/meta.json b/tests/data/aloha_sim_transfer_cube_scripted/next/observation/meta.json
new file mode 100644
index 00000000..65ce1ca2
--- /dev/null
+++ b/tests/data/aloha_sim_transfer_cube_scripted/next/observation/meta.json
@@ -0,0 +1 @@
+{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/next/observation/state.memmap b/tests/data/aloha_sim_transfer_cube_scripted/next/observation/state.memmap
new file mode 100644
index 00000000..e88320d1
--- /dev/null
+++ b/tests/data/aloha_sim_transfer_cube_scripted/next/observation/state.memmap
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9f484e7ea4f5f612dd53ee2c0f7891b8f7b2168a54fc81941ac2f2447260c294
+size 2800
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/observation/image/meta.json b/tests/data/aloha_sim_transfer_cube_scripted/observation/image/meta.json
new file mode 100644
index 00000000..cb29a5ab
--- /dev/null
+++ b/tests/data/aloha_sim_transfer_cube_scripted/observation/image/meta.json
@@ -0,0 +1 @@
+{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/observation/image/top.memmap b/tests/data/aloha_sim_transfer_cube_scripted/observation/image/top.memmap
new file mode 100644
index 00000000..d415da0a
--- /dev/null
+++ b/tests/data/aloha_sim_transfer_cube_scripted/observation/image/top.memmap
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:09600206f56cc5b52dfb896204b0044c4e830da368da141d7bd10e52181f6835
+size 46080000
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/observation/meta.json b/tests/data/aloha_sim_transfer_cube_scripted/observation/meta.json
new file mode 100644
index 00000000..65ce1ca2
--- /dev/null
+++ b/tests/data/aloha_sim_transfer_cube_scripted/observation/meta.json
@@ -0,0 +1 @@
+{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/observation/state.memmap b/tests/data/aloha_sim_transfer_cube_scripted/observation/state.memmap
new file mode 100644
index 00000000..be3436fb
--- /dev/null
+++ b/tests/data/aloha_sim_transfer_cube_scripted/observation/state.memmap
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:60dcb547cf9a6372b78a455217a2408b6bece4371fba1df2a302b334d45c42a8
+size 2800
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/stats.pth b/tests/data/aloha_sim_transfer_cube_scripted/stats.pth
new file mode 100644
index 00000000..b00756e6
Binary files /dev/null and b/tests/data/aloha_sim_transfer_cube_scripted/stats.pth differ
diff --git a/tests/data/pusht/stats.pth b/tests/data/pusht/stats.pth
index c34e6a94..037e02f0 100644
Binary files a/tests/data/pusht/stats.pth and b/tests/data/pusht/stats.pth differ
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index e63ae2c1..b7d1e6f8 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -1,8 +1,9 @@
import pytest
+import torch
from lerobot.common.datasets.factory import make_offline_buffer
-from .utils import init_config
+from .utils import DEVICE, init_config
@pytest.mark.parametrize(
@@ -12,12 +13,18 @@ from .utils import init_config
# ("simxarm", "lift"),
("pusht", "pusht"),
# TODO(aliberts): add aloha when dataset is available on hub
- # ("aloha", "sim_insertion_human"),
- # ("aloha", "sim_insertion_scripted"),
- # ("aloha", "sim_transfer_cube_human"),
- # ("aloha", "sim_transfer_cube_scripted"),
+ ("aloha", "sim_insertion_human"),
+ ("aloha", "sim_insertion_scripted"),
+ ("aloha", "sim_transfer_cube_human"),
+ ("aloha", "sim_transfer_cube_scripted"),
],
)
def test_factory(env_name, dataset_id):
- cfg = init_config(overrides=[f"env={env_name}", f"env.task={dataset_id}"])
+ cfg = init_config(overrides=[f"env={env_name}", f"env.task={dataset_id}", f"device={DEVICE}"])
offline_buffer = make_offline_buffer(cfg)
+ for key in offline_buffer.image_keys:
+ img = offline_buffer[0].get(key)
+ assert img.dtype == torch.float32
+ # TODO(rcadene): we assume for now that image normalization takes place in the model
+ assert img.max() <= 1.0
+ assert img.min() >= 0.0
diff --git a/tests/test_envs.py b/tests/test_envs.py
index b51c441b..7776ba3c 100644
--- a/tests/test_envs.py
+++ b/tests/test_envs.py
@@ -1,12 +1,15 @@
+import os
import pytest
from tensordict import TensorDict
+import torch
from torchrl.envs.utils import check_env_specs, step_mdp
+from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.pusht.env import PushtEnv
from lerobot.common.envs.simxarm import SimxarmEnv
-from .utils import init_config
+from .utils import DEVICE, init_config
def print_spec_rollout(env):
@@ -83,9 +86,24 @@ def test_pusht(from_pixels, pixels_only):
[
# "simxarm",
"pusht",
+ "aloha",
],
)
def test_factory(env_name):
- cfg = init_config(overrides=[f"env={env_name}"])
+ cfg = init_config(overrides=[f"env={env_name}", f"device={DEVICE}"])
+
+ offline_buffer = make_offline_buffer(cfg)
+
env = make_env(cfg)
+ for key in offline_buffer.image_keys:
+ assert env.reset().get(key).dtype == torch.uint8
+ check_env_specs(env)
+
+ env = make_env(cfg, transform=offline_buffer.transform)
+ for key in offline_buffer.image_keys:
+ img = env.reset().get(key)
+ assert img.dtype == torch.float32
+ # TODO(rcadene): we assume for now that image normalization takes place in the model
+ assert img.max() <= 1.0
+ assert img.min() >= 0.0
check_env_specs(env)
diff --git a/tests/test_policies.py b/tests/test_policies.py
index 03f20bd0..f00429bc 100644
--- a/tests/test_policies.py
+++ b/tests/test_policies.py
@@ -2,7 +2,7 @@ import pytest
from lerobot.common.policies.factory import make_policy
-from .utils import init_config
+from .utils import DEVICE, init_config
@pytest.mark.parametrize(
@@ -19,6 +19,7 @@ def test_factory(env_name, policy_name):
overrides=[
f"env={env_name}",
f"policy={policy_name}",
+ f"device={DEVICE}",
]
)
policy = make_policy(cfg)
diff --git a/tests/utils.py b/tests/utils.py
index 40dc6de0..55709330 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -1,8 +1,10 @@
+import os
import hydra
from hydra import compose, initialize
CONFIG_PATH = "../lerobot/configs"
+DEVICE = os.environ.get('LEROBOT_TESTS_DEVICE', "cuda")
def init_config(config_name="default", overrides=None):
hydra.core.global_hydra.GlobalHydra.instance().clear()