diff --git a/.github/poetry/cpu/poetry.lock b/.github/poetry/cpu/poetry.lock
index ba820f34..98a3d58d 100644
--- a/.github/poetry/cpu/poetry.lock
+++ b/.github/poetry/cpu/poetry.lock
@@ -517,21 +517,11 @@ files = [
{file = "distlib-0.3.8.tar.gz", hash = "sha256:1530ea13e350031b6312d8580ddb6b27a104275a31106523b8f123787f494f64"},
]
-[[package]]
-name = "dm"
-version = "1.3"
-description = "Dict to Data mapper"
-optional = false
-python-versions = "*"
-files = [
- {file = "dm-1.3.tar.gz", hash = "sha256:ce77537bf346b5d8c0dc0b5d679cfc4a946faadcd5315e6c80ef6f3af824130d"},
-]
-
[[package]]
name = "dm-control"
version = "1.0.14"
description = "Continuous control environments and MuJoCo Python bindings."
-optional = false
+optional = true
python-versions = ">=3.8"
files = [
{file = "dm_control-1.0.14-py3-none-any.whl", hash = "sha256:883c63244a7ebf598700a97564ed19fffd3479ca79efd090aed881609cdb9fc6"},
@@ -562,7 +552,7 @@ hdf5 = ["h5py"]
name = "dm-env"
version = "1.6"
description = "A Python interface for Reinforcement Learning environments."
-optional = false
+optional = true
python-versions = ">=3.7"
files = [
{file = "dm-env-1.6.tar.gz", hash = "sha256:a436eb1c654c39e0c986a516cee218bea7140b510fceff63f97eb4fcff3d93de"},
@@ -578,7 +568,7 @@ numpy = "*"
name = "dm-tree"
version = "0.1.8"
description = "Tree is a library for working with nested data structures."
-optional = false
+optional = true
python-versions = "*"
files = [
{file = "dm-tree-0.1.8.tar.gz", hash = "sha256:0fcaabbb14e7980377439e7140bd05552739ca5e515ecb3119f234acee4b9430"},
@@ -806,7 +796,7 @@ test = ["black", "coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre
name = "glfw"
version = "2.7.0"
description = "A ctypes-based wrapper for GLFW3."
-optional = false
+optional = true
python-versions = "*"
files = [
{file = "glfw-2.7.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38-none-macosx_10_6_intel.whl", hash = "sha256:bd82849edcceda4e262bd1227afaa74b94f9f0731c1197863cd25c15bfc613fc"},
@@ -889,6 +879,69 @@ files = [
[package.extras]
protobuf = ["grpcio-tools (>=1.62.1)"]
+[[package]]
+name = "gym-aloha"
+version = "0.1.0"
+description = "A gym environment for ALOHA"
+optional = true
+python-versions = "^3.10"
+files = []
+develop = false
+
+[package.dependencies]
+dm-control = "1.0.14"
+gymnasium = "^0.29.1"
+mujoco = "^2.3.7"
+
+[package.source]
+type = "git"
+url = "git@github.com:huggingface/gym-aloha.git"
+reference = "HEAD"
+resolved_reference = "c636f05ba0d1760df94537da84c860be1487e17f"
+
+[[package]]
+name = "gym-pusht"
+version = "0.1.0"
+description = "A gymnasium environment for PushT."
+optional = true
+python-versions = "^3.10"
+files = []
+develop = false
+
+[package.dependencies]
+gymnasium = "^0.29.1"
+opencv-python = "^4.9.0.80"
+pygame = "^2.5.2"
+pymunk = "^6.6.0"
+scikit-image = "^0.22.0"
+shapely = "^2.0.3"
+
+[package.source]
+type = "git"
+url = "git@github.com:huggingface/gym-pusht.git"
+reference = "HEAD"
+resolved_reference = "6c9893504f670ff069d0f759a733e971ea1efdbf"
+
+[[package]]
+name = "gym-xarm"
+version = "0.1.0"
+description = "A gym environment for xArm"
+optional = true
+python-versions = "^3.10"
+files = []
+develop = false
+
+[package.dependencies]
+gymnasium = "^0.29.1"
+gymnasium-robotics = "^1.2.4"
+mujoco = "^2.3.7"
+
+[package.source]
+type = "git"
+url = "git@github.com:huggingface/gym-xarm.git"
+reference = "HEAD"
+resolved_reference = "6a88f7d63833705dfbec4b997bf36cac6b4a448c"
+
[[package]]
name = "gymnasium"
version = "0.29.1"
@@ -923,7 +976,7 @@ toy-text = ["pygame (>=2.1.3)", "pygame (>=2.1.3)"]
name = "gymnasium-robotics"
version = "1.2.4"
description = "Robotics environments for the Gymnasium repo."
-optional = false
+optional = true
python-versions = ">=3.8"
files = [
{file = "gymnasium-robotics-1.2.4.tar.gz", hash = "sha256:d304192b066f8b800599dfbe3d9d90bba9b761ee884472bdc4d05968a8bc61cb"},
@@ -1155,7 +1208,7 @@ i18n = ["Babel (>=2.7)"]
name = "labmaze"
version = "1.0.6"
description = "LabMaze: DeepMind Lab's text maze generator."
-optional = false
+optional = true
python-versions = "*"
files = [
{file = "labmaze-1.0.6-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:b2ddef976dfd8d992b19cfa6c633f2eba7576d759c2082da534e3f727479a84a"},
@@ -1199,7 +1252,7 @@ setuptools = "!=50.0.0"
name = "lazy-loader"
version = "0.3"
description = "lazy_loader"
-optional = false
+optional = true
python-versions = ">=3.7"
files = [
{file = "lazy_loader-0.3-py3-none-any.whl", hash = "sha256:1e9e76ee8631e264c62ce10006718e80b2cfc74340d17d1031e0f84af7478554"},
@@ -1244,7 +1297,7 @@ files = [
name = "lxml"
version = "5.1.0"
description = "Powerful and Pythonic XML processing library combining libxml2/libxslt with the ElementTree API."
-optional = false
+optional = true
python-versions = ">=3.6"
files = [
{file = "lxml-5.1.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:704f5572ff473a5f897745abebc6df40f22d4133c1e0a1f124e4f2bd3330ff7e"},
@@ -1462,7 +1515,7 @@ tests = ["pytest (>=4.6)"]
name = "mujoco"
version = "2.3.7"
description = "MuJoCo Physics Simulator"
-optional = false
+optional = true
python-versions = ">=3.8"
files = [
{file = "mujoco-2.3.7-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:e8714a5ff6a1561b364b7b4648d4c0c8d13e751874cf7401c309b9d23fa9598b"},
@@ -1776,7 +1829,7 @@ xml = ["lxml (>=4.9.2)"]
name = "pettingzoo"
version = "1.24.3"
description = "Gymnasium for multi-agent reinforcement learning."
-optional = false
+optional = true
python-versions = ">=3.8"
files = [
{file = "pettingzoo-1.24.3-py3-none-any.whl", hash = "sha256:23ed90517d2e8a7098bdaf5e31234b3a7f7b73ca578d70d1ca7b9d0cb0e37982"},
@@ -2144,7 +2197,7 @@ dev = ["aafigure", "matplotlib", "pygame", "pyglet (<2.0.0)", "sphinx", "wheel"]
name = "pyopengl"
version = "3.1.7"
description = "Standard OpenGL bindings for Python"
-optional = false
+optional = true
python-versions = "*"
files = [
{file = "PyOpenGL-3.1.7-py3-none-any.whl", hash = "sha256:a6ab19cf290df6101aaf7470843a9c46207789855746399d0af92521a0a92b7a"},
@@ -2155,7 +2208,7 @@ files = [
name = "pyparsing"
version = "3.1.2"
description = "pyparsing module - Classes and methods to define and execute parsing grammars"
-optional = false
+optional = true
python-versions = ">=3.6.8"
files = [
{file = "pyparsing-3.1.2-py3-none-any.whl", hash = "sha256:f9db75911801ed778fe61bb643079ff86601aca99fcae6345aa67292038fb742"},
@@ -2586,7 +2639,7 @@ torch = ["safetensors[numpy]", "torch (>=1.10)"]
name = "scikit-image"
version = "0.22.0"
description = "Image processing in Python"
-optional = false
+optional = true
python-versions = ">=3.9"
files = [
{file = "scikit_image-0.22.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:74ec5c1d4693506842cc7c9487c89d8fc32aed064e9363def7af08b8f8cbb31d"},
@@ -2634,7 +2687,7 @@ test = ["asv", "matplotlib (>=3.5)", "numpydoc (>=1.5)", "pooch (>=1.6.0)", "pyt
name = "scipy"
version = "1.12.0"
description = "Fundamental algorithms for scientific computing in Python"
-optional = false
+optional = true
python-versions = ">=3.9"
files = [
{file = "scipy-1.12.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:78e4402e140879387187f7f25d91cc592b3501a2e51dfb320f48dfb73565f10b"},
@@ -2839,7 +2892,7 @@ testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jar
name = "shapely"
version = "2.0.3"
description = "Manipulation and analysis of geometric objects"
-optional = false
+optional = true
python-versions = ">=3.7"
files = [
{file = "shapely-2.0.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:af7e9abe180b189431b0f490638281b43b84a33a960620e6b2e8d3e3458b61a1"},
@@ -2988,31 +3041,6 @@ numpy = "*"
packaging = "*"
protobuf = ">=3.20"
-[[package]]
-name = "tensordict"
-version = "0.4.0+b4c91e8"
-description = ""
-optional = false
-python-versions = "*"
-files = []
-develop = false
-
-[package.dependencies]
-cloudpickle = "*"
-numpy = "*"
-torch = ">=2.1.0"
-
-[package.extras]
-checkpointing = ["torchsnapshot-nightly"]
-h5 = ["h5py (>=3.8)"]
-tests = ["pytest", "pytest-benchmark", "pytest-instafail", "pytest-rerunfailures", "pyyaml"]
-
-[package.source]
-type = "git"
-url = "https://github.com/pytorch/tensordict"
-reference = "HEAD"
-resolved_reference = "b4c91e8828c538ca0a50d8383fd99311a9afb078"
-
[[package]]
name = "termcolor"
version = "2.4.0"
@@ -3031,7 +3059,7 @@ tests = ["pytest", "pytest-cov"]
name = "tifffile"
version = "2024.2.12"
description = "Read and write TIFF files"
-optional = false
+optional = true
python-versions = ">=3.9"
files = [
{file = "tifffile-2024.2.12-py3-none-any.whl", hash = "sha256:870998f82fbc94ff7c3528884c1b0ae54863504ff51dbebea431ac3fa8fb7c21"},
@@ -3091,40 +3119,6 @@ type = "legacy"
url = "https://download.pytorch.org/whl/cpu"
reference = "torch-cpu"
-[[package]]
-name = "torchrl"
-version = "0.4.0+13bef42"
-description = ""
-optional = false
-python-versions = "*"
-files = []
-develop = false
-
-[package.dependencies]
-cloudpickle = "*"
-numpy = "*"
-packaging = "*"
-tensordict = ">=0.4.0"
-torch = ">=2.1.0"
-
-[package.extras]
-all = ["ale-py", "atari-py", "dm_control", "git", "gym", "gym[accept-rom-license]", "gymnasium", "h5py", "huggingface_hub", "hydra-core (>=1.1)", "hydra-submitit-launcher", "minari", "moviepy", "mujoco", "pandas", "pettingzoo (>=1.24.1)", "pillow", "pygame", "pytest", "pytest-instafail", "pyyaml", "requests", "scikit-learn", "scipy", "tensorboard", "torchsnapshot", "torchvision", "tqdm", "vmas (>=1.2.10)", "wandb"]
-atari = ["ale-py", "atari-py", "gym", "gym[accept-rom-license]", "pygame"]
-checkpointing = ["torchsnapshot"]
-dm-control = ["dm_control"]
-gym-continuous = ["gymnasium", "mujoco"]
-marl = ["pettingzoo (>=1.24.1)", "vmas (>=1.2.10)"]
-offline-data = ["h5py", "huggingface_hub", "minari", "pandas", "pillow", "requests", "scikit-learn", "torchvision", "tqdm"]
-rendering = ["moviepy"]
-tests = ["pytest", "pytest-instafail", "pyyaml", "scipy"]
-utils = ["git", "hydra-core (>=1.1)", "hydra-submitit-launcher", "tensorboard", "tqdm", "wandb"]
-
-[package.source]
-type = "git"
-url = "https://github.com/pytorch/rl"
-reference = "13bef426dcfa5887c6e5034a6e9697993fa92c37"
-resolved_reference = "13bef426dcfa5887c6e5034a6e9697993fa92c37"
-
[[package]]
name = "torchvision"
version = "0.17.1+cpu"
@@ -3327,7 +3321,12 @@ files = [
docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"]
+[extras]
+aloha = ["gym-aloha"]
+pusht = ["gym-pusht"]
+xarm = ["gym-xarm"]
+
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
-content-hash = "8800bb8b24312d17b765cd2ce2799f49436171dd5fbf1bec3b07f853cfa9befd"
+content-hash = "8fa6dfc30e605741c24f5de58b89125d5b02153f550e5af7a44356956d6bb167"
diff --git a/.github/poetry/cpu/pyproject.toml b/.github/poetry/cpu/pyproject.toml
index e84b93c9..f5c439dc 100644
--- a/.github/poetry/cpu/pyproject.toml
+++ b/.github/poetry/cpu/pyproject.toml
@@ -23,7 +23,6 @@ packages = [{include = "lerobot"}]
python = "^3.10"
termcolor = "^2.4.0"
omegaconf = "^2.3.0"
-dm-env = "^1.6"
pandas = "^2.2.1"
wandb = "^0.16.3"
moviepy = "^1.0.3"
@@ -34,30 +33,41 @@ einops = "^0.7.0"
pygame = "^2.5.2"
pymunk = "^6.6.0"
zarr = "^2.17.0"
-shapely = "^2.0.3"
-scikit-image = "^0.22.0"
numba = "^0.59.0"
mpmath = "^1.3.0"
torch = {version = "^2.2.1", source = "torch-cpu"}
-tensordict = {git = "https://github.com/pytorch/tensordict"}
-torchrl = {git = "https://github.com/pytorch/rl", rev = "13bef426dcfa5887c6e5034a6e9697993fa92c37"}
-mujoco = "^2.3.7"
opencv-python = "^4.9.0.80"
diffusers = "^0.26.3"
torchvision = {version = "^0.17.1", source = "torch-cpu"}
h5py = "^3.10.0"
-dm = "^1.3"
-dm-control = "1.0.14"
robomimic = "0.2.0"
huggingface-hub = "^0.21.4"
-gymnasium-robotics = "^1.2.4"
gymnasium = "^0.29.1"
cmake = "^3.29.0.1"
+gym-pusht = { git = "git@github.com:huggingface/gym-pusht.git", optional = true}
+gym-xarm = { git = "git@github.com:huggingface/gym-xarm.git", optional = true}
+gym-aloha = { git = "git@github.com:huggingface/gym-aloha.git", optional = true}
+# gym-pusht = { path = "../gym-pusht", develop = true, optional = true}
+# gym-xarm = { path = "../gym-xarm", develop = true, optional = true}
+# gym-aloha = { path = "../gym-aloha", develop = true, optional = true}
+
+
+[tool.poetry.extras]
+pusht = ["gym-pusht"]
+xarm = ["gym-xarm"]
+aloha = ["gym-aloha"]
+
+
+[tool.poetry.group.dev]
+optional = true
[tool.poetry.group.dev.dependencies]
pre-commit = "^3.6.2"
debugpy = "^1.8.1"
+
+
+[tool.poetry.group.test.dependencies]
pytest = "^8.1.0"
pytest-cov = "^5.0.0"
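
With the simulation environments moved behind extras, any code that needs them should guard the import and point users to the corresponding extra. A minimal sketch of the pattern (the same try/except this diff introduces in `lerobot/common/datasets/pusht.py`):

```python
# Sketch: guard an optional simulation dependency behind its extra.
try:
    import gym_pusht  # noqa: F401  # installed via `pip install 'lerobot[pusht]'`
except ModuleNotFoundError as e:
    print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[pusht]'`")
    raise e
```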
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 478be771..b3411e11 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -34,6 +34,11 @@ jobs:
with:
python-version: '3.10'
+ - name: Add SSH key for installing envs
+ uses: webfactory/ssh-agent@v0.9.0
+ with:
+ ssh-private-key: ${{ secrets.SSH_PRIVATE_KEY }}
+
#----------------------------------------------
# install & configure poetry
#----------------------------------------------
@@ -87,7 +92,7 @@ jobs:
TMP: ~/tmp
run: |
mkdir ~/tmp
- poetry install --no-interaction --no-root
+ poetry install --no-interaction --no-root --all-extras
- name: Save cached venv
if: |
@@ -106,7 +111,7 @@ jobs:
# install project
#----------------------------------------------
- name: Install project
- run: poetry install --no-interaction
+ run: poetry install --no-interaction --all-extras
#----------------------------------------------
# run tests & coverage
@@ -137,6 +142,7 @@ jobs:
wandb.enable=False \
offline_steps=2 \
online_steps=0 \
+ eval_episodes=1 \
device=cpu \
save_model=true \
save_freq=2 \
@@ -154,17 +160,6 @@ jobs:
device=cpu \
policy.pretrained_model_path=tests/outputs/act/models/2.pt
- # TODO(aliberts): This takes ~2mn to run, needs to be improved
- # - name: Test eval ACT on ALOHA end-to-end (policy is None)
- # run: |
- # source .venv/bin/activate
- # python lerobot/scripts/eval.py \
- # --config lerobot/configs/default.yaml \
- # policy=act \
- # env=aloha \
- # eval_episodes=1 \
- # device=cpu
-
- name: Test train Diffusion on PushT end-to-end
run: |
source .venv/bin/activate
@@ -174,9 +169,11 @@ jobs:
wandb.enable=False \
offline_steps=2 \
online_steps=0 \
+ eval_episodes=1 \
device=cpu \
save_model=true \
save_freq=2 \
+ policy.batch_size=2 \
hydra.run.dir=tests/outputs/diffusion/
- name: Test eval Diffusion on PushT end-to-end
@@ -189,28 +186,20 @@ jobs:
device=cpu \
policy.pretrained_model_path=tests/outputs/diffusion/models/2.pt
- - name: Test eval Diffusion on PushT end-to-end (policy is None)
- run: |
- source .venv/bin/activate
- python lerobot/scripts/eval.py \
- --config lerobot/configs/default.yaml \
- policy=diffusion \
- env=pusht \
- eval_episodes=1 \
- device=cpu
-
- name: Test train TDMPC on Simxarm end-to-end
run: |
source .venv/bin/activate
python lerobot/scripts/train.py \
policy=tdmpc \
- env=simxarm \
+ env=xarm \
wandb.enable=False \
offline_steps=1 \
online_steps=1 \
+ eval_episodes=1 \
device=cpu \
save_model=true \
save_freq=2 \
+ policy.batch_size=2 \
hydra.run.dir=tests/outputs/tdmpc/
- name: Test eval TDMPC on Simxarm end-to-end
@@ -222,13 +211,3 @@ jobs:
env.episode_length=8 \
device=cpu \
policy.pretrained_model_path=tests/outputs/tdmpc/models/2.pt
-
- - name: Test eval TDPMC on Simxarm end-to-end (policy is None)
- run: |
- source .venv/bin/activate
- python lerobot/scripts/eval.py \
- --config lerobot/configs/default.yaml \
- policy=tdmpc \
- env=simxarm \
- eval_episodes=1 \
- device=cpu
diff --git a/.gitignore b/.gitignore
index ad9892d4..3132aba0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,6 +11,9 @@ rl
nautilus/*.yaml
*.key
+# Slurm
+sbatch*.sh
+
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
diff --git a/README.md b/README.md
index 31fdde0a..25b8d1e4 100644
--- a/README.md
+++ b/README.md
@@ -62,21 +62,29 @@
Download our source code:
```bash
-git clone https://github.com/huggingface/lerobot.git
-cd lerobot
+git clone https://github.com/huggingface/lerobot.git && cd lerobot
```
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniconda`](https://docs.anaconda.com/free/miniconda/index.html):
```bash
-conda create -y -n lerobot python=3.10
-conda activate lerobot
+conda create -y -n lerobot python=3.10 && conda activate lerobot
```
-Then, install 🤗 LeRobot:
+Install 🤗 LeRobot:
```bash
python -m pip install .
```
+For simulations, 🤗 LeRobot comes with gymnasium environments that can be installed as extras:
+- [aloha](https://github.com/huggingface/gym-aloha)
+- [xarm](https://github.com/huggingface/gym-xarm)
+- [pusht](https://github.com/huggingface/gym-pusht)
+
+For instance, to install 🤗 LeRobot with aloha and pusht, use:
+```bash
+python -m pip install ".[aloha, pusht]"
+```
+
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiments tracking, log in with
```bash
wandb login
@@ -89,11 +97,11 @@ wandb login
├── lerobot
| ├── configs # contains hydra yaml files with all options that you can override in the command line
| | ├── default.yaml # selected by default, it loads pusht environment and diffusion policy
-| | ├── env # various sim environments and their datasets: aloha.yaml, pusht.yaml, simxarm.yaml
+| | ├── env # various sim environments and their datasets: aloha.yaml, pusht.yaml, xarm.yaml
| | └── policy # various policies: act.yaml, diffusion.yaml, tdmpc.yaml
| ├── common # contains classes and utilities
-| | ├── datasets # various datasets of human demonstrations: aloha, pusht, simxarm
-| | ├── envs # various sim environments: aloha, pusht, simxarm
+| | ├── datasets # various datasets of human demonstrations: aloha, pusht, xarm
+| | ├── envs # various sim environments: aloha, pusht, xarm
| | └── policies # various policies: act, diffusion, tdmpc
| └── scripts # contains functions to execute via command line
| ├── visualize_dataset.py # load a dataset and render its demonstrations
@@ -112,34 +120,32 @@ wandb login
You can import our dataset class, download the data from the HuggingFace hub and use our rendering utilities:
```python
""" Copy pasted from `examples/1_visualize_dataset.py` """
+import os
+from pathlib import Path
+
import lerobot
-from lerobot.common.datasets.aloha import AlohaDataset
-from torchrl.data.replay_buffers import SamplerWithoutReplacement
+from lerobot.common.datasets.pusht import PushtDataset
 from lerobot.scripts.visualize_dataset import render_dataset
print(lerobot.available_datasets)
# >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted', 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted', 'pusht', 'xarm_lift_medium']
-# we use this sampler to sample 1 frame after the other
-sampler = SamplerWithoutReplacement(shuffle=False)
-
-dataset = AlohaDataset("aloha_sim_transfer_cube_human", sampler=sampler)
+# TODO(rcadene): remove DATA_DIR
+dataset = PushtDataset("pusht", root=Path(os.environ.get("DATA_DIR")))
video_paths = render_dataset(
dataset,
out_dir="outputs/visualize_dataset/example",
- max_num_samples=300,
- fps=50,
+ max_num_episodes=1,
)
print(video_paths)
-# >>> ['outputs/visualize_dataset/example/episode_0.mp4']
+# ['outputs/visualize_dataset/example/episode_0.mp4']
```
Or you can achieve the same result by executing our script from the command line:
```bash
python lerobot/scripts/visualize_dataset.py \
-env=aloha \
-task=sim_sim_transfer_cube_human \
+env=pusht \
hydra.run.dir=outputs/visualize_dataset/example
# >>> ['outputs/visualize_dataset/example/episode_0.mp4']
```
@@ -198,21 +204,33 @@ pre-commit install
pre-commit
```
-### Add dependencies
+### Dependencies
Instead of using `pip` directly, we use `poetry` for development purposes to easily track our dependencies.
If you don't have it already, follow the [instructions](https://python-poetry.org/docs/#installation) to install it.
-Install the project with:
+Install the project with dev dependencies and all environments:
```bash
-poetry install
+poetry install --sync --with dev --all-extras
+```
+This command should be run when pulling code with an updated version of `pyproject.toml` and `poetry.lock`, in order to synchronize your virtual environment with the dependencies.
+
+To selectively install environments (for example aloha and pusht), use:
+```bash
+poetry install --sync --with dev --extras "aloha pusht"
```
-Then, the equivalent of `pip install some-package`, would just be:
+The equivalent of `pip install some-package` would just be:
```bash
poetry add some-package
```
+When changes are made to the poetry sections of `pyproject.toml`, you should run the following command to lock the dependencies:
+```bash
+poetry lock --no-update
+```
+
+
**NOTE:** Currently, to ensure the CI works properly, any new package must also be added in the CPU-only environment dedicated to the CI. To do this, you should create a separate environment and add the new package there as well. For example:
```bash
# Add the new package to your main poetry env
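
A quick, hypothetical way to check which optional environment packages made it into the current interpreter after installing extras (import names assumed to match the repos listed above):

```python
# Hypothetical sanity check: list which optional env packages are importable
# after, e.g., `python -m pip install ".[aloha, pusht]"`.
import importlib.util

for pkg in ("gym_aloha", "gym_pusht", "gym_xarm"):
    found = importlib.util.find_spec(pkg) is not None
    print(f"{pkg}: {'installed' if found else 'missing'}")
```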
diff --git a/examples/1_visualize_dataset.py b/examples/1_visualize_dataset.py
index f52ab76a..15e0e54d 100644
--- a/examples/1_visualize_dataset.py
+++ b/examples/1_visualize_dataset.py
@@ -1,24 +1,20 @@
import os
-
-from torchrl.data.replay_buffers import SamplerWithoutReplacement
+from pathlib import Path
import lerobot
-from lerobot.common.datasets.aloha import AlohaDataset
+from lerobot.common.datasets.pusht import PushtDataset
from lerobot.scripts.visualize_dataset import render_dataset
print(lerobot.available_datasets)
# >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted', 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted', 'pusht', 'xarm_lift_medium']
-# we use this sampler to sample 1 frame after the other
-sampler = SamplerWithoutReplacement(shuffle=False)
-
-dataset = AlohaDataset("aloha_sim_transfer_cube_human", sampler=sampler, root=os.environ.get("DATA_DIR"))
+# TODO(rcadene): remove DATA_DIR
+dataset = PushtDataset("pusht", root=Path(os.environ.get("DATA_DIR")))
video_paths = render_dataset(
dataset,
out_dir="outputs/visualize_dataset/example",
- max_num_samples=300,
- fps=50,
+ max_num_episodes=1,
)
print(video_paths)
# ['outputs/visualize_dataset/example/episode_0.mp4']
diff --git a/examples/3_train_policy.py b/examples/3_train_policy.py
index 01a4cf76..238f953d 100644
--- a/examples/3_train_policy.py
+++ b/examples/3_train_policy.py
@@ -9,9 +9,8 @@ from pathlib import Path
import torch
from omegaconf import OmegaConf
-from tqdm import trange
-from lerobot.common.datasets.factory import make_offline_buffer
+from lerobot.common.datasets.factory import make_dataset
from lerobot.common.policies.diffusion.policy import DiffusionPolicy
from lerobot.common.utils import init_hydra_config
@@ -37,19 +36,33 @@ policy = DiffusionPolicy(
cfg_obs_encoder=cfg.obs_encoder,
cfg_optimizer=cfg.optimizer,
cfg_ema=cfg.ema,
- n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
**cfg.policy,
)
policy.train()
-offline_buffer = make_offline_buffer(cfg)
+dataset = make_dataset(cfg)
+
+# create dataloader for offline training
+dataloader = torch.utils.data.DataLoader(
+ dataset,
+ num_workers=4,
+ batch_size=cfg.policy.batch_size,
+ shuffle=True,
+ pin_memory=cfg.device != "cpu",
+ drop_last=True,
+)
+
+for step, batch in enumerate(dataloader):
+ info = policy(batch, step)
+
+ if step % cfg.log_freq == 0:
+ num_samples = (step + 1) * cfg.policy.batch_size
+ loss = info["loss"]
+ update_s = info["update_s"]
+ print(f"step:{step} samples:{num_samples} loss:{loss:.3f} update_time:{update_s:.3f}(seconds)")
-for offline_step in trange(cfg.offline_steps):
- train_info = policy.update(offline_buffer, offline_step)
- if offline_step % cfg.log_freq == 0:
- print(train_info)
# Save the policy, configuration, and normalization stats for later use.
policy.save(output_directory / "model.pt")
OmegaConf.save(cfg, output_directory / "config.yaml")
-torch.save(offline_buffer.transform[-1].stats, output_directory / "stats.pth")
+torch.save(dataset.transform.transforms[-1].stats, output_directory / "stats.pth")
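
Note that the rewritten loop makes a single pass over the dataset (one epoch). If a fixed step budget such as `cfg.offline_steps` is wanted instead, one option is to cycle the dataloader; a sketch reusing the names from the example above:

```python
import itertools

# Sketch: train for a fixed number of optimization steps rather than one
# epoch, assuming `dataloader`, `policy`, and `cfg` are set up as above.
for step, batch in enumerate(itertools.cycle(dataloader)):
    if step >= cfg.offline_steps:
        break
    info = policy(batch, step)
```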
diff --git a/lerobot/__init__.py b/lerobot/__init__.py
index 5cf8bdb8..8ab95df8 100644
--- a/lerobot/__init__.py
+++ b/lerobot/__init__.py
@@ -12,14 +12,11 @@ Example:
print(lerobot.available_policies)
```
-Note:
- When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
- 1. set the required class attributes:
- - for classes inheriting from `AbstractDataset`: `available_datasets`
- - for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
- - for classes inheriting from `AbstractPolicy`: `name`
- 2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
- 3. update variables in `tests/test_available.py` by importing your new class
+When implementing a new dataset (e.g. `AlohaDataset`), policy (e.g. `DiffusionPolicy`), or environment, follow these steps:
+- Set the required class attributes: `available_datasets` for datasets, `name` for policies.
+- Update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`).
+- Update variables in `tests/test_available.py` by importing your new class.
"""
from lerobot.__version__ import __version__ # noqa: F401
@@ -27,16 +24,16 @@ from lerobot.__version__ import __version__ # noqa: F401
available_envs = [
"aloha",
"pusht",
- "simxarm",
+ "xarm",
]
available_tasks_per_env = {
"aloha": [
- "sim_insertion",
- "sim_transfer_cube",
+ "AlohaInsertion-v0",
+ "AlohaTransferCube-v0",
],
- "pusht": ["pusht"],
- "simxarm": ["lift"],
+ "pusht": ["PushT-v0"],
+ "xarm": ["XarmLift-v0"],
}
available_datasets_per_env = {
@@ -47,7 +44,7 @@ available_datasets_per_env = {
"aloha_sim_transfer_cube_scripted",
],
"pusht": ["pusht"],
- "simxarm": ["xarm_lift_medium"],
+ "xarm": ["xarm_lift_medium"],
}
available_datasets = [dataset for env in available_envs for dataset in available_datasets_per_env[env]]
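
Since these renames are user-facing, the new values can be checked directly; a small sketch assuming the package is installed:

```python
import lerobot

# Env names and task ids after the simxarm -> xarm rename.
print(lerobot.available_envs)                      # ['aloha', 'pusht', 'xarm']
print(lerobot.available_tasks_per_env["xarm"])     # ['XarmLift-v0']
print(lerobot.available_datasets_per_env["xarm"])  # ['xarm_lift_medium']
```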
diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py
deleted file mode 100644
index e9e9c610..00000000
--- a/lerobot/common/datasets/abstract.py
+++ /dev/null
@@ -1,234 +0,0 @@
-import logging
-from copy import deepcopy
-from math import ceil
-from pathlib import Path
-from typing import Callable
-
-import einops
-import torch
-import torchrl
-import tqdm
-from huggingface_hub import snapshot_download
-from tensordict import TensorDict
-from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
-from torchrl.data.replay_buffers.samplers import Sampler, SamplerWithoutReplacement
-from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
-from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
-from torchrl.envs.transforms.transforms import Compose
-
-HF_USER = "lerobot"
-
-
-class AbstractDataset(TensorDictReplayBuffer):
- """
- AbstractDataset represents a dataset in the context of imitation learning or reinforcement learning.
- This class is designed to be subclassed by concrete implementations that specify particular types of datasets.
- These implementations can vary based on the source of the data, the environment the data pertains to,
- or the specific kind of data manipulation applied.
-
- Note:
- - `TensorDictReplayBuffer` is the base class from which `AbstractDataset` inherits. It provides the foundational
- functionality for storing and retrieving `TensorDict`-like data.
- - `available_datasets` should be overridden by concrete subclasses to list the specific dataset variants supported.
- It is expected that these variants correspond to a HuggingFace dataset on the hub.
- For instance, the `AlohaDataset` which inherites from `AbstractDataset` has 4 available dataset variants:
- - [aloha_sim_transfer_cube_scripted](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted)
- - [aloha_sim_insertion_scripted](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted)
- - [aloha_sim_transfer_cube_human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human)
- - [aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human)
- - When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
- 1. set the required class attributes:
- - for classes inheriting from `AbstractDataset`: `available_datasets`
- - for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
- - for classes inheriting from `AbstractPolicy`: `name`
- 2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
- 3. update variables in `tests/test_available.py` by importing your new class
- """
-
- available_datasets: list[str] | None = None
-
- def __init__(
- self,
- dataset_id: str,
- version: str | None = None,
- batch_size: int | None = None,
- *,
- shuffle: bool = True,
- root: Path | None = None,
- pin_memory: bool = False,
- prefetch: int = None,
- sampler: Sampler | None = None,
- collate_fn: Callable | None = None,
- writer: Writer | None = None,
- transform: "torchrl.envs.Transform" = None,
- ):
- assert (
- self.available_datasets is not None
- ), "Subclasses of `AbstractDataset` should set the `available_datasets` class attribute."
- assert (
- dataset_id in self.available_datasets
- ), f"The provided dataset ({dataset_id}) is not on the list of available datasets {self.available_datasets}."
-
- self.dataset_id = dataset_id
- self.version = version
- self.shuffle = shuffle
- self.root = root if root is None else Path(root)
-
- if self.root is not None and self.version is not None:
- logging.warning(
- f"The version of the dataset ({self.version}) is not enforced when root is provided ({self.root})."
- )
-
- storage = self._download_or_load_dataset()
-
- super().__init__(
- storage=storage,
- sampler=sampler,
- writer=ImmutableDatasetWriter() if writer is None else writer,
- collate_fn=_collate_id if collate_fn is None else collate_fn,
- pin_memory=pin_memory,
- prefetch=prefetch,
- batch_size=batch_size,
- transform=transform,
- )
-
- @property
- def stats_patterns(self) -> dict:
- return {
- ("observation", "state"): "b c -> c",
- ("observation", "image"): "b c h w -> c 1 1",
- ("action",): "b c -> c",
- }
-
- @property
- def image_keys(self) -> list:
- return [("observation", "image")]
-
- @property
- def num_cameras(self) -> int:
- return len(self.image_keys)
-
- @property
- def num_samples(self) -> int:
- return len(self)
-
- @property
- def num_episodes(self) -> int:
- return len(self._storage._storage["episode"].unique())
-
- @property
- def transform(self):
- return self._transform
-
- def set_transform(self, transform):
- if not isinstance(transform, Compose):
- # required since torchrl calls `len(self._transform)` downstream
- if isinstance(transform, list):
- self._transform = Compose(*transform)
- else:
- self._transform = Compose(transform)
- else:
- self._transform = transform
-
- def compute_or_load_stats(self, batch_size: int = 32) -> TensorDict:
- stats_path = self.data_dir / "stats.pth"
- if stats_path.exists():
- stats = torch.load(stats_path)
- else:
- logging.info(f"compute_stats and save to {stats_path}")
- stats = self._compute_stats(batch_size)
- torch.save(stats, stats_path)
- return stats
-
- def _download_or_load_dataset(self) -> torch.StorageBase:
- if self.root is None:
- self.data_dir = Path(
- snapshot_download(
- repo_id=f"{HF_USER}/{self.dataset_id}", repo_type="dataset", revision=self.version
- )
- )
- else:
- self.data_dir = self.root / self.dataset_id
- return TensorStorage(TensorDict.load_memmap(self.data_dir / "replay_buffer"))
-
- def _compute_stats(self, batch_size: int = 32):
- """Compute dataset statistics including minimum, maximum, mean, and standard deviation.
-
- TODO(alexander-soare): Add a num_batches argument which essentially allows one to use a subset of the
- full dataset (for handling very large datasets). The sampling would then have to be random
- (preferably without replacement). Both stats computation loops would ideally sample the same
- items.
- """
- rb = TensorDictReplayBuffer(
- storage=self._storage,
- batch_size=32,
- prefetch=True,
- # Note: Due to be refactored soon. The point is that we should go through the whole dataset.
- sampler=SamplerWithoutReplacement(drop_last=False, shuffle=False),
- )
-
- # mean and std will be computed incrementally while max and min will track the running value.
- mean, std, max, min = {}, {}, {}, {}
- for key in self.stats_patterns:
- mean[key] = torch.tensor(0.0).float()
- std[key] = torch.tensor(0.0).float()
- max[key] = torch.tensor(-float("inf")).float()
- min[key] = torch.tensor(float("inf")).float()
-
- # Compute mean, min, max.
- # Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
- # surprises when rerunning the sampler.
- first_batch = None
- running_item_count = 0 # for online mean computation
- for _ in tqdm.tqdm(range(ceil(len(rb) / batch_size))):
- batch = rb.sample()
- this_batch_size = batch.batch_size[0]
- running_item_count += this_batch_size
- if first_batch is None:
- first_batch = deepcopy(batch)
- for key, pattern in self.stats_patterns.items():
- batch[key] = batch[key].float()
- # Numerically stable update step for mean computation.
- batch_mean = einops.reduce(batch[key], pattern, "mean")
- # Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
- # the update step, N is the running item count, B is this batch size, x̄ is the running mean,
- # and x is the current batch mean. Some rearrangement is then required to avoid risking
- # numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
- # x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
- mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
- max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
- min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
-
- # Compute std.
- first_batch_ = None
- running_item_count = 0 # for online std computation
- for _ in tqdm.tqdm(range(ceil(len(rb) / batch_size))):
- batch = rb.sample()
- this_batch_size = batch.batch_size[0]
- running_item_count += this_batch_size
- # Sanity check to make sure the batches are still in the same order as before.
- if first_batch_ is None:
- first_batch_ = deepcopy(batch)
- for key in self.stats_patterns:
- assert torch.equal(first_batch_[key], first_batch[key])
- for key, pattern in self.stats_patterns.items():
- batch[key] = batch[key].float()
- # Numerically stable update step for mean computation (where the mean is over squared
- # residuals).See notes in the mean computation loop above.
- batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
- std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
-
- for key in self.stats_patterns:
- std[key] = torch.sqrt(std[key])
-
- stats = TensorDict({}, batch_size=[])
- for key in self.stats_patterns:
- stats[(*key, "mean")] = mean[key]
- stats[(*key, "std")] = std[key]
- stats[(*key, "max")] = max[key]
- stats[(*key, "min")] = min[key]
-
- if key[0] == "observation":
- # use same stats for the next observations
- stats[("next", *key)] = stats[key]
- return stats
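
The deleted `_compute_stats` relied on the numerically stable running-mean update spelled out in its comments, x̄ₙ = x̄ₙ₋₁ + Bₙ(xₙ − x̄ₙ₋₁)/Nₙ. A standalone sketch of that update, for reference:

```python
import torch

# Running mean over batches without summing all samples at once:
# mean_n = mean_{n-1} + B_n * (batch_mean - mean_{n-1}) / N_n,
# where B_n is the batch size and N_n the running item count.
def running_mean(batches: list[torch.Tensor]) -> torch.Tensor:
    mean = torch.tensor(0.0)
    count = 0
    for batch in batches:
        count += len(batch)
        mean = mean + len(batch) * (batch.mean() - mean) / count
    return mean

batches = [torch.randn(32) for _ in range(10)]
assert torch.allclose(running_mean(batches), torch.cat(batches).mean(), atol=1e-5)
```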
diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py
index 031c2cd3..4b241ad8 100644
--- a/lerobot/common/datasets/aloha.py
+++ b/lerobot/common/datasets/aloha.py
@@ -1,26 +1,13 @@
import logging
from pathlib import Path
-from typing import Callable
import einops
import gdown
import h5py
import torch
-import torchrl
import tqdm
-from tensordict import TensorDict
-from torchrl.data.replay_buffers.samplers import Sampler
-from torchrl.data.replay_buffers.storages import TensorStorage
-from torchrl.data.replay_buffers.writers import Writer
-from lerobot.common.datasets.abstract import AbstractDataset
-
-DATASET_IDS = [
- "aloha_sim_insertion_human",
- "aloha_sim_insertion_scripted",
- "aloha_sim_transfer_cube_human",
- "aloha_sim_transfer_cube_scripted",
-]
+from lerobot.common.datasets.utils import load_data_with_delta_timestamps
FOLDER_URLS = {
"aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF",
@@ -66,7 +53,6 @@ CAMERAS = {
def download(data_dir, dataset_id):
- assert dataset_id in DATASET_IDS
assert dataset_id in FOLDER_URLS
assert dataset_id in EP48_URLS
assert dataset_id in EP49_URLS
@@ -80,51 +66,80 @@ def download(data_dir, dataset_id):
gdown.download(EP49_URLS[dataset_id], output=str(data_dir / "episode_49.hdf5"), fuzzy=True)
-class AlohaDataset(AbstractDataset):
- available_datasets = DATASET_IDS
+class AlohaDataset(torch.utils.data.Dataset):
+ available_datasets = [
+ "aloha_sim_insertion_human",
+ "aloha_sim_insertion_scripted",
+ "aloha_sim_transfer_cube_human",
+ "aloha_sim_transfer_cube_scripted",
+ ]
+ fps = 50
+ image_keys = ["observation.images.top"]
def __init__(
self,
dataset_id: str,
version: str | None = "v1.2",
- batch_size: int | None = None,
- *,
- shuffle: bool = True,
root: Path | None = None,
- pin_memory: bool = False,
- prefetch: int = None,
- sampler: Sampler | None = None,
- collate_fn: Callable | None = None,
- writer: Writer | None = None,
- transform: "torchrl.envs.Transform" = None,
+ transform: callable = None,
+        delta_timestamps: dict[str, list[float]] | None = None,
):
- super().__init__(
- dataset_id,
- version,
- batch_size,
- shuffle=shuffle,
- root=root,
- pin_memory=pin_memory,
- prefetch=prefetch,
- sampler=sampler,
- collate_fn=collate_fn,
- writer=writer,
- transform=transform,
- )
+ super().__init__()
+ self.dataset_id = dataset_id
+ self.version = version
+ self.root = root
+ self.transform = transform
+ self.delta_timestamps = delta_timestamps
+
+ self.data_dir = self.root / f"{self.dataset_id}"
+ if (self.data_dir / "data_dict.pth").exists() and (
+ self.data_dir / "data_ids_per_episode.pth"
+ ).exists():
+ self.data_dict = torch.load(self.data_dir / "data_dict.pth")
+ self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
+ else:
+ self._download_and_preproc_obsolete()
+ self.data_dir.mkdir(parents=True, exist_ok=True)
+ torch.save(self.data_dict, self.data_dir / "data_dict.pth")
+ torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")
@property
- def stats_patterns(self) -> dict:
- d = {
- ("observation", "state"): "b c -> c",
- ("action",): "b c -> c",
- }
- for cam in CAMERAS[self.dataset_id]:
- d[("observation", "image", cam)] = "b c h w -> c 1 1"
- return d
+ def num_samples(self) -> int:
+ return len(self.data_dict["index"]) if "index" in self.data_dict else 0
@property
- def image_keys(self) -> list:
- return [("observation", "image", cam) for cam in CAMERAS[self.dataset_id]]
+ def num_episodes(self) -> int:
+ return len(self.data_ids_per_episode)
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, idx):
+ item = {}
+
+ # get episode id and timestamp of the sampled frame
+ current_ts = self.data_dict["timestamp"][idx].item()
+ episode = self.data_dict["episode"][idx].item()
+
+ for key in self.data_dict:
+ if self.delta_timestamps is not None and key in self.delta_timestamps:
+ data, is_pad = load_data_with_delta_timestamps(
+ self.data_dict,
+ self.data_ids_per_episode,
+ self.delta_timestamps,
+ key,
+ current_ts,
+ episode,
+ )
+ item[key] = data
+ item[f"{key}_is_pad"] = is_pad
+ else:
+ item[key] = self.data_dict[key][idx]
+
+ if self.transform is not None:
+ item = self.transform(item)
+
+ return item
def _download_and_preproc_obsolete(self):
assert self.root is not None
@@ -132,54 +147,61 @@ class AlohaDataset(AbstractDataset):
if not raw_dir.is_dir():
download(raw_dir, self.dataset_id)
- total_num_frames = 0
+ total_frames = 0
logging.info("Compute total number of frames to initialize offline buffer")
for ep_id in range(NUM_EPISODES[self.dataset_id]):
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
with h5py.File(ep_path, "r") as ep:
- total_num_frames += ep["/action"].shape[0] - 1
- logging.info(f"{total_num_frames=}")
+ total_frames += ep["/action"].shape[0] - 1
+ logging.info(f"{total_frames=}")
- logging.info("Initialize and feed offline buffer")
- idxtd = 0
+ self.data_ids_per_episode = {}
+ ep_dicts = []
+
+ frame_idx = 0
for ep_id in tqdm.tqdm(range(NUM_EPISODES[self.dataset_id])):
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
with h5py.File(ep_path, "r") as ep:
- ep_num_frames = ep["/action"].shape[0]
+ num_frames = ep["/action"].shape[0]
# last step of demonstration is considered done
- done = torch.zeros(ep_num_frames, 1, dtype=torch.bool)
+ done = torch.zeros(num_frames, dtype=torch.bool)
done[-1] = True
state = torch.from_numpy(ep["/observations/qpos"][:])
action = torch.from_numpy(ep["/action"][:])
- ep_td = TensorDict(
- {
- ("observation", "state"): state[:-1],
- "action": action[:-1],
- "episode": torch.tensor([ep_id] * (ep_num_frames - 1)),
- "frame_id": torch.arange(0, ep_num_frames - 1, 1),
- ("next", "observation", "state"): state[1:],
- # TODO: compute reward and success
- # ("next", "reward"): reward[1:],
- ("next", "done"): done[1:],
- # ("next", "success"): success[1:],
- },
- batch_size=ep_num_frames - 1,
- )
+ ep_dict = {
+ "observation.state": state,
+ "action": action,
+ "episode": torch.tensor([ep_id] * num_frames),
+ "frame_id": torch.arange(0, num_frames, 1),
+ "timestamp": torch.arange(0, num_frames, 1) / self.fps,
+ # "next.observation.state": state,
+ # TODO(rcadene): compute reward and success
+ # "next.reward": reward[1:],
+ "next.done": done[1:],
+ # "next.success": success[1:],
+ }
for cam in CAMERAS[self.dataset_id]:
image = torch.from_numpy(ep[f"/observations/images/{cam}"][:])
image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
- ep_td["observation", "image", cam] = image[:-1]
- ep_td["next", "observation", "image", cam] = image[1:]
+ ep_dict[f"observation.images.{cam}"] = image[:-1]
+ # ep_dict[f"next.observation.images.{cam}"] = image[1:]
- if ep_id == 0:
- # hack to initialize tensordict data structure to store episodes
- td_data = ep_td[0].expand(total_num_frames).memmap_like(self.root / f"{self.dataset_id}")
+ assert isinstance(ep_id, int)
+ self.data_ids_per_episode[ep_id] = torch.arange(frame_idx, frame_idx + num_frames, 1)
+ assert len(self.data_ids_per_episode[ep_id]) == num_frames
- td_data[idxtd : idxtd + len(ep_td)] = ep_td
- idxtd = idxtd + len(ep_td)
+ ep_dicts.append(ep_dict)
- return TensorStorage(td_data.lock_())
+ frame_idx += num_frames
+
+ self.data_dict = {}
+
+ keys = ep_dicts[0].keys()
+ for key in keys:
+ self.data_dict[key] = torch.cat([x[key] for x in ep_dicts])
+
+ self.data_dict["index"] = torch.arange(0, total_frames, 1)
diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py
index 04077034..4ae161f6 100644
--- a/lerobot/common/datasets/factory.py
+++ b/lerobot/common/datasets/factory.py
@@ -3,8 +3,9 @@ import os
from pathlib import Path
import torch
-from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler
+from torchvision.transforms import v2
+from lerobot.common.datasets.utils import compute_stats
from lerobot.common.transforms import NormalizeTransform, Prod
# DATA_DIR specifies to location where datasets are loaded. By default, DATA_DIR is None and
@@ -13,61 +14,16 @@ from lerobot.common.transforms import NormalizeTransform, Prod
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
-def make_offline_buffer(
+def make_dataset(
cfg,
- overwrite_sampler=None,
# set normalize=False to remove all transformations and keep images unnormalized in [0,255]
normalize=True,
- overwrite_batch_size=None,
- overwrite_prefetch=None,
stats_path=None,
):
- if cfg.policy.balanced_sampling:
- assert cfg.online_steps > 0
- batch_size = None
- pin_memory = False
- prefetch = None
- else:
- assert cfg.online_steps == 0
- num_slices = cfg.policy.batch_size
- batch_size = cfg.policy.horizon * num_slices
- pin_memory = cfg.device == "cuda"
- prefetch = cfg.prefetch
+ if cfg.env.name == "xarm":
+ from lerobot.common.datasets.xarm import XarmDataset
- if overwrite_batch_size is not None:
- batch_size = overwrite_batch_size
-
- if overwrite_prefetch is not None:
- prefetch = overwrite_prefetch
-
- if overwrite_sampler is None:
- # TODO(rcadene): move batch_size outside
- num_traj_per_batch = cfg.policy.batch_size # // cfg.horizon
- # TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
- # We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size.
-
- if cfg.offline_prioritized_sampler:
- logging.info("use prioritized sampler for offline dataset")
- sampler = PrioritizedSliceSampler(
- max_capacity=100_000,
- alpha=cfg.policy.per_alpha,
- beta=cfg.policy.per_beta,
- num_slices=num_traj_per_batch,
- strict_length=False,
- )
- else:
- logging.info("use simple sampler for offline dataset")
- sampler = SliceSampler(
- num_slices=num_traj_per_batch,
- strict_length=False,
- )
- else:
- sampler = overwrite_sampler
-
- if cfg.env.name == "simxarm":
- from lerobot.common.datasets.simxarm import SimxarmDataset
-
- clsfunc = SimxarmDataset
+ clsfunc = XarmDataset
elif cfg.env.name == "pusht":
from lerobot.common.datasets.pusht import PushtDataset
@@ -81,56 +37,66 @@ def make_offline_buffer(
else:
raise ValueError(cfg.env.name)
- offline_buffer = clsfunc(
- dataset_id=cfg.dataset_id,
- sampler=sampler,
- batch_size=batch_size,
- root=DATA_DIR,
- pin_memory=pin_memory,
- prefetch=prefetch if isinstance(prefetch, int) else None,
- )
-
- if cfg.policy.name == "tdmpc":
- img_keys = []
- for key in offline_buffer.image_keys:
- img_keys.append(("next", *key))
- img_keys += offline_buffer.image_keys
- else:
- img_keys = offline_buffer.image_keys
-
+ transforms = None
if normalize:
- transforms = [Prod(in_keys=img_keys, prod=1 / 255)]
-
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max,
# min_max_from_spec
- stats = offline_buffer.compute_or_load_stats() if stats_path is None else torch.load(stats_path)
-
- # we only normalize the state and action, since the images are usually normalized inside the model for
- # now (except for tdmpc: see the following)
- in_keys = [("observation", "state"), ("action")]
-
- if cfg.policy.name == "tdmpc":
- # TODO(rcadene): we add img_keys to the keys to normalize for tdmpc only, since diffusion and act policies normalize the image inside the model for now
- in_keys += img_keys
- # TODO(racdene): since we use next observations in tdmpc, we also add them to the normalization. We are wasting a bit of compute on this for now.
- in_keys += [("next", *key) for key in img_keys]
- in_keys.append(("next", "observation", "state"))
-
- if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
- # TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
- stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
- stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
- stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
- stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
-
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
- transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode))
- offline_buffer.set_transform(transforms)
+ if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
+ stats = {}
+ # TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
+ stats["observation.state"] = {}
+ stats["observation.state"]["min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
+ stats["observation.state"]["max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
+ stats["action"] = {}
+ stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
+ stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
+ elif stats_path is None:
+ # instantiate a one frame dataset with light transform
+ stats_dataset = clsfunc(
+ dataset_id=cfg.dataset_id,
+ root=DATA_DIR,
+ transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
+ )
- if not overwrite_sampler:
- index = torch.arange(0, offline_buffer.num_samples, 1)
- sampler.extend(index)
+ # load stats if the file exists already or compute stats and save it
+ precomputed_stats_path = stats_dataset.data_dir / "stats.pth"
+ if precomputed_stats_path.exists():
+ stats = torch.load(precomputed_stats_path)
+ else:
+ logging.info(f"compute_stats and save to {precomputed_stats_path}")
+ stats = compute_stats(stats_dataset)
+            torch.save(stats, precomputed_stats_path)
+ else:
+ stats = torch.load(stats_path)
- return offline_buffer
+ transforms = v2.Compose(
+ [
+ Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
+ NormalizeTransform(
+ stats,
+ in_keys=[
+ "observation.state",
+ "action",
+ ],
+ mode=normalization_mode,
+ ),
+ ]
+ )
+
+ delta_timestamps = cfg.policy.get("delta_timestamps")
+ if delta_timestamps is not None:
+ for key in delta_timestamps:
+ if isinstance(delta_timestamps[key], str):
+ delta_timestamps[key] = eval(delta_timestamps[key])
+
+ dataset = clsfunc(
+ dataset_id=cfg.dataset_id,
+ root=DATA_DIR,
+ delta_timestamps=delta_timestamps,
+ transform=transforms,
+ )
+
+ return dataset
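
The `eval` above lets Hydra configs express `delta_timestamps` either as plain lists or as Python strings; a small illustration of the string form (the expression itself is made up):

```python
# Sketch of how make_dataset resolves string-valued delta_timestamps.
delta_timestamps = {
    "action": "[i / 10 for i in range(8)]",  # e.g. 8 future steps at fps=10
}
for key, value in delta_timestamps.items():
    if isinstance(value, str):
        delta_timestamps[key] = eval(value)  # mirrors the factory's behavior

print(delta_timestamps["action"])  # [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
```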
diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py
index 624fb140..34d92daa 100644
--- a/lerobot/common/datasets/pusht.py
+++ b/lerobot/common/datasets/pusht.py
@@ -1,21 +1,11 @@
from pathlib import Path
-from typing import Callable
import einops
import numpy as np
-import pygame
-import pymunk
import torch
-import torchrl
import tqdm
-from tensordict import TensorDict
-from torchrl.data.replay_buffers.samplers import Sampler
-from torchrl.data.replay_buffers.storages import TensorStorage
-from torchrl.data.replay_buffers.writers import Writer
-from lerobot.common.datasets.abstract import AbstractDataset
-from lerobot.common.datasets.utils import download_and_extract_zip
-from lerobot.common.envs.pusht.pusht_env import pymunk_to_shapely
+from lerobot.common.datasets.utils import download_and_extract_zip, load_data_with_delta_timestamps
from lerobot.common.policies.diffusion.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
# as define in env
@@ -25,97 +15,93 @@ PUSHT_URL = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
PUSHT_ZARR = Path("pusht/pusht_cchi_v7_replay.zarr")
-def get_goal_pose_body(pose):
- mass = 1
- inertia = pymunk.moment_for_box(mass, (50, 100))
- body = pymunk.Body(mass, inertia)
- # preserving the legacy assignment order for compatibility
- # the order here doesn't matter somehow, maybe because CoM is aligned with body origin
- body.position = pose[:2].tolist()
- body.angle = pose[2]
- return body
+class PushtDataset(torch.utils.data.Dataset):
+ """
+ Arguments
+ ----------
+    delta_timestamps : dict[str, list[float]] | None, optional
+        Loads data from frames with a shift in timestamps, using a different strategy for each data key (e.g. state, action or image).
+        If `None`, no shift is applied to the current timestamp and the data from the current frame is loaded.
+ """
-def add_segment(space, a, b, radius):
- shape = pymunk.Segment(space.static_body, a, b, radius)
- shape.color = pygame.Color("LightGray") # https://htmlcolorcodes.com/color-names
- return shape
-
-
-def add_tee(
- space,
- position,
- angle,
- scale=30,
- color="LightSlateGray",
- mask=None,
-):
- if mask is None:
- mask = pymunk.ShapeFilter.ALL_MASKS()
- mass = 1
- length = 4
- vertices1 = [
- (-length * scale / 2, scale),
- (length * scale / 2, scale),
- (length * scale / 2, 0),
- (-length * scale / 2, 0),
- ]
- inertia1 = pymunk.moment_for_poly(mass, vertices=vertices1)
- vertices2 = [
- (-scale / 2, scale),
- (-scale / 2, length * scale),
- (scale / 2, length * scale),
- (scale / 2, scale),
- ]
- inertia2 = pymunk.moment_for_poly(mass, vertices=vertices1)
- body = pymunk.Body(mass, inertia1 + inertia2)
- shape1 = pymunk.Poly(body, vertices1)
- shape2 = pymunk.Poly(body, vertices2)
- shape1.color = pygame.Color(color)
- shape2.color = pygame.Color(color)
- shape1.filter = pymunk.ShapeFilter(mask=mask)
- shape2.filter = pymunk.ShapeFilter(mask=mask)
- body.center_of_gravity = (shape1.center_of_gravity + shape2.center_of_gravity) / 2
- body.position = position
- body.angle = angle
- body.friction = 1
- space.add(body, shape1, shape2)
- return body
-
-
-class PushtDataset(AbstractDataset):
available_datasets = ["pusht"]
+ fps = 10
+ image_keys = ["observation.image"]
def __init__(
self,
dataset_id: str,
version: str | None = "v1.2",
- batch_size: int | None = None,
- *,
- shuffle: bool = True,
root: Path | None = None,
- pin_memory: bool = False,
- prefetch: int = None,
- sampler: Sampler | None = None,
- collate_fn: Callable | None = None,
- writer: Writer | None = None,
- transform: "torchrl.envs.Transform" = None,
+ transform: callable = None,
+        delta_timestamps: dict[str, list[float]] | None = None,
):
- super().__init__(
- dataset_id,
- version,
- batch_size,
- shuffle=shuffle,
- root=root,
- pin_memory=pin_memory,
- prefetch=prefetch,
- sampler=sampler,
- collate_fn=collate_fn,
- writer=writer,
- transform=transform,
- )
+ super().__init__()
+ self.dataset_id = dataset_id
+ self.version = version
+ self.root = root
+ self.transform = transform
+ self.delta_timestamps = delta_timestamps
+
+ self.data_dir = self.root / f"{self.dataset_id}"
+ if (self.data_dir / "data_dict.pth").exists() and (
+ self.data_dir / "data_ids_per_episode.pth"
+ ).exists():
+ self.data_dict = torch.load(self.data_dir / "data_dict.pth")
+ self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
+ else:
+ self._download_and_preproc_obsolete()
+ self.data_dir.mkdir(parents=True, exist_ok=True)
+ torch.save(self.data_dict, self.data_dir / "data_dict.pth")
+ torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")
+
+ @property
+ def num_samples(self) -> int:
+ return len(self.data_dict["index"]) if "index" in self.data_dict else 0
+
+ @property
+ def num_episodes(self) -> int:
+ return len(self.data_ids_per_episode)
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, idx):
+ item = {}
+
+ # get episode id and timestamp of the sampled frame
+ current_ts = self.data_dict["timestamp"][idx].item()
+ episode = self.data_dict["episode"][idx].item()
+
+ for key in self.data_dict:
+ if self.delta_timestamps is not None and key in self.delta_timestamps:
+ data, is_pad = load_data_with_delta_timestamps(
+ self.data_dict,
+ self.data_ids_per_episode,
+ self.delta_timestamps,
+ key,
+ current_ts,
+ episode,
+ )
+ item[key] = data
+ item[f"{key}_is_pad"] = is_pad
+ else:
+ item[key] = self.data_dict[key][idx]
+
+ if self.transform is not None:
+ item = self.transform(item)
+
+ return item
def _download_and_preproc_obsolete(self):
+ try:
+ import pymunk
+ from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
+ except ModuleNotFoundError as e:
+ print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
+ raise e
+
assert self.root is not None
raw_dir = self.root / f"{self.dataset_id}_raw"
zarr_path = (raw_dir / PUSHT_ZARR).resolve()
@@ -140,34 +126,37 @@ class PushtDataset(AbstractDataset):
# TODO: verify that goal pose is expected to be fixed
goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians)
- goal_body = get_goal_pose_body(goal_pos_angle)
+ goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle)
imgs = torch.from_numpy(dataset_dict["img"])
imgs = einops.rearrange(imgs, "b h w c -> b c h w")
states = torch.from_numpy(dataset_dict["state"])
actions = torch.from_numpy(dataset_dict["action"])
+ self.data_ids_per_episode = {}
+ ep_dicts = []
+
idx0 = 0
- idxtd = 0
for episode_id in tqdm.tqdm(range(num_episodes)):
idx1 = dataset_dict.meta["episode_ends"][episode_id]
- # to create test artifact
- # idx1 = 51
num_frames = idx1 - idx0
assert (episode_ids[idx0:idx1] == episode_id).all()
image = imgs[idx0:idx1]
+ assert image.min() >= 0.0
+ assert image.max() <= 255.0
+ image = image.type(torch.uint8)
state = states[idx0:idx1]
agent_pos = state[:, :2]
block_pos = state[:, 2:4]
block_angle = state[:, 4]
- reward = torch.zeros(num_frames, 1)
- success = torch.zeros(num_frames, 1, dtype=torch.bool)
- done = torch.zeros(num_frames, 1, dtype=torch.bool)
+ reward = torch.zeros(num_frames)
+ success = torch.zeros(num_frames, dtype=torch.bool)
+ done = torch.zeros(num_frames, dtype=torch.bool)
for i in range(num_frames):
space = pymunk.Space()
space.gravity = 0, 0
@@ -175,14 +164,14 @@ class PushtDataset(AbstractDataset):
# Add walls.
walls = [
- add_segment(space, (5, 506), (5, 5), 2),
- add_segment(space, (5, 5), (506, 5), 2),
- add_segment(space, (506, 5), (506, 506), 2),
- add_segment(space, (5, 506), (506, 506), 2),
+ PushTEnv.add_segment(space, (5, 506), (5, 5), 2),
+ PushTEnv.add_segment(space, (5, 5), (506, 5), 2),
+ PushTEnv.add_segment(space, (506, 5), (506, 506), 2),
+ PushTEnv.add_segment(space, (5, 506), (506, 506), 2),
]
space.add(*walls)
- block_body = add_tee(space, block_pos[i].tolist(), block_angle[i].item())
+ block_body = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
intersection_area = goal_geom.intersection(block_geom).area
@@ -194,30 +183,32 @@ class PushtDataset(AbstractDataset):
# last step of demonstration is considered done
done[-1] = True
- ep_td = TensorDict(
- {
- ("observation", "image"): image[:-1],
- ("observation", "state"): agent_pos[:-1],
- "action": actions[idx0:idx1][:-1],
- "episode": episode_ids[idx0:idx1][:-1],
- "frame_id": torch.arange(0, num_frames - 1, 1),
- ("next", "observation", "image"): image[1:],
- ("next", "observation", "state"): agent_pos[1:],
- # TODO: verify that reward and done are aligned with image and agent_pos
- ("next", "reward"): reward[1:],
- ("next", "done"): done[1:],
- ("next", "success"): success[1:],
- },
- batch_size=num_frames - 1,
- )
+ ep_dict = {
+ "observation.image": image,
+ "observation.state": agent_pos,
+ "action": actions[idx0:idx1],
+ "episode": torch.tensor([episode_id] * num_frames, dtype=torch.int),
+ "frame_id": torch.arange(0, num_frames, 1),
+ "timestamp": torch.arange(0, num_frames, 1) / self.fps,
+ # "next.observation.image": image[1:],
+ # "next.observation.state": agent_pos[1:],
+ # TODO(rcadene): verify that reward and done are aligned with image and agent_pos
+ "next.reward": torch.cat([reward[1:], reward[[-1]]]),
+ "next.done": torch.cat([done[1:], done[[-1]]]),
+ "next.success": torch.cat([success[1:], success[[-1]]]),
+ }
+ ep_dicts.append(ep_dict)
- if episode_id == 0:
- # hack to initialize tensordict data structure to store episodes
- td_data = ep_td[0].expand(total_frames).memmap_like(self.root / f"{self.dataset_id}")
-
- td_data[idxtd : idxtd + len(ep_td)] = ep_td
+ assert isinstance(episode_id, int)
+ self.data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1)
+ assert len(self.data_ids_per_episode[episode_id]) == num_frames
idx0 = idx1
- idxtd = idxtd + len(ep_td)
- return TensorStorage(td_data.lock_())
+ self.data_dict = {}
+
+ keys = ep_dicts[0].keys()
+ for key in keys:
+ self.data_dict[key] = torch.cat([x[key] for x in ep_dicts])
+
+ self.data_dict["index"] = torch.arange(0, total_frames, 1)
diff --git a/lerobot/common/datasets/simxarm.py b/lerobot/common/datasets/simxarm.py
deleted file mode 100644
index dc30e69e..00000000
--- a/lerobot/common/datasets/simxarm.py
+++ /dev/null
@@ -1,127 +0,0 @@
-import pickle
-import zipfile
-from pathlib import Path
-from typing import Callable
-
-import torch
-import torchrl
-import tqdm
-from tensordict import TensorDict
-from torchrl.data.replay_buffers.samplers import (
- Sampler,
-)
-from torchrl.data.replay_buffers.storages import TensorStorage
-from torchrl.data.replay_buffers.writers import Writer
-
-from lerobot.common.datasets.abstract import AbstractDataset
-
-
-def download():
- raise NotImplementedError()
- import gdown
-
- url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
- download_path = "data.zip"
- gdown.download(url, download_path, quiet=False)
- print("Extracting...")
- with zipfile.ZipFile(download_path, "r") as zip_f:
- for member in zip_f.namelist():
- if member.startswith("data/xarm") and member.endswith(".pkl"):
- print(member)
- zip_f.extract(member=member)
- Path(download_path).unlink()
-
-
-class SimxarmDataset(AbstractDataset):
- available_datasets = [
- "xarm_lift_medium",
- ]
-
- def __init__(
- self,
- dataset_id: str,
- version: str | None = "v1.1",
- batch_size: int | None = None,
- *,
- shuffle: bool = True,
- root: Path | None = None,
- pin_memory: bool = False,
- prefetch: int = None,
- sampler: Sampler | None = None,
- collate_fn: Callable | None = None,
- writer: Writer | None = None,
- transform: "torchrl.envs.Transform" = None,
- ):
- super().__init__(
- dataset_id,
- version,
- batch_size,
- shuffle=shuffle,
- root=root,
- pin_memory=pin_memory,
- prefetch=prefetch,
- sampler=sampler,
- collate_fn=collate_fn,
- writer=writer,
- transform=transform,
- )
-
- def _download_and_preproc_obsolete(self):
- # assert self.root is not None
- # TODO(rcadene): finish download
- # download()
-
- dataset_path = self.root / f"{self.dataset_id}" / "buffer.pkl"
- print(f"Using offline dataset '{dataset_path}'")
- with open(dataset_path, "rb") as f:
- dataset_dict = pickle.load(f)
-
- total_frames = dataset_dict["actions"].shape[0]
-
- idx0 = 0
- idx1 = 0
- episode_id = 0
- for i in tqdm.tqdm(range(total_frames)):
- idx1 += 1
-
- if not dataset_dict["dones"][i]:
- continue
-
- num_frames = idx1 - idx0
-
- image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1])
- state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
- next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1])
- next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1])
- next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1])
- next_done = torch.tensor(dataset_dict["dones"][idx0:idx1])
-
- episode = TensorDict(
- {
- ("observation", "image"): image,
- ("observation", "state"): state,
- "action": torch.tensor(dataset_dict["actions"][idx0:idx1]),
- "episode": torch.tensor([episode_id] * num_frames, dtype=torch.int),
- "frame_id": torch.arange(0, num_frames, 1),
- ("next", "observation", "image"): next_image,
- ("next", "observation", "state"): next_state,
- ("next", "reward"): next_reward,
- ("next", "done"): next_done,
- },
- batch_size=num_frames,
- )
-
- if episode_id == 0:
- # hack to initialize tensordict data structure to store episodes
- td_data = (
- episode[0]
- .expand(total_frames)
- .memmap_like(self.root / f"{self.dataset_id}" / "replay_buffer")
- )
-
- td_data[idx0:idx1] = episode
-
- episode_id += 1
- idx0 = idx1
-
- return TensorStorage(td_data.lock_())
diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index 0ad43a65..e67d8a04 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -1,8 +1,12 @@
import io
import zipfile
+from copy import deepcopy
+from math import ceil
from pathlib import Path
+import einops
import requests
+import torch
import tqdm
@@ -28,3 +32,185 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
return True
else:
return False
+
+
+def load_data_with_delta_timestamps(
+    data_dict: dict[str, torch.Tensor],
+    data_ids_per_episode: dict[int, torch.Tensor],
+    delta_timestamps: dict[str, list[float]],
+ key: str,
+ current_ts: float,
+ episode: int,
+ tol: float = 0.04,
+):
+ """
+    Given a current timestamp (e.g. current_ts=0.6) and a list of timestamp differences (e.g. delta_timestamps=[-0.8, -0.2, 0, 0.2]),
+    this function computes the query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest frames of the specified modality (e.g. key="observation.image").
+
+    Importantly, when no frame can be found around a query timestamp within the specified tolerance window (e.g. tol=0.04), this function raises an AssertionError.
+    However, when a timestamp is queried before the first available timestamp of the episode or after the last one,
+    the tolerance violation doesn't raise an AssertionError; instead, the function populates a boolean array indicating which frames are outside of the episode range.
+    For instance, this boolean array is useful during batched training to avoid supervising actions associated with timestamps coming after the end of the episode,
+    or to pad the observations in a specific way. Note that by default, observation frames queried before the start of the episode are the same as the first frame of the episode.
+
+ Parameters:
+ - data_dict (dict): A dictionary containing the data, where each key corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
+ - data_ids_per_episode (dict): A dictionary where keys are episode identifiers and values are lists of indices corresponding to frames associated with each episode.
+ - delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible key to be retrieved. These deltas are added to the current_ts to form the query timestamps.
+ - key (str): The key specifying which data modality is to be retrieved from the data_dict.
+ - current_ts (float): The current timestamp to which the delta timestamps are added to form the query timestamps.
+ - episode (int): The identifier of the episode from which frames are to be retrieved.
+ - tol (float, optional): The tolerance level used to determine if a data point is close enough to the query timestamp. Defaults to 0.04.
+
+ Returns:
+ - tuple: A tuple containing two elements:
+ - The first element is the data retrieved from the specified modality based on the closest match to the query timestamps.
+ - The second element is a boolean array indicating which frames were considered as padding (True if the distance to the closest timestamp was greater than the tolerance level).
+
+ Raises:
+ - AssertionError: If any of the frames unexpectedly violate the tolerance level. This could indicate synchronization issues with timestamps during data collection.
+ """
+ # get indices of the frames associated to the episode, and their timestamps
+ ep_data_ids = data_ids_per_episode[episode]
+ ep_timestamps = data_dict["timestamp"][ep_data_ids]
+
+ # we make the assumption that the timestamps are sorted
+ ep_first_ts = ep_timestamps[0]
+ ep_last_ts = ep_timestamps[-1]
+
+ # get timestamps used as query to retrieve data of previous/future frames
+ delta_ts = delta_timestamps[key]
+ query_ts = current_ts + torch.tensor(delta_ts)
+
+ # compute distances between each query timestamp and all timestamps of all the frames belonging to the episode
+ dist = torch.cdist(query_ts[:, None], ep_timestamps[:, None], p=1)
+ min_, argmin_ = dist.min(1)
+
+ # get the indices of the data that are closest to the query timestamps
+ data_ids = ep_data_ids[argmin_]
+ # closest_ts = ep_timestamps[argmin_]
+
+ # get the data
+ data = data_dict[key][data_ids].clone()
+
+ # TODO(rcadene): synchronize timestamps + interpolation if needed
+
+ is_pad = min_ > tol
+
+ # check violated query timestamps are all outside the episode range
+ assert ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all(), (
+ f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tol=}) inside episode range."
+ "This might be due to synchronization issues with timestamps during data collection."
+ )
+
+ return data, is_pad
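+
+# Worked example for load_data_with_delta_timestamps (illustrative values):
+# delta_timestamps = {"action": [-0.2, 0, 0.2]} and current_ts = 1.0 give
+# query_ts = tensor([0.8, 1.0, 1.2]); each query is matched to the closest frame
+# timestamp in the episode, and a query that falls outside the episode range with an
+# error above `tol` is flagged in `is_pad` instead of raising an AssertionError.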
+
+
+def get_stats_einops_patterns(dataset):
+ """These einops patterns will be used to aggregate batches and compute statistics."""
+ stats_patterns = {
+ "action": "b c -> c",
+ "observation.state": "b c -> c",
+ }
+ for key in dataset.image_keys:
+ stats_patterns[key] = "b c h w -> c 1 1"
+ return stats_patterns
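+
+# For example, the image pattern "b c h w -> c 1 1" reduces a batch of images to one
+# statistic per channel, keeping singleton spatial dims so the result broadcasts
+# against (c, h, w) image tensors during normalization.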
+
+
+def compute_stats(dataset, batch_size=32, max_num_samples=None):
+ if max_num_samples is None:
+ max_num_samples = len(dataset)
+ else:
+        raise NotImplementedError("We need to set shuffle=True, but this violates an assert for now.")
+
+ dataloader = torch.utils.data.DataLoader(
+ dataset,
+ num_workers=4,
+ batch_size=batch_size,
+ shuffle=False,
+ # pin_memory=cfg.device != "cpu",
+ drop_last=False,
+ )
+
+ # get einops patterns to aggregate batches and compute statistics
+ stats_patterns = get_stats_einops_patterns(dataset)
+
+ # mean and std will be computed incrementally while max and min will track the running value.
+ mean, std, max, min = {}, {}, {}, {}
+ for key in stats_patterns:
+ mean[key] = torch.tensor(0.0).float()
+ std[key] = torch.tensor(0.0).float()
+ max[key] = torch.tensor(-float("inf")).float()
+ min[key] = torch.tensor(float("inf")).float()
+
+ # Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
+ # surprises when rerunning the sampler.
+ first_batch = None
+ running_item_count = 0 # for online mean computation
+ for i, batch in enumerate(
+ tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
+ ):
+ this_batch_size = len(batch["index"])
+ running_item_count += this_batch_size
+ if first_batch is None:
+ first_batch = deepcopy(batch)
+ for key, pattern in stats_patterns.items():
+ batch[key] = batch[key].float()
+ # Numerically stable update step for mean computation.
+ batch_mean = einops.reduce(batch[key], pattern, "mean")
+ # Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
+ # the update step, N is the running item count, B is this batch size, x̄ is the running mean,
+ # and x is the current batch mean. Some rearrangement is then required to avoid risking
+ # numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
+ # x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
+ mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
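+            # Sanity check with small numbers: after N = 2 items with running mean 1.0,
+            # a batch of B = 2 items whose batch mean is 3.0 yields
+            # 1.0 + 2 * (3.0 - 1.0) / 4 = 2.0, the exact mean over all 4 items.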
+ max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
+ min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
+
+ if i == ceil(max_num_samples / batch_size) - 1:
+ break
+
+ first_batch_ = None
+ running_item_count = 0 # for online std computation
+ for i, batch in enumerate(
+ tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
+ ):
+ this_batch_size = len(batch["index"])
+ running_item_count += this_batch_size
+ # Sanity check to make sure the batches are still in the same order as before.
+ if first_batch_ is None:
+ first_batch_ = deepcopy(batch)
+ for key in stats_patterns:
+ assert torch.equal(first_batch_[key], first_batch[key])
+ for key, pattern in stats_patterns.items():
+ batch[key] = batch[key].float()
+            # Numerically stable update step for mean computation (where the mean is over squared
+            # residuals). See notes in the mean computation loop above.
+ batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
+ std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
+
+ if i == ceil(max_num_samples / batch_size) - 1:
+ break
+
+ for key in stats_patterns:
+ std[key] = torch.sqrt(std[key])
+
+ stats = {}
+ for key in stats_patterns:
+ stats[key] = {
+ "mean": mean[key],
+ "std": std[key],
+ "max": max[key],
+ "min": min[key],
+ }
+
+ return stats
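+
+# Usage sketch (hypothetical call site): stats = compute_stats(dataset) returns, for
+# each key in the einops patterns, a dict with "mean", "std", "max" and "min" tensors
+# that a training script can use to normalize inputs and targets.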
+
+
+def cycle(iterable):
+ iterator = iter(iterable)
+ while True:
+ try:
+ yield next(iterator)
+ except StopIteration:
+ iterator = iter(iterable)
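+
+# Usage sketch: dataloader_iter = cycle(dataloader); batch = next(dataloader_iter)
+# yields batches indefinitely, transparently restarting the underlying iterator
+# whenever it is exhausted.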
diff --git a/lerobot/common/datasets/xarm.py b/lerobot/common/datasets/xarm.py
new file mode 100644
index 00000000..0dfcc5c9
--- /dev/null
+++ b/lerobot/common/datasets/xarm.py
@@ -0,0 +1,163 @@
+import pickle
+import zipfile
+from pathlib import Path
+
+import torch
+import tqdm
+
+from lerobot.common.datasets.utils import load_data_with_delta_timestamps
+
+
+def download(raw_dir):
+ import gdown
+
+ raw_dir.mkdir(parents=True, exist_ok=True)
+ url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
+ zip_path = raw_dir / "data.zip"
+ gdown.download(url, str(zip_path), quiet=False)
+ print("Extracting...")
+ with zipfile.ZipFile(str(zip_path), "r") as zip_f:
+ for member in zip_f.namelist():
+ if member.startswith("data/xarm") and member.endswith(".pkl"):
+ print(member)
+ zip_f.extract(member=member)
+ zip_path.unlink()
+
+
+class XarmDataset(torch.utils.data.Dataset):
+ available_datasets = [
+ "xarm_lift_medium",
+ ]
+ fps = 15
+ image_keys = ["observation.image"]
+
+ def __init__(
+ self,
+ dataset_id: str,
+ version: str | None = "v1.1",
+ root: Path | None = None,
+ transform: callable = None,
+        delta_timestamps: dict[str, list[float]] | None = None,
+ ):
+ super().__init__()
+ self.dataset_id = dataset_id
+ self.version = version
+ self.root = root
+ self.transform = transform
+ self.delta_timestamps = delta_timestamps
+
+ self.data_dir = self.root / f"{self.dataset_id}"
+ if (self.data_dir / "data_dict.pth").exists() and (
+ self.data_dir / "data_ids_per_episode.pth"
+ ).exists():
+ self.data_dict = torch.load(self.data_dir / "data_dict.pth")
+ self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
+ else:
+ self._download_and_preproc_obsolete()
+ self.data_dir.mkdir(parents=True, exist_ok=True)
+ torch.save(self.data_dict, self.data_dir / "data_dict.pth")
+ torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")
+
+ @property
+ def num_samples(self) -> int:
+ return len(self.data_dict["index"]) if "index" in self.data_dict else 0
+
+ @property
+ def num_episodes(self) -> int:
+ return len(self.data_ids_per_episode)
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, idx):
+ item = {}
+
+ # get episode id and timestamp of the sampled frame
+ current_ts = self.data_dict["timestamp"][idx].item()
+ episode = self.data_dict["episode"][idx].item()
+
+ for key in self.data_dict:
+ if self.delta_timestamps is not None and key in self.delta_timestamps:
+ data, is_pad = load_data_with_delta_timestamps(
+ self.data_dict,
+ self.data_ids_per_episode,
+ self.delta_timestamps,
+ key,
+ current_ts,
+ episode,
+ )
+ item[key] = data
+ item[f"{key}_is_pad"] = is_pad
+ else:
+ item[key] = self.data_dict[key][idx]
+
+ if self.transform is not None:
+ item = self.transform(item)
+
+ return item
+
+ def _download_and_preproc_obsolete(self):
+ assert self.root is not None
+ raw_dir = self.root / f"{self.dataset_id}_raw"
+ if not raw_dir.exists():
+ download(raw_dir)
+
+ dataset_path = self.root / f"{self.dataset_id}" / "buffer.pkl"
+ print(f"Using offline dataset '{dataset_path}'")
+ with open(dataset_path, "rb") as f:
+ dataset_dict = pickle.load(f)
+
+ total_frames = dataset_dict["actions"].shape[0]
+
+ self.data_ids_per_episode = {}
+ ep_dicts = []
+
+ idx0 = 0
+ idx1 = 0
+ episode_id = 0
+ for i in tqdm.tqdm(range(total_frames)):
+ idx1 += 1
+
+ if not dataset_dict["dones"][i]:
+ continue
+
+ num_frames = idx1 - idx0
+
+ image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1])
+ state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
+ action = torch.tensor(dataset_dict["actions"][idx0:idx1])
+ # TODO(rcadene): we have a missing last frame which is the observation when the env is done
+ # it is critical to have this frame for tdmpc to predict a "done observation/state"
+ # next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1])
+ # next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1])
+ next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1])
+ next_done = torch.tensor(dataset_dict["dones"][idx0:idx1])
+
+ ep_dict = {
+ "observation.image": image,
+ "observation.state": state,
+ "action": action,
+ "episode": torch.tensor([episode_id] * num_frames, dtype=torch.int),
+ "frame_id": torch.arange(0, num_frames, 1),
+ "timestamp": torch.arange(0, num_frames, 1) / self.fps,
+ # "next.observation.image": next_image,
+ # "next.observation.state": next_state,
+ "next.reward": next_reward,
+ "next.done": next_done,
+ }
+ ep_dicts.append(ep_dict)
+
+ assert isinstance(episode_id, int)
+ self.data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1)
+ assert len(self.data_ids_per_episode[episode_id]) == num_frames
+
+ idx0 = idx1
+ episode_id += 1
+
+ self.data_dict = {}
+
+ keys = ep_dicts[0].keys()
+ for key in keys:
+ self.data_dict[key] = torch.cat([x[key] for x in ep_dicts])
+
+ self.data_dict["index"] = torch.arange(0, total_frames, 1)
diff --git a/lerobot/common/envs/abstract.py b/lerobot/common/envs/abstract.py
deleted file mode 100644
index ea5ce3da..00000000
--- a/lerobot/common/envs/abstract.py
+++ /dev/null
@@ -1,92 +0,0 @@
-from collections import deque
-from typing import Optional
-
-from tensordict import TensorDict
-from torchrl.envs import EnvBase
-
-from lerobot.common.utils import set_global_seed
-
-
-class AbstractEnv(EnvBase):
- """
- Note:
- When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
- 1. set the required class attributes:
- - for classes inheriting from `AbstractDataset`: `available_datasets`
- - for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
- - for classes inheriting from `AbstractPolicy`: `name`
- 2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
- 3. update variables in `tests/test_available.py` by importing your new class
- """
-
- name: str | None = None # same name should be used to instantiate the environment in factory.py
- available_tasks: list[str] | None = None # for instance: sim_insertion, sim_transfer_cube, pusht, lift
-
- def __init__(
- self,
- task,
- frame_skip: int = 1,
- from_pixels: bool = False,
- pixels_only: bool = False,
- image_size=None,
- seed=1337,
- device="cpu",
- num_prev_obs=1,
- num_prev_action=0,
- ):
- super().__init__(device=device, batch_size=[])
- assert self.name is not None, "Subclasses of `AbstractEnv` should set the `name` class attribute."
- assert (
- self.available_tasks is not None
- ), "Subclasses of `AbstractEnv` should set the `available_tasks` class attribute."
- assert (
- task in self.available_tasks
- ), f"The provided task ({task}) is not on the list of available tasks {self.available_tasks}."
-
- self.task = task
- self.frame_skip = frame_skip
- self.from_pixels = from_pixels
- self.pixels_only = pixels_only
- self.image_size = image_size
- self.num_prev_obs = num_prev_obs
- self.num_prev_action = num_prev_action
-
- if pixels_only:
- assert from_pixels
- if from_pixels:
- assert image_size
-
- self._make_env()
- self._make_spec()
-
- # self._next_seed will be used for the next reset. It is recommended that when self.set_seed is called
- # you store the return value in self._next_seed (it will be a new randomly generated seed).
- self._next_seed = seed
- # Don't store the result of this in self._next_seed, as we want to make sure that the first time
- # self._reset is called, we use seed.
- self.set_seed(seed)
-
- if self.num_prev_obs > 0:
- self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
- self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs)
- if self.num_prev_action > 0:
- raise NotImplementedError()
- # self._prev_action_queue = deque(maxlen=self.num_prev_action)
-
- def render(self, mode="rgb_array", width=640, height=480):
- raise NotImplementedError("Abstract method")
-
- def _reset(self, tensordict: Optional[TensorDict] = None):
- raise NotImplementedError("Abstract method")
-
- def _step(self, tensordict: TensorDict):
- raise NotImplementedError("Abstract method")
-
- def _make_env(self):
- raise NotImplementedError("Abstract method")
-
- def _make_spec(self):
- raise NotImplementedError("Abstract method")
-
- def _set_seed(self, seed: Optional[int]):
- set_global_seed(seed)
diff --git a/lerobot/common/envs/aloha/assets/bimanual_viperx_end_effector_insertion.xml b/lerobot/common/envs/aloha/assets/bimanual_viperx_end_effector_insertion.xml
deleted file mode 100644
index 8002838c..00000000
--- a/lerobot/common/envs/aloha/assets/bimanual_viperx_end_effector_insertion.xml
+++ /dev/null
@@ -1,59 +0,0 @@
-[XML markup stripped during extraction]
diff --git a/lerobot/common/envs/aloha/assets/bimanual_viperx_end_effector_transfer_cube.xml b/lerobot/common/envs/aloha/assets/bimanual_viperx_end_effector_transfer_cube.xml
deleted file mode 100644
index 05249ad2..00000000
--- a/lerobot/common/envs/aloha/assets/bimanual_viperx_end_effector_transfer_cube.xml
+++ /dev/null
@@ -1,48 +0,0 @@
-[XML markup stripped during extraction]
diff --git a/lerobot/common/envs/aloha/assets/bimanual_viperx_insertion.xml b/lerobot/common/envs/aloha/assets/bimanual_viperx_insertion.xml
deleted file mode 100644
index 511f7947..00000000
--- a/lerobot/common/envs/aloha/assets/bimanual_viperx_insertion.xml
+++ /dev/null
@@ -1,53 +0,0 @@
-[XML markup stripped during extraction]
diff --git a/lerobot/common/envs/aloha/assets/bimanual_viperx_transfer_cube.xml b/lerobot/common/envs/aloha/assets/bimanual_viperx_transfer_cube.xml
deleted file mode 100644
index 2d85a47c..00000000
--- a/lerobot/common/envs/aloha/assets/bimanual_viperx_transfer_cube.xml
+++ /dev/null
@@ -1,42 +0,0 @@
-[XML markup stripped during extraction]
diff --git a/lerobot/common/envs/aloha/assets/scene.xml b/lerobot/common/envs/aloha/assets/scene.xml
deleted file mode 100644
index 0f61b8a5..00000000
--- a/lerobot/common/envs/aloha/assets/scene.xml
+++ /dev/null
@@ -1,38 +0,0 @@
-[XML markup stripped during extraction]
diff --git a/lerobot/common/envs/aloha/assets/tabletop.stl b/lerobot/common/envs/aloha/assets/tabletop.stl
deleted file mode 100644
index 1c17d3f0..00000000
--- a/lerobot/common/envs/aloha/assets/tabletop.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:76a1571d1aa36520f2bd81c268991b99816c2a7819464d718e0fd9976fe30dce
-size 684
diff --git a/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_left.stl b/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_left.stl
deleted file mode 100644
index ef1f3f35..00000000
--- a/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_left.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:df73ae5b9058e5d50a6409ac2ab687dade75053a86591bb5e23ab051dbf2d659
-size 83384
diff --git a/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_right.stl b/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_right.stl
deleted file mode 100644
index 7eb8aefd..00000000
--- a/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_right.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:56fb3cc1236d4193106038adf8e457c7252ae9e86c7cee6dabf0578c53666358
-size 83384
diff --git a/lerobot/common/envs/aloha/assets/vx300s_10_gripper_finger.stl b/lerobot/common/envs/aloha/assets/vx300s_10_gripper_finger.stl
deleted file mode 100644
index 4c2b3a1f..00000000
--- a/lerobot/common/envs/aloha/assets/vx300s_10_gripper_finger.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:a4baacd9a64df1be60ea5e98f50f3c660e1b7a1fe9684aace6004c5058c09483
-size 42884
diff --git a/lerobot/common/envs/aloha/assets/vx300s_11_ar_tag.stl b/lerobot/common/envs/aloha/assets/vx300s_11_ar_tag.stl
deleted file mode 100644
index 8a30f7cc..00000000
--- a/lerobot/common/envs/aloha/assets/vx300s_11_ar_tag.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:a18a1601074d29ed1d546ead70cd18fbb063f1db7b5b96b9f0365be714f3136a
-size 3884
diff --git a/lerobot/common/envs/aloha/assets/vx300s_1_base.stl b/lerobot/common/envs/aloha/assets/vx300s_1_base.stl
deleted file mode 100644
index 9198e625..00000000
--- a/lerobot/common/envs/aloha/assets/vx300s_1_base.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:d100cafe656671ca8fde98fb6a4cf2d1b746995c51c61c25ad9ea2715635d146
-size 99984
diff --git a/lerobot/common/envs/aloha/assets/vx300s_2_shoulder.stl b/lerobot/common/envs/aloha/assets/vx300s_2_shoulder.stl
deleted file mode 100644
index ab3d9570..00000000
--- a/lerobot/common/envs/aloha/assets/vx300s_2_shoulder.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:139745a74055cb0b23430bb5bc032bf68cf7bea5e4975c8f4c04107ae005f7f0
-size 63884
diff --git a/lerobot/common/envs/aloha/assets/vx300s_3_upper_arm.stl b/lerobot/common/envs/aloha/assets/vx300s_3_upper_arm.stl
deleted file mode 100644
index 3d6f663c..00000000
--- a/lerobot/common/envs/aloha/assets/vx300s_3_upper_arm.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:900f236320dd3d500870c5fde763b2d47502d51e043a5c377875e70237108729
-size 102984
diff --git a/lerobot/common/envs/aloha/assets/vx300s_4_upper_forearm.stl b/lerobot/common/envs/aloha/assets/vx300s_4_upper_forearm.stl
deleted file mode 100644
index 4eb249e7..00000000
--- a/lerobot/common/envs/aloha/assets/vx300s_4_upper_forearm.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:4104fc54bbfb8a9b533029f1e7e3ade3d54d638372b3195daa0c98f57e0295b5
-size 49584
diff --git a/lerobot/common/envs/aloha/assets/vx300s_5_lower_forearm.stl b/lerobot/common/envs/aloha/assets/vx300s_5_lower_forearm.stl
deleted file mode 100644
index 34c76221..00000000
--- a/lerobot/common/envs/aloha/assets/vx300s_5_lower_forearm.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:66814e27fa728056416e25e02e89eb7d34c51d51c51e7c3df873829037ddc6b8
-size 99884
diff --git a/lerobot/common/envs/aloha/assets/vx300s_6_wrist.stl b/lerobot/common/envs/aloha/assets/vx300s_6_wrist.stl
deleted file mode 100644
index 232fabf7..00000000
--- a/lerobot/common/envs/aloha/assets/vx300s_6_wrist.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:90eb145c85627968c3776ae6de23ccff7e112c9dd713c46bc9acdfdaa859a048
-size 70784
diff --git a/lerobot/common/envs/aloha/assets/vx300s_7_gripper.stl b/lerobot/common/envs/aloha/assets/vx300s_7_gripper.stl
deleted file mode 100644
index 946c3c86..00000000
--- a/lerobot/common/envs/aloha/assets/vx300s_7_gripper.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:786c1077bfd226f14219581b11d5f19464ca95b17132e0bb7532503568f5af90
-size 450084
diff --git a/lerobot/common/envs/aloha/assets/vx300s_8_gripper_prop.stl b/lerobot/common/envs/aloha/assets/vx300s_8_gripper_prop.stl
deleted file mode 100644
index 28d5bd76..00000000
--- a/lerobot/common/envs/aloha/assets/vx300s_8_gripper_prop.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:d1275a93fe2157c83dbc095617fb7e672888bdd48ec070a35ef4ab9ebd9755b0
-size 31684
diff --git a/lerobot/common/envs/aloha/assets/vx300s_9_gripper_bar.stl b/lerobot/common/envs/aloha/assets/vx300s_9_gripper_bar.stl
deleted file mode 100644
index 5201d5ea..00000000
--- a/lerobot/common/envs/aloha/assets/vx300s_9_gripper_bar.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:a4de62c9a2ed2c78433010e4c05530a1254b1774a7651967f406120c9bf8973e
-size 379484
diff --git a/lerobot/common/envs/aloha/assets/vx300s_dependencies.xml b/lerobot/common/envs/aloha/assets/vx300s_dependencies.xml
deleted file mode 100644
index 93037ab7..00000000
--- a/lerobot/common/envs/aloha/assets/vx300s_dependencies.xml
+++ /dev/null
@@ -1,17 +0,0 @@
-[XML markup stripped during extraction]
diff --git a/lerobot/common/envs/aloha/assets/vx300s_left.xml b/lerobot/common/envs/aloha/assets/vx300s_left.xml
deleted file mode 100644
index 3af6c235..00000000
--- a/lerobot/common/envs/aloha/assets/vx300s_left.xml
+++ /dev/null
@@ -1,59 +0,0 @@
-[XML markup stripped during extraction]
diff --git a/lerobot/common/envs/aloha/assets/vx300s_right.xml b/lerobot/common/envs/aloha/assets/vx300s_right.xml
deleted file mode 100644
index 495df478..00000000
--- a/lerobot/common/envs/aloha/assets/vx300s_right.xml
+++ /dev/null
@@ -1,59 +0,0 @@
-[XML markup stripped during extraction]
diff --git a/lerobot/common/envs/aloha/constants.py b/lerobot/common/envs/aloha/constants.py
deleted file mode 100644
index e582e5f3..00000000
--- a/lerobot/common/envs/aloha/constants.py
+++ /dev/null
@@ -1,163 +0,0 @@
-from pathlib import Path
-
-### Simulation envs fixed constants
-DT = 0.02 # 0.02 s -> 1/0.02 = 50 Hz
-FPS = 50
-
-
-JOINTS = [
- # absolute joint position
- "left_arm_waist",
- "left_arm_shoulder",
- "left_arm_elbow",
- "left_arm_forearm_roll",
- "left_arm_wrist_angle",
- "left_arm_wrist_rotate",
- # normalized gripper position 0: close, 1: open
- "left_arm_gripper",
- # absolute joint position
- "right_arm_waist",
- "right_arm_shoulder",
- "right_arm_elbow",
- "right_arm_forearm_roll",
- "right_arm_wrist_angle",
- "right_arm_wrist_rotate",
- # normalized gripper position 0: close, 1: open
- "right_arm_gripper",
-]
-
-ACTIONS = [
- # position and quaternion for end effector
- "left_arm_waist",
- "left_arm_shoulder",
- "left_arm_elbow",
- "left_arm_forearm_roll",
- "left_arm_wrist_angle",
- "left_arm_wrist_rotate",
- # normalized gripper position (0: close, 1: open)
- "left_arm_gripper",
- "right_arm_waist",
- "right_arm_shoulder",
- "right_arm_elbow",
- "right_arm_forearm_roll",
- "right_arm_wrist_angle",
- "right_arm_wrist_rotate",
- # normalized gripper position (0: close, 1: open)
- "right_arm_gripper",
-]
-
-
-START_ARM_POSE = [
- 0,
- -0.96,
- 1.16,
- 0,
- -0.3,
- 0,
- 0.02239,
- -0.02239,
- 0,
- -0.96,
- 1.16,
- 0,
- -0.3,
- 0,
- 0.02239,
- -0.02239,
-]
-
-ASSETS_DIR = Path(__file__).parent.resolve() / "assets" # note: absolute path
-
-# Left finger position limits (qpos[7]), right_finger = -1 * left_finger
-MASTER_GRIPPER_POSITION_OPEN = 0.02417
-MASTER_GRIPPER_POSITION_CLOSE = 0.01244
-PUPPET_GRIPPER_POSITION_OPEN = 0.05800
-PUPPET_GRIPPER_POSITION_CLOSE = 0.01844
-
-# Gripper joint limits (qpos[6])
-MASTER_GRIPPER_JOINT_OPEN = 0.3083
-MASTER_GRIPPER_JOINT_CLOSE = -0.6842
-PUPPET_GRIPPER_JOINT_OPEN = 1.4910
-PUPPET_GRIPPER_JOINT_CLOSE = -0.6213
-
-MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE) / 2
-
-############################ Helper functions ############################
-
-
-def normalize_master_gripper_position(x):
- return (x - MASTER_GRIPPER_POSITION_CLOSE) / (
- MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE
- )
-
-
-def normalize_puppet_gripper_position(x):
- return (x - PUPPET_GRIPPER_POSITION_CLOSE) / (
- PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE
- )
-
-
-def unnormalize_master_gripper_position(x):
- return x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE
-
-
-def unnormalize_puppet_gripper_position(x):
- return x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE
-
-
-def convert_position_from_master_to_puppet(x):
- return unnormalize_puppet_gripper_position(normalize_master_gripper_position(x))
-
-
-def normalizer_master_gripper_joint(x):
- return (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
-
-
-def normalize_puppet_gripper_joint(x):
- return (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
-
-
-def unnormalize_master_gripper_joint(x):
- return x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
-
-
-def unnormalize_puppet_gripper_joint(x):
- return x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
-
-
-def convert_join_from_master_to_puppet(x):
- return unnormalize_puppet_gripper_joint(normalizer_master_gripper_joint(x))
-
-
-def normalize_master_gripper_velocity(x):
- return x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)
-
-
-def normalize_puppet_gripper_velocity(x):
- return x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)
-
-
-def convert_master_from_position_to_joint(x):
- return (
- normalize_master_gripper_position(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
- + MASTER_GRIPPER_JOINT_CLOSE
- )
-
-
-def convert_master_from_joint_to_position(x):
- return unnormalize_master_gripper_position(
- (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
- )
-
-
-def convert_puppet_from_position_to_join(x):
- return (
- normalize_puppet_gripper_position(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
- + PUPPET_GRIPPER_JOINT_CLOSE
- )
-
-
-def convert_puppet_from_joint_to_position(x):
- return unnormalize_puppet_gripper_position(
- (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
- )
diff --git a/lerobot/common/envs/aloha/env.py b/lerobot/common/envs/aloha/env.py
deleted file mode 100644
index 8f907650..00000000
--- a/lerobot/common/envs/aloha/env.py
+++ /dev/null
@@ -1,298 +0,0 @@
-import importlib
-import logging
-from collections import deque
-from typing import Optional
-
-import einops
-import numpy as np
-import torch
-from dm_control import mujoco
-from dm_control.rl import control
-from tensordict import TensorDict
-from torchrl.data.tensor_specs import (
- BoundedTensorSpec,
- CompositeSpec,
- DiscreteTensorSpec,
- UnboundedContinuousTensorSpec,
-)
-
-from lerobot.common.envs.abstract import AbstractEnv
-from lerobot.common.envs.aloha.constants import (
- ACTIONS,
- ASSETS_DIR,
- DT,
- JOINTS,
-)
-from lerobot.common.envs.aloha.tasks.sim import BOX_POSE, InsertionTask, TransferCubeTask
-from lerobot.common.envs.aloha.tasks.sim_end_effector import (
- InsertionEndEffectorTask,
- TransferCubeEndEffectorTask,
-)
-from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose
-from lerobot.common.utils import set_global_seed
-
-_has_gym = importlib.util.find_spec("gymnasium") is not None
-
-
-class AlohaEnv(AbstractEnv):
- name = "aloha"
- available_tasks = ["sim_insertion", "sim_transfer_cube"]
- _reset_warning_issued = False
-
- def __init__(
- self,
- task,
- frame_skip: int = 1,
- from_pixels: bool = False,
- pixels_only: bool = False,
- image_size=None,
- seed=1337,
- device="cpu",
- num_prev_obs=1,
- num_prev_action=0,
- ):
- super().__init__(
- task=task,
- frame_skip=frame_skip,
- from_pixels=from_pixels,
- pixels_only=pixels_only,
- image_size=image_size,
- seed=seed,
- device=device,
- num_prev_obs=num_prev_obs,
- num_prev_action=num_prev_action,
- )
-
- def _make_env(self):
- if not _has_gym:
- raise ImportError("Cannot import gymnasium.")
-
- if not self.from_pixels:
- raise NotImplementedError()
-
- self._env = self._make_env_task(self.task)
-
- def render(self, mode="rgb_array", width=640, height=480):
- # TODO(rcadene): render and visualizer several cameras (e.g. angle, front_close)
- image = self._env.physics.render(height=height, width=width, camera_id="top")
- return image
-
- def _make_env_task(self, task_name):
- # time limit is controlled by StepCounter in env factory
- time_limit = float("inf")
-
- if "sim_transfer_cube" in task_name:
- xml_path = ASSETS_DIR / "bimanual_viperx_transfer_cube.xml"
- physics = mujoco.Physics.from_xml_path(str(xml_path))
- task = TransferCubeTask(random=False)
- elif "sim_insertion" in task_name:
- xml_path = ASSETS_DIR / "bimanual_viperx_insertion.xml"
- physics = mujoco.Physics.from_xml_path(str(xml_path))
- task = InsertionTask(random=False)
- elif "sim_end_effector_transfer_cube" in task_name:
- raise NotImplementedError()
- xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_transfer_cube.xml"
- physics = mujoco.Physics.from_xml_path(str(xml_path))
- task = TransferCubeEndEffectorTask(random=False)
- elif "sim_end_effector_insertion" in task_name:
- raise NotImplementedError()
- xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_insertion.xml"
- physics = mujoco.Physics.from_xml_path(str(xml_path))
- task = InsertionEndEffectorTask(random=False)
- else:
- raise NotImplementedError(task_name)
-
- env = control.Environment(
- physics, task, time_limit, control_timestep=DT, n_sub_steps=None, flat_observation=False
- )
- return env
-
- def _format_raw_obs(self, raw_obs):
- if self.from_pixels:
- image = torch.from_numpy(raw_obs["images"]["top"].copy())
- image = einops.rearrange(image, "h w c -> c h w")
- assert image.dtype == torch.uint8
- obs = {"image": {"top": image}}
-
- if not self.pixels_only:
- obs["state"] = torch.from_numpy(raw_obs["qpos"]).type(torch.float32)
- else:
- # TODO(rcadene):
- raise NotImplementedError()
- # obs = {"state": torch.from_numpy(raw_obs["observation"]).type(torch.float32)}
-
- return obs
-
- def _reset(self, tensordict: Optional[TensorDict] = None):
- if tensordict is not None and not AlohaEnv._reset_warning_issued:
- logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
- AlohaEnv._reset_warning_issued = True
-
- # Seed the environment and update the seed to be used for the next reset.
- self._next_seed = self.set_seed(self._next_seed)
-
- # TODO(rcadene): do not use global variable for this
- if "sim_transfer_cube" in self.task:
- BOX_POSE[0] = sample_box_pose() # used in sim reset
- elif "sim_insertion" in self.task:
- BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset
-
- raw_obs = self._env.reset()
-
- obs = self._format_raw_obs(raw_obs.observation)
-
- if self.num_prev_obs > 0:
- stacked_obs = {}
- if "image" in obs:
- self._prev_obs_image_queue = deque(
- [obs["image"]["top"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
- )
- stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
- if "state" in obs:
- self._prev_obs_state_queue = deque(
- [obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
- )
- stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
- obs = stacked_obs
-
- td = TensorDict(
- {
- "observation": TensorDict(obs, batch_size=[]),
- "done": torch.tensor([False], dtype=torch.bool),
- },
- batch_size=[],
- )
-
- return td
-
- def _step(self, tensordict: TensorDict):
- td = tensordict
- action = td["action"].numpy()
- assert action.ndim == 1
- # TODO(rcadene): add info["is_success"] and info["success"] ?
-
- _, reward, _, raw_obs = self._env.step(action)
-
- # TODO(rcadene): add an enum
- success = done = reward == 4
- obs = self._format_raw_obs(raw_obs)
-
- if self.num_prev_obs > 0:
- stacked_obs = {}
- if "image" in obs:
- self._prev_obs_image_queue.append(obs["image"]["top"])
- stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
- if "state" in obs:
- self._prev_obs_state_queue.append(obs["state"])
- stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
- obs = stacked_obs
-
- td = TensorDict(
- {
- "observation": TensorDict(obs, batch_size=[]),
- "reward": torch.tensor([reward], dtype=torch.float32),
- # success and done are true when coverage > self.success_threshold in env
- "done": torch.tensor([done], dtype=torch.bool),
- "success": torch.tensor([success], dtype=torch.bool),
- },
- batch_size=[],
- )
- return td
-
- def _make_spec(self):
- obs = {}
- from omegaconf import OmegaConf
-
- if self.from_pixels:
- if isinstance(self.image_size, int):
- image_shape = (3, self.image_size, self.image_size)
- elif OmegaConf.is_list(self.image_size) or isinstance(self.image_size, list):
- assert len(self.image_size) == 3 # c h w
- assert self.image_size[0] == 3 # c is RGB
- image_shape = tuple(self.image_size)
- else:
- raise ValueError(self.image_size)
- if self.num_prev_obs > 0:
- image_shape = (self.num_prev_obs + 1, *image_shape)
-
- obs["image"] = {
- "top": BoundedTensorSpec(
- low=0,
- high=255,
- shape=image_shape,
- dtype=torch.uint8,
- device=self.device,
- )
- }
- if not self.pixels_only:
- state_shape = (len(JOINTS),)
- if self.num_prev_obs > 0:
- state_shape = (self.num_prev_obs + 1, *state_shape)
-
- obs["state"] = UnboundedContinuousTensorSpec(
- # TODO: add low and high bounds
- shape=state_shape,
- dtype=torch.float32,
- device=self.device,
- )
- else:
- # TODO(rcadene): add observation_space achieved_goal and desired_goal?
- state_shape = (len(JOINTS),)
- if self.num_prev_obs > 0:
- state_shape = (self.num_prev_obs + 1, *state_shape)
-
- obs["state"] = UnboundedContinuousTensorSpec(
- # TODO: add low and high bounds
- shape=state_shape,
- dtype=torch.float32,
- device=self.device,
- )
- self.observation_spec = CompositeSpec({"observation": obs})
-
- # TODO(rcadene): valid when controling end effector?
- # action_space = self._env.action_spec()
- # self.action_spec = BoundedTensorSpec(
- # low=action_space.minimum,
- # high=action_space.maximum,
- # shape=action_space.shape,
- # dtype=torch.float32,
- # device=self.device,
- # )
-
- # TODO(rcaene): add bounds (where are they????)
- self.action_spec = BoundedTensorSpec(
- shape=(len(ACTIONS)),
- low=-1,
- high=1,
- dtype=torch.float32,
- device=self.device,
- )
-
- self.reward_spec = UnboundedContinuousTensorSpec(
- shape=(1,),
- dtype=torch.float32,
- device=self.device,
- )
-
- self.done_spec = CompositeSpec(
- {
- "done": DiscreteTensorSpec(
- 2,
- shape=(1,),
- dtype=torch.bool,
- device=self.device,
- ),
- "success": DiscreteTensorSpec(
- 2,
- shape=(1,),
- dtype=torch.bool,
- device=self.device,
- ),
- }
- )
-
- def _set_seed(self, seed: Optional[int]):
- set_global_seed(seed)
- # TODO(rcadene): seed the env
- # self._env.seed(seed)
- logging.warning("Aloha env is not seeded")
diff --git a/lerobot/common/envs/aloha/tasks/sim.py b/lerobot/common/envs/aloha/tasks/sim.py
deleted file mode 100644
index ee1d0927..00000000
--- a/lerobot/common/envs/aloha/tasks/sim.py
+++ /dev/null
@@ -1,219 +0,0 @@
-import collections
-
-import numpy as np
-from dm_control.suite import base
-
-from lerobot.common.envs.aloha.constants import (
- START_ARM_POSE,
- normalize_puppet_gripper_position,
- normalize_puppet_gripper_velocity,
- unnormalize_puppet_gripper_position,
-)
-
-BOX_POSE = [None] # to be changed from outside
-
-"""
-Environment for simulated robot bi-manual manipulation, with joint position control
-Action space: [left_arm_qpos (6), # absolute joint position
- left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
- right_arm_qpos (6), # absolute joint position
- right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
-
-Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
- left_gripper_position (1), # normalized gripper position (0: close, 1: open)
- right_arm_qpos (6), # absolute joint position
- right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
- "qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
- left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
- right_arm_qvel (6), # absolute joint velocity (rad)
- right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
- "images": {"main": (480x640x3)} # h, w, c, dtype='uint8'
-"""
-
-
-class BimanualViperXTask(base.Task):
- def __init__(self, random=None):
- super().__init__(random=random)
-
- def before_step(self, action, physics):
- left_arm_action = action[:6]
- right_arm_action = action[7 : 7 + 6]
- normalized_left_gripper_action = action[6]
- normalized_right_gripper_action = action[7 + 6]
-
- left_gripper_action = unnormalize_puppet_gripper_position(normalized_left_gripper_action)
- right_gripper_action = unnormalize_puppet_gripper_position(normalized_right_gripper_action)
-
- full_left_gripper_action = [left_gripper_action, -left_gripper_action]
- full_right_gripper_action = [right_gripper_action, -right_gripper_action]
-
- env_action = np.concatenate(
- [left_arm_action, full_left_gripper_action, right_arm_action, full_right_gripper_action]
- )
- super().before_step(env_action, physics)
- return
-
- def initialize_episode(self, physics):
- """Sets the state of the environment at the start of each episode."""
- super().initialize_episode(physics)
-
- @staticmethod
- def get_qpos(physics):
- qpos_raw = physics.data.qpos.copy()
- left_qpos_raw = qpos_raw[:8]
- right_qpos_raw = qpos_raw[8:16]
- left_arm_qpos = left_qpos_raw[:6]
- right_arm_qpos = right_qpos_raw[:6]
- left_gripper_qpos = [normalize_puppet_gripper_position(left_qpos_raw[6])]
- right_gripper_qpos = [normalize_puppet_gripper_position(right_qpos_raw[6])]
- return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
-
- @staticmethod
- def get_qvel(physics):
- qvel_raw = physics.data.qvel.copy()
- left_qvel_raw = qvel_raw[:8]
- right_qvel_raw = qvel_raw[8:16]
- left_arm_qvel = left_qvel_raw[:6]
- right_arm_qvel = right_qvel_raw[:6]
- left_gripper_qvel = [normalize_puppet_gripper_velocity(left_qvel_raw[6])]
- right_gripper_qvel = [normalize_puppet_gripper_velocity(right_qvel_raw[6])]
- return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
-
- @staticmethod
- def get_env_state(physics):
- raise NotImplementedError
-
- def get_observation(self, physics):
- obs = collections.OrderedDict()
- obs["qpos"] = self.get_qpos(physics)
- obs["qvel"] = self.get_qvel(physics)
- obs["env_state"] = self.get_env_state(physics)
- obs["images"] = {}
- obs["images"]["top"] = physics.render(height=480, width=640, camera_id="top")
- obs["images"]["angle"] = physics.render(height=480, width=640, camera_id="angle")
- obs["images"]["vis"] = physics.render(height=480, width=640, camera_id="front_close")
-
- return obs
-
- def get_reward(self, physics):
- # return whether left gripper is holding the box
- raise NotImplementedError
-
-
-class TransferCubeTask(BimanualViperXTask):
- def __init__(self, random=None):
- super().__init__(random=random)
- self.max_reward = 4
-
- def initialize_episode(self, physics):
- """Sets the state of the environment at the start of each episode."""
- # TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside
- # reset qpos, control and box position
- with physics.reset_context():
- physics.named.data.qpos[:16] = START_ARM_POSE
- np.copyto(physics.data.ctrl, START_ARM_POSE)
- assert BOX_POSE[0] is not None
- physics.named.data.qpos[-7:] = BOX_POSE[0]
- # print(f"{BOX_POSE=}")
- super().initialize_episode(physics)
-
- @staticmethod
- def get_env_state(physics):
- env_state = physics.data.qpos.copy()[16:]
- return env_state
-
- def get_reward(self, physics):
- # return whether left gripper is holding the box
- all_contact_pairs = []
- for i_contact in range(physics.data.ncon):
- id_geom_1 = physics.data.contact[i_contact].geom1
- id_geom_2 = physics.data.contact[i_contact].geom2
- name_geom_1 = physics.model.id2name(id_geom_1, "geom")
- name_geom_2 = physics.model.id2name(id_geom_2, "geom")
- contact_pair = (name_geom_1, name_geom_2)
- all_contact_pairs.append(contact_pair)
-
- touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
- touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
- touch_table = ("red_box", "table") in all_contact_pairs
-
- reward = 0
- if touch_right_gripper:
- reward = 1
- if touch_right_gripper and not touch_table: # lifted
- reward = 2
- if touch_left_gripper: # attempted transfer
- reward = 3
- if touch_left_gripper and not touch_table: # successful transfer
- reward = 4
- return reward
-
-
-class InsertionTask(BimanualViperXTask):
- def __init__(self, random=None):
- super().__init__(random=random)
- self.max_reward = 4
-
- def initialize_episode(self, physics):
- """Sets the state of the environment at the start of each episode."""
- # TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside
- # reset qpos, control and box position
- with physics.reset_context():
- physics.named.data.qpos[:16] = START_ARM_POSE
- np.copyto(physics.data.ctrl, START_ARM_POSE)
- assert BOX_POSE[0] is not None
- physics.named.data.qpos[-7 * 2 :] = BOX_POSE[0] # two objects
- # print(f"{BOX_POSE=}")
- super().initialize_episode(physics)
-
- @staticmethod
- def get_env_state(physics):
- env_state = physics.data.qpos.copy()[16:]
- return env_state
-
- def get_reward(self, physics):
- # return whether peg touches the pin
- all_contact_pairs = []
- for i_contact in range(physics.data.ncon):
- id_geom_1 = physics.data.contact[i_contact].geom1
- id_geom_2 = physics.data.contact[i_contact].geom2
- name_geom_1 = physics.model.id2name(id_geom_1, "geom")
- name_geom_2 = physics.model.id2name(id_geom_2, "geom")
- contact_pair = (name_geom_1, name_geom_2)
- all_contact_pairs.append(contact_pair)
-
- touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
- touch_left_gripper = (
- ("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
- or ("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
- or ("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
- or ("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
- )
-
- peg_touch_table = ("red_peg", "table") in all_contact_pairs
- socket_touch_table = (
- ("socket-1", "table") in all_contact_pairs
- or ("socket-2", "table") in all_contact_pairs
- or ("socket-3", "table") in all_contact_pairs
- or ("socket-4", "table") in all_contact_pairs
- )
- peg_touch_socket = (
- ("red_peg", "socket-1") in all_contact_pairs
- or ("red_peg", "socket-2") in all_contact_pairs
- or ("red_peg", "socket-3") in all_contact_pairs
- or ("red_peg", "socket-4") in all_contact_pairs
- )
- pin_touched = ("red_peg", "pin") in all_contact_pairs
-
- reward = 0
- if touch_left_gripper and touch_right_gripper: # touch both
- reward = 1
- if (
- touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table)
- ): # grasp both
- reward = 2
- if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching
- reward = 3
- if pin_touched: # successful insertion
- reward = 4
- return reward
diff --git a/lerobot/common/envs/aloha/tasks/sim_end_effector.py b/lerobot/common/envs/aloha/tasks/sim_end_effector.py
deleted file mode 100644
index d93c8330..00000000
--- a/lerobot/common/envs/aloha/tasks/sim_end_effector.py
+++ /dev/null
@@ -1,263 +0,0 @@
-import collections
-
-import numpy as np
-from dm_control.suite import base
-
-from lerobot.common.envs.aloha.constants import (
- PUPPET_GRIPPER_POSITION_CLOSE,
- START_ARM_POSE,
- normalize_puppet_gripper_position,
- normalize_puppet_gripper_velocity,
- unnormalize_puppet_gripper_position,
-)
-from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose
-
-"""
-Environment for simulated robot bi-manual manipulation, with end-effector control.
-Action space: [left_arm_pose (7), # position and quaternion for end effector
- left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
- right_arm_pose (7), # position and quaternion for end effector
- right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
-
-Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
- left_gripper_position (1), # normalized gripper position (0: close, 1: open)
- right_arm_qpos (6), # absolute joint position
- right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
- "qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
- left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
- right_arm_qvel (6), # absolute joint velocity (rad)
- right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
- "images": {"main": (480x640x3)} # h, w, c, dtype='uint8'
-"""
-
-
-class BimanualViperXEndEffectorTask(base.Task):
- def __init__(self, random=None):
- super().__init__(random=random)
-
- def before_step(self, action, physics):
- a_len = len(action) // 2
- action_left = action[:a_len]
- action_right = action[a_len:]
-
- # set mocap position and quat
- # left
- np.copyto(physics.data.mocap_pos[0], action_left[:3])
- np.copyto(physics.data.mocap_quat[0], action_left[3:7])
- # right
- np.copyto(physics.data.mocap_pos[1], action_right[:3])
- np.copyto(physics.data.mocap_quat[1], action_right[3:7])
-
- # set gripper
- g_left_ctrl = unnormalize_puppet_gripper_position(action_left[7])
- g_right_ctrl = unnormalize_puppet_gripper_position(action_right[7])
- np.copyto(physics.data.ctrl, np.array([g_left_ctrl, -g_left_ctrl, g_right_ctrl, -g_right_ctrl]))
-
- def initialize_robots(self, physics):
- # reset joint position
- physics.named.data.qpos[:16] = START_ARM_POSE
-
- # reset mocap to align with end effector
- # to obtain these numbers:
- # (1) make an ee_sim env and reset to the same start_pose
- # (2) get env._physics.named.data.xpos['vx300s_left/gripper_link']
- # get env._physics.named.data.xquat['vx300s_left/gripper_link']
- # repeat the same for right side
- np.copyto(physics.data.mocap_pos[0], [-0.31718881, 0.5, 0.29525084])
- np.copyto(physics.data.mocap_quat[0], [1, 0, 0, 0])
- # right
- np.copyto(physics.data.mocap_pos[1], np.array([0.31718881, 0.49999888, 0.29525084]))
- np.copyto(physics.data.mocap_quat[1], [1, 0, 0, 0])
-
- # reset gripper control
- close_gripper_control = np.array(
- [
- PUPPET_GRIPPER_POSITION_CLOSE,
- -PUPPET_GRIPPER_POSITION_CLOSE,
- PUPPET_GRIPPER_POSITION_CLOSE,
- -PUPPET_GRIPPER_POSITION_CLOSE,
- ]
- )
- np.copyto(physics.data.ctrl, close_gripper_control)
-
- def initialize_episode(self, physics):
- """Sets the state of the environment at the start of each episode."""
- super().initialize_episode(physics)
-
- @staticmethod
- def get_qpos(physics):
- qpos_raw = physics.data.qpos.copy()
- left_qpos_raw = qpos_raw[:8]
- right_qpos_raw = qpos_raw[8:16]
- left_arm_qpos = left_qpos_raw[:6]
- right_arm_qpos = right_qpos_raw[:6]
- left_gripper_qpos = [normalize_puppet_gripper_position(left_qpos_raw[6])]
- right_gripper_qpos = [normalize_puppet_gripper_position(right_qpos_raw[6])]
- return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
-
- @staticmethod
- def get_qvel(physics):
- qvel_raw = physics.data.qvel.copy()
- left_qvel_raw = qvel_raw[:8]
- right_qvel_raw = qvel_raw[8:16]
- left_arm_qvel = left_qvel_raw[:6]
- right_arm_qvel = right_qvel_raw[:6]
- left_gripper_qvel = [normalize_puppet_gripper_velocity(left_qvel_raw[6])]
- right_gripper_qvel = [normalize_puppet_gripper_velocity(right_qvel_raw[6])]
- return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
-
- @staticmethod
- def get_env_state(physics):
- raise NotImplementedError
-
- def get_observation(self, physics):
- # note: it is important to do .copy()
- obs = collections.OrderedDict()
- obs["qpos"] = self.get_qpos(physics)
- obs["qvel"] = self.get_qvel(physics)
- obs["env_state"] = self.get_env_state(physics)
- obs["images"] = {}
- obs["images"]["top"] = physics.render(height=480, width=640, camera_id="top")
- obs["images"]["angle"] = physics.render(height=480, width=640, camera_id="angle")
- obs["images"]["vis"] = physics.render(height=480, width=640, camera_id="front_close")
- # used in scripted policy to obtain starting pose
- obs["mocap_pose_left"] = np.concatenate(
- [physics.data.mocap_pos[0], physics.data.mocap_quat[0]]
- ).copy()
- obs["mocap_pose_right"] = np.concatenate(
- [physics.data.mocap_pos[1], physics.data.mocap_quat[1]]
- ).copy()
-
- # used when replaying joint trajectory
- obs["gripper_ctrl"] = physics.data.ctrl.copy()
- return obs
-
- def get_reward(self, physics):
- raise NotImplementedError
-
-
-class TransferCubeEndEffectorTask(BimanualViperXEndEffectorTask):
- def __init__(self, random=None):
- super().__init__(random=random)
- self.max_reward = 4
-
- def initialize_episode(self, physics):
- """Sets the state of the environment at the start of each episode."""
- self.initialize_robots(physics)
- # randomize box position
- cube_pose = sample_box_pose()
- box_start_idx = physics.model.name2id("red_box_joint", "joint")
- np.copyto(physics.data.qpos[box_start_idx : box_start_idx + 7], cube_pose)
- # print(f"randomized cube position to {cube_position}")
-
- super().initialize_episode(physics)
-
- @staticmethod
- def get_env_state(physics):
- env_state = physics.data.qpos.copy()[16:]
- return env_state
-
- def get_reward(self, physics):
- # return whether left gripper is holding the box
- all_contact_pairs = []
- for i_contact in range(physics.data.ncon):
- id_geom_1 = physics.data.contact[i_contact].geom1
- id_geom_2 = physics.data.contact[i_contact].geom2
- name_geom_1 = physics.model.id2name(id_geom_1, "geom")
- name_geom_2 = physics.model.id2name(id_geom_2, "geom")
- contact_pair = (name_geom_1, name_geom_2)
- all_contact_pairs.append(contact_pair)
-
- touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
- touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
- touch_table = ("red_box", "table") in all_contact_pairs
-
- reward = 0
- if touch_right_gripper:
- reward = 1
- if touch_right_gripper and not touch_table: # lifted
- reward = 2
- if touch_left_gripper: # attempted transfer
- reward = 3
- if touch_left_gripper and not touch_table: # successful transfer
- reward = 4
- return reward
-
-
-class InsertionEndEffectorTask(BimanualViperXEndEffectorTask):
- def __init__(self, random=None):
- super().__init__(random=random)
- self.max_reward = 4
-
- def initialize_episode(self, physics):
- """Sets the state of the environment at the start of each episode."""
- self.initialize_robots(physics)
- # randomize peg and socket position
- peg_pose, socket_pose = sample_insertion_pose()
-
- def id2index(j_id):
- return 16 + (j_id - 16) * 7 # first 16 entries are the robot qpos; each later free joint is a 7-dim pose (hacky)
-
- peg_start_id = physics.model.name2id("red_peg_joint", "joint")
- peg_start_idx = id2index(peg_start_id)
- np.copyto(physics.data.qpos[peg_start_idx : peg_start_idx + 7], peg_pose)
- # print(f"randomized cube position to {cube_position}")
-
- socket_start_id = physics.model.name2id("blue_socket_joint", "joint")
- socket_start_idx = id2index(socket_start_id)
- np.copyto(physics.data.qpos[socket_start_idx : socket_start_idx + 7], socket_pose)
- # print(f"randomized cube position to {cube_position}")
-
- super().initialize_episode(physics)
-
- @staticmethod
- def get_env_state(physics):
- env_state = physics.data.qpos.copy()[16:]
- return env_state
-
- def get_reward(self, physics):
- # return whether peg touches the pin
- all_contact_pairs = []
- for i_contact in range(physics.data.ncon):
- id_geom_1 = physics.data.contact[i_contact].geom1
- id_geom_2 = physics.data.contact[i_contact].geom2
- name_geom_1 = physics.model.id2name(id_geom_1, "geom")
- name_geom_2 = physics.model.id2name(id_geom_2, "geom")
- contact_pair = (name_geom_1, name_geom_2)
- all_contact_pairs.append(contact_pair)
-
- touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
- touch_left_gripper = (
- ("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
- or ("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
- or ("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
- or ("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
- )
-
- peg_touch_table = ("red_peg", "table") in all_contact_pairs
- socket_touch_table = (
- ("socket-1", "table") in all_contact_pairs
- or ("socket-2", "table") in all_contact_pairs
- or ("socket-3", "table") in all_contact_pairs
- or ("socket-4", "table") in all_contact_pairs
- )
- peg_touch_socket = (
- ("red_peg", "socket-1") in all_contact_pairs
- or ("red_peg", "socket-2") in all_contact_pairs
- or ("red_peg", "socket-3") in all_contact_pairs
- or ("red_peg", "socket-4") in all_contact_pairs
- )
- pin_touched = ("red_peg", "pin") in all_contact_pairs
-
- reward = 0
- if touch_left_gripper and touch_right_gripper: # touch both
- reward = 1
- if (
- touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table)
- ): # grasp both
- reward = 2
- if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching
- reward = 3
- if pin_touched: # successful insertion
- reward = 4
- return reward
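
The end-effector variant drives the arms indirectly: each 8-dim half of the action is a mocap pose plus a normalized gripper command, and the welded mocap bodies pull the grippers along. A condensed sketch of that write path, assuming the same [pos(3), quat(4), gripper(1)] layout per arm and an `unnormalize_gripper` callable like the helper imported above:

import numpy as np

def apply_ee_action(physics, action, unnormalize_gripper):
    # One 8-dim block per arm: xyz position, wxyz quaternion, gripper in [0, 1].
    left, right = action[:8], action[8:]
    for i, arm in enumerate((left, right)):
        np.copyto(physics.data.mocap_pos[i], arm[:3])
        np.copyto(physics.data.mocap_quat[i], arm[3:7])
    # Gripper actuators are mirrored: +ctrl on one finger, -ctrl on the other.
    g_left = unnormalize_gripper(left[7])
    g_right = unnormalize_gripper(right[7])
    np.copyto(physics.data.ctrl, np.array([g_left, -g_left, g_right, -g_right]))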
diff --git a/lerobot/common/envs/aloha/utils.py b/lerobot/common/envs/aloha/utils.py
deleted file mode 100644
index 5ac8b955..00000000
--- a/lerobot/common/envs/aloha/utils.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import numpy as np
-
-
-def sample_box_pose():
- x_range = [0.0, 0.2]
- y_range = [0.4, 0.6]
- z_range = [0.05, 0.05]
-
- ranges = np.vstack([x_range, y_range, z_range])
- cube_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
-
- cube_quat = np.array([1, 0, 0, 0])
- return np.concatenate([cube_position, cube_quat])
-
-
-def sample_insertion_pose():
- # Peg
- x_range = [0.1, 0.2]
- y_range = [0.4, 0.6]
- z_range = [0.05, 0.05]
-
- ranges = np.vstack([x_range, y_range, z_range])
- peg_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
-
- peg_quat = np.array([1, 0, 0, 0])
- peg_pose = np.concatenate([peg_position, peg_quat])
-
- # Socket
- x_range = [-0.2, -0.1]
- y_range = [0.4, 0.6]
- z_range = [0.05, 0.05]
-
- ranges = np.vstack([x_range, y_range, z_range])
- socket_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
-
- socket_quat = np.array([1, 0, 0, 0])
- socket_pose = np.concatenate([socket_position, socket_quat])
-
- return peg_pose, socket_pose
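
Both samplers follow one recipe: uniform x/y inside a workspace rectangle, a pinned z, and an identity quaternion. A parameterized sketch; the helper name and default table height are illustrative:

import numpy as np

def sample_flat_pose(x_range, y_range, z=0.05):
    # Uniform position inside the box, with z fixed to the table height.
    position = np.random.uniform(
        low=[x_range[0], y_range[0], z], high=[x_range[1], y_range[1], z]
    )
    identity_quat = np.array([1.0, 0.0, 0.0, 0.0])  # wxyz, no rotation
    return np.concatenate([position, identity_quat])

# sample_box_pose() above is equivalent to sample_flat_pose([0.0, 0.2], [0.4, 0.6])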
diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py
index 855e073b..d5571935 100644
--- a/lerobot/common/envs/factory.py
+++ b/lerobot/common/envs/factory.py
@@ -1,64 +1,42 @@
-from torchrl.envs import SerialEnv
-from torchrl.envs.transforms import Compose, StepCounter, Transform, TransformedEnv
+import importlib
+
+import gymnasium as gym
-def make_env(cfg, transform=None):
+def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv:
"""
- Note: The returned environment is wrapped in a torchrl.SerialEnv with cfg.rollout_batch_size underlying
- environments. The env therefore returns batches.
+ Note: When `num_parallel_envs > 0`, this function returns a `SyncVectorEnv`, which takes a batched action as input and
+ returns batched observations, rewards, terminated and truncated flags, each with `num_parallel_envs` items.
"""
-
kwargs = {
- "frame_skip": cfg.env.action_repeat,
- "from_pixels": cfg.env.from_pixels,
- "pixels_only": cfg.env.pixels_only,
- "image_size": cfg.env.image_size,
- "num_prev_obs": cfg.n_obs_steps - 1,
+ "obs_type": "pixels_agent_pos",
+ "render_mode": "rgb_array",
+ "max_episode_steps": cfg.env.episode_length,
+ "visualization_width": 384,
+ "visualization_height": 384,
}
- if cfg.env.name == "simxarm":
- from lerobot.common.envs.simxarm.env import SimxarmEnv
+ package_name = f"gym_{cfg.env.name}"
- kwargs["task"] = cfg.env.task
- clsfunc = SimxarmEnv
- elif cfg.env.name == "pusht":
- from lerobot.common.envs.pusht.env import PushtEnv
+ try:
+ importlib.import_module(package_name)
+ except ModuleNotFoundError as e:
+ print(
+ f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.env.name}]'`"
+ )
+ raise e
- # assert kwargs["seed"] > 200, "Seed 0-200 are used for the demonstration dataset, so we don't want to seed the eval env with this range."
+ gym_handle = f"{package_name}/{cfg.env.task}"
- clsfunc = PushtEnv
- elif cfg.env.name == "aloha":
- from lerobot.common.envs.aloha.env import AlohaEnv
-
- kwargs["task"] = cfg.env.task
- clsfunc = AlohaEnv
+ if num_parallel_envs == 0:
+ # non-batched version of the env that returns an observation of shape (c)
+ env = gym.make(gym_handle, disable_env_checker=True, **kwargs)
else:
- raise ValueError(cfg.env.name)
-
- def _make_env(seed):
- nonlocal kwargs
- kwargs["seed"] = seed
- env = clsfunc(**kwargs)
-
- # limit rollout to max_steps
- env = TransformedEnv(env, StepCounter(max_steps=cfg.env.episode_length))
-
- if transform is not None:
- # useful to add normalization
- if isinstance(transform, Compose):
- for tf in transform:
- env.append_transform(tf.clone())
- elif isinstance(transform, Transform):
- env.append_transform(transform.clone())
- else:
- raise NotImplementedError()
-
- return env
-
- return SerialEnv(
- cfg.rollout_batch_size,
- create_env_fn=_make_env,
- create_env_kwargs=[
- {"seed": env_seed} for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
- ],
- )
+ # batched version of the env that returns an observation of shape (b, c)
+ env = gym.vector.SyncVectorEnv(
+ [
+ lambda: gym.make(gym_handle, disable_env_checker=True, **kwargs)
+ for _ in range(num_parallel_envs)
+ ]
+ )
+ return env
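
A usage sketch for the rewritten factory, assuming a config whose `env.name` and `env.task` resolve to an installed gym package (the batch size below is illustrative):

env = make_env(cfg, num_parallel_envs=4)  # gym.vector.SyncVectorEnv
observations, infos = env.reset(seed=42)
# Every array now carries a leading batch dimension of 4.
actions = env.action_space.sample()
observations, rewards, terminateds, truncateds, infos = env.step(actions)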
diff --git a/lerobot/common/envs/pusht/env.py b/lerobot/common/envs/pusht/env.py
deleted file mode 100644
index 5f7fb2c3..00000000
--- a/lerobot/common/envs/pusht/env.py
+++ /dev/null
@@ -1,245 +0,0 @@
-import importlib
-import logging
-from collections import deque
-from typing import Optional
-
-import cv2
-import numpy as np
-import torch
-from tensordict import TensorDict
-from torchrl.data.tensor_specs import (
- BoundedTensorSpec,
- CompositeSpec,
- DiscreteTensorSpec,
- UnboundedContinuousTensorSpec,
-)
-from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform
-
-from lerobot.common.envs.abstract import AbstractEnv
-from lerobot.common.utils import set_global_seed
-
-_has_gym = importlib.util.find_spec("gymnasium") is not None
-
-
-class PushtEnv(AbstractEnv):
- name = "pusht"
- available_tasks = ["pusht"]
- _reset_warning_issued = False
-
- def __init__(
- self,
- task="pusht",
- frame_skip: int = 1,
- from_pixels: bool = False,
- pixels_only: bool = False,
- image_size=None,
- seed=1337,
- device="cpu",
- num_prev_obs=1,
- num_prev_action=0,
- ):
- super().__init__(
- task=task,
- frame_skip=frame_skip,
- from_pixels=from_pixels,
- pixels_only=pixels_only,
- image_size=image_size,
- seed=seed,
- device=device,
- num_prev_obs=num_prev_obs,
- num_prev_action=num_prev_action,
- )
-
- def _make_env(self):
- if not _has_gym:
- raise ImportError("Cannot import gymnasium.")
-
- # TODO(rcadene) (PushTEnv is similar to PushTImageEnv, but without the image rendering, it's faster to iterate on)
- # from lerobot.common.envs.pusht.pusht_env import PushTEnv
-
- if not self.from_pixels:
- raise NotImplementedError("Use PushTEnv, instead of PushTImageEnv")
- from lerobot.common.envs.pusht.pusht_image_env import PushTImageEnv
-
- self._env = PushTImageEnv(render_size=self.image_size)
-
- def render(self, mode="rgb_array", width=96, height=96, with_marker=True):
- """
- with_marker adds a cursor showing the targeted action for the controller.
- """
- if width != height:
- raise NotImplementedError()
- tmp = self._env.render_size
- if width != self._env.render_size:
- self._env.render_cache = None
- self._env.render_size = width
- out = self._env.render(mode).copy()
- if with_marker and self._env.latest_action is not None:
- action = np.array(self._env.latest_action)
- coord = (action / 512 * self._env.render_size).astype(np.int32)
- marker_size = int(8 / 96 * self._env.render_size)
- thickness = int(1 / 96 * self._env.render_size)
- cv2.drawMarker(
- out,
- coord,
- color=(255, 0, 0),
- markerType=cv2.MARKER_CROSS,
- markerSize=marker_size,
- thickness=thickness,
- )
- self._env.render_size = tmp
- return out
-
- def _format_raw_obs(self, raw_obs):
- if self.from_pixels:
- image = torch.from_numpy(raw_obs["image"])
- obs = {"image": image}
-
- if not self.pixels_only:
- obs["state"] = torch.from_numpy(raw_obs["agent_pos"]).type(torch.float32)
- else:
- # TODO:
- obs = {"state": torch.from_numpy(raw_obs["observation"]).type(torch.float32)}
-
- return obs
-
- def _reset(self, tensordict: Optional[TensorDict] = None):
- if tensordict is not None and not PushtEnv._reset_warning_issued:
- logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
- PushtEnv._reset_warning_issued = True
-
- # Seed the environment and update the seed to be used for the next reset.
- self._next_seed = self.set_seed(self._next_seed)
- raw_obs = self._env.reset()
-
- obs = self._format_raw_obs(raw_obs)
-
- if self.num_prev_obs > 0:
- stacked_obs = {}
- if "image" in obs:
- self._prev_obs_image_queue = deque(
- [obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
- )
- stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
- if "state" in obs:
- self._prev_obs_state_queue = deque(
- [obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
- )
- stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
- obs = stacked_obs
-
- td = TensorDict(
- {
- "observation": TensorDict(obs, batch_size=[]),
- "done": torch.tensor([False], dtype=torch.bool),
- },
- batch_size=[],
- )
-
- return td
-
- def _step(self, tensordict: TensorDict):
- td = tensordict
- action = td["action"].numpy()
- assert action.ndim == 1
- # TODO(rcadene): add info["is_success"] and info["success"] ?
-
- raw_obs, reward, done, info = self._env.step(action)
-
- obs = self._format_raw_obs(raw_obs)
-
- if self.num_prev_obs > 0:
- stacked_obs = {}
- if "image" in obs:
- self._prev_obs_image_queue.append(obs["image"])
- stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
- if "state" in obs:
- self._prev_obs_state_queue.append(obs["state"])
- stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
- obs = stacked_obs
-
- td = TensorDict(
- {
- "observation": TensorDict(obs, batch_size=[]),
- "reward": torch.tensor([reward], dtype=torch.float32),
- # success and done are true when coverage > self.success_threshold in env
- "done": torch.tensor([done], dtype=torch.bool),
- "success": torch.tensor([done], dtype=torch.bool),
- },
- batch_size=[],
- )
- return td
-
- def _make_spec(self):
- obs = {}
- if self.from_pixels:
- image_shape = (3, self.image_size, self.image_size)
- if self.num_prev_obs > 0:
- image_shape = (self.num_prev_obs + 1, *image_shape)
-
- obs["image"] = BoundedTensorSpec(
- low=0,
- high=255,
- shape=image_shape,
- dtype=torch.uint8,
- device=self.device,
- )
- if not self.pixels_only:
- state_shape = self._env.observation_space["agent_pos"].shape
- if self.num_prev_obs > 0:
- state_shape = (self.num_prev_obs + 1, *state_shape)
-
- obs["state"] = BoundedTensorSpec(
- low=0,
- high=512,
- shape=state_shape,
- dtype=torch.float32,
- device=self.device,
- )
- else:
- # TODO(rcadene): add observation_space achieved_goal and desired_goal?
- state_shape = self._env.observation_space["observation"].shape
- if self.num_prev_obs > 0:
- state_shape = (self.num_prev_obs + 1, *state_shape)
-
- obs["state"] = UnboundedContinuousTensorSpec(
- # TODO:
- shape=state_shape,
- dtype=torch.float32,
- device=self.device,
- )
- self.observation_spec = CompositeSpec({"observation": obs})
-
- self.action_spec = _gym_to_torchrl_spec_transform(
- self._env.action_space,
- device=self.device,
- )
-
- self.reward_spec = UnboundedContinuousTensorSpec(
- shape=(1,),
- dtype=torch.float32,
- device=self.device,
- )
-
- self.done_spec = CompositeSpec(
- {
- "done": DiscreteTensorSpec(
- 2,
- shape=(1,),
- dtype=torch.bool,
- device=self.device,
- ),
- "success": DiscreteTensorSpec(
- 2,
- shape=(1,),
- dtype=torch.bool,
- device=self.device,
- ),
- }
- )
-
- def _set_seed(self, seed: Optional[int]):
- # Set global seed.
- set_global_seed(seed)
- # Set PushTImageEnv seed as it relies on its own internal _seed attribute.
- self._env.seed(seed)
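
`_reset` and `_step` above repeat one pattern per modality: a deque of length `num_prev_obs + 1` that is flooded with the first observation at reset and rolled forward every step, so the policy always sees a fixed-size history. A minimal sketch of that mechanism in isolation:

from collections import deque

import torch

class ObsStacker:
    def __init__(self, num_frames: int):
        self.queue = deque(maxlen=num_frames)

    def reset(self, obs: torch.Tensor) -> torch.Tensor:
        # Fill the whole history with the first frame so shapes are stable.
        self.queue.extend([obs] * self.queue.maxlen)
        return torch.stack(list(self.queue))

    def step(self, obs: torch.Tensor) -> torch.Tensor:
        self.queue.append(obs)  # the oldest frame falls off automatically
        return torch.stack(list(self.queue))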
diff --git a/lerobot/common/envs/pusht/pusht_env.py b/lerobot/common/envs/pusht/pusht_env.py
deleted file mode 100644
index 6ef70aec..00000000
--- a/lerobot/common/envs/pusht/pusht_env.py
+++ /dev/null
@@ -1,378 +0,0 @@
-import collections
-
-import cv2
-import gymnasium as gym
-import numpy as np
-import pygame
-import pymunk
-import pymunk.pygame_util
-import shapely.geometry as sg
-import skimage.transform as st
-from gymnasium import spaces
-from pymunk.vec2d import Vec2d
-
-from lerobot.common.envs.pusht.pymunk_override import DrawOptions
-
-
-def pymunk_to_shapely(body, shapes):
- geoms = []
- for shape in shapes:
- if isinstance(shape, pymunk.shapes.Poly):
- verts = [body.local_to_world(v) for v in shape.get_vertices()]
- verts += [verts[0]]
- geoms.append(sg.Polygon(verts))
- else:
- raise RuntimeError(f"Unsupported shape type {type(shape)}")
- geom = sg.MultiPolygon(geoms)
- return geom
-
-
-class PushTEnv(gym.Env):
- metadata = {"render.modes": ["human", "rgb_array"], "video.frames_per_second": 10}
- reward_range = (0.0, 1.0)
-
- def __init__(
- self,
- legacy=True, # compatibility with original
- block_cog=None,
- damping=None,
- render_action=True,
- render_size=96,
- reset_to_state=None,
- ):
- self._seed = None
- self.seed()
- self.window_size = ws = 512 # The size of the PyGame window
- self.render_size = render_size
- self.sim_hz = 100
- # Local controller params.
- self.k_p, self.k_v = 100, 20 # PD control.
- self.control_hz = self.metadata["video.frames_per_second"]
- # legacy set_state for data compatibility
- self.legacy = legacy
-
- # agent_pos, block_pos, block_angle
- self.observation_space = spaces.Box(
- low=np.array([0, 0, 0, 0, 0], dtype=np.float64),
- high=np.array([ws, ws, ws, ws, np.pi * 2], dtype=np.float64),
- shape=(5,),
- dtype=np.float64,
- )
-
- # positional goal for agent
- self.action_space = spaces.Box(
- low=np.array([0, 0], dtype=np.float64),
- high=np.array([ws, ws], dtype=np.float64),
- shape=(2,),
- dtype=np.float64,
- )
-
- self.block_cog = block_cog
- self.damping = damping
- self.render_action = render_action
-
- """
- If human-rendering is used, `self.window` will be a reference
- to the window that we draw to. `self.clock` will be a clock that is used
- to ensure that the environment is rendered at the correct framerate in
- human-mode. They will remain `None` until human-mode is used for the
- first time.
- """
- self.window = None
- self.clock = None
- self.screen = None
-
- self.space = None
- self.teleop = None
- self.render_buffer = None
- self.latest_action = None
- self.reset_to_state = reset_to_state
-
- def reset(self):
- seed = self._seed
- self._setup()
- if self.block_cog is not None:
- self.block.center_of_gravity = self.block_cog
- if self.damping is not None:
- self.space.damping = self.damping
-
- # use legacy RandomState for compatibility
- state = self.reset_to_state
- if state is None:
- rs = np.random.RandomState(seed=seed)
- state = np.array(
- [
- rs.randint(50, 450),
- rs.randint(50, 450),
- rs.randint(100, 400),
- rs.randint(100, 400),
- rs.randn() * 2 * np.pi - np.pi,
- ]
- )
- self._set_state(state)
-
- observation = self._get_obs()
- return observation
-
- def step(self, action):
- dt = 1.0 / self.sim_hz
- self.n_contact_points = 0
- n_steps = self.sim_hz // self.control_hz
- if action is not None:
- self.latest_action = action
- for _ in range(n_steps):
- # Step PD control.
- # self.agent.velocity = self.k_p * (act - self.agent.position) # P control works too.
- acceleration = self.k_p * (action - self.agent.position) + self.k_v * (
- Vec2d(0, 0) - self.agent.velocity
- )
- self.agent.velocity += acceleration * dt
-
- # Step physics.
- self.space.step(dt)
-
- # compute reward
- goal_body = self._get_goal_pose_body(self.goal_pose)
- goal_geom = pymunk_to_shapely(goal_body, self.block.shapes)
- block_geom = pymunk_to_shapely(self.block, self.block.shapes)
-
- intersection_area = goal_geom.intersection(block_geom).area
- goal_area = goal_geom.area
- coverage = intersection_area / goal_area
- reward = np.clip(coverage / self.success_threshold, 0, 1)
- done = coverage > self.success_threshold
-
- observation = self._get_obs()
- info = self._get_info()
-
- return observation, reward, done, info
-
- def render(self, mode):
- return self._render_frame(mode)
-
- def teleop_agent(self):
- TeleopAgent = collections.namedtuple("TeleopAgent", ["act"])
-
- def act(obs):
- act = None
- mouse_position = pymunk.pygame_util.from_pygame(Vec2d(*pygame.mouse.get_pos()), self.screen)
- if self.teleop or (mouse_position - self.agent.position).length < 30:
- self.teleop = True
- act = mouse_position
- return act
-
- return TeleopAgent(act)
-
- def _get_obs(self):
- obs = np.array(
- tuple(self.agent.position) + tuple(self.block.position) + (self.block.angle % (2 * np.pi),)
- )
- return obs
-
- def _get_goal_pose_body(self, pose):
- mass = 1
- inertia = pymunk.moment_for_box(mass, (50, 100))
- body = pymunk.Body(mass, inertia)
- # preserving the legacy assignment order for compatibility
- # the order here doesn't matter somehow, maybe because CoM is aligned with body origin
- body.position = pose[:2].tolist()
- body.angle = pose[2]
- return body
-
- def _get_info(self):
- n_steps = self.sim_hz // self.control_hz
- n_contact_points_per_step = int(np.ceil(self.n_contact_points / n_steps))
- info = {
- "pos_agent": np.array(self.agent.position),
- "vel_agent": np.array(self.agent.velocity),
- "block_pose": np.array(list(self.block.position) + [self.block.angle]),
- "goal_pose": self.goal_pose,
- "n_contacts": n_contact_points_per_step,
- }
- return info
-
- def _render_frame(self, mode):
- if self.window is None and mode == "human":
- pygame.init()
- pygame.display.init()
- self.window = pygame.display.set_mode((self.window_size, self.window_size))
- if self.clock is None and mode == "human":
- self.clock = pygame.time.Clock()
-
- canvas = pygame.Surface((self.window_size, self.window_size))
- canvas.fill((255, 255, 255))
- self.screen = canvas
-
- draw_options = DrawOptions(canvas)
-
- # Draw goal pose.
- goal_body = self._get_goal_pose_body(self.goal_pose)
- for shape in self.block.shapes:
- goal_points = [
- pymunk.pygame_util.to_pygame(goal_body.local_to_world(v), draw_options.surface)
- for v in shape.get_vertices()
- ]
- goal_points += [goal_points[0]]
- pygame.draw.polygon(canvas, self.goal_color, goal_points)
-
- # Draw agent and block.
- self.space.debug_draw(draw_options)
-
- if mode == "human":
- # The following line copies our drawings from `canvas` to the visible window
- self.window.blit(canvas, canvas.get_rect())
- pygame.event.pump()
- pygame.display.update()
-
- # the clock is already ticked in step for "human" mode
-
- img = np.transpose(np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2))
- img = cv2.resize(img, (self.render_size, self.render_size))
- if self.render_action and self.latest_action is not None:
- action = np.array(self.latest_action)
- coord = (action / 512 * self.render_size).astype(np.int32)
- marker_size = int(8 / 96 * self.render_size)
- thickness = int(1 / 96 * self.render_size)
- cv2.drawMarker(
- img,
- coord,
- color=(255, 0, 0),
- markerType=cv2.MARKER_CROSS,
- markerSize=marker_size,
- thickness=thickness,
- )
- return img
-
- def close(self):
- if self.window is not None:
- pygame.display.quit()
- pygame.quit()
-
- def seed(self, seed=None):
- if seed is None:
- seed = np.random.randint(0, 25536)
- self._seed = seed
- self.np_random = np.random.default_rng(seed)
-
- def _handle_collision(self, arbiter, space, data):
- self.n_contact_points += len(arbiter.contact_point_set.points)
-
- def _set_state(self, state):
- if isinstance(state, np.ndarray):
- state = state.tolist()
- pos_agent = state[:2]
- pos_block = state[2:4]
- rot_block = state[4]
- self.agent.position = pos_agent
- # Setting the angle rotates the shape about its center of mass, so when the
- # CoM differs from the body origin it also shifts the geometric position;
- # hence the order of these two assignments matters.
- if self.legacy:
- # for compatibility with legacy data
- self.block.position = pos_block
- self.block.angle = rot_block
- else:
- self.block.angle = rot_block
- self.block.position = pos_block
-
- # Run physics to take effect
- self.space.step(1.0 / self.sim_hz)
-
- def _set_state_local(self, state_local):
- agent_pos_local = state_local[:2]
- block_pose_local = state_local[2:]
- tf_img_obj = st.AffineTransform(translation=self.goal_pose[:2], rotation=self.goal_pose[2])
- tf_obj_new = st.AffineTransform(translation=block_pose_local[:2], rotation=block_pose_local[2])
- tf_img_new = st.AffineTransform(matrix=tf_img_obj.params @ tf_obj_new.params)
- agent_pos_new = tf_img_new(agent_pos_local)
- new_state = np.array(list(agent_pos_new[0]) + list(tf_img_new.translation) + [tf_img_new.rotation])
- self._set_state(new_state)
- return new_state
-
- def _setup(self):
- self.space = pymunk.Space()
- self.space.gravity = 0, 0
- self.space.damping = 0
- self.teleop = False
- self.render_buffer = []
-
- # Add walls.
- walls = [
- self._add_segment((5, 506), (5, 5), 2),
- self._add_segment((5, 5), (506, 5), 2),
- self._add_segment((506, 5), (506, 506), 2),
- self._add_segment((5, 506), (506, 506), 2),
- ]
- self.space.add(*walls)
-
- # Add agent, block, and goal zone.
- self.agent = self.add_circle((256, 400), 15)
- self.block = self.add_tee((256, 300), 0)
- self.goal_color = pygame.Color("LightGreen")
- self.goal_pose = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians)
-
- # Add collision handling
- self.collision_handler = self.space.add_collision_handler(0, 0)
- self.collision_handler.post_solve = self._handle_collision
- self.n_contact_points = 0
-
- self.max_score = 50 * 100
- self.success_threshold = 0.95 # 95% coverage.
-
- def _add_segment(self, a, b, radius):
- shape = pymunk.Segment(self.space.static_body, a, b, radius)
- shape.color = pygame.Color("LightGray") # https://htmlcolorcodes.com/color-names
- return shape
-
- def add_circle(self, position, radius):
- body = pymunk.Body(body_type=pymunk.Body.KINEMATIC)
- body.position = position
- body.friction = 1
- shape = pymunk.Circle(body, radius)
- shape.color = pygame.Color("RoyalBlue")
- self.space.add(body, shape)
- return body
-
- def add_box(self, position, height, width):
- mass = 1
- inertia = pymunk.moment_for_box(mass, (height, width))
- body = pymunk.Body(mass, inertia)
- body.position = position
- shape = pymunk.Poly.create_box(body, (height, width))
- shape.color = pygame.Color("LightSlateGray")
- self.space.add(body, shape)
- return body
-
- def add_tee(self, position, angle, scale=30, color="LightSlateGray", mask=None):
- if mask is None:
- mask = pymunk.ShapeFilter.ALL_MASKS()
- mass = 1
- length = 4
- vertices1 = [
- (-length * scale / 2, scale),
- (length * scale / 2, scale),
- (length * scale / 2, 0),
- (-length * scale / 2, 0),
- ]
- inertia1 = pymunk.moment_for_poly(mass, vertices=vertices1)
- vertices2 = [
- (-scale / 2, scale),
- (-scale / 2, length * scale),
- (scale / 2, length * scale),
- (scale / 2, scale),
- ]
- inertia2 = pymunk.moment_for_poly(mass, vertices=vertices2)
- body = pymunk.Body(mass, inertia1 + inertia2)
- shape1 = pymunk.Poly(body, vertices1)
- shape2 = pymunk.Poly(body, vertices2)
- shape1.color = pygame.Color(color)
- shape2.color = pygame.Color(color)
- shape1.filter = pymunk.ShapeFilter(mask=mask)
- shape2.filter = pymunk.ShapeFilter(mask=mask)
- body.center_of_gravity = (shape1.center_of_gravity + shape2.center_of_gravity) / 2
- body.position = position
- body.angle = angle
- body.friction = 1
- self.space.add(body, shape1, shape2)
- return body
diff --git a/lerobot/common/envs/pusht/pusht_image_env.py b/lerobot/common/envs/pusht/pusht_image_env.py
deleted file mode 100644
index 6547835a..00000000
--- a/lerobot/common/envs/pusht/pusht_image_env.py
+++ /dev/null
@@ -1,41 +0,0 @@
-import numpy as np
-from gymnasium import spaces
-
-from lerobot.common.envs.pusht.pusht_env import PushTEnv
-
-
-class PushTImageEnv(PushTEnv):
- metadata = {"render.modes": ["rgb_array"], "video.frames_per_second": 10}
-
- # Note: legacy defaults to True for compatibility with original
- def __init__(self, legacy=True, block_cog=None, damping=None, render_size=96):
- super().__init__(
- legacy=legacy, block_cog=block_cog, damping=damping, render_size=render_size, render_action=False
- )
- ws = self.window_size
- self.observation_space = spaces.Dict(
- {
- "image": spaces.Box(low=0, high=1, shape=(3, render_size, render_size), dtype=np.float32),
- "agent_pos": spaces.Box(low=0, high=ws, shape=(2,), dtype=np.float32),
- }
- )
- self.render_cache = None
-
- def _get_obs(self):
- img = super()._render_frame(mode="rgb_array")
-
- agent_pos = np.array(self.agent.position)
- img_obs = np.moveaxis(img, -1, 0)
- obs = {"image": img_obs, "agent_pos": agent_pos}
-
- self.render_cache = img
-
- return obs
-
- def render(self, mode):
- assert mode == "rgb_array"
-
- if self.render_cache is None:
- self._get_obs()
-
- return self.render_cache
diff --git a/lerobot/common/envs/pusht/pymunk_override.py b/lerobot/common/envs/pusht/pymunk_override.py
deleted file mode 100644
index 7ad76237..00000000
--- a/lerobot/common/envs/pusht/pymunk_override.py
+++ /dev/null
@@ -1,244 +0,0 @@
-# ----------------------------------------------------------------------------
-# pymunk
-# Copyright (c) 2007-2016 Victor Blomqvist
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.
-# ----------------------------------------------------------------------------
-
-"""This submodule contains helper functions to help with quick prototyping
-using pymunk together with pygame.
-
-Intended to help with debugging and prototyping, not for actual production use
-in a full application. The methods contained in this module are opinionated
-about your coordinate system and not in any way optimized.
-"""
-
-__docformat__ = "reStructuredText"
-
-__all__ = [
- "DrawOptions",
- "get_mouse_pos",
- "to_pygame",
- "from_pygame",
- # "lighten",
- "positive_y_is_up",
-]
-
-from typing import Sequence, Tuple
-
-import numpy as np
-import pygame
-import pymunk
-from pymunk.space_debug_draw_options import SpaceDebugColor
-from pymunk.vec2d import Vec2d
-
-positive_y_is_up: bool = False
-"""Make increasing values of y point upwards.
-
-When True::
-
- y
- ^
- | . (3, 3)
- |
- | . (2, 2)
- |
- +------ > x
-
-When False::
-
- +------ > x
- |
- | . (2, 2)
- |
- | . (3, 3)
- v
- y
-
-"""
-
-
-class DrawOptions(pymunk.SpaceDebugDrawOptions):
- def __init__(self, surface: pygame.Surface) -> None:
- """Draw a pymunk.Space on a pygame.Surface object.
-
- Typical usage::
-
- >>> import pymunk
- >>> surface = pygame.Surface((10,10))
- >>> space = pymunk.Space()
- >>> options = pymunk.pygame_util.DrawOptions(surface)
- >>> space.debug_draw(options)
-
- You can control the color of a shape by setting shape.color to the color
- you want it drawn in::
-
- >>> c = pymunk.Circle(None, 10)
- >>> c.color = pygame.Color("pink")
-
- See pygame_util.demo.py for a full example
-
- Since pygame uses a coordinate system where y points down (in contrast
- to many other cases), you either have to make the physics simulation
- with Pymunk also behave in that way, or flip everything when you draw.
-
- The easiest is probably to just make the simulation behave the same
- way as Pygame does. In that way all coordinates used are in the same
- orientation and easy to reason about::
-
- >>> space = pymunk.Space()
- >>> space.gravity = (0, -1000)
- >>> body = pymunk.Body()
- >>> body.position = (0, 0) # will be positioned in the top left corner
- >>> space.debug_draw(options)
-
- To flip the drawing it's possible to set the module property
- :py:data:`positive_y_is_up` to True. Then the pygame drawing will flip
- the simulation upside down before drawing::
-
- >>> positive_y_is_up = True
- >>> body = pymunk.Body()
- >>> body.position = (0, 0)
- >>> # Body will be positioned in the bottom left corner
-
- :Parameters:
- surface : pygame.Surface
- Surface that the objects will be drawn on
- """
- self.surface = surface
- super().__init__()
-
- def draw_circle(
- self,
- pos: Vec2d,
- angle: float,
- radius: float,
- outline_color: SpaceDebugColor,
- fill_color: SpaceDebugColor,
- ) -> None:
- p = to_pygame(pos, self.surface)
-
- pygame.draw.circle(self.surface, fill_color.as_int(), p, round(radius), 0)
- pygame.draw.circle(self.surface, light_color(fill_color).as_int(), p, round(radius - 4), 0)
-
- # circle_edge = pos + Vec2d(radius, 0).rotated(angle)
- # p2 = to_pygame(circle_edge, self.surface)
- # line_r = 2 if radius > 20 else 1
- # pygame.draw.lines(self.surface, outline_color.as_int(), False, [p, p2], line_r)
-
- def draw_segment(self, a: Vec2d, b: Vec2d, color: SpaceDebugColor) -> None:
- p1 = to_pygame(a, self.surface)
- p2 = to_pygame(b, self.surface)
-
- pygame.draw.aalines(self.surface, color.as_int(), False, [p1, p2])
-
- def draw_fat_segment(
- self,
- a: Tuple[float, float],
- b: Tuple[float, float],
- radius: float,
- outline_color: SpaceDebugColor,
- fill_color: SpaceDebugColor,
- ) -> None:
- p1 = to_pygame(a, self.surface)
- p2 = to_pygame(b, self.surface)
-
- r = round(max(1, radius * 2))
- pygame.draw.lines(self.surface, fill_color.as_int(), False, [p1, p2], r)
- if r > 2:
- orthog = [abs(p2[1] - p1[1]), abs(p2[0] - p1[0])]
- if orthog[0] == 0 and orthog[1] == 0:
- return
- scale = radius / (orthog[0] * orthog[0] + orthog[1] * orthog[1]) ** 0.5
- orthog[0] = round(orthog[0] * scale)
- orthog[1] = round(orthog[1] * scale)
- points = [
- (p1[0] - orthog[0], p1[1] - orthog[1]),
- (p1[0] + orthog[0], p1[1] + orthog[1]),
- (p2[0] + orthog[0], p2[1] + orthog[1]),
- (p2[0] - orthog[0], p2[1] - orthog[1]),
- ]
- pygame.draw.polygon(self.surface, fill_color.as_int(), points)
- pygame.draw.circle(
- self.surface,
- fill_color.as_int(),
- (round(p1[0]), round(p1[1])),
- round(radius),
- )
- pygame.draw.circle(
- self.surface,
- fill_color.as_int(),
- (round(p2[0]), round(p2[1])),
- round(radius),
- )
-
- def draw_polygon(
- self,
- verts: Sequence[Tuple[float, float]],
- radius: float,
- outline_color: SpaceDebugColor,
- fill_color: SpaceDebugColor,
- ) -> None:
- ps = [to_pygame(v, self.surface) for v in verts]
- ps += [ps[0]]
-
- radius = 2
- pygame.draw.polygon(self.surface, light_color(fill_color).as_int(), ps)
-
- if radius > 0:
- for i in range(len(verts)):
- a = verts[i]
- b = verts[(i + 1) % len(verts)]
- self.draw_fat_segment(a, b, radius, fill_color, fill_color)
-
- def draw_dot(self, size: float, pos: Tuple[float, float], color: SpaceDebugColor) -> None:
- p = to_pygame(pos, self.surface)
- pygame.draw.circle(self.surface, color.as_int(), p, round(size), 0)
-
-
-def get_mouse_pos(surface: pygame.Surface) -> Tuple[int, int]:
- """Get position of the mouse pointer in pymunk coordinates."""
- p = pygame.mouse.get_pos()
- return from_pygame(p, surface)
-
-
-def to_pygame(p: Tuple[float, float], surface: pygame.Surface) -> Tuple[int, int]:
- """Convenience method to convert pymunk coordinates to pygame surface
- local coordinates.
-
- Note that when positive_y_is_up is False, this function won't actually do
- anything except convert the point to integers.
- """
- if positive_y_is_up:
- return round(p[0]), surface.get_height() - round(p[1])
- else:
- return round(p[0]), round(p[1])
-
-
-def from_pygame(p: Tuple[float, float], surface: pygame.Surface) -> Tuple[int, int]:
- """Convenience method to convert pygame surface local coordinates to
- pymunk coordinates
- """
- return to_pygame(p, surface)
-
-
-def light_color(color: SpaceDebugColor):
- color = np.minimum(1.2 * np.float32([color.r, color.g, color.b, color.a]), np.float32([255]))
- color = SpaceDebugColor(r=color[0], g=color[1], b=color[2], a=color[3])
- return color
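
A minimal offscreen-render sketch using the `DrawOptions` class above, mirroring how `PushTEnv._render_frame` turns a pymunk space into a numpy image; the toy space below is illustrative:

import numpy as np
import pygame
import pymunk

pygame.init()
canvas = pygame.Surface((512, 512))
canvas.fill((255, 255, 255))

space = pymunk.Space()
body = pymunk.Body(body_type=pymunk.Body.KINEMATIC)
body.position = (256, 256)
shape = pymunk.Circle(body, 15)
shape.color = pygame.Color("RoyalBlue")
space.add(body, shape)

space.debug_draw(DrawOptions(canvas))
# pygame's pixel array is (width, height, 3); transpose to (h, w, c).
img = np.transpose(np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2))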
diff --git a/lerobot/common/envs/simxarm/env.py b/lerobot/common/envs/simxarm/env.py
deleted file mode 100644
index b81bf499..00000000
--- a/lerobot/common/envs/simxarm/env.py
+++ /dev/null
@@ -1,237 +0,0 @@
-import importlib
-import logging
-from collections import deque
-from typing import Optional
-
-import einops
-import numpy as np
-import torch
-from tensordict import TensorDict
-from torchrl.data.tensor_specs import (
- BoundedTensorSpec,
- CompositeSpec,
- DiscreteTensorSpec,
- UnboundedContinuousTensorSpec,
-)
-from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform
-
-from lerobot.common.envs.abstract import AbstractEnv
-from lerobot.common.utils import set_global_seed
-
-MAX_NUM_ACTIONS = 4
-
-_has_gym = importlib.util.find_spec("gymnasium") is not None
-
-
-class SimxarmEnv(AbstractEnv):
- name = "simxarm"
- available_tasks = ["lift"]
-
- def __init__(
- self,
- task,
- frame_skip: int = 1,
- from_pixels: bool = False,
- pixels_only: bool = False,
- image_size=None,
- seed=1337,
- device="cpu",
- num_prev_obs=0,
- num_prev_action=0,
- ):
- super().__init__(
- task=task,
- frame_skip=frame_skip,
- from_pixels=from_pixels,
- pixels_only=pixels_only,
- image_size=image_size,
- seed=seed,
- device=device,
- num_prev_obs=num_prev_obs,
- num_prev_action=num_prev_action,
- )
-
- def _make_env(self):
- if not _has_gym:
- raise ImportError("Cannot import gymnasium.")
-
- import gymnasium
-
- from lerobot.common.envs.simxarm.simxarm import TASKS
-
- if self.task not in TASKS:
- raise ValueError(f"Unknown task {self.task}. Must be one of {list(TASKS.keys())}")
-
- self._env = TASKS[self.task]["env"]()
-
- num_actions = len(TASKS[self.task]["action_space"])
- self._action_space = gymnasium.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,))
- self._action_padding = np.zeros((MAX_NUM_ACTIONS - num_actions), dtype=np.float32)
- if "w" not in TASKS[self.task]["action_space"]:
- self._action_padding[-1] = 1.0
-
- def render(self, mode="rgb_array", width=384, height=384):
- return self._env.render(mode, width=width, height=height)
-
- def _format_raw_obs(self, raw_obs):
- if self.from_pixels:
- image = self.render(mode="rgb_array", width=self.image_size, height=self.image_size)
- image = image.transpose(2, 0, 1) # (H, W, C) -> (C, H, W)
- image = torch.tensor(image.copy(), dtype=torch.uint8)
-
- obs = {"image": image}
-
- if not self.pixels_only:
- obs["state"] = torch.tensor(self._env.robot_state, dtype=torch.float32)
- else:
- obs = {"state": torch.tensor(raw_obs["observation"], dtype=torch.float32)}
-
- # obs = TensorDict(obs, batch_size=[])
- return obs
-
- def _reset(self, tensordict: Optional[TensorDict] = None):
- td = tensordict
- if td is None or td.is_empty():
- raw_obs = self._env.reset()
-
- obs = self._format_raw_obs(raw_obs)
-
- if self.num_prev_obs > 0:
- stacked_obs = {}
- if "image" in obs:
- self._prev_obs_image_queue = deque(
- [obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
- )
- stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
- if "state" in obs:
- self._prev_obs_state_queue = deque(
- [obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
- )
- stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
- obs = stacked_obs
-
- td = TensorDict(
- {
- "observation": TensorDict(obs, batch_size=[]),
- "done": torch.tensor([False], dtype=torch.bool),
- },
- batch_size=[],
- )
- else:
- raise NotImplementedError()
-
- return td
-
- def _step(self, tensordict: TensorDict):
- td = tensordict
- action = td["action"].numpy()
- # step expects shape=(4,) so we pad if necessary
- action = np.concatenate([action, self._action_padding])
- # TODO(rcadene): add info["is_success"] and info["success"] ?
- sum_reward = 0
-
- if action.ndim == 1:
- action = einops.repeat(action, "c -> t c", t=self.frame_skip)
- else:
- if self.frame_skip > 1:
- raise NotImplementedError()
-
- num_action_steps = action.shape[0]
- for i in range(num_action_steps):
- raw_obs, reward, done, info = self._env.step(action[i])
- sum_reward += reward
-
- obs = self._format_raw_obs(raw_obs)
-
- if self.num_prev_obs > 0:
- stacked_obs = {}
- if "image" in obs:
- self._prev_obs_image_queue.append(obs["image"])
- stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
- if "state" in obs:
- self._prev_obs_state_queue.append(obs["state"])
- stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
- obs = stacked_obs
-
- td = TensorDict(
- {
- "observation": self._format_raw_obs(raw_obs),
- "reward": torch.tensor([sum_reward], dtype=torch.float32),
- "done": torch.tensor([done], dtype=torch.bool),
- "success": torch.tensor([info["success"]], dtype=torch.bool),
- },
- batch_size=[],
- )
- return td
-
- def _make_spec(self):
- obs = {}
- if self.from_pixels:
- image_shape = (3, self.image_size, self.image_size)
- if self.num_prev_obs > 0:
- image_shape = (self.num_prev_obs + 1, *image_shape)
-
- obs["image"] = BoundedTensorSpec(
- low=0,
- high=255,
- shape=image_shape,
- dtype=torch.uint8,
- device=self.device,
- )
- if not self.pixels_only:
- state_shape = (len(self._env.robot_state),)
- if self.num_prev_obs > 0:
- state_shape = (self.num_prev_obs + 1, *state_shape)
-
- obs["state"] = UnboundedContinuousTensorSpec(
- shape=state_shape,
- dtype=torch.float32,
- device=self.device,
- )
- else:
- # TODO(rcadene): add observation_space achieved_goal and desired_goal?
- state_shape = self._env.observation_space["observation"].shape
- if self.num_prev_obs > 0:
- state_shape = (self.num_prev_obs + 1, *state_shape)
-
- obs["state"] = UnboundedContinuousTensorSpec(
- # TODO:
- shape=state_shape,
- dtype=torch.float32,
- device=self.device,
- )
- self.observation_spec = CompositeSpec({"observation": obs})
-
- self.action_spec = _gym_to_torchrl_spec_transform(
- self._action_space,
- device=self.device,
- )
-
- self.reward_spec = UnboundedContinuousTensorSpec(
- shape=(1,),
- dtype=torch.float32,
- device=self.device,
- )
-
- self.done_spec = CompositeSpec(
- {
- "done": DiscreteTensorSpec(
- 2,
- shape=(1,),
- dtype=torch.bool,
- device=self.device,
- ),
- "success": DiscreteTensorSpec(
- 2,
- shape=(1,),
- dtype=torch.bool,
- device=self.device,
- ),
- }
- )
-
- def _set_seed(self, seed: Optional[int]):
- set_global_seed(seed)
- self._seed = seed
- # TODO(aliberts): change self._reset so that it takes in a seed value
- logging.warning("simxarm env is not properly seeded")
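
Two details of `_step` above are easy to miss: actions narrower than the simulator's fixed 4-dim interface are zero-padded, and `frame_skip` is implemented by tiling one action into a `(t, c)` block before stepping. A sketch of both, assuming `einops` as imported above; the helper name is hypothetical:

import einops
import numpy as np

def expand_action(action, frame_skip, max_dims=4):
    # Pad the action up to the simulator's fixed width.
    padding = np.zeros(max_dims - len(action), dtype=np.float32)
    action = np.concatenate([action, padding])
    # Repeat it so the same command is applied for `frame_skip` sim steps.
    return einops.repeat(action, "c -> t c", t=frame_skip)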
diff --git a/lerobot/common/envs/simxarm/simxarm/__init__.py b/lerobot/common/envs/simxarm/simxarm/__init__.py
deleted file mode 100644
index 903d6042..00000000
--- a/lerobot/common/envs/simxarm/simxarm/__init__.py
+++ /dev/null
@@ -1,166 +0,0 @@
-from collections import OrderedDict, deque
-
-import gymnasium as gym
-import numpy as np
-from gymnasium.wrappers import TimeLimit
-
-from lerobot.common.envs.simxarm.simxarm.tasks.base import Base as Base
-from lerobot.common.envs.simxarm.simxarm.tasks.lift import Lift
-from lerobot.common.envs.simxarm.simxarm.tasks.peg_in_box import PegInBox
-from lerobot.common.envs.simxarm.simxarm.tasks.push import Push
-from lerobot.common.envs.simxarm.simxarm.tasks.reach import Reach
-
-TASKS = OrderedDict(
- (
- (
- "reach",
- {
- "env": Reach,
- "action_space": "xyz",
- "episode_length": 50,
- "description": "Reach a target location with the end effector",
- },
- ),
- (
- "push",
- {
- "env": Push,
- "action_space": "xyz",
- "episode_length": 50,
- "description": "Push a cube to a target location",
- },
- ),
- (
- "peg_in_box",
- {
- "env": PegInBox,
- "action_space": "xyz",
- "episode_length": 50,
- "description": "Insert a peg into a box",
- },
- ),
- (
- "lift",
- {
- "env": Lift,
- "action_space": "xyzw",
- "episode_length": 50,
- "description": "Lift a cube above a height threshold",
- },
- ),
- )
-)
-
-
-class SimXarmWrapper(gym.Wrapper):
- """
- A wrapper for the SimXarm environments. This wrapper is used to
- convert the action and observation spaces to the correct format.
- """
-
- def __init__(self, env, task, obs_mode, image_size, action_repeat, frame_stack=1, channel_last=False):
- super().__init__(env)
- self._env = env
- self.obs_mode = obs_mode
- self.image_size = image_size
- self.action_repeat = action_repeat
- self.frame_stack = frame_stack
- self._frames = deque([], maxlen=frame_stack)
- self.channel_last = channel_last
- self._max_episode_steps = task["episode_length"] // action_repeat
-
- image_shape = (
- (image_size, image_size, 3 * frame_stack)
- if channel_last
- else (3 * frame_stack, image_size, image_size)
- )
- if obs_mode == "state":
- self.observation_space = env.observation_space["observation"]
- elif obs_mode == "rgb":
- self.observation_space = gym.spaces.Box(low=0, high=255, shape=image_shape, dtype=np.uint8)
- elif obs_mode == "all":
- self.observation_space = gym.spaces.Dict(
- state=gym.spaces.Box(low=-np.inf, high=np.inf, shape=(4,), dtype=np.float32),
- rgb=gym.spaces.Box(low=0, high=255, shape=image_shape, dtype=np.uint8),
- )
- else:
- raise ValueError(f"Unknown obs_mode {obs_mode}. Must be one of [rgb, all, state]")
- self.action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(len(task["action_space"]),))
- self.action_padding = np.zeros(4 - len(task["action_space"]), dtype=np.float32)
- if "w" not in task["action_space"]:
- self.action_padding[-1] = 1.0
-
- def _render_obs(self):
- obs = self.render(mode="rgb_array", width=self.image_size, height=self.image_size)
- if not self.channel_last:
- obs = obs.transpose(2, 0, 1)
- return obs.copy()
-
- def _update_frames(self, reset=False):
- pixels = self._render_obs()
- self._frames.append(pixels)
- if reset:
- for _ in range(1, self.frame_stack):
- self._frames.append(pixels)
- assert len(self._frames) == self.frame_stack
-
- def transform_obs(self, obs, reset=False):
- if self.obs_mode == "state":
- return obs["observation"]
- elif self.obs_mode == "rgb":
- self._update_frames(reset=reset)
- rgb_obs = np.concatenate(list(self._frames), axis=-1 if self.channel_last else 0)
- return rgb_obs
- elif self.obs_mode == "all":
- self._update_frames(reset=reset)
- rgb_obs = np.concatenate(list(self._frames), axis=-1 if self.channel_last else 0)
- return OrderedDict((("rgb", rgb_obs), ("state", self.robot_state)))
- else:
- raise ValueError(f"Unknown obs_mode {self.obs_mode}. Must be one of [rgb, all, state]")
-
- def reset(self):
- return self.transform_obs(self._env.reset(), reset=True)
-
- def step(self, action):
- action = np.concatenate([action, self.action_padding])
- reward = 0.0
- for _ in range(self.action_repeat):
- obs, r, done, info = self._env.step(action)
- reward += r
- return self.transform_obs(obs), reward, done, info
-
- def render(self, mode="rgb_array", width=384, height=384, **kwargs):
- return self._env.render(mode, width=width, height=height)
-
- @property
- def state(self):
- return self._env.robot_state
-
-
-def make(task, obs_mode="state", image_size=84, action_repeat=1, frame_stack=1, channel_last=False, seed=0):
- """
- Create a new environment.
- Args:
- task (str): The task to create an environment for. Must be one of:
- - 'reach'
- - 'push'
- - 'peg_in_box'
- - 'lift'
- obs_mode (str): The observation mode to use. Must be one of:
- - 'state': Only state observations
- - 'rgb': RGB images
- - 'all': RGB images and state observations
- image_size (int): The size of the image observations
- action_repeat (int): The number of times to repeat the action
- frame_stack (int): The number of consecutive frames to stack
- channel_last (bool): Whether image observations are channel-last (H, W, C)
- seed (int): The random seed to use
- Returns:
- gym.Env: The environment
- """
- if task not in TASKS:
- raise ValueError(f"Unknown task {task}. Must be one of {list(TASKS.keys())}")
- env = TASKS[task]["env"]()
- env = TimeLimit(env, TASKS[task]["episode_length"])
- env = SimXarmWrapper(env, TASKS[task], obs_mode, image_size, action_repeat, frame_stack, channel_last)
- env.seed(seed)
-
- return env
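
A usage sketch for `make` above under the 'all' observation mode (legacy gym API, so `reset` returns only the observation); shapes follow from the defaults:

env = make("lift", obs_mode="all", image_size=84, action_repeat=2, seed=0)
obs = env.reset()
# obs["rgb"] has shape (3, 84, 84) with frame_stack=1; obs["state"] is the
# 4-dim robot state exposed by SimXarmWrapper.state.
action = env.action_space.sample()
obs, reward, done, info = env.step(action)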
diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/__init__.py b/lerobot/common/envs/simxarm/simxarm/tasks/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/lift.xml b/lerobot/common/envs/simxarm/simxarm/tasks/assets/lift.xml
deleted file mode 100644
index 92231f92..00000000
--- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/lift.xml
+++ /dev/null
@@ -1,53 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/base_link.stl b/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/base_link.stl
deleted file mode 100644
index f1f52955..00000000
--- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/base_link.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:21fb81ae7fba19e3c6b2d2ca60c8051712ba273357287eb5a397d92d61c7a736
-size 1211434
diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/block_inner.stl b/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/block_inner.stl
deleted file mode 100644
index 6cb88945..00000000
--- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/block_inner.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:be68ce180d11630a667a5f37f4dffcc3feebe4217d4bb3912c813b6d9ca3ec66
-size 3284
diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/block_inner2.stl b/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/block_inner2.stl
deleted file mode 100644
index dab55ef5..00000000
--- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/block_inner2.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:2c6448552bf6b1c4f17334d686a5320ce051bcdfe31431edf69303d8a570d1de
-size 3284
diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/block_outer.stl b/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/block_outer.stl
deleted file mode 100644
index 21cf11fa..00000000
--- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/block_outer.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:748b9e197e6521914f18d1f6383a36f211136b3f33f2ad2a8c11b9f921c2cf86
-size 6284
diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/left_finger.stl b/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/left_finger.stl
deleted file mode 100644
index 6bf4e502..00000000
--- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/left_finger.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:a44756eb72f9c214cb37e61dc209cd7073fdff3e4271a7423476ef6fd090d2d4
-size 242684
diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/left_inner_knuckle.stl b/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/left_inner_knuckle.stl
deleted file mode 100644
index 817c7e1d..00000000
--- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/left_inner_knuckle.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:e8e48692ad26837bb3d6a97582c89784d09948fc09bfe4e5a59017859ff04dac
-size 366284
diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/left_outer_knuckle.stl b/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/left_outer_knuckle.stl
deleted file mode 100644
index 010c0f3b..00000000
--- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/left_outer_knuckle.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:501665812b08d67e764390db781e839adc6896a9540301d60adf606f57648921
-size 22284
diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link1.stl b/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link1.stl
deleted file mode 100644
index f2b676f2..00000000
--- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link1.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:34b541122df84d2ef5fcb91b715eb19659dc15ad8d44a191dde481f780265636
-size 184184
diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link2.stl b/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link2.stl
deleted file mode 100644
index bf93580c..00000000
--- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link2.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:61e641cd47c169ecef779683332e00e4914db729bf02dfb61bfbe69351827455
-size 225584
diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link3.stl b/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link3.stl
deleted file mode 100644
index d316d233..00000000
--- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link3.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:9e2798e7946dd70046c95455d5ba96392d0b54a6069caba91dc4ca66e1379b42
-size 237084
diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link4.stl b/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link4.stl
deleted file mode 100644
index f6d5fe94..00000000
--- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link4.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:c757fee95f873191a0633c355c07a360032960771cabbd7593a6cdb0f1ffb089
-size 243684
diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link5.stl b/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link5.stl
deleted file mode 100644
index e037b8b9..00000000
--- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link5.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:715ad5787c5dab57589937fd47289882707b5e1eb997e340d567785b02f4ec90
-size 229084
diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link6.stl b/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link6.stl
deleted file mode 100644
index 198c5300..00000000
--- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link6.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:85b320aa420497827223d16d492bba8de091173374e361396fc7a5dad7bdb0cb
-size 399384
diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link7.stl b/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link7.stl
deleted file mode 100644
index ce9a39ac..00000000
--- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link7.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:97115d848fbf802cb770cd9be639ae2af993103b9d9bbb0c50c943c738a36f18
-size 231684
diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link_base.stl b/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link_base.stl
deleted file mode 100644
index 110b9531..00000000
--- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link_base.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:f6fcbc18258090eb56c21cfb17baa5ae43abc98b1958cd366f3a73b9898fc7f0
-size 2106184
diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/right_finger.stl b/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/right_finger.stl
deleted file mode 100644
index 03f26e9a..00000000
--- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/right_finger.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:c5dee87c7f37baf554b8456ebfe0b3e8ed0b22b8938bd1add6505c2ad6d32c7d
-size 242684
diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/right_inner_knuckle.stl b/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/right_inner_knuckle.stl
deleted file mode 100644
index 8586f344..00000000
--- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/right_inner_knuckle.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:b41dd2c2c550281bf78d7cc6fa117b14786700e5c453560a0cb5fd6dfa0ffb3e
-size 366284
diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/right_outer_knuckle.stl b/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/right_outer_knuckle.stl
deleted file mode 100644
index ae7afc25..00000000
--- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/right_outer_knuckle.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:75ca1107d0a42a0f03802a9a49cab48419b31851ee8935f8f1ca06be1c1c91e8
-size 22284
diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/peg_in_box.xml b/lerobot/common/envs/simxarm/simxarm/tasks/assets/peg_in_box.xml
deleted file mode 100644
index 0f85459f..00000000
--- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/peg_in_box.xml
+++ /dev/null
@@ -1,74 +0,0 @@
- [peg_in_box.xml: 74 lines of MuJoCo scene XML deleted; the markup was stripped when this diff was extracted and is not recoverable here]
diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/push.xml b/lerobot/common/envs/simxarm/simxarm/tasks/assets/push.xml
deleted file mode 100644
index 42a78c8a..00000000
--- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/push.xml
+++ /dev/null
@@ -1,54 +0,0 @@
- [push.xml: 54 lines of MuJoCo scene XML deleted; markup stripped from this extract]
diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/reach.xml b/lerobot/common/envs/simxarm/simxarm/tasks/assets/reach.xml
deleted file mode 100644
index ded6d209..00000000
--- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/reach.xml
+++ /dev/null
@@ -1,48 +0,0 @@
- [reach.xml: 48 lines of MuJoCo scene XML deleted; markup stripped from this extract]
diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/shared.xml b/lerobot/common/envs/simxarm/simxarm/tasks/assets/shared.xml
deleted file mode 100644
index ee56f8f0..00000000
--- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/shared.xml
+++ /dev/null
@@ -1,51 +0,0 @@
- [shared.xml: 51 lines of shared MuJoCo XML assets deleted; markup stripped from this extract]
diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/xarm.xml b/lerobot/common/envs/simxarm/simxarm/tasks/assets/xarm.xml
deleted file mode 100644
index 023474d6..00000000
--- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/xarm.xml
+++ /dev/null
@@ -1,88 +0,0 @@
- [xarm.xml: 88 lines of xArm robot MuJoCo XML deleted; markup stripped from this extract]
diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/base.py b/lerobot/common/envs/simxarm/simxarm/tasks/base.py
deleted file mode 100644
index b937b290..00000000
--- a/lerobot/common/envs/simxarm/simxarm/tasks/base.py
+++ /dev/null
@@ -1,145 +0,0 @@
-import os
-
-import mujoco
-import numpy as np
-from gymnasium_robotics.envs import robot_env
-
-from lerobot.common.envs.simxarm.simxarm.tasks import mocap
-
-
-class Base(robot_env.MujocoRobotEnv):
- """
- Superclass for all simxarm environments.
- Args:
- xml_name (str): name of the xml environment file
- gripper_rotation (list): initial rotation of the gripper (given as a quaternion)
- """
-
- def __init__(self, xml_name, gripper_rotation=None):
- if gripper_rotation is None:
- gripper_rotation = [0, 1, 0, 0]
- self.gripper_rotation = np.array(gripper_rotation, dtype=np.float32)
- self.center_of_table = np.array([1.655, 0.3, 0.63625])
- self.max_z = 1.2
- self.min_z = 0.2
- super().__init__(
- model_path=os.path.join(os.path.dirname(__file__), "assets", xml_name + ".xml"),
- n_substeps=20,
- n_actions=4,
- initial_qpos={},
- )
-
- @property
- def dt(self):
- return self.n_substeps * self.model.opt.timestep
-
- @property
- def eef(self):
- return self._utils.get_site_xpos(self.model, self.data, "grasp")
-
- @property
- def obj(self):
- return self._utils.get_site_xpos(self.model, self.data, "object_site")
-
- @property
- def robot_state(self):
- gripper_angle = self._utils.get_joint_qpos(self.model, self.data, "right_outer_knuckle_joint")
- return np.concatenate([self.eef, gripper_angle])
-
- def is_success(self):
- raise NotImplementedError()
-
- def get_reward(self):
- raise NotImplementedError()
-
- def _sample_goal(self):
- raise NotImplementedError()
-
- def get_obs(self):
- return self._get_obs()
-
- def _step_callback(self):
- self._mujoco.mj_forward(self.model, self.data)
-
- def _limit_gripper(self, gripper_pos, pos_ctrl):
- if gripper_pos[0] > self.center_of_table[0] - 0.105 + 0.15:
- pos_ctrl[0] = min(pos_ctrl[0], 0)
- if gripper_pos[0] < self.center_of_table[0] - 0.105 - 0.3:
- pos_ctrl[0] = max(pos_ctrl[0], 0)
- if gripper_pos[1] > self.center_of_table[1] + 0.3:
- pos_ctrl[1] = min(pos_ctrl[1], 0)
- if gripper_pos[1] < self.center_of_table[1] - 0.3:
- pos_ctrl[1] = max(pos_ctrl[1], 0)
- if gripper_pos[2] > self.max_z:
- pos_ctrl[2] = min(pos_ctrl[2], 0)
- if gripper_pos[2] < self.min_z:
- pos_ctrl[2] = max(pos_ctrl[2], 0)
- return pos_ctrl
-
- def _apply_action(self, action):
- assert action.shape == (4,)
- action = action.copy()
- pos_ctrl, gripper_ctrl = action[:3], action[3]
- pos_ctrl = self._limit_gripper(
- self._utils.get_site_xpos(self.model, self.data, "grasp"), pos_ctrl
- ) * (1 / self.n_substeps)
- gripper_ctrl = np.array([gripper_ctrl, gripper_ctrl])
- mocap.apply_action(
- self.model,
- self._model_names,
- self.data,
- np.concatenate([pos_ctrl, self.gripper_rotation, gripper_ctrl]),
- )
-
- def _render_callback(self):
- self._mujoco.mj_forward(self.model, self.data)
-
- def _reset_sim(self):
- self.data.time = self.initial_time
- self.data.qpos[:] = np.copy(self.initial_qpos)
- self.data.qvel[:] = np.copy(self.initial_qvel)
- self._sample_goal()
- self._mujoco.mj_step(self.model, self.data, nstep=10)
- return True
-
- def _set_gripper(self, gripper_pos, gripper_rotation):
- self._utils.set_mocap_pos(self.model, self.data, "robot0:mocap", gripper_pos)
- self._utils.set_mocap_quat(self.model, self.data, "robot0:mocap", gripper_rotation)
- self._utils.set_joint_qpos(self.model, self.data, "right_outer_knuckle_joint", 0)
- self.data.qpos[10] = 0.0
- self.data.qpos[12] = 0.0
-
- def _env_setup(self, initial_qpos):
- for name, value in initial_qpos.items():
- self.data.set_joint_qpos(name, value)
- mocap.reset(self.model, self.data)
- mujoco.mj_forward(self.model, self.data)
- self._sample_goal()
- mujoco.mj_forward(self.model, self.data)
-
- def reset(self):
- self._reset_sim()
- return self._get_obs()
-
- def step(self, action):
- assert action.shape == (4,)
- assert self.action_space.contains(action), "{!r} ({}) invalid".format(action, type(action))
- self._apply_action(action)
- self._mujoco.mj_step(self.model, self.data, nstep=2)
- self._step_callback()
- obs = self._get_obs()
- reward = self.get_reward()
- done = False
- info = {"is_success": self.is_success(), "success": self.is_success()}
- return obs, reward, done, info
-
- def render(self, mode="rgb_array", width=384, height=384):
- self._render_callback()
- # HACK
- self.model.vis.global_.offwidth = width
- self.model.vis.global_.offheight = height
- return self.mujoco_renderer.render(mode)
-
- def close(self):
- if self.mujoco_renderer is not None:
- self.mujoco_renderer.close()
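For orientation, a rollout against one of these (now removed) environments looked roughly like the sketch below. It assumes the task classes (e.g. `Lift`, defined in the next file) were re-exported from the package root alongside `Base`; the `[dx, dy, dz, gripper]` action layout and the clamping behavior follow `_apply_action` and `_limit_gripper` above, and the horizon is an arbitrary choice.

import numpy as np

from lerobot.common.envs.simxarm.simxarm import Lift  # removed by this PR

env = Lift()
obs = env.reset()  # the observation dict built by _get_obs()
for _ in range(25):
    # [dx, dy, dz, gripper]; position deltas are clamped near the table bounds
    action = np.array([0.0, 0.0, 0.1, -1.0], dtype=np.float32)
    obs, reward, done, info = env.step(action)
env.close()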
diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/lift.py b/lerobot/common/envs/simxarm/simxarm/tasks/lift.py
deleted file mode 100644
index 0b11196c..00000000
--- a/lerobot/common/envs/simxarm/simxarm/tasks/lift.py
+++ /dev/null
@@ -1,100 +0,0 @@
-import numpy as np
-
-from lerobot.common.envs.simxarm.simxarm import Base
-
-
-class Lift(Base):
- def __init__(self):
- self._z_threshold = 0.15
- super().__init__("lift")
-
- @property
- def z_target(self):
- return self._init_z + self._z_threshold
-
- def is_success(self):
- return self.obj[2] >= self.z_target
-
- def get_reward(self):
- reach_dist = np.linalg.norm(self.obj - self.eef)
- reach_dist_xy = np.linalg.norm(self.obj[:-1] - self.eef[:-1])
- pick_completed = self.obj[2] >= (self.z_target - 0.01)
- obj_dropped = (self.obj[2] < (self._init_z + 0.005)) and (reach_dist > 0.02)
-
- # Reach
- if reach_dist < 0.05:
- reach_reward = -reach_dist + max(self._action[-1], 0) / 50
- elif reach_dist_xy < 0.05:
- reach_reward = -reach_dist
- else:
- z_bonus = np.linalg.norm(self.obj[-1] - self.eef[-1])
- reach_reward = -reach_dist - 2 * z_bonus
-
- # Pick
- if pick_completed and not obj_dropped:
- pick_reward = self.z_target
- elif (reach_dist < 0.1) and (self.obj[2] > (self._init_z + 0.005)):
- pick_reward = min(self.z_target, self.obj[2])
- else:
- pick_reward = 0
-
- return reach_reward / 100 + pick_reward
-
- def _get_obs(self):
- eef_velp = self._utils.get_site_xvelp(self.model, self.data, "grasp") * self.dt
- gripper_angle = self._utils.get_joint_qpos(self.model, self.data, "right_outer_knuckle_joint")
- eef = self.eef - self.center_of_table
-
- obj = self.obj - self.center_of_table
- obj_rot = self._utils.get_joint_qpos(self.model, self.data, "object_joint0")[-4:]
- obj_velp = self._utils.get_site_xvelp(self.model, self.data, "object_site") * self.dt
- obj_velr = self._utils.get_site_xvelr(self.model, self.data, "object_site") * self.dt
-
- obs = np.concatenate(
- [
- eef,
- eef_velp,
- obj,
- obj_rot,
- obj_velp,
- obj_velr,
- eef - obj,
- np.array(
- [
- np.linalg.norm(eef - obj),
- np.linalg.norm(eef[:-1] - obj[:-1]),
- self.z_target,
- self.z_target - obj[-1],
- self.z_target - eef[-1],
- ]
- ),
- gripper_angle,
- ],
- axis=0,
- )
- return {"observation": obs, "state": eef, "achieved_goal": eef, "desired_goal": eef}
-
- def _sample_goal(self):
- # Gripper
- gripper_pos = np.array([1.280, 0.295, 0.735]) + self.np_random.uniform(-0.05, 0.05, size=3)
- super()._set_gripper(gripper_pos, self.gripper_rotation)
-
- # Object
- object_pos = self.center_of_table - np.array([0.15, 0.10, 0.07])
- object_pos[0] += self.np_random.uniform(-0.05, 0.05, size=1)
- object_pos[1] += self.np_random.uniform(-0.05, 0.05, size=1)
- object_qpos = self._utils.get_joint_qpos(self.model, self.data, "object_joint0")
- object_qpos[:3] = object_pos
- self._utils.set_joint_qpos(self.model, self.data, "object_joint0", object_qpos)
- self._init_z = object_pos[2]
-
- # Goal
- return object_pos + np.array([0, 0, self._z_threshold])
-
- def reset(self):
- self._action = np.zeros(4)
- return super().reset()
-
- def step(self, action):
- self._action = action.copy()
- return super().step(action)
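To make the staged reward above concrete, here is an illustrative hand computation. The numbers are made up, but the constants follow this file: spawning the object at center_of_table - [0.15, 0.10, 0.07] gives _init_z ≈ 0.566, hence z_target ≈ 0.716.

# reach_dist = 0.03 (< 0.05) with gripper action 1.0:
#     reach_reward = -0.03 + 1.0 / 50 = -0.01
# obj z = 0.58 (> _init_z + 0.005) and reach_dist < 0.1, pick not yet complete:
#     pick_reward = min(z_target, 0.58) = 0.58
# total = reach_reward / 100 + pick_reward = -0.0001 + 0.58 = 0.5799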
diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/mocap.py b/lerobot/common/envs/simxarm/simxarm/tasks/mocap.py
deleted file mode 100644
index 4295bf19..00000000
--- a/lerobot/common/envs/simxarm/simxarm/tasks/mocap.py
+++ /dev/null
@@ -1,67 +0,0 @@
-# import mujoco_py
-import mujoco
-import numpy as np
-
-
-def apply_action(model, model_names, data, action):
- if model.nmocap > 0:
- pos_action, gripper_action = np.split(action, (model.nmocap * 7,))
- if data.ctrl is not None:
- for i in range(gripper_action.shape[0]):
- data.ctrl[i] = gripper_action[i]
- pos_action = pos_action.reshape(model.nmocap, 7)
- pos_delta, quat_delta = pos_action[:, :3], pos_action[:, 3:]
- reset_mocap2body_xpos(model, model_names, data)
- data.mocap_pos[:] = data.mocap_pos + pos_delta
- data.mocap_quat[:] = data.mocap_quat + quat_delta
-
-
-def reset(model, data):
- if model.nmocap > 0 and model.eq_data is not None:
- for i in range(model.eq_data.shape[0]):
- # if sim.model.eq_type[i] == mujoco_py.const.EQ_WELD:
- if model.eq_type[i] == mujoco.mjtEq.mjEQ_WELD:
- # model.eq_data[i, :] = np.array([0., 0., 0., 1., 0., 0., 0.])
- model.eq_data[i, :] = np.array(
- [
- 0.0,
- 0.0,
- 0.0,
- 1.0,
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- ]
- )
- # sim.forward()
- mujoco.mj_forward(model, data)
-
-
-def reset_mocap2body_xpos(model, model_names, data):
- if model.eq_type is None or model.eq_obj1id is None or model.eq_obj2id is None:
- return
-
- # For all weld constraints
- for eq_type, obj1_id, obj2_id in zip(model.eq_type, model.eq_obj1id, model.eq_obj2id, strict=False):
- # if eq_type != mujoco_py.const.EQ_WELD:
- if eq_type != mujoco.mjtEq.mjEQ_WELD:
- continue
- # body2 = model.body_id2name(obj2_id)
- body2 = model_names.body_id2name[obj2_id]
- if body2 == "B0" or body2 == "B9" or body2 == "B1":
- continue
- mocap_id = model.body_mocapid[obj1_id]
- if mocap_id != -1:
- # obj1 is the mocap, obj2 is the welded body
- body_idx = obj2_id
- else:
- # obj2 is the mocap, obj1 is the welded body
- mocap_id = model.body_mocapid[obj2_id]
- body_idx = obj1_id
- assert mocap_id != -1
- data.mocap_pos[mocap_id][:] = data.xpos[body_idx]
- data.mocap_quat[mocap_id][:] = data.xquat[body_idx]
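The action consumed by `apply_action` above packs `nmocap * 7` pose-delta entries (3 position + 4 quaternion per mocap body), followed by one entry per actuator in `data.ctrl`. A minimal sketch of assembling it for a single mocap body and a two-actuator gripper, with placeholder values; the shapes follow the `np.split`/`reshape` logic above:

import numpy as np

pos_delta = np.array([0.01, 0.0, 0.0])  # (3,) Cartesian delta added to mocap_pos
quat_delta = np.zeros(4)                # (4,) raw add to mocap_quat, not a quaternion product
gripper = np.array([1.0, 1.0])          # one value per data.ctrl actuator
action = np.concatenate([pos_delta, quat_delta, gripper])  # (nmocap * 7 + nctrl,)
# apply_action(model, model_names, data, action)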
diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/peg_in_box.py b/lerobot/common/envs/simxarm/simxarm/tasks/peg_in_box.py
deleted file mode 100644
index 42e41520..00000000
--- a/lerobot/common/envs/simxarm/simxarm/tasks/peg_in_box.py
+++ /dev/null
@@ -1,86 +0,0 @@
-import numpy as np
-
-from lerobot.common.envs.simxarm.simxarm import Base
-
-
-class PegInBox(Base):
- def __init__(self):
- super().__init__("peg_in_box")
-
- def _reset_sim(self):
- self._act_magnitude = 0
- super()._reset_sim()
- for _ in range(10):
- self._apply_action(np.array([0, 0, 0, 1], dtype=np.float32))
- self.sim.step()
-
- @property
- def box(self):
- return self.sim.data.get_site_xpos("box_site")
-
- def is_success(self):
- return np.linalg.norm(self.obj - self.box) <= 0.05
-
- def get_reward(self):
- dist_xy = np.linalg.norm(self.obj[:2] - self.box[:2])
- dist_xyz = np.linalg.norm(self.obj - self.box)
- return float(dist_xy <= 0.045) * (2 - 6 * dist_xyz) - 0.2 * np.square(self._act_magnitude) - dist_xy
-
- def _get_obs(self):
- eef_velp = self.sim.data.get_site_xvelp("grasp") * self.dt
- gripper_angle = self.sim.data.get_joint_qpos("right_outer_knuckle_joint")
- eef, box = self.eef - self.center_of_table, self.box - self.center_of_table
-
- obj = self.obj - self.center_of_table
- obj_rot = self.sim.data.get_joint_qpos("object_joint0")[-4:]
- obj_velp = self.sim.data.get_site_xvelp("object_site") * self.dt
- obj_velr = self.sim.data.get_site_xvelr("object_site") * self.dt
-
- obs = np.concatenate(
- [
- eef,
- eef_velp,
- box,
- obj,
- obj_rot,
- obj_velp,
- obj_velr,
- eef - box,
- eef - obj,
- obj - box,
- np.array(
- [
- np.linalg.norm(eef - box),
- np.linalg.norm(eef - obj),
- np.linalg.norm(obj - box),
- gripper_angle,
- ]
- ),
- ],
- axis=0,
- )
- return {"observation": obs, "state": eef, "achieved_goal": eef, "desired_goal": box}
-
- def _sample_goal(self):
- # Gripper
- gripper_pos = np.array([1.280, 0.295, 0.9]) + self.np_random.uniform(-0.05, 0.05, size=3)
- super()._set_gripper(gripper_pos, self.gripper_rotation)
-
- # Object
- object_pos = gripper_pos - np.array([0, 0, 0.06]) + self.np_random.uniform(-0.005, 0.005, size=3)
- object_qpos = self.sim.data.get_joint_qpos("object_joint0")
- object_qpos[:3] = object_pos
- self.sim.data.set_joint_qpos("object_joint0", object_qpos)
-
- # Box
- box_pos = np.array([1.61, 0.18, 0.58])
- box_pos[:2] += self.np_random.uniform(-0.11, 0.11, size=2)
- box_qpos = self.sim.data.get_joint_qpos("box_joint0")
- box_qpos[:3] = box_pos
- self.sim.data.set_joint_qpos("box_joint0", box_qpos)
-
- return self.box
-
- def step(self, action):
- self._act_magnitude = np.linalg.norm(action[:3])
- return super().step(action)
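An analogous hand computation for this gated reward (made-up numbers): with dist_xy = 0.03 the dist_xy <= 0.045 gate is open, so for dist_xyz = 0.05 and an action magnitude of 0.5:

# reward = 1.0 * (2 - 6 * 0.05) - 0.2 * 0.5 ** 2 - 0.03
#        = 1.7 - 0.05 - 0.03
#        = 1.62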
diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/push.py b/lerobot/common/envs/simxarm/simxarm/tasks/push.py
deleted file mode 100644
index 36c4a550..00000000
--- a/lerobot/common/envs/simxarm/simxarm/tasks/push.py
+++ /dev/null
@@ -1,78 +0,0 @@
-import numpy as np
-
-from lerobot.common.envs.simxarm.simxarm import Base
-
-
-class Push(Base):
- def __init__(self):
- super().__init__("push")
-
- def _reset_sim(self):
- self._act_magnitude = 0
- super()._reset_sim()
-
- def is_success(self):
- return np.linalg.norm(self.obj - self.goal) <= 0.05
-
- def get_reward(self):
- dist = np.linalg.norm(self.obj - self.goal)
- penalty = self._act_magnitude**2
- return -(dist + 0.15 * penalty)
-
- def _get_obs(self):
- eef_velp = self.sim.data.get_site_xvelp("grasp") * self.dt
- gripper_angle = self.sim.data.get_joint_qpos("right_outer_knuckle_joint")
- eef, goal = self.eef - self.center_of_table, self.goal - self.center_of_table
-
- obj = self.obj - self.center_of_table
- obj_rot = self.sim.data.get_joint_qpos("object_joint0")[-4:]
- obj_velp = self.sim.data.get_site_xvelp("object_site") * self.dt
- obj_velr = self.sim.data.get_site_xvelr("object_site") * self.dt
-
- obs = np.concatenate(
- [
- eef,
- eef_velp,
- goal,
- obj,
- obj_rot,
- obj_velp,
- obj_velr,
- eef - goal,
- eef - obj,
- obj - goal,
- np.array(
- [
- np.linalg.norm(eef - goal),
- np.linalg.norm(eef - obj),
- np.linalg.norm(obj - goal),
- gripper_angle,
- ]
- ),
- ],
- axis=0,
- )
- return {"observation": obs, "state": eef, "achieved_goal": eef, "desired_goal": goal}
-
- def _sample_goal(self):
- # Gripper
- gripper_pos = np.array([1.280, 0.295, 0.735]) + self.np_random.uniform(-0.05, 0.05, size=3)
- super()._set_gripper(gripper_pos, self.gripper_rotation)
-
- # Object
- object_pos = self.center_of_table - np.array([0.25, 0, 0.07])
- object_pos[0] += self.np_random.uniform(-0.08, 0.08, size=1)
- object_pos[1] += self.np_random.uniform(-0.08, 0.08, size=1)
- object_qpos = self.sim.data.get_joint_qpos("object_joint0")
- object_qpos[:3] = object_pos
- self.sim.data.set_joint_qpos("object_joint0", object_qpos)
-
- # Goal
- self.goal = np.array([1.600, 0.200, 0.545])
- self.goal[:2] += self.np_random.uniform(-0.1, 0.1, size=2)
- self.sim.model.site_pos[self.sim.model.site_name2id("target0")] = self.goal
- return self.goal
-
- def step(self, action):
- self._act_magnitude = np.linalg.norm(action[:3])
- return super().step(action)
diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/reach.py b/lerobot/common/envs/simxarm/simxarm/tasks/reach.py
deleted file mode 100644
index 941a586f..00000000
--- a/lerobot/common/envs/simxarm/simxarm/tasks/reach.py
+++ /dev/null
@@ -1,44 +0,0 @@
-import numpy as np
-
-from lerobot.common.envs.simxarm.simxarm import Base
-
-
-class Reach(Base):
- def __init__(self):
- super().__init__("reach")
-
- def _reset_sim(self):
- self._act_magnitude = 0
- super()._reset_sim()
-
- def is_success(self):
- return np.linalg.norm(self.eef - self.goal) <= 0.05
-
- def get_reward(self):
- dist = np.linalg.norm(self.eef - self.goal)
- penalty = self._act_magnitude**2
- return -(dist + 0.15 * penalty)
-
- def _get_obs(self):
- eef_velp = self.sim.data.get_site_xvelp("grasp") * self.dt
- gripper_angle = self.sim.data.get_joint_qpos("right_outer_knuckle_joint")
- eef, goal = self.eef - self.center_of_table, self.goal - self.center_of_table
- obs = np.concatenate(
- [eef, eef_velp, goal, eef - goal, np.array([np.linalg.norm(eef - goal), gripper_angle])], axis=0
- )
- return {"observation": obs, "state": eef, "achieved_goal": eef, "desired_goal": goal}
-
- def _sample_goal(self):
- # Gripper
- gripper_pos = np.array([1.280, 0.295, 0.735]) + self.np_random.uniform(-0.05, 0.05, size=3)
- super()._set_gripper(gripper_pos, self.gripper_rotation)
-
- # Goal
- self.goal = np.array([1.550, 0.287, 0.580])
- self.goal[:2] += self.np_random.uniform(-0.125, 0.125, size=2)
- self.sim.model.site_pos[self.sim.model.site_name2id("target0")] = self.goal
- return self.goal
-
- def step(self, action):
- self._act_magnitude = np.linalg.norm(action[:3])
- return super().step(action)
diff --git a/lerobot/common/envs/utils.py b/lerobot/common/envs/utils.py
new file mode 100644
index 00000000..4d31ddb2
--- /dev/null
+++ b/lerobot/common/envs/utils.py
@@ -0,0 +1,41 @@
+import einops
+import torch
+
+from lerobot.common.transforms import apply_inverse_transform
+
+
+def preprocess_observation(observation, transform=None):
+ # map to expected inputs for the policy
+ obs = {}
+
+ if isinstance(observation["pixels"], dict):
+ imgs = {f"observation.images.{key}": img for key, img in observation["pixels"].items()}
+ else:
+ imgs = {"observation.image": observation["pixels"]}
+
+ for imgkey, img in imgs.items():
+ img = torch.from_numpy(img).float()
+ # convert to (b c h w) torch format
+ img = einops.rearrange(img, "b h w c -> b c h w")
+ obs[imgkey] = img
+
+ # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing requirement for "agent_pos"
+ obs["observation.state"] = torch.from_numpy(observation["agent_pos"]).float()
+
+ # apply same transforms as in training
+ if transform is not None:
+ for key in obs:
+ obs[key] = torch.stack([transform({key: item})[key] for item in obs[key]])
+
+ return obs
+
+
+def postprocess_action(action, transform=None):
+ action = action.to("cpu")
+ # action is a batch (num_env, action_dim) instead of a single item (action_dim);
+ # we assume applying the inverse transform on a batch works the same as on items
+ action = apply_inverse_transform({"action": action}, transform)["action"].numpy()
+ assert (
+ action.ndim == 2
+ ), "we assume dimensions are respectively the number of parallel envs, action dimensions"
+ return action
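These two helpers are the glue between a vectorized gym environment and a policy. A hedged sketch of the intended call pattern: `policy` and `transform` are stand-ins assumed to be built elsewhere, and the env id and `obs_type` kwarg assume gym_pusht's pixel+state observation mode.

import gymnasium as gym
import torch

from lerobot.common.envs.utils import postprocess_action, preprocess_observation

env = gym.vector.make("gym_pusht/PushT-v0", obs_type="pixels_agent_pos", num_envs=2)
observation, _ = env.reset(seed=0)

policy = ...     # e.g. an ActionChunkingTransformerPolicy built elsewhere
transform = ...  # the same transform used at training time

obs = preprocess_observation(observation, transform)  # dict of (num_envs, ...) tensors
with torch.no_grad():
    action = policy.select_action(obs)  # (num_envs, action_dim), in policy space
env_action = postprocess_action(action, transform)  # back to env units, as numpy
observation, reward, terminated, truncated, info = env.step(env_action)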
diff --git a/lerobot/common/policies/abstract.py b/lerobot/common/policies/abstract.py
deleted file mode 100644
index 6dc72bef..00000000
--- a/lerobot/common/policies/abstract.py
+++ /dev/null
@@ -1,82 +0,0 @@
-from collections import deque
-
-import torch
-from torch import Tensor, nn
-
-
-class AbstractPolicy(nn.Module):
- """Base policy which all policies should be derived from.
-
- The forward method should generally not be overridden as it plays the role of handling multi-step policies. See its
- documentation for more information.
-
- Note:
- When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
- 1. set the required class attributes:
- - for classes inheriting from `AbstractDataset`: `available_datasets`
- - for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
- - for classes inheriting from `AbstractPolicy`: `name`
- 2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
- 3. update variables in `tests/test_available.py` by importing your new class
- """
-
- name: str | None = None # same name should be used to instantiate the policy in factory.py
-
- def __init__(self, n_action_steps: int | None):
- """
- n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single
- action is returned by `select_actions` and that it doesn't have a horizon dimension. The `forward` method then
- adds that dimension.
- """
- super().__init__()
- assert self.name is not None, "Subclasses of `AbstractPolicy` should set the `name` class attribute."
- self.n_action_steps = n_action_steps
- self.clear_action_queue()
-
- def update(self, replay_buffer, step):
- """One step of the policy's learning algorithm."""
- raise NotImplementedError("Abstract method")
-
- def save(self, fp):
- torch.save(self.state_dict(), fp)
-
- def load(self, fp):
- d = torch.load(fp)
- self.load_state_dict(d)
-
- def select_actions(self, observation) -> Tensor:
- """Select an action (or trajectory of actions) based on an observation during rollout.
-
- If n_action_steps was provided at initialization, this should return a (batch_size, n_action_steps, *) tensor of
- actions. Otherwise, if n_action_steps is None, this should return a (batch_size, *) tensor of actions.
- """
- raise NotImplementedError("Abstract method")
-
- def clear_action_queue(self):
- """This should be called whenever the environment is reset."""
- if self.n_action_steps is not None:
- self._action_queue = deque([], maxlen=self.n_action_steps)
-
- def forward(self, *args, **kwargs) -> Tensor:
- """Inference step that makes multi-step policies compatible with their single-step environments.
-
- WARNING: In general, this should not be overridden.
-
- Consider a "policy" that observes the environment then charts a course of N actions to take. To make this fit
- into the formalism of a TorchRL environment, we view it as being effectively a policy that (1) makes an
- observation and prepares a queue of actions, (2) consumes that queue when queried, regardless of the environment
- observation, (3) repopulates the action queue when empty. This method handles the aforementioned logic so that
- the subclass doesn't have to.
-
- This method effectively wraps the `select_actions` method of the subclass. The following assumptions are made:
- 1. The `select_actions` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is
- the action trajectory horizon and * is the action dimensions.
- 2. Prior to the `select_actions` method being called, there is an `n_action_steps` instance attribute defined.
- """
- if self.n_action_steps is None:
- return self.select_actions(*args, **kwargs)
- if len(self._action_queue) == 0:
- # `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has shape
- # (n_action_steps, batch_size, *), hence the transpose.
- self._action_queue.extend(self.select_actions(*args, **kwargs).transpose(0, 1))
- return self._action_queue.popleft()
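The queue bookkeeping being deleted here survives in `ActionChunkingTransformerPolicy.select_action` below, and the transpose is the subtle part: the model emits (batch_size, n_action_steps, *) but each queue entry must be one timestep's batch. A standalone sketch of just that mechanism, with illustrative shapes:

from collections import deque

import torch

batch_size, n_action_steps, action_dim = 2, 3, 4
queue = deque([], maxlen=n_action_steps)

chunk = torch.randn(batch_size, n_action_steps, action_dim)  # as from select_actions
# (batch_size, n_action_steps, *) -> (n_action_steps, batch_size, *) so that each
# popleft() yields the next action for every parallel environment at once
queue.extend(chunk.transpose(0, 1))
step_action = queue.popleft()
assert step_action.shape == (batch_size, action_dim)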
diff --git a/lerobot/common/policies/act/backbone.py b/lerobot/common/policies/act/backbone.py
deleted file mode 100644
index 6399d339..00000000
--- a/lerobot/common/policies/act/backbone.py
+++ /dev/null
@@ -1,115 +0,0 @@
-from typing import List
-
-import torch
-import torchvision
-from torch import nn
-from torchvision.models._utils import IntermediateLayerGetter
-
-from .position_encoding import build_position_encoding
-from .utils import NestedTensor, is_main_process
-
-
-class FrozenBatchNorm2d(torch.nn.Module):
- """
- BatchNorm2d where the batch statistics and the affine parameters are fixed.
-
- Copy-paste from torchvision.misc.ops with an added eps before rsqrt,
- without which models other than torchvision.models.resnet[18,34,50,101]
- produce NaNs.
- """
-
- def __init__(self, n):
- super().__init__()
- self.register_buffer("weight", torch.ones(n))
- self.register_buffer("bias", torch.zeros(n))
- self.register_buffer("running_mean", torch.zeros(n))
- self.register_buffer("running_var", torch.ones(n))
-
- def _load_from_state_dict(
- self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
- ):
- num_batches_tracked_key = prefix + "num_batches_tracked"
- if num_batches_tracked_key in state_dict:
- del state_dict[num_batches_tracked_key]
-
- super()._load_from_state_dict(
- state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
- )
-
- def forward(self, x):
- # move reshapes to the beginning
- # to make it fuser-friendly
- w = self.weight.reshape(1, -1, 1, 1)
- b = self.bias.reshape(1, -1, 1, 1)
- rv = self.running_var.reshape(1, -1, 1, 1)
- rm = self.running_mean.reshape(1, -1, 1, 1)
- eps = 1e-5
- scale = w * (rv + eps).rsqrt()
- bias = b - rm * scale
- return x * scale + bias
-
-
-class BackboneBase(nn.Module):
- def __init__(
- self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool
- ):
- super().__init__()
- # for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this?
- # if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
- # parameter.requires_grad_(False)
- if return_interm_layers:
- return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
- else:
- return_layers = {"layer4": "0"}
- self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
- self.num_channels = num_channels
-
- def forward(self, tensor):
- xs = self.body(tensor)
- return xs
- # out: Dict[str, NestedTensor] = {}
- # for name, x in xs.items():
- # m = tensor_list.mask
- # assert m is not None
- # mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
- # out[name] = NestedTensor(x, mask)
- # return out
-
-
-class Backbone(BackboneBase):
- """ResNet backbone with frozen BatchNorm."""
-
- def __init__(self, name: str, train_backbone: bool, return_interm_layers: bool, dilation: bool):
- backbone = getattr(torchvision.models, name)(
- replace_stride_with_dilation=[False, False, dilation],
- pretrained=is_main_process(),
- norm_layer=FrozenBatchNorm2d,
- ) # pretrained # TODO do we want frozen batch_norm??
- num_channels = 512 if name in ("resnet18", "resnet34") else 2048
- super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
-
-
-class Joiner(nn.Sequential):
- def __init__(self, backbone, position_embedding):
- super().__init__(backbone, position_embedding)
-
- def forward(self, tensor_list: NestedTensor):
- xs = self[0](tensor_list)
- out: List[NestedTensor] = []
- pos = []
- for _, x in xs.items():
- out.append(x)
- # position encoding
- pos.append(self[1](x).to(x.dtype))
-
- return out, pos
-
-
-def build_backbone(args):
- position_embedding = build_position_encoding(args)
- train_backbone = args.lr_backbone > 0
- return_interm_layers = args.masks
- backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
- model = Joiner(backbone, position_embedding)
- model.num_channels = backbone.num_channels
- return model
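The frozen batch-norm above is just eval-mode normalization, y = (x - running_mean) / sqrt(running_var + eps) * weight + bias, refactored into a single scale-and-shift. A quick numerical check against `nn.BatchNorm2d`, using torchvision's `FrozenBatchNorm2d` (which the rewritten policy below imports instead of this local copy):

import torch
from torch import nn
from torchvision.ops.misc import FrozenBatchNorm2d

bn = nn.BatchNorm2d(8, eps=1e-5).eval()
bn.running_mean.normal_()          # make the check non-trivial
bn.running_var.uniform_(0.5, 2.0)

fbn = FrozenBatchNorm2d(8, eps=1e-5)
# give both modules identical affine parameters and running statistics
for name in ("weight", "bias", "running_mean", "running_var"):
    getattr(fbn, name).data.copy_(getattr(bn, name).data)

x = torch.randn(2, 8, 4, 4)
torch.testing.assert_close(fbn(x), bn(x))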
diff --git a/lerobot/common/policies/act/detr_vae.py b/lerobot/common/policies/act/detr_vae.py
deleted file mode 100644
index 0f2626f7..00000000
--- a/lerobot/common/policies/act/detr_vae.py
+++ /dev/null
@@ -1,212 +0,0 @@
-import numpy as np
-import torch
-from torch import nn
-from torch.autograd import Variable
-
-from .backbone import build_backbone
-from .transformer import TransformerEncoder, TransformerEncoderLayer, build_transformer
-
-
-def reparametrize(mu, logvar):
- std = logvar.div(2).exp()
- eps = Variable(std.data.new(std.size()).normal_())
- return mu + std * eps
-
-
-def get_sinusoid_encoding_table(n_position, d_hid):
- def get_position_angle_vec(position):
- return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
-
- sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
- sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
- sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
-
- return torch.FloatTensor(sinusoid_table).unsqueeze(0)
-
-
-class DETRVAE(nn.Module):
- """This is the DETR module that performs object detection"""
-
- def __init__(
- self, backbones, transformer, encoder, state_dim, action_dim, num_queries, camera_names, vae
- ):
- """Initializes the model.
- Parameters:
- backbones: torch module of the backbone to be used. See backbone.py
- transformer: torch module of the transformer architecture. See transformer.py
- state_dim: robot state dimension of the environment
- num_queries: number of object queries, ie detection slot. This is the maximal number of objects
- DETR can detect in a single image. For COCO, we recommend 100 queries.
- aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
- """
- super().__init__()
- self.num_queries = num_queries
- self.camera_names = camera_names
- self.transformer = transformer
- self.encoder = encoder
- self.vae = vae
- hidden_dim = transformer.d_model
- self.action_head = nn.Linear(hidden_dim, action_dim)
- self.is_pad_head = nn.Linear(hidden_dim, 1)
- self.query_embed = nn.Embedding(num_queries, hidden_dim)
- if backbones is not None:
- self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
- self.backbones = nn.ModuleList(backbones)
- self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
- else:
- # input_dim = 14 + 7 # robot_state + env_state
- self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
- # TODO(rcadene): understand what is env_state, and why it needs to be 7
- self.input_proj_env_state = nn.Linear(state_dim // 2, hidden_dim)
- self.pos = torch.nn.Embedding(2, hidden_dim)
- self.backbones = None
-
- # encoder extra parameters
- self.latent_dim = 32 # final size of latent z # TODO tune
- self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
- self.encoder_action_proj = nn.Linear(14, hidden_dim) # project action to embedding
- self.encoder_joint_proj = nn.Linear(14, hidden_dim) # project qpos to embedding
- self.latent_proj = nn.Linear(
- hidden_dim, self.latent_dim * 2
- ) # project hidden state to latent std, var
- self.register_buffer(
- "pos_table", get_sinusoid_encoding_table(1 + 1 + num_queries, hidden_dim)
- ) # [CLS], qpos, a_seq
-
- # decoder extra parameters
- self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding
- self.additional_pos_embed = nn.Embedding(
- 2, hidden_dim
- ) # learned position embedding for proprio and latent
-
- def forward(self, qpos, image, env_state, actions=None, is_pad=None):
- """
- qpos: batch, qpos_dim
- image: batch, num_cam, channel, height, width
- env_state: None
- actions: batch, seq, action_dim
- """
- is_training = actions is not None # train or val
- bs, _ = qpos.shape
- ### Obtain latent z from action sequence
- if self.vae and is_training:
- # project action sequence to embedding dim, and concat with a CLS token
- action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim)
- qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim)
- qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim)
- cls_embed = self.cls_embed.weight # (1, hidden_dim)
- cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim)
- encoder_input = torch.cat(
- [cls_embed, qpos_embed, action_embed], axis=1
- ) # (bs, seq+1, hidden_dim)
- encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim)
- # do not mask cls token
- # cls_joint_is_pad = torch.full((bs, 2), False).to(qpos.device) # False: not a padding
- # is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1)
- # obtain position embedding
- pos_embed = self.pos_table.clone().detach()
- pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim)
- # query model
- encoder_output = self.encoder(encoder_input, pos=pos_embed) # , src_key_padding_mask=is_pad)
- encoder_output = encoder_output[0] # take cls output only
- latent_info = self.latent_proj(encoder_output)
- mu = latent_info[:, : self.latent_dim]
- logvar = latent_info[:, self.latent_dim :]
- latent_sample = reparametrize(mu, logvar)
- latent_input = self.latent_out_proj(latent_sample)
- else:
- mu = logvar = None
- latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(qpos.device)
- latent_input = self.latent_out_proj(latent_sample)
-
- if self.backbones is not None:
- # Image observation features and position embeddings
- all_cam_features = []
- all_cam_pos = []
- for cam_id, _ in enumerate(self.camera_names):
- features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED
- features = features[0] # take the last layer feature
- pos = pos[0]
- all_cam_features.append(self.input_proj(features))
- all_cam_pos.append(pos)
- # proprioception features
- proprio_input = self.input_proj_robot_state(qpos)
- # fold camera dimension into width dimension
- src = torch.cat(all_cam_features, axis=3)
- pos = torch.cat(all_cam_pos, axis=3)
- hs = self.transformer(
- src,
- None,
- self.query_embed.weight,
- pos,
- latent_input,
- proprio_input,
- self.additional_pos_embed.weight,
- )[0]
- else:
- qpos = self.input_proj_robot_state(qpos)
- env_state = self.input_proj_env_state(env_state)
- transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2
- hs = self.transformer(transformer_input, None, self.query_embed.weight, self.pos.weight)[0]
- a_hat = self.action_head(hs)
- is_pad_hat = self.is_pad_head(hs)
- return a_hat, is_pad_hat, [mu, logvar]
-
-
-def mlp(input_dim, hidden_dim, output_dim, hidden_depth):
- if hidden_depth == 0:
- mods = [nn.Linear(input_dim, output_dim)]
- else:
- mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
- for _ in range(hidden_depth - 1):
- mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)]
- mods.append(nn.Linear(hidden_dim, output_dim))
- trunk = nn.Sequential(*mods)
- return trunk
-
-
-def build_encoder(args):
- d_model = args.hidden_dim # 256
- dropout = args.dropout # 0.1
- nhead = args.nheads # 8
- dim_feedforward = args.dim_feedforward # 2048
- num_encoder_layers = args.enc_layers # 4 # TODO shared with VAE decoder
- normalize_before = args.pre_norm # False
- activation = "relu"
-
- encoder_layer = TransformerEncoderLayer(
- d_model, nhead, dim_feedforward, dropout, activation, normalize_before
- )
- encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
- encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
-
- return encoder
-
-
-def build(args):
- # From state
- # backbone = None # from state for now, no need for conv nets
- # From image
- backbones = []
- backbone = build_backbone(args)
- backbones.append(backbone)
-
- transformer = build_transformer(args)
-
- encoder = build_encoder(args)
-
- model = DETRVAE(
- backbones,
- transformer,
- encoder,
- state_dim=args.state_dim,
- action_dim=args.action_dim,
- num_queries=args.num_queries,
- camera_names=args.camera_names,
- vae=args.vae,
- )
-
- n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
- print("number of parameters: {:.2f}M".format(n_parameters / 1e6))
-
- return model
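The `reparametrize` helper above is the standard reparameterization trick, z = mu + sigma * eps with eps ~ N(0, I), where the network predicts log sigma^2 (hence `logvar.div(2).exp()` for sigma). The `Variable` wrapper is a pre-0.4 PyTorch relic; the rewritten policy below expresses the same thing with `torch.randn_like`. A minimal sanity check:

import torch

mu = torch.zeros(10_000, 32)
logvar = torch.full((10_000, 32), -2.0)  # sigma = exp(-1) ≈ 0.368
z = mu + logvar.div(2).exp() * torch.randn_like(mu)
print(z.std().item())  # ≈ 0.368 in expectation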
diff --git a/lerobot/common/policies/act/policy.py b/lerobot/common/policies/act/policy.py
index ae4f7320..25b814ed 100644
--- a/lerobot/common/policies/act/policy.py
+++ b/lerobot/common/policies/act/policy.py
@@ -1,173 +1,200 @@
-import logging
-import time
+"""Action Chunking Transformer Policy
+As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705).
+The majority of changes here involve removing unused code, unifying naming, and adding helpful comments.
+"""
+
+import math
+import time
+from collections import deque
+from itertools import chain
+from typing import Callable
+
+import einops
+import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
+import torchvision
import torchvision.transforms as transforms
+from torch import Tensor, nn
+from torchvision.models._utils import IntermediateLayerGetter
+from torchvision.ops.misc import FrozenBatchNorm2d
-from lerobot.common.policies.abstract import AbstractPolicy
-from lerobot.common.policies.act.detr_vae import build
from lerobot.common.utils import get_safe_torch_device
-def build_act_model_and_optimizer(cfg):
- model = build(cfg)
+class ActionChunkingTransformerPolicy(nn.Module):
+ """
+ Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
+ Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
- param_dicts = [
- {"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
- {
- "params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
- "lr": cfg.lr_backbone,
- },
- ]
- optimizer = torch.optim.AdamW(param_dicts, lr=cfg.lr, weight_decay=cfg.weight_decay)
+ Note: In this code we use the terms `vae_encoder`, 'encoder', `decoder`. The meanings are as follows.
+ - The `vae_encoder` is, as per the literature around variational auto-encoders (VAE), the part of the
+ model that encodes the target data (a sequence of actions), and the condition (the robot
+ joint-space).
+ - A transformer with an `encoder` (not the VAE encoder) and `decoder` (not the VAE decoder) with
+ cross-attention is used as the VAE decoder. For these terms, we drop the `vae_` prefix because we
+ have an option to train this model without the variational objective (in which case we drop the
+ `vae_encoder` altogether, and nothing about this model has anything to do with a VAE).
- return model, optimizer
+ Transformer
+ Used alone for inference
+ (acts as VAE decoder
+ during training)
+ ┌───────────────────────┐
+ │ Outputs │
+ │ ▲ │
+ │ ┌─────►┌───────┐ │
+ ┌──────┐ │ │ │Transf.│ │
+ │ │ │ ├─────►│decoder│ │
+ ┌────┴────┐ │ │ │ │ │ │
+ │ │ │ │ ┌───┴───┬─►│ │ │
+ │ VAE │ │ │ │ │ └───────┘ │
+ │ encoder │ │ │ │Transf.│ │
+ │ │ │ │ │encoder│ │
+ └───▲─────┘ │ │ │ │ │
+ │ │ │ └───▲───┘ │
+ │ │ │ │ │
+ inputs └─────┼─────┘ │
+ │ │
+ └───────────────────────┘
+ """
-
-def kl_divergence(mu, logvar):
- batch_size = mu.size(0)
- assert batch_size != 0
- if mu.data.ndimension() == 4:
- mu = mu.view(mu.size(0), mu.size(1))
- if logvar.data.ndimension() == 4:
- logvar = logvar.view(logvar.size(0), logvar.size(1))
-
- klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
- total_kld = klds.sum(1).mean(0, True)
- dimension_wise_kld = klds.mean(0)
- mean_kld = klds.mean(1).mean(0, True)
-
- return total_kld, dimension_wise_kld, mean_kld
-
-
-class ActionChunkingTransformerPolicy(AbstractPolicy):
name = "act"
+ _multiple_obs_steps_not_handled_msg = (
+ "ActionChunkingTransformerPolicy does not handle multiple observation steps."
+ )
- def __init__(self, cfg, device, n_action_steps=1):
- super().__init__(n_action_steps)
+ def __init__(self, cfg, device):
+ """
+ TODO(alexander-soare): Add documentation for all parameters once we have model configs established.
+ """
+ super().__init__()
+ if getattr(cfg, "n_obs_steps", 1) != 1:
+ raise ValueError(self._multiple_obs_steps_not_handled_msg)
self.cfg = cfg
- self.n_action_steps = n_action_steps
+ self.n_action_steps = cfg.n_action_steps
self.device = get_safe_torch_device(device)
- self.model, self.optimizer = build_act_model_and_optimizer(cfg)
- self.kl_weight = self.cfg.kl_weight
- logging.info(f"KL Weight {self.kl_weight}")
+ self.camera_names = cfg.camera_names
+ self.use_vae = cfg.use_vae
+ self.horizon = cfg.horizon
+ self.d_model = cfg.d_model
+
+ transformer_common_kwargs = dict( # noqa: C408
+ d_model=self.d_model,
+ num_heads=cfg.num_heads,
+ dim_feedforward=cfg.dim_feedforward,
+ dropout=cfg.dropout,
+ activation=cfg.activation,
+ normalize_before=cfg.pre_norm,
+ )
+
+ # BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
+ # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
+ if self.use_vae:
+ self.vae_encoder = _TransformerEncoder(num_layers=cfg.vae_enc_layers, **transformer_common_kwargs)
+ self.vae_encoder_cls_embed = nn.Embedding(1, self.d_model)
+ # Projection layer for joint-space configuration to hidden dimension.
+ self.vae_encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, self.d_model)
+ # Projection layer for action (joint-space target) to hidden dimension.
+ self.vae_encoder_action_input_proj = nn.Linear(cfg.state_dim, self.d_model)
+ self.latent_dim = cfg.latent_dim
+ # Projection layer from the VAE encoder's output to the latent distribution's parameter space.
+ self.vae_encoder_latent_output_proj = nn.Linear(self.d_model, self.latent_dim * 2)
+ # Fixed sinusoidal positional embedding for the whole input to the VAE encoder. Unsqueeze for batch
+ # dimension.
+ self.register_buffer(
+ "vae_encoder_pos_enc",
+ _create_sinusoidal_position_embedding(1 + 1 + self.horizon, self.d_model).unsqueeze(0),
+ )
+
+ # Backbone for image feature extraction.
+ self.image_normalizer = transforms.Normalize(
+ mean=cfg.image_normalization.mean, std=cfg.image_normalization.std
+ )
+ backbone_model = getattr(torchvision.models, cfg.backbone)(
+ replace_stride_with_dilation=[False, False, cfg.dilation],
+ pretrained=cfg.pretrained_backbone,
+ norm_layer=FrozenBatchNorm2d,
+ )
+ # Note: The forward method of this returns a dict: {"feature_map": output}.
+ self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
+
+ # Transformer (acts as VAE decoder when training with the variational objective).
+ self.encoder = _TransformerEncoder(num_layers=cfg.enc_layers, **transformer_common_kwargs)
+ self.decoder = _TransformerDecoder(num_layers=cfg.dec_layers, **transformer_common_kwargs)
+
+ # Transformer encoder input projections. The tokens will be structured like
+ # [latent, robot_state, image_feature_map_pixels].
+ self.encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, self.d_model)
+ self.encoder_latent_input_proj = nn.Linear(self.latent_dim, self.d_model)
+ self.encoder_img_feat_input_proj = nn.Conv2d(
+ backbone_model.fc.in_features, self.d_model, kernel_size=1
+ )
+ # Transformer encoder positional embeddings.
+ self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, self.d_model)
+ self.encoder_cam_feat_pos_embed = _SinusoidalPositionEmbedding2D(self.d_model // 2)
+
+ # Transformer decoder.
+ # Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
+ self.decoder_pos_embed = nn.Embedding(self.horizon, self.d_model)
+
+ # Final action regression head on the output of the transformer's decoder.
+ self.action_head = nn.Linear(self.d_model, cfg.action_dim)
+
+ self._reset_parameters()
+
+ self._create_optimizer()
self.to(self.device)
- def update(self, replay_buffer, step):
- del step
-
- start_time = time.time()
-
- self.train()
-
- num_slices = self.cfg.batch_size
- batch_size = self.cfg.horizon * num_slices
-
- assert batch_size % self.cfg.horizon == 0
- assert batch_size % num_slices == 0
-
- def process_batch(batch, horizon, num_slices):
- # trajectory t = 64, horizon h = 16
- # (t h) ... -> t h ...
- batch = batch.reshape(num_slices, horizon)
-
- image = batch["observation", "image", "top"]
- image = image[:, 0] # first observation t=0
- # batch, num_cam, channel, height, width
- image = image.unsqueeze(1)
- assert image.ndim == 5
- image = image.float()
-
- state = batch["observation", "state"]
- state = state[:, 0] # first observation t=0
- # batch, qpos_dim
- assert state.ndim == 2
-
- action = batch["action"]
- # batch, seq, action_dim
- assert action.ndim == 3
- assert action.shape[1] == horizon
-
- if self.cfg.n_obs_steps > 1:
- raise NotImplementedError()
- # # keep first n observations of the slice corresponding to t=[-1,0]
- # image = image[:, : self.cfg.n_obs_steps]
- # state = state[:, : self.cfg.n_obs_steps]
-
- out = {
- "obs": {
- "image": image.to(self.device, non_blocking=True),
- "agent_pos": state.to(self.device, non_blocking=True),
- },
- "action": action.to(self.device, non_blocking=True),
- }
- return out
-
- batch = replay_buffer.sample(batch_size)
- batch = process_batch(batch, self.cfg.horizon, num_slices)
-
- data_s = time.time() - start_time
-
- loss = self.compute_loss(batch)
- loss.backward()
-
- grad_norm = torch.nn.utils.clip_grad_norm_(
- self.model.parameters(),
- self.cfg.grad_clip_norm,
- error_if_nonfinite=False,
+ def _create_optimizer(self):
+ optimizer_params_dicts = [
+ {
+ "params": [
+ p for n, p in self.named_parameters() if not n.startswith("backbone") and p.requires_grad
+ ]
+ },
+ {
+ "params": [
+ p for n, p in self.named_parameters() if n.startswith("backbone") and p.requires_grad
+ ],
+ "lr": self.cfg.lr_backbone,
+ },
+ ]
+ self.optimizer = torch.optim.AdamW(
+ optimizer_params_dicts, lr=self.cfg.lr, weight_decay=self.cfg.weight_decay
)
- self.optimizer.step()
- self.optimizer.zero_grad()
- # self.lr_scheduler.step()
+ def _reset_parameters(self):
+ """Xavier-uniform initialization of the transformer parameters as in the original code."""
+ for p in chain(self.encoder.parameters(), self.decoder.parameters()):
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
- info = {
- "loss": loss.item(),
- "grad_norm": float(grad_norm),
- # "lr": self.lr_scheduler.get_last_lr()[0],
- "lr": self.cfg.lr,
- "data_s": data_s,
- "update_s": time.time() - start_time,
- }
+ def reset(self):
+ """This should be called whenever the environment is reset."""
+ if self.n_action_steps is not None:
+ self._action_queue = deque([], maxlen=self.n_action_steps)
- return info
-
- def save(self, fp):
- torch.save(self.state_dict(), fp)
-
- def load(self, fp):
- d = torch.load(fp)
- self.load_state_dict(d)
-
- def compute_loss(self, batch):
- loss_dict = self._forward(
- qpos=batch["obs"]["agent_pos"],
- image=batch["obs"]["image"],
- actions=batch["action"],
- )
- loss = loss_dict["loss"]
- return loss
+ def select_action(self, batch: dict[str, Tensor], *_, **__) -> Tensor:
+ """
+ This method wraps `select_actions` in order to return one action at a time for execution in the
+ environment. It works by managing the actions in a queue and only calling `select_actions` when the
+ queue is empty.
+ """
+ if len(self._action_queue) == 0:
+ # `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has shape
+ # (n_action_steps, batch_size, *), hence the transpose.
+ self._action_queue.extend(self.select_actions(batch).transpose(0, 1))
+ return self._action_queue.popleft()
@torch.no_grad()
- def select_actions(self, observation, step_count):
- if observation["image"].shape[0] != 1:
- raise NotImplementedError("Batch size > 1 not handled")
-
- # TODO(rcadene): remove unused step_count
- del step_count
-
+ def select_actions(self, batch: dict[str, Tensor]) -> Tensor:
+ """Use the action chunking transformer to generate a sequence of actions."""
self.eval()
+ self._preprocess_batch(batch, add_obs_steps_dim=True)
- # TODO(rcadene): remove hack
- # add 1 camera dimension
- observation["image", "top"] = observation["image", "top"].unsqueeze(1)
-
- obs_dict = {
- "image": observation["image", "top"],
- "agent_pos": observation["state"],
- }
- action = self._forward(qpos=obs_dict["agent_pos"], image=obs_dict["image"])
+ action = self.forward(batch, return_loss=False)
if self.cfg.temporal_agg:
# TODO(rcadene): implement temporal aggregation
@@ -182,35 +209,470 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
# exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
# raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
- # take first predicted action or n first actions
- action = action[: self.n_action_steps]
- return action
+ return action[: self.n_action_steps]
- def _forward(self, qpos, image, actions=None, is_pad=None):
- env_state = None
- normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
- image = normalize(image)
+ def __call__(self, *args, **kwargs) -> dict:
+ # TODO(now): Temporary bridge until we know what to do about the `update` method.
+ return self.update(*args, **kwargs)
- is_training = actions is not None
- if is_training: # training time
- actions = actions[:, : self.model.num_queries]
- if is_pad is not None:
- is_pad = is_pad[:, : self.model.num_queries]
+ def _preprocess_batch(
+ self, batch: dict[str, Tensor], add_obs_steps_dim: bool = False
+ ) -> dict[str, Tensor]:
+ """
+ This function expects `batch` to have (at least):
+ {
+ "observation.state": (B, 1, J) OR (B, J) tensor of robot states (joint configuration).
+ "observation.images.top": (B, 1, C, H, W) OR (B, C, H, W) tensor of images.
+ "action": (B, H, J) tensor of actions (positional target for robot joint configuration)
+ "action_is_pad": (B, H) mask for whether the actions are padding outside of the episode bounds.
+ }
+ """
+ if add_obs_steps_dim:
+ # Add a dimension for the observations steps. Since n_obs_steps > 1 is not supported right now,
+ # this just amounts to an unsqueeze.
+ for k in batch:
+ if k.startswith("observation."):
+ batch[k] = batch[k].unsqueeze(1)
- a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
+ if batch["observation.state"].shape[1] != 1:
+ raise ValueError(self._multiple_obs_steps_not_handled_msg)
+ batch["observation.state"] = batch["observation.state"].squeeze(1)
+ # TODO(alexander-soare): generalize this to multiple images.
+ assert (
+ sum(k.startswith("observation.images.") and not k.endswith("is_pad") for k in batch) == 1
+ ), "ACT only handles one image for now."
+ # Note: "observation.images.top" is not squeezed here because its observation-step dim doubles as
+ # the camera index dimension that the forward pass expects.
- all_l1 = F.l1_loss(actions, a_hat, reduction="none")
- l1 = all_l1.mean() if is_pad is None else (all_l1 * ~is_pad.unsqueeze(-1)).mean()
+ def update(self, batch, *_, **__) -> dict:
+ start_time = time.time()
+ self._preprocess_batch(batch)
+
+ self.train()
+
+ num_slices = self.cfg.batch_size
+ batch_size = self.cfg.horizon * num_slices
+
+ assert batch_size % self.cfg.horizon == 0
+ assert batch_size % num_slices == 0
+
+ loss = self.forward(batch, return_loss=True)["loss"]
+ loss.backward()
+
+ grad_norm = torch.nn.utils.clip_grad_norm_(
+ self.parameters(),
+ self.cfg.grad_clip_norm,
+ error_if_nonfinite=False,
+ )
+
+ self.optimizer.step()
+ self.optimizer.zero_grad()
+
+ info = {
+ "loss": loss.item(),
+ "grad_norm": float(grad_norm),
+ "lr": self.cfg.lr,
+ "update_s": time.time() - start_time,
+ }
+
+ return info
+
+ def forward(self, batch: dict[str, Tensor], return_loss: bool = False) -> dict | Tensor:
+ images = self.image_normalizer(batch["observation.images.top"])
+
+ if return_loss: # training time
+ actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(
+ batch["observation.state"], images, batch["action"]
+ )
+
+ l1_loss = (
+ F.l1_loss(batch["action"], actions_hat, reduction="none")
+ * ~batch["action_is_pad"].unsqueeze(-1)
+ ).mean()
loss_dict = {}
- loss_dict["l1"] = l1
- if self.cfg.vae:
- total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
- loss_dict["kl"] = total_kld[0]
- loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight
+ loss_dict["l1"] = l1_loss
+ if self.cfg.use_vae:
+ # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
+ # each dimension independently, we sum over the latent dimension to get the total
+ # KL-divergence per batch element, then take the mean over the batch.
+ # (See App. B of https://arxiv.org/abs/1312.6114 for more details).
+ mean_kld = (
+ (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
+ )
+ loss_dict["kl"] = mean_kld
+ loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.cfg.kl_weight
else:
loss_dict["loss"] = loss_dict["l1"]
return loss_dict
else:
- action, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior
+ action, _ = self._forward(batch["observation.state"], images)
return action
+
+ def _forward(
+ self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None
+ ) -> tuple[Tensor, tuple[Tensor | None, Tensor | None]]:
+ """
+ Args:
+ robot_state: (B, J) batch of robot joint configurations.
+ image: (B, N, C, H, W) batch of N camera frames.
+ actions: (B, S, A) batch of actions from the target dataset, which must be provided if the
+ VAE is enabled and the model is in training mode.
+ Returns:
+ (B, S, A) batch of action sequences
+ Tuple containing the latent PDF's parameters (mean, log(σ²)) both as (B, L) tensors where L is the
+ latent dimension.
+ """
+ if self.use_vae and self.training:
+ assert (
+ actions is not None
+ ), "actions must be provided when using the variational objective in training mode."
+
+ batch_size = robot_state.shape[0]
+
+ # Prepare the latent for input to the transformer encoder.
+ if self.use_vae and actions is not None:
+ # Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence].
+ cls_embed = einops.repeat(
+ self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
+ ) # (B, 1, D)
+ robot_state_embed = self.vae_encoder_robot_state_input_proj(robot_state).unsqueeze(1) # (B, 1, D)
+ action_embed = self.vae_encoder_action_input_proj(actions) # (B, S, D)
+ vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], dim=1) # (B, S+2, D)
+
+ # Prepare fixed positional embedding.
+ # Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
+ pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D)
+
+ # Forward pass through VAE encoder to get the latent PDF parameters.
+ cls_token_out = self.vae_encoder(
+ vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2)
+ )[0] # select the class token, with shape (B, D)
+ latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
+ mu = latent_pdf_params[:, : self.latent_dim]
+ # This is 2log(sigma). Done this way to match the original implementation.
+ log_sigma_x2 = latent_pdf_params[:, self.latent_dim :]
+
+ # Sample the latent with the reparameterization trick.
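+ # That is, latent = μ + σ * ε with ε ~ N(0, 1), where σ = exp(log(σ²) / 2).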
+ latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu)
+ else:
+ # When not using the VAE encoder, we set the latent to be all zeros.
+ mu = log_sigma_x2 = None
+ latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to(
+ robot_state.device
+ )
+
+ # Prepare all other transformer encoder inputs.
+ # Camera observation features and positional embeddings.
+ all_cam_features = []
+ all_cam_pos_embeds = []
+ for cam_id, _ in enumerate(self.camera_names):
+ cam_features = self.backbone(image[:, cam_id])["feature_map"]
+ cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
+ cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
+ all_cam_features.append(cam_features)
+ all_cam_pos_embeds.append(cam_pos_embed)
+ # Concatenate camera observation feature maps and positional embeddings along the width dimension.
+ encoder_in = torch.cat(all_cam_features, dim=3)
+ cam_pos_embed = torch.cat(all_cam_pos_embeds, dim=3)
+
+ # Get positional embeddings for robot state and latent.
+ robot_state_embed = self.encoder_robot_state_input_proj(robot_state)
+ latent_embed = self.encoder_latent_input_proj(latent_sample)
+
+ # Stack encoder input and positional embeddings moving to (S, B, C).
+ encoder_in = torch.cat(
+ [
+ torch.stack([latent_embed, robot_state_embed], dim=0),
+ encoder_in.flatten(2).permute(2, 0, 1),
+ ]
+ )
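+ # encoder_in is now (2 + H*W, B, C): one latent token, one robot state token, then the flattened
+ # camera feature tokens (W here being the total width of the concatenated feature maps).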
+ pos_embed = torch.cat(
+ [
+ self.encoder_robot_and_latent_pos_embed.weight.unsqueeze(1),
+ cam_pos_embed.flatten(2).permute(2, 0, 1),
+ ],
+ dim=0,
+ )
+
+ # Forward pass through the transformer modules.
+ encoder_out = self.encoder(encoder_in, pos_embed=pos_embed)
+ decoder_in = torch.zeros(
+ (self.horizon, batch_size, self.d_model), dtype=pos_embed.dtype, device=pos_embed.device
+ )
+ decoder_out = self.decoder(
+ decoder_in,
+ encoder_out,
+ encoder_pos_embed=pos_embed,
+ decoder_pos_embed=self.decoder_pos_embed.weight.unsqueeze(1),
+ )
+
+ # Move back to (B, S, C).
+ decoder_out = decoder_out.transpose(0, 1)
+
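+ # Project the decoder features to the action space: (B, S, A).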
+ actions = self.action_head(decoder_out)
+
+ return actions, (mu, log_sigma_x2)
+
+ def save(self, fp):
+ torch.save(self.state_dict(), fp)
+
+ def load(self, fp):
+ d = torch.load(fp)
+ self.load_state_dict(d)
+
+
+class _TransformerEncoder(nn.Module):
+ """Convenience module for running multiple encoder layers, maybe followed by normalization."""
+
+ def __init__(self, num_layers: int, **encoder_layer_kwargs: dict):
+ super().__init__()
+ self.layers = nn.ModuleList(
+ [_TransformerEncoderLayer(**encoder_layer_kwargs) for _ in range(num_layers)]
+ )
+ self.norm = (
+ nn.LayerNorm(encoder_layer_kwargs["d_model"])
+ if encoder_layer_kwargs["normalize_before"]
+ else nn.Identity()
+ )
+
+ def forward(self, x: Tensor, pos_embed: Tensor | None = None) -> Tensor:
+ for layer in self.layers:
+ x = layer(x, pos_embed=pos_embed)
+ x = self.norm(x)
+ return x
+
+
+class _TransformerEncoderLayer(nn.Module):
+ def __init__(
+ self,
+ d_model: int,
+ num_heads: int,
+ dim_feedforward: int,
+ dropout: float,
+ activation: str,
+ normalize_before: bool,
+ ):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)
+
+ # Feed forward layers.
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+ self.normalize_before = normalize_before
+
+ def forward(self, x, pos_embed: Tensor | None = None) -> Tensor:
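+ # The branching below implements both the pre-norm (normalize_before=True) and post-norm
+ # residual arrangements. Positional embeddings are added to the queries and keys only (not the
+ # values), following the DETR-style attention used here.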
+ skip = x
+ if self.normalize_before:
+ x = self.norm1(x)
+ q = k = x if pos_embed is None else x + pos_embed
+ x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights
+ x = skip + self.dropout1(x)
+ if self.normalize_before:
+ skip = x
+ x = self.norm2(x)
+ else:
+ x = self.norm1(x)
+ skip = x
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
+ x = skip + self.dropout2(x)
+ if not self.normalize_before:
+ x = self.norm2(x)
+ return x
+
+
+class _TransformerDecoder(nn.Module):
+ """Convenience module for running multiple decoder layers followed by normalization."""
+
+ def __init__(self, num_layers: int, **decoder_layer_kwargs):
+ super().__init__()
+ self.layers = nn.ModuleList(
+ [_TransformerDecoderLayer(**decoder_layer_kwargs) for _ in range(num_layers)]
+ )
+ self.num_layers = num_layers
+ self.norm = nn.LayerNorm(decoder_layer_kwargs["d_model"])
+
+ def forward(
+ self,
+ x: Tensor,
+ encoder_out: Tensor,
+ decoder_pos_embed: Tensor | None = None,
+ encoder_pos_embed: Tensor | None = None,
+ ) -> Tensor:
+ for layer in self.layers:
+ x = layer(
+ x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed
+ )
+ if self.norm is not None:
+ x = self.norm(x)
+ return x
+
+
+class _TransformerDecoderLayer(nn.Module):
+ def __init__(
+ self,
+ d_model: int,
+ num_heads: int,
+ dim_feedforward: int,
+ dropout: float,
+ activation: str,
+ normalize_before: bool,
+ ):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)
+ self.multihead_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)
+
+ # Feed forward layers.
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+ self.norm3 = nn.LayerNorm(d_model)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+ self.dropout3 = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+ self.normalize_before = normalize_before
+
+ def maybe_add_pos_embed(self, tensor: Tensor, pos_embed: Tensor | None) -> Tensor:
+ return tensor if pos_embed is None else tensor + pos_embed
+
+ def forward(
+ self,
+ x: Tensor,
+ encoder_out: Tensor,
+ decoder_pos_embed: Tensor | None = None,
+ encoder_pos_embed: Tensor | None = None,
+ ) -> Tensor:
+ """
+ Args:
+ x: (Decoder Sequence, Batch, Channel) tensor of input tokens.
+ encoder_out: (Encoder Sequence, B, C) output features from the last layer of the encoder we are
+ cross-attending with.
+ decoder_pos_embed: (DS, 1, C) positional embedding for the queries (from the decoder).
+ encoder_pos_embed: (ES, 1, C) positional embedding for the keys (from the encoder).
+ Returns:
+ (DS, B, C) tensor of decoder output features.
+ """
+ skip = x
+ if self.normalize_before:
+ x = self.norm1(x)
+ q = k = self.maybe_add_pos_embed(x, decoder_pos_embed)
+ x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights
+ x = skip + self.dropout1(x)
+ if self.normalize_before:
+ skip = x
+ x = self.norm2(x)
+ else:
+ x = self.norm1(x)
+ skip = x
+ x = self.multihead_attn(
+ query=self.maybe_add_pos_embed(x, decoder_pos_embed),
+ key=self.maybe_add_pos_embed(encoder_out, encoder_pos_embed),
+ value=encoder_out,
+ )[0] # select just the output, not the attention weights
+ x = skip + self.dropout2(x)
+ if self.normalize_before:
+ skip = x
+ x = self.norm3(x)
+ else:
+ x = self.norm2(x)
+ skip = x
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
+ x = skip + self.dropout3(x)
+ if not self.normalize_before:
+ x = self.norm3(x)
+ return x
+
+
+def _create_sinusoidal_position_embedding(num_positions: int, dimension: int) -> Tensor:
+ """1D sinusoidal positional embeddings as in Attention is All You Need.
+
+ Args:
+ num_positions: Number of token positions required.
+ dimension: Dimension of each position embedding vector.
+ Returns: (num_positions, dimension) position embeddings.
+ """
+
+ def get_position_angle_vec(position):
+ return [position / np.power(10000, 2 * (hid_j // 2) / dimension) for hid_j in range(dimension)]
+
+ sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(num_positions)])
+ sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
+ sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
+ return torch.from_numpy(sinusoid_table).float()
+
+
+class _SinusoidalPositionEmbedding2D(nn.Module):
+ """2D sinusoidal positional embeddings similar to what's presented in Attention Is All You Need.
+
+ The variation is that the position indices are normalized in [0, 2π] (not quite: the lower bound is 1/H
+ for the vertical direction, and 1/W for the horizontal direction).
+ """
+
+ def __init__(self, dimension: int):
+ """
+ Args:
+ dimension: The desired dimension of the embeddings.
+ """
+ super().__init__()
+ self.dimension = dimension
+ self._two_pi = 2 * math.pi
+ self._eps = 1e-6
+ # Inverse "common ratio" for the geometric progression in sinusoid frequencies.
+ self._temperature = 10000
+
+ def forward(self, x: Tensor) -> Tensor:
+ """
+ Args:
+ x: A (B, C, H, W) batch of 2D feature map to generate the embeddings for.
+ Returns:
+ A (1, C, H, W) batch of corresponding sinusoidal positional embeddings.
+ """
+ not_mask = torch.ones_like(x[0, :1]) # (1, H, W)
+ # Note: These are like range(1, H+1) and range(1, W+1) respectively, but in most implementations
+ # they would be range(0, H) and range(0, W). Keeping it as is to match the original code.
+ y_range = not_mask.cumsum(1, dtype=torch.float32)
+ x_range = not_mask.cumsum(2, dtype=torch.float32)
+
+ # "Normalize" the position index such that it ranges in [0, 2π].
+ # Note: Adding epsilon to the denominator should not be needed as all values of y_range and x_range
+ # are non-zero by construction. This is an artifact of the original code.
+ y_range = y_range / (y_range[:, -1:, :] + self._eps) * self._two_pi
+ x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi
+
+ inverse_frequency = self._temperature ** (
+ 2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension
+ )
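+ # Each consecutive (even, odd) channel pair shares a frequency thanks to the integer division
+ # by 2; the sin/cos interleaving happens below.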
+
+ x_range = x_range.unsqueeze(-1) / inverse_frequency # (1, H, W, C // 2)
+ y_range = y_range.unsqueeze(-1) / inverse_frequency # (1, H, W, C // 2)
+
+ # Note: this stack then flatten operation results in interleaved sine and cosine terms.
+ # pos_embed_x and pos_embed_y are (1, H, W, C // 2).
+ pos_embed_x = torch.stack((x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1).flatten(3)
+ pos_embed_y = torch.stack((y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1).flatten(3)
+ pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(0, 3, 1, 2) # (1, C, H, W)
+
+ return pos_embed
+
+
+def _get_activation_fn(activation: str) -> Callable:
+ """Return an activation function given a string."""
+ if activation == "relu":
+ return F.relu
+ if activation == "gelu":
+ return F.gelu
+ if activation == "glu":
+ return F.glu
+ raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
diff --git a/lerobot/common/policies/act/position_encoding.py b/lerobot/common/policies/act/position_encoding.py
deleted file mode 100644
index 63bb4840..00000000
--- a/lerobot/common/policies/act/position_encoding.py
+++ /dev/null
@@ -1,102 +0,0 @@
-"""
-Various positional encodings for the transformer.
-"""
-
-import math
-
-import torch
-from torch import nn
-
-from .utils import NestedTensor
-
-
-class PositionEmbeddingSine(nn.Module):
- """
- This is a more standard version of the position embedding, very similar to the one
- used by the Attention is all you need paper, generalized to work on images.
- """
-
- def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
- super().__init__()
- self.num_pos_feats = num_pos_feats
- self.temperature = temperature
- self.normalize = normalize
- if scale is not None and normalize is False:
- raise ValueError("normalize should be True if scale is passed")
- if scale is None:
- scale = 2 * math.pi
- self.scale = scale
-
- def forward(self, tensor):
- x = tensor
- # mask = tensor_list.mask
- # assert mask is not None
- # not_mask = ~mask
-
- not_mask = torch.ones_like(x[0, [0]])
- y_embed = not_mask.cumsum(1, dtype=torch.float32)
- x_embed = not_mask.cumsum(2, dtype=torch.float32)
- if self.normalize:
- eps = 1e-6
- y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
- x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
-
- dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
- dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
-
- pos_x = x_embed[:, :, :, None] / dim_t
- pos_y = y_embed[:, :, :, None] / dim_t
- pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
- pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
- pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
- return pos
-
-
-class PositionEmbeddingLearned(nn.Module):
- """
- Absolute pos embedding, learned.
- """
-
- def __init__(self, num_pos_feats=256):
- super().__init__()
- self.row_embed = nn.Embedding(50, num_pos_feats)
- self.col_embed = nn.Embedding(50, num_pos_feats)
- self.reset_parameters()
-
- def reset_parameters(self):
- nn.init.uniform_(self.row_embed.weight)
- nn.init.uniform_(self.col_embed.weight)
-
- def forward(self, tensor_list: NestedTensor):
- x = tensor_list.tensors
- h, w = x.shape[-2:]
- i = torch.arange(w, device=x.device)
- j = torch.arange(h, device=x.device)
- x_emb = self.col_embed(i)
- y_emb = self.row_embed(j)
- pos = (
- torch.cat(
- [
- x_emb.unsqueeze(0).repeat(h, 1, 1),
- y_emb.unsqueeze(1).repeat(1, w, 1),
- ],
- dim=-1,
- )
- .permute(2, 0, 1)
- .unsqueeze(0)
- .repeat(x.shape[0], 1, 1, 1)
- )
- return pos
-
-
-def build_position_encoding(args):
- n_steps = args.hidden_dim // 2
- if args.position_embedding in ("v2", "sine"):
- # TODO find a better way of exposing other arguments
- position_embedding = PositionEmbeddingSine(n_steps, normalize=True)
- elif args.position_embedding in ("v3", "learned"):
- position_embedding = PositionEmbeddingLearned(n_steps)
- else:
- raise ValueError(f"not supported {args.position_embedding}")
-
- return position_embedding
diff --git a/lerobot/common/policies/act/transformer.py b/lerobot/common/policies/act/transformer.py
deleted file mode 100644
index 20cfc815..00000000
--- a/lerobot/common/policies/act/transformer.py
+++ /dev/null
@@ -1,371 +0,0 @@
-"""
-DETR Transformer class.
-
-Copy-paste from torch.nn.Transformer with modifications:
- * positional encodings are passed in MHattention
- * extra LN at the end of encoder is removed
- * decoder returns a stack of activations from all decoding layers
-"""
-
-import copy
-from typing import Optional
-
-import torch
-import torch.nn.functional as F # noqa: N812
-from torch import Tensor, nn
-
-
-class Transformer(nn.Module):
- def __init__(
- self,
- d_model=512,
- nhead=8,
- num_encoder_layers=6,
- num_decoder_layers=6,
- dim_feedforward=2048,
- dropout=0.1,
- activation="relu",
- normalize_before=False,
- return_intermediate_dec=False,
- ):
- super().__init__()
-
- encoder_layer = TransformerEncoderLayer(
- d_model, nhead, dim_feedforward, dropout, activation, normalize_before
- )
- encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
- self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
-
- decoder_layer = TransformerDecoderLayer(
- d_model, nhead, dim_feedforward, dropout, activation, normalize_before
- )
- decoder_norm = nn.LayerNorm(d_model)
- self.decoder = TransformerDecoder(
- decoder_layer, num_decoder_layers, decoder_norm, return_intermediate=return_intermediate_dec
- )
-
- self._reset_parameters()
-
- self.d_model = d_model
- self.nhead = nhead
-
- def _reset_parameters(self):
- for p in self.parameters():
- if p.dim() > 1:
- nn.init.xavier_uniform_(p)
-
- def forward(
- self,
- src,
- mask,
- query_embed,
- pos_embed,
- latent_input=None,
- proprio_input=None,
- additional_pos_embed=None,
- ):
- # TODO flatten only when input has H and W
- if len(src.shape) == 4: # has H and W
- # flatten NxCxHxW to HWxNxC
- bs, c, h, w = src.shape
- src = src.flatten(2).permute(2, 0, 1)
- pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)
- query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
- # mask = mask.flatten(1)
-
- additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim
- pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)
-
- addition_input = torch.stack([latent_input, proprio_input], axis=0)
- src = torch.cat([addition_input, src], axis=0)
- else:
- assert len(src.shape) == 3
- # flatten NxHWxC to HWxNxC
- bs, hw, c = src.shape
- src = src.permute(1, 0, 2)
- pos_embed = pos_embed.unsqueeze(1).repeat(1, bs, 1)
- query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
-
- tgt = torch.zeros_like(query_embed)
- memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
- hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed)
- hs = hs.transpose(1, 2)
- return hs
-
-
-class TransformerEncoder(nn.Module):
- def __init__(self, encoder_layer, num_layers, norm=None):
- super().__init__()
- self.layers = _get_clones(encoder_layer, num_layers)
- self.num_layers = num_layers
- self.norm = norm
-
- def forward(
- self,
- src,
- mask: Optional[Tensor] = None,
- src_key_padding_mask: Optional[Tensor] = None,
- pos: Optional[Tensor] = None,
- ):
- output = src
-
- for layer in self.layers:
- output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos)
-
- if self.norm is not None:
- output = self.norm(output)
-
- return output
-
-
-class TransformerDecoder(nn.Module):
- def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
- super().__init__()
- self.layers = _get_clones(decoder_layer, num_layers)
- self.num_layers = num_layers
- self.norm = norm
- self.return_intermediate = return_intermediate
-
- def forward(
- self,
- tgt,
- memory,
- tgt_mask: Optional[Tensor] = None,
- memory_mask: Optional[Tensor] = None,
- tgt_key_padding_mask: Optional[Tensor] = None,
- memory_key_padding_mask: Optional[Tensor] = None,
- pos: Optional[Tensor] = None,
- query_pos: Optional[Tensor] = None,
- ):
- output = tgt
-
- intermediate = []
-
- for layer in self.layers:
- output = layer(
- output,
- memory,
- tgt_mask=tgt_mask,
- memory_mask=memory_mask,
- tgt_key_padding_mask=tgt_key_padding_mask,
- memory_key_padding_mask=memory_key_padding_mask,
- pos=pos,
- query_pos=query_pos,
- )
- if self.return_intermediate:
- intermediate.append(self.norm(output))
-
- if self.norm is not None:
- output = self.norm(output)
- if self.return_intermediate:
- intermediate.pop()
- intermediate.append(output)
-
- if self.return_intermediate:
- return torch.stack(intermediate)
-
- return output.unsqueeze(0)
-
-
-class TransformerEncoderLayer(nn.Module):
- def __init__(
- self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False
- ):
- super().__init__()
- self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
- # Implementation of Feedforward model
- self.linear1 = nn.Linear(d_model, dim_feedforward)
- self.dropout = nn.Dropout(dropout)
- self.linear2 = nn.Linear(dim_feedforward, d_model)
-
- self.norm1 = nn.LayerNorm(d_model)
- self.norm2 = nn.LayerNorm(d_model)
- self.dropout1 = nn.Dropout(dropout)
- self.dropout2 = nn.Dropout(dropout)
-
- self.activation = _get_activation_fn(activation)
- self.normalize_before = normalize_before
-
- def with_pos_embed(self, tensor, pos: Optional[Tensor]):
- return tensor if pos is None else tensor + pos
-
- def forward_post(
- self,
- src,
- src_mask: Optional[Tensor] = None,
- src_key_padding_mask: Optional[Tensor] = None,
- pos: Optional[Tensor] = None,
- ):
- q = k = self.with_pos_embed(src, pos)
- src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
- src = src + self.dropout1(src2)
- src = self.norm1(src)
- src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
- src = src + self.dropout2(src2)
- src = self.norm2(src)
- return src
-
- def forward_pre(
- self,
- src,
- src_mask: Optional[Tensor] = None,
- src_key_padding_mask: Optional[Tensor] = None,
- pos: Optional[Tensor] = None,
- ):
- src2 = self.norm1(src)
- q = k = self.with_pos_embed(src2, pos)
- src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
- src = src + self.dropout1(src2)
- src2 = self.norm2(src)
- src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
- src = src + self.dropout2(src2)
- return src
-
- def forward(
- self,
- src,
- src_mask: Optional[Tensor] = None,
- src_key_padding_mask: Optional[Tensor] = None,
- pos: Optional[Tensor] = None,
- ):
- if self.normalize_before:
- return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
- return self.forward_post(src, src_mask, src_key_padding_mask, pos)
-
-
-class TransformerDecoderLayer(nn.Module):
- def __init__(
- self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False
- ):
- super().__init__()
- self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
- self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
- # Implementation of Feedforward model
- self.linear1 = nn.Linear(d_model, dim_feedforward)
- self.dropout = nn.Dropout(dropout)
- self.linear2 = nn.Linear(dim_feedforward, d_model)
-
- self.norm1 = nn.LayerNorm(d_model)
- self.norm2 = nn.LayerNorm(d_model)
- self.norm3 = nn.LayerNorm(d_model)
- self.dropout1 = nn.Dropout(dropout)
- self.dropout2 = nn.Dropout(dropout)
- self.dropout3 = nn.Dropout(dropout)
-
- self.activation = _get_activation_fn(activation)
- self.normalize_before = normalize_before
-
- def with_pos_embed(self, tensor, pos: Optional[Tensor]):
- return tensor if pos is None else tensor + pos
-
- def forward_post(
- self,
- tgt,
- memory,
- tgt_mask: Optional[Tensor] = None,
- memory_mask: Optional[Tensor] = None,
- tgt_key_padding_mask: Optional[Tensor] = None,
- memory_key_padding_mask: Optional[Tensor] = None,
- pos: Optional[Tensor] = None,
- query_pos: Optional[Tensor] = None,
- ):
- q = k = self.with_pos_embed(tgt, query_pos)
- tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
- tgt = tgt + self.dropout1(tgt2)
- tgt = self.norm1(tgt)
- tgt2 = self.multihead_attn(
- query=self.with_pos_embed(tgt, query_pos),
- key=self.with_pos_embed(memory, pos),
- value=memory,
- attn_mask=memory_mask,
- key_padding_mask=memory_key_padding_mask,
- )[0]
- tgt = tgt + self.dropout2(tgt2)
- tgt = self.norm2(tgt)
- tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
- tgt = tgt + self.dropout3(tgt2)
- tgt = self.norm3(tgt)
- return tgt
-
- def forward_pre(
- self,
- tgt,
- memory,
- tgt_mask: Optional[Tensor] = None,
- memory_mask: Optional[Tensor] = None,
- tgt_key_padding_mask: Optional[Tensor] = None,
- memory_key_padding_mask: Optional[Tensor] = None,
- pos: Optional[Tensor] = None,
- query_pos: Optional[Tensor] = None,
- ):
- tgt2 = self.norm1(tgt)
- q = k = self.with_pos_embed(tgt2, query_pos)
- tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
- tgt = tgt + self.dropout1(tgt2)
- tgt2 = self.norm2(tgt)
- tgt2 = self.multihead_attn(
- query=self.with_pos_embed(tgt2, query_pos),
- key=self.with_pos_embed(memory, pos),
- value=memory,
- attn_mask=memory_mask,
- key_padding_mask=memory_key_padding_mask,
- )[0]
- tgt = tgt + self.dropout2(tgt2)
- tgt2 = self.norm3(tgt)
- tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
- tgt = tgt + self.dropout3(tgt2)
- return tgt
-
- def forward(
- self,
- tgt,
- memory,
- tgt_mask: Optional[Tensor] = None,
- memory_mask: Optional[Tensor] = None,
- tgt_key_padding_mask: Optional[Tensor] = None,
- memory_key_padding_mask: Optional[Tensor] = None,
- pos: Optional[Tensor] = None,
- query_pos: Optional[Tensor] = None,
- ):
- if self.normalize_before:
- return self.forward_pre(
- tgt,
- memory,
- tgt_mask,
- memory_mask,
- tgt_key_padding_mask,
- memory_key_padding_mask,
- pos,
- query_pos,
- )
- return self.forward_post(
- tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos
- )
-
-
-def _get_clones(module, n):
- return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])
-
-
-def build_transformer(args):
- return Transformer(
- d_model=args.hidden_dim,
- dropout=args.dropout,
- nhead=args.nheads,
- dim_feedforward=args.dim_feedforward,
- num_encoder_layers=args.enc_layers,
- num_decoder_layers=args.dec_layers,
- normalize_before=args.pre_norm,
- return_intermediate_dec=True,
- )
-
-
-def _get_activation_fn(activation):
- """Return an activation function given a string"""
- if activation == "relu":
- return F.relu
- if activation == "gelu":
- return F.gelu
- if activation == "glu":
- return F.glu
- raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
diff --git a/lerobot/common/policies/act/utils.py b/lerobot/common/policies/act/utils.py
deleted file mode 100644
index 0d935839..00000000
--- a/lerobot/common/policies/act/utils.py
+++ /dev/null
@@ -1,478 +0,0 @@
-"""
-Misc functions, including distributed helpers.
-
-Mostly copy-paste from torchvision references.
-"""
-
-import datetime
-import os
-import pickle
-import subprocess
-import time
-from collections import defaultdict, deque
-from typing import List, Optional
-
-import torch
-import torch.distributed as dist
-
-# needed due to empty tensor bug in pytorch and torchvision 0.5
-import torchvision
-from packaging import version
-from torch import Tensor
-
-if version.parse(torchvision.__version__) < version.parse("0.7"):
- from torchvision.ops import _new_empty_tensor
- from torchvision.ops.misc import _output_size
-
-
-class SmoothedValue:
- """Track a series of values and provide access to smoothed values over a
- window or the global series average.
- """
-
- def __init__(self, window_size=20, fmt=None):
- if fmt is None:
- fmt = "{median:.4f} ({global_avg:.4f})"
- self.deque = deque(maxlen=window_size)
- self.total = 0.0
- self.count = 0
- self.fmt = fmt
-
- def update(self, value, n=1):
- self.deque.append(value)
- self.count += n
- self.total += value * n
-
- def synchronize_between_processes(self):
- """
- Warning: does not synchronize the deque!
- """
- if not is_dist_avail_and_initialized():
- return
- t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
- dist.barrier()
- dist.all_reduce(t)
- t = t.tolist()
- self.count = int(t[0])
- self.total = t[1]
-
- @property
- def median(self):
- d = torch.tensor(list(self.deque))
- return d.median().item()
-
- @property
- def avg(self):
- d = torch.tensor(list(self.deque), dtype=torch.float32)
- return d.mean().item()
-
- @property
- def global_avg(self):
- return self.total / self.count
-
- @property
- def max(self):
- return max(self.deque)
-
- @property
- def value(self):
- return self.deque[-1]
-
- def __str__(self):
- return self.fmt.format(
- median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
- )
-
-
-def all_gather(data):
- """
- Run all_gather on arbitrary picklable data (not necessarily tensors)
- Args:
- data: any picklable object
- Returns:
- list[data]: list of data gathered from each rank
- """
- world_size = get_world_size()
- if world_size == 1:
- return [data]
-
- # serialized to a Tensor
- buffer = pickle.dumps(data)
- storage = torch.ByteStorage.from_buffer(buffer)
- tensor = torch.ByteTensor(storage).to("cuda")
-
- # obtain Tensor size of each rank
- local_size = torch.tensor([tensor.numel()], device="cuda")
- size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
- dist.all_gather(size_list, local_size)
- size_list = [int(size.item()) for size in size_list]
- max_size = max(size_list)
-
- # receiving Tensor from all ranks
- # we pad the tensor because torch all_gather does not support
- # gathering tensors of different shapes
- tensor_list = []
- for _ in size_list:
- tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
- if local_size != max_size:
- padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
- tensor = torch.cat((tensor, padding), dim=0)
- dist.all_gather(tensor_list, tensor)
-
- data_list = []
- for size, tensor in zip(size_list, tensor_list, strict=False):
- buffer = tensor.cpu().numpy().tobytes()[:size]
- data_list.append(pickle.loads(buffer))
-
- return data_list
-
-
-def reduce_dict(input_dict, average=True):
- """
- Args:
- input_dict (dict): all the values will be reduced
- average (bool): whether to do average or sum
- Reduce the values in the dictionary from all processes so that all processes
- have the averaged results. Returns a dict with the same fields as
- input_dict, after reduction.
- """
- world_size = get_world_size()
- if world_size < 2:
- return input_dict
- with torch.no_grad():
- names = []
- values = []
- # sort the keys so that they are consistent across processes
- for k in sorted(input_dict.keys()):
- names.append(k)
- values.append(input_dict[k])
- values = torch.stack(values, dim=0)
- dist.all_reduce(values)
- if average:
- values /= world_size
- reduced_dict = {k: v for k, v in zip(names, values, strict=False)} # noqa: C416
- return reduced_dict
-
-
-class MetricLogger:
- def __init__(self, delimiter="\t"):
- self.meters = defaultdict(SmoothedValue)
- self.delimiter = delimiter
-
- def update(self, **kwargs):
- for k, v in kwargs.items():
- if isinstance(v, torch.Tensor):
- v = v.item()
- assert isinstance(v, (float, int))
- self.meters[k].update(v)
-
- def __getattr__(self, attr):
- if attr in self.meters:
- return self.meters[attr]
- if attr in self.__dict__:
- return self.__dict__[attr]
- raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))
-
- def __str__(self):
- loss_str = []
- for name, meter in self.meters.items():
- loss_str.append("{}: {}".format(name, str(meter)))
- return self.delimiter.join(loss_str)
-
- def synchronize_between_processes(self):
- for meter in self.meters.values():
- meter.synchronize_between_processes()
-
- def add_meter(self, name, meter):
- self.meters[name] = meter
-
- def log_every(self, iterable, print_freq, header=None):
- if not header:
- header = ""
- start_time = time.time()
- end = time.time()
- iter_time = SmoothedValue(fmt="{avg:.4f}")
- data_time = SmoothedValue(fmt="{avg:.4f}")
- space_fmt = ":" + str(len(str(len(iterable)))) + "d"
- if torch.cuda.is_available():
- log_msg = self.delimiter.join(
- [
- header,
- "[{0" + space_fmt + "}/{1}]",
- "eta: {eta}",
- "{meters}",
- "time: {time}",
- "data: {data}",
- "max mem: {memory:.0f}",
- ]
- )
- else:
- log_msg = self.delimiter.join(
- [
- header,
- "[{0" + space_fmt + "}/{1}]",
- "eta: {eta}",
- "{meters}",
- "time: {time}",
- "data: {data}",
- ]
- )
- mega_b = 1024.0 * 1024.0
- for i, obj in enumerate(iterable):
- data_time.update(time.time() - end)
- yield obj
- iter_time.update(time.time() - end)
- if i % print_freq == 0 or i == len(iterable) - 1:
- eta_seconds = iter_time.global_avg * (len(iterable) - i)
- eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
- if torch.cuda.is_available():
- print(
- log_msg.format(
- i,
- len(iterable),
- eta=eta_string,
- meters=str(self),
- time=str(iter_time),
- data=str(data_time),
- memory=torch.cuda.max_memory_allocated() / mega_b,
- )
- )
- else:
- print(
- log_msg.format(
- i,
- len(iterable),
- eta=eta_string,
- meters=str(self),
- time=str(iter_time),
- data=str(data_time),
- )
- )
- end = time.time()
- total_time = time.time() - start_time
- total_time_str = str(datetime.timedelta(seconds=int(total_time)))
- print("{} Total time: {} ({:.4f} s / it)".format(header, total_time_str, total_time / len(iterable)))
-
-
-def get_sha():
- cwd = os.path.dirname(os.path.abspath(__file__))
-
- def _run(command):
- return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()
-
- sha = "N/A"
- diff = "clean"
- branch = "N/A"
- try:
- sha = _run(["git", "rev-parse", "HEAD"])
- subprocess.check_output(["git", "diff"], cwd=cwd)
- diff = _run(["git", "diff-index", "HEAD"])
- diff = "has uncommited changes" if diff else "clean"
- branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
- except Exception:
- pass
- message = f"sha: {sha}, status: {diff}, branch: {branch}"
- return message
-
-
-def collate_fn(batch):
- batch = list(zip(*batch, strict=False))
- batch[0] = nested_tensor_from_tensor_list(batch[0])
- return tuple(batch)
-
-
-def _max_by_axis(the_list):
- # type: (List[List[int]]) -> List[int]
- maxes = the_list[0]
- for sublist in the_list[1:]:
- for index, item in enumerate(sublist):
- maxes[index] = max(maxes[index], item)
- return maxes
-
-
-class NestedTensor:
- def __init__(self, tensors, mask: Optional[Tensor]):
- self.tensors = tensors
- self.mask = mask
-
- def to(self, device):
- # type: (Device) -> NestedTensor # noqa
- cast_tensor = self.tensors.to(device)
- mask = self.mask
- if mask is not None:
- assert mask is not None
- cast_mask = mask.to(device)
- else:
- cast_mask = None
- return NestedTensor(cast_tensor, cast_mask)
-
- def decompose(self):
- return self.tensors, self.mask
-
- def __repr__(self):
- return str(self.tensors)
-
-
-def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
- # TODO make this more general
- if tensor_list[0].ndim == 3:
- if torchvision._is_tracing():
- # nested_tensor_from_tensor_list() does not export well to ONNX
- # call _onnx_nested_tensor_from_tensor_list() instead
- return _onnx_nested_tensor_from_tensor_list(tensor_list)
-
- # TODO make it support different-sized images
- max_size = _max_by_axis([list(img.shape) for img in tensor_list])
- # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
- batch_shape = [len(tensor_list)] + max_size
- b, c, h, w = batch_shape
- dtype = tensor_list[0].dtype
- device = tensor_list[0].device
- tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
- mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
- for img, pad_img, m in zip(tensor_list, tensor, mask, strict=False):
- pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
- m[: img.shape[1], : img.shape[2]] = False
- else:
- raise ValueError("not supported")
- return NestedTensor(tensor, mask)
-
-
-# _onnx_nested_tensor_from_tensor_list() is an implementation of
-# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
-@torch.jit.unused
-def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
- max_size = []
- for i in range(tensor_list[0].dim()):
- max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(
- torch.int64
- )
- max_size.append(max_size_i)
- max_size = tuple(max_size)
-
- # work around for
- # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
- # m[: img.shape[1], :img.shape[2]] = False
- # which is not yet supported in onnx
- padded_imgs = []
- padded_masks = []
- for img in tensor_list:
- padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape), strict=False)]
- padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
- padded_imgs.append(padded_img)
-
- m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
- padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
- padded_masks.append(padded_mask.to(torch.bool))
-
- tensor = torch.stack(padded_imgs)
- mask = torch.stack(padded_masks)
-
- return NestedTensor(tensor, mask=mask)
-
-
-def setup_for_distributed(is_master):
- """
- This function disables printing when not in master process
- """
- import builtins as __builtin__
-
- builtin_print = __builtin__.print
-
- def print(*args, **kwargs):
- force = kwargs.pop("force", False)
- if is_master or force:
- builtin_print(*args, **kwargs)
-
- __builtin__.print = print
-
-
-def is_dist_avail_and_initialized():
- if not dist.is_available():
- return False
- if not dist.is_initialized():
- return False
- return True
-
-
-def get_world_size():
- if not is_dist_avail_and_initialized():
- return 1
- return dist.get_world_size()
-
-
-def get_rank():
- if not is_dist_avail_and_initialized():
- return 0
- return dist.get_rank()
-
-
-def is_main_process():
- return get_rank() == 0
-
-
-def save_on_master(*args, **kwargs):
- if is_main_process():
- torch.save(*args, **kwargs)
-
-
-def init_distributed_mode(args):
- if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
- args.rank = int(os.environ["RANK"])
- args.world_size = int(os.environ["WORLD_SIZE"])
- args.gpu = int(os.environ["LOCAL_RANK"])
- elif "SLURM_PROCID" in os.environ:
- args.rank = int(os.environ["SLURM_PROCID"])
- args.gpu = args.rank % torch.cuda.device_count()
- else:
- print("Not using distributed mode")
- args.distributed = False
- return
-
- args.distributed = True
-
- torch.cuda.set_device(args.gpu)
- args.dist_backend = "nccl"
- print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True)
- torch.distributed.init_process_group(
- backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
- )
- torch.distributed.barrier()
- setup_for_distributed(args.rank == 0)
-
-
-@torch.no_grad()
-def accuracy(output, target, topk=(1,)):
- """Computes the precision@k for the specified values of k"""
- if target.numel() == 0:
- return [torch.zeros([], device=output.device)]
- maxk = max(topk)
- batch_size = target.size(0)
-
- _, pred = output.topk(maxk, 1, True, True)
- pred = pred.t()
- correct = pred.eq(target.view(1, -1).expand_as(pred))
-
- res = []
- for k in topk:
- correct_k = correct[:k].view(-1).float().sum(0)
- res.append(correct_k.mul_(100.0 / batch_size))
- return res
-
-
-def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
- # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
- """
- Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
- This will eventually be supported natively by PyTorch, and this
- class can go away.
- """
- if version.parse(torchvision.__version__) < version.parse("0.7"):
- if input.numel() > 0:
- return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners)
-
- output_shape = _output_size(2, input, size, scale_factor)
- output_shape = list(input.shape[:-2]) + list(output_shape)
- return _new_empty_tensor(input, output_shape)
- else:
- return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
diff --git a/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py b/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py
index 7719fdde..f7432db3 100644
--- a/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py
+++ b/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py
@@ -244,8 +244,10 @@ class DiffusionUnetImagePolicy(BaseImagePolicy):
return result
def compute_loss(self, batch):
- assert "valid_mask" not in batch
- nobs = batch["obs"]
+ nobs = {
+ "image": batch["observation.image"],
+ "agent_pos": batch["observation.state"],
+ }
nactions = batch["action"]
batch_size = nactions.shape[0]
horizon = nactions.shape[1]
@@ -303,6 +305,11 @@ class DiffusionUnetImagePolicy(BaseImagePolicy):
loss = F.mse_loss(pred, target, reduction="none")
loss = loss * loss_mask.type(loss.dtype)
- loss = reduce(loss, "b ... -> b (...)", "mean")
+
+ if "action_is_pad" in batch:
+ in_episode_bound = ~batch["action_is_pad"]
+ loss = loss * in_episode_bound[:, :, None].type(loss.dtype)
+
+ loss = reduce(loss, "b t c -> b", "mean", b=batch_size)
loss = loss.mean()
return loss
diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py
index 82f39b28..9785358b 100644
--- a/lerobot/common/policies/diffusion/policy.py
+++ b/lerobot/common/policies/diffusion/policy.py
@@ -1,18 +1,20 @@
import copy
import logging
import time
+from collections import deque
import hydra
import torch
+from torch import nn
-from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder, RgbEncoder
+from lerobot.common.policies.utils import populate_queues
from lerobot.common.utils import get_safe_torch_device
-class DiffusionPolicy(AbstractPolicy):
+class DiffusionPolicy(nn.Module):
name = "diffusion"
def __init__(
@@ -38,8 +40,12 @@ class DiffusionPolicy(AbstractPolicy):
# parameters passed to step
**kwargs,
):
- super().__init__(n_action_steps)
+ super().__init__()
self.cfg = cfg
+ self.n_obs_steps = n_obs_steps
+ self.n_action_steps = n_action_steps
+ # queues are populated during rollout of the policy; they contain the n latest observations and actions
+ self._queues = None
noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
rgb_model_input_shape = copy.deepcopy(shape_meta.obs.image.shape)
@@ -100,75 +106,51 @@ class DiffusionPolicy(AbstractPolicy):
last_epoch=self.global_step - 1,
)
+ def reset(self):
+ """
+ Clear observation and action queues. Should be called on `env.reset()`
+ """
+ self._queues = {
+ "observation.image": deque(maxlen=self.n_obs_steps),
+ "observation.state": deque(maxlen=self.n_obs_steps),
+ "action": deque(maxlen=self.n_action_steps),
+ }
+
@torch.no_grad()
- def select_actions(self, observation, step_count):
+ def select_action(self, batch, step):
"""
Note: this uses the ema model weights if self.training == False, otherwise the non-ema model weights.
"""
- # TODO(rcadene): remove unused step_count
- del step_count
+ # TODO(rcadene): remove unused step
+ del step
+ assert "observation.image" in batch
+ assert "observation.state" in batch
+ assert len(batch) == 2
- obs_dict = {
- "image": observation["image"],
- "agent_pos": observation["state"],
- }
- if self.training:
- out = self.diffusion.predict_action(obs_dict)
- else:
- out = self.ema_diffusion.predict_action(obs_dict)
- action = out["action"]
+ self._queues = populate_queues(self._queues, batch)
+
+ if len(self._queues["action"]) == 0:
+ # stack n latest observations from the queue
+ batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
+
+ obs_dict = {
+ "image": batch["observation.image"],
+ "agent_pos": batch["observation.state"],
+ }
+ if self.training:
+ out = self.diffusion.predict_action(obs_dict)
+ else:
+ out = self.ema_diffusion.predict_action(obs_dict)
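+ # out["action"] is (B, n_action_steps, A); the transpose lets us queue one (B, A) action per
+ # timestep, to be consumed by successive calls to select_action.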
+ self._queues["action"].extend(out["action"].transpose(0, 1))
+
+ action = self._queues["action"].popleft()
return action
- def update(self, replay_buffer, step):
+ def forward(self, batch, step):
start_time = time.time()
self.diffusion.train()
- num_slices = self.cfg.batch_size
- batch_size = self.cfg.horizon * num_slices
-
- assert batch_size % self.cfg.horizon == 0
- assert batch_size % num_slices == 0
-
- def process_batch(batch, horizon, num_slices):
- # trajectory t = 64, horizon h = 16
- # (t h) ... -> t h ...
- batch = batch.reshape(num_slices, horizon) # .transpose(1, 0).contiguous()
-
- # |-1|0|1|2|3|4|5|6|7|8|9|10|11|12|13|14| timestamps: 16
- # |o|o| observations: 2
- # | |a|a|a|a|a|a|a|a| actions executed: 8
- # |p|p|p|p|p|p|p|p|p|p|p| p| p| p| p| p| actions predicted: 16
- # note: we predict the action needed to go from t=-1 to t=0 similarly to an inverse kinematic model
-
- image = batch["observation", "image"]
- state = batch["observation", "state"]
- action = batch["action"]
- assert image.shape[1] == horizon
- assert state.shape[1] == horizon
- assert action.shape[1] == horizon
-
- if not (horizon == 16 and self.cfg.n_obs_steps == 2):
- raise NotImplementedError()
-
- # keep first 2 observations of the slice corresponding to t=[-1,0]
- image = image[:, : self.cfg.n_obs_steps]
- state = state[:, : self.cfg.n_obs_steps]
-
- out = {
- "obs": {
- "image": image.to(self.device, non_blocking=True),
- "agent_pos": state.to(self.device, non_blocking=True),
- },
- "action": action.to(self.device, non_blocking=True),
- }
- return out
-
- batch = replay_buffer.sample(batch_size)
- batch = process_batch(batch, self.cfg.horizon, num_slices)
-
- data_s = time.time() - start_time
-
loss = self.diffusion.compute_loss(batch)
loss.backward()
@@ -189,7 +171,6 @@ class DiffusionPolicy(AbstractPolicy):
"loss": loss.item(),
"grad_norm": float(grad_norm),
"lr": self.lr_scheduler.get_last_lr()[0],
- "data_s": data_s,
"update_s": time.time() - start_time,
}
diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py
index 934f0962..9077d4d0 100644
--- a/lerobot/common/policies/factory.py
+++ b/lerobot/common/policies/factory.py
@@ -1,11 +1,10 @@
def make_policy(cfg):
- if cfg.policy.name != "diffusion" and cfg.rollout_batch_size > 1:
- raise NotImplementedError("Only diffusion policy supports rollout_batch_size > 1 for the time being.")
-
if cfg.policy.name == "tdmpc":
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
- policy = TDMPCPolicy(cfg.policy, cfg.device)
+ policy = TDMPCPolicy(
+ cfg.policy, n_obs_steps=cfg.n_obs_steps, n_action_steps=cfg.n_action_steps, device=cfg.device
+ )
elif cfg.policy.name == "diffusion":
from lerobot.common.policies.diffusion.policy import DiffusionPolicy
@@ -17,24 +16,24 @@ def make_policy(cfg):
cfg_obs_encoder=cfg.obs_encoder,
cfg_optimizer=cfg.optimizer,
cfg_ema=cfg.ema,
- n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
+ # n_obs_steps=cfg.n_obs_steps,
+ # n_action_steps=cfg.n_action_steps,
**cfg.policy,
)
elif cfg.policy.name == "act":
from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
- policy = ActionChunkingTransformerPolicy(
- cfg.policy, cfg.device, n_action_steps=cfg.n_action_steps + cfg.n_latency_steps
- )
+ policy = ActionChunkingTransformerPolicy(cfg.policy, cfg.device)
+ policy.to(cfg.device)
else:
raise ValueError(cfg.policy.name)
if cfg.policy.pretrained_model_path:
# TODO(rcadene): hack for old pretrained models from fowm
if cfg.policy.name == "tdmpc" and "fowm" in cfg.policy.pretrained_model_path:
- if "offline" in cfg.pretrained_model_path:
+ if "offline" in cfg.policy.pretrained_model_path:
policy.step[0] = 25000
- elif "final" in cfg.pretrained_model_path:
+ elif "final" in cfg.policy.pretrained_model_path:
policy.step[0] = 100000
else:
raise NotImplementedError()
diff --git a/lerobot/common/policies/tdmpc/policy.py b/lerobot/common/policies/tdmpc/policy.py
index 64dcc94d..14728576 100644
--- a/lerobot/common/policies/tdmpc/policy.py
+++ b/lerobot/common/policies/tdmpc/policy.py
@@ -1,6 +1,7 @@
# ruff: noqa: N806
import time
+from collections import deque
from copy import deepcopy
import einops
@@ -9,7 +10,7 @@ import torch
import torch.nn as nn
import lerobot.common.policies.tdmpc.helper as h
-from lerobot.common.policies.abstract import AbstractPolicy
+from lerobot.common.policies.utils import populate_queues
from lerobot.common.utils import get_safe_torch_device
FIRST_FRAME = 0
@@ -87,16 +88,18 @@ class TOLD(nn.Module):
return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2
-class TDMPCPolicy(AbstractPolicy):
+class TDMPCPolicy(nn.Module):
"""Implementation of TD-MPC learning + inference."""
name = "tdmpc"
- def __init__(self, cfg, device):
- super().__init__(None)
+ def __init__(self, cfg, n_obs_steps, n_action_steps, device):
+ super().__init__()
self.action_dim = cfg.action_dim
self.cfg = cfg
+ self.n_obs_steps = n_obs_steps
+ self.n_action_steps = n_action_steps
self.device = get_safe_torch_device(device)
self.std = h.linear_schedule(cfg.std_schedule, 0)
self.model = TOLD(cfg)
@@ -107,7 +110,6 @@ class TDMPCPolicy(AbstractPolicy):
# self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
self.model.eval()
self.model_target.eval()
- self.batch_size = cfg.batch_size
self.register_buffer("step", torch.zeros(1))
@@ -128,20 +130,54 @@ class TDMPCPolicy(AbstractPolicy):
self.model.load_state_dict(d["model"])
self.model_target.load_state_dict(d["model_target"])
- @torch.no_grad()
- def select_actions(self, observation, step_count):
- if observation["image"].shape[0] != 1:
- raise NotImplementedError("Batch size > 1 not handled")
-
- t0 = step_count.item() == 0
-
- obs = {
- # TODO(rcadene): remove contiguous hack...
- "rgb": observation["image"].contiguous(),
- "state": observation["state"].contiguous(),
+ def reset(self):
+ """
+ Clear observation and action queues. Should be called on `env.reset()`
+ """
+ self._queues = {
+ "observation.image": deque(maxlen=self.n_obs_steps),
+ "observation.state": deque(maxlen=self.n_obs_steps),
+ "action": deque(maxlen=self.n_action_steps),
}
- # Note: unsqueeze needed because `act` still uses non-batch logic.
- action = self.act(obs, t0=t0, step=self.step.item()).unsqueeze(0)
+
+ @torch.no_grad()
+ def select_action(self, batch, step):
+ assert "observation.image" in batch
+ assert "observation.state" in batch
+ assert len(batch) == 2
+
+ self._queues = populate_queues(self._queues, batch)
+
+ t0 = step == 0
+
+ self.eval()
+
+ if len(self._queues["action"]) == 0:
+ batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
+
+ if self.n_obs_steps == 1:
+ # hack to remove the time dimension
+ for key in batch:
+ assert batch[key].shape[1] == 1
+ batch[key] = batch[key][:, 0]
+
+ actions = []
+ batch_size = batch["observation.image"].shape[0]
+ for i in range(batch_size):
+ obs = {
+ "rgb": batch["observation.image"][[i]],
+ "state": batch["observation.state"][[i]],
+ }
+ # Note: unsqueeze needed because `act` still uses non-batch logic.
+ action = self.act(obs, t0=t0, step=self.step)
+ actions.append(action)
+ action = torch.stack(actions)
+
+ # tdmpc returns an action for 1 timestep only, so we copy it over `n_action_steps` times
+ for _ in range(self.n_action_steps):
+ self._queues["action"].append(action)
+
+ action = self._queues["action"].popleft()
return action
@torch.no_grad()
@@ -290,117 +326,54 @@ class TDMPCPolicy(AbstractPolicy):
def _td_target(self, next_z, reward, mask):
"""Compute the TD-target from a reward and the observation at the following time step."""
next_v = self.model.V(next_z)
- td_target = reward + self.cfg.discount * mask * next_v
+ td_target = reward + self.cfg.discount * mask * next_v.squeeze(2)
return td_target
- def update(self, replay_buffer, step, demo_buffer=None):
+ def forward(self, batch, step):
"""Main update function. Corresponds to one iteration of the model learning."""
start_time = time.time()
- num_slices = self.cfg.batch_size
- batch_size = self.cfg.horizon * num_slices
+ batch_size = batch["index"].shape[0]
- if demo_buffer is None:
- demo_batch_size = 0
- else:
- # Update oversampling ratio
- demo_pc_batch = h.linear_schedule(self.cfg.demo_schedule, step)
- demo_num_slices = int(demo_pc_batch * self.batch_size)
- demo_batch_size = self.cfg.horizon * demo_num_slices
- batch_size -= demo_batch_size
- num_slices -= demo_num_slices
- replay_buffer._sampler.num_slices = num_slices
- demo_buffer._sampler.num_slices = demo_num_slices
+ # TODO(rcadene): convert tdmpc with (batch size, time/horizon, channels)
+ # instead of currently (time/horizon, batch size, channels) which is not the pytorch convention
+ # batch size b = 256, time/horizon t = 5
+ # b t ... -> t b ...
+ for key in batch:
+ if batch[key].ndim > 1:
+ batch[key] = batch[key].transpose(1, 0)
- assert demo_batch_size % self.cfg.horizon == 0
- assert demo_batch_size % demo_num_slices == 0
+ action = batch["action"]
+ reward = batch["next.reward"]
+ # idxs = batch["index"] # TODO(rcadene): use idxs to update sampling weights
+ done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
+ mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
+ weights = torch.ones(batch_size, dtype=torch.bool, device=reward.device)
- assert batch_size % self.cfg.horizon == 0
- assert batch_size % num_slices == 0
+ obses = {
+ "rgb": batch["observation.image"],
+ "state": batch["observation.state"],
+ }
- # Sample from interaction dataset
-
- def process_batch(batch, horizon, num_slices):
- # trajectory t = 256, horizon h = 5
- # (t h) ... -> h t ...
- batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
-
- obs = {
- "rgb": batch["observation", "image"][FIRST_FRAME].to(self.device, non_blocking=True),
- "state": batch["observation", "state"][FIRST_FRAME].to(self.device, non_blocking=True),
- }
- action = batch["action"].to(self.device, non_blocking=True)
- next_obses = {
- "rgb": batch["next", "observation", "image"].to(self.device, non_blocking=True),
- "state": batch["next", "observation", "state"].to(self.device, non_blocking=True),
- }
- reward = batch["next", "reward"].to(self.device, non_blocking=True)
-
- idxs = batch["index"][FIRST_FRAME].to(self.device, non_blocking=True)
- weights = batch["_weight"][FIRST_FRAME, :, None].to(self.device, non_blocking=True)
-
- # TODO(rcadene): rearrange directly in offline dataset
- if reward.ndim == 2:
- reward = einops.rearrange(reward, "h t -> h t 1")
-
- assert reward.ndim == 3
- assert reward.shape == (horizon, num_slices, 1)
- # We dont use `batch["next", "done"]` since it only indicates the end of an
- # episode, but not the end of the trajectory of an episode.
- # Neither does `batch["next", "terminated"]`
- done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
- mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
- return obs, action, next_obses, reward, mask, done, idxs, weights
-
- batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample()
-
- obs, action, next_obses, reward, mask, done, idxs, weights = process_batch(
- batch, self.cfg.horizon, num_slices
- )
-
- # Sample from demonstration dataset
- if demo_batch_size > 0:
- demo_batch = demo_buffer.sample(demo_batch_size)
- (
- demo_obs,
- demo_action,
- demo_next_obses,
- demo_reward,
- demo_mask,
- demo_done,
- demo_idxs,
- demo_weights,
- ) = process_batch(demo_batch, self.cfg.horizon, demo_num_slices)
-
- if isinstance(obs, dict):
- obs = {k: torch.cat([obs[k], demo_obs[k]]) for k in obs}
- next_obses = {k: torch.cat([next_obses[k], demo_next_obses[k]], dim=1) for k in next_obses}
- else:
- obs = torch.cat([obs, demo_obs])
- next_obses = torch.cat([next_obses, demo_next_obses], dim=1)
- action = torch.cat([action, demo_action], dim=1)
- reward = torch.cat([reward, demo_reward], dim=1)
- mask = torch.cat([mask, demo_mask], dim=1)
- done = torch.cat([done, demo_done], dim=1)
- idxs = torch.cat([idxs, demo_idxs])
- weights = torch.cat([weights, demo_weights])
+ shapes = {}
+ for k in obses:
+ shapes[k] = obses[k].shape
+ obses[k] = einops.rearrange(obses[k], "t b ... -> (t b) ... ")
# Apply augmentations
aug_tf = h.aug(self.cfg)
- obs = aug_tf(obs)
+ obses = aug_tf(obses)
- for k in next_obses:
- next_obses[k] = einops.rearrange(next_obses[k], "h t ... -> (h t) ...")
- next_obses = aug_tf(next_obses)
- for k in next_obses:
- next_obses[k] = einops.rearrange(
- next_obses[k],
- "(h t) ... -> h t ...",
- h=self.cfg.horizon,
- t=self.cfg.batch_size,
- )
+ for k in obses:
+ t, b = shapes[k][:2]
+ obses[k] = einops.rearrange(obses[k], "(t b) ... -> t b ... ", b=b, t=t)
- horizon = self.cfg.horizon
+ obs, next_obses = {}, {}
+ for k in obses:
+ obs[k] = obses[k][0]
+ next_obses[k] = obses[k][1:].clone()
+
+ horizon = next_obses["rgb"].shape[0]
loss_mask = torch.ones_like(mask, device=self.device)
for t in range(1, horizon):
loss_mask[t] = loss_mask[t - 1] * (~done[t - 1])
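
The loop above relies on the time-major layout: `loss_mask` is indexed by timestep first, and once `done` is raised at some step, every later step of that episode stops contributing to the loss. A minimal standalone sketch of the propagation (shapes and values are illustrative):

    import torch

    horizon, batch_size = 5, 3
    done = torch.zeros(horizon, batch_size, dtype=torch.bool)
    done[2, 1] = True  # episode 1 finishes at t=2

    loss_mask = torch.ones(horizon, batch_size)
    for t in range(1, horizon):
        # once an episode is done, all its later timesteps are masked out
        loss_mask[t] = loss_mask[t - 1] * (~done[t - 1])

    print(loss_mask[:, 1])  # tensor([1., 1., 1., 0., 0.])
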
@@ -418,7 +391,7 @@ class TDMPCPolicy(AbstractPolicy):
td_targets = self._td_target(next_z, reward, mask)
# Latent rollout
- zs = torch.empty(horizon + 1, self.batch_size, self.cfg.latent_dim, device=self.device)
+ zs = torch.empty(horizon + 1, batch_size, self.cfg.latent_dim, device=self.device)
reward_preds = torch.empty_like(reward, device=self.device)
assert reward.shape[0] == horizon
z = self.model.encode(obs)
@@ -427,22 +400,21 @@ class TDMPCPolicy(AbstractPolicy):
for t in range(horizon):
z, reward_pred = self.model.next(z, action[t])
zs[t + 1] = z
- reward_preds[t] = reward_pred
+ reward_preds[t] = reward_pred.squeeze(1)
with torch.no_grad():
v_target = self.model_target.Q(zs[:-1].detach(), action, return_type="min")
# Predictions
qs = self.model.Q(zs[:-1], action, return_type="all")
+ qs = qs.squeeze(3)
value_info["Q"] = qs.mean().item()
v = self.model.V(zs[:-1])
value_info["V"] = v.mean().item()
# Losses
- rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(-1, 1, 1)
- consistency_loss = (rho * torch.mean(h.mse(zs[1:], z_targets), dim=2, keepdim=True) * loss_mask).sum(
- dim=0
- )
+ rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(-1, 1)
+ consistency_loss = (rho * torch.mean(h.mse(zs[1:], z_targets), dim=2) * loss_mask).sum(dim=0)
reward_loss = (rho * h.mse(reward_preds, reward) * loss_mask).sum(dim=0)
q_value_loss, priority_loss = 0, 0
for q in range(self.cfg.num_q):
@@ -450,7 +422,9 @@ class TDMPCPolicy(AbstractPolicy):
priority_loss += (rho * h.l1(qs[q], td_targets) * loss_mask).sum(dim=0)
expectile = h.linear_schedule(self.cfg.expectile, step)
- v_value_loss = (rho * h.l2_expectile(v_target - v, expectile=expectile) * loss_mask).sum(dim=0)
+ v_value_loss = (rho * h.l2_expectile(v_target - v, expectile=expectile).squeeze(2) * loss_mask).sum(
+ dim=0
+ )
total_loss = (
self.cfg.consistency_coef * consistency_loss
@@ -459,7 +433,7 @@ class TDMPCPolicy(AbstractPolicy):
+ self.cfg.value_coef * v_value_loss
)
- weighted_loss = (total_loss.squeeze(1) * weights).mean()
+ weighted_loss = (total_loss * weights).mean()
weighted_loss.register_hook(lambda grad: grad * (1 / self.cfg.horizon))
has_nan = torch.isnan(weighted_loss).item()
if has_nan:
@@ -472,19 +446,20 @@ class TDMPCPolicy(AbstractPolicy):
)
self.optim.step()
- if self.cfg.per:
- # Update priorities
- priorities = priority_loss.clamp(max=1e4).detach()
- has_nan = torch.isnan(priorities).any().item()
- if has_nan:
- print(f"priorities has nan: {priorities=}")
- else:
- replay_buffer.update_priority(
- idxs[:num_slices],
- priorities[:num_slices],
- )
- if demo_batch_size > 0:
- demo_buffer.update_priority(demo_idxs, priorities[num_slices:])
+ # TODO(rcadene): implement PrioritizedSampling by modifying sampler.weights with priorities computed by a criterion
+ # if self.cfg.per:
+ # # Update priorities
+ # priorities = priority_loss.clamp(max=1e4).detach()
+ # has_nan = torch.isnan(priorities).any().item()
+ # if has_nan:
+ # print(f"priorities has nan: {priorities=}")
+ # else:
+ # replay_buffer.update_priority(
+ # idxs[:num_slices],
+ # priorities[:num_slices],
+ # )
+ # if demo_batch_size > 0:
+ # demo_buffer.update_priority(demo_idxs, priorities[num_slices:])
# Update policy + target network
_, pi_update_info = self.update_pi(zs[:-1].detach(), acts=action)
@@ -507,7 +482,7 @@ class TDMPCPolicy(AbstractPolicy):
"data_s": data_s,
"update_s": time.time() - start_time,
}
- info["demo_batch_size"] = demo_batch_size
+ # info["demo_batch_size"] = demo_batch_size
info["expectile"] = expectile
info.update(value_info)
info.update(pi_update_info)
diff --git a/lerobot/common/policies/utils.py b/lerobot/common/policies/utils.py
new file mode 100644
index 00000000..b0503fe4
--- /dev/null
+++ b/lerobot/common/policies/utils.py
@@ -0,0 +1,10 @@
+def populate_queues(queues, batch):
+ for key in batch:
+ if len(queues[key]) != queues[key].maxlen:
+ # initialize by copying the first observation several times until the queue is full
+ while len(queues[key]) != queues[key].maxlen:
+ queues[key].append(batch[key])
+ else:
+ # add latest observation to the queue
+ queues[key].append(batch[key])
+ return queues
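
A quick usage sketch for the helper above, assuming it is imported from the new lerobot/common/policies/utils.py module; the key name and queue length are illustrative:

    from collections import deque

    import torch

    from lerobot.common.policies.utils import populate_queues

    queues = {"observation.state": deque(maxlen=2)}

    # on the first observation, the queue is filled with copies until full
    queues = populate_queues(queues, {"observation.state": torch.zeros(4)})
    assert len(queues["observation.state"]) == 2

    # afterwards, the oldest entry is dropped and the latest observation appended
    queues = populate_queues(queues, {"observation.state": torch.ones(4)})
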
diff --git a/lerobot/common/transforms.py b/lerobot/common/transforms.py
index 4832c91b..ec967614 100644
--- a/lerobot/common/transforms.py
+++ b/lerobot/common/transforms.py
@@ -1,53 +1,49 @@
-from typing import Sequence
-
import torch
-from tensordict import TensorDictBase
-from tensordict.nn import dispatch
-from tensordict.utils import NestedKey
-from torchrl.envs.transforms import ObservationTransform, Transform
+from torchvision.transforms.v2 import Compose, Transform
-class Prod(ObservationTransform):
+def apply_inverse_transform(item, transform):
+ transforms = transform.transforms if isinstance(transform, Compose) else [transform]
+ for tf in transforms[::-1]:
+ if tf.invertible:
+ item = tf.inverse_transform(item)
+ else:
+ raise ValueError(f"Inverse transform called on a non-invertible transform ({tf}).")
+ return item
+
+
+class Prod(Transform):
invertible = True
- def __init__(self, in_keys: Sequence[NestedKey], prod: float):
+ def __init__(self, in_keys: list[str], prod: float):
super().__init__()
self.in_keys = in_keys
self.prod = prod
self.original_dtypes = {}
- def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
- # _reset is called once when the environment reset to normalize the first observation
- tensordict_reset = self._call(tensordict_reset)
- return tensordict_reset
-
- @dispatch(source="in_keys", dest="out_keys")
- def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
- return self._call(tensordict)
-
- def _call(self, td):
+ def forward(self, item):
for key in self.in_keys:
- if td.get(key, None) is None:
+ if key not in item:
continue
- self.original_dtypes[key] = td[key].dtype
- td[key] = td[key].type(torch.float32) * self.prod
- return td
+ self.original_dtypes[key] = item[key].dtype
+ item[key] = item[key].type(torch.float32) * self.prod
+ return item
- def _inv_call(self, td: TensorDictBase) -> TensorDictBase:
+ def inverse_transform(self, item):
for key in self.in_keys:
- if td.get(key, None) is None:
+ if key not in item:
continue
- td[key] = (td[key] / self.prod).type(self.original_dtypes[key])
- return td
+ item[key] = (item[key] / self.prod).type(self.original_dtypes[key])
+ return item
- def transform_observation_spec(self, obs_spec):
- for key in self.in_keys:
- if obs_spec.get(key, None) is None:
- continue
- obs_spec[key].space.high = obs_spec[key].space.high.type(torch.float32) * self.prod
- obs_spec[key].space.low = obs_spec[key].space.low.type(torch.float32) * self.prod
- obs_spec[key].dtype = torch.float32
- return obs_spec
+ # def transform_observation_spec(self, obs_spec):
+ # for key in self.in_keys:
+ # if obs_spec.get(key, None) is None:
+ # continue
+ # obs_spec[key].space.high = obs_spec[key].space.high.type(torch.float32) * self.prod
+ # obs_spec[key].space.low = obs_spec[key].space.low.type(torch.float32) * self.prod
+ # obs_spec[key].dtype = torch.float32
+ # return obs_spec
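
`apply_inverse_transform` walks a `Compose` in reverse, so the last transform applied is the first one undone, and it refuses to invert anything without an `inverse_transform`. A minimal round trip with the `Prod` transform defined above (the key and scale are illustrative):

    import torch
    from torchvision.transforms.v2 import Compose

    tf = Compose([Prod(in_keys=["observation.image"], prod=1 / 255)])

    item = {"observation.image": torch.full((3, 2, 2), 255, dtype=torch.uint8)}
    item = tf(item)  # cast to float32 and scaled into [0, 1]
    item = apply_inverse_transform(item, tf)  # scaled back and restored to uint8
    assert item["observation.image"].dtype == torch.uint8
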
class NormalizeTransform(Transform):
@@ -55,65 +51,50 @@ class NormalizeTransform(Transform):
def __init__(
self,
- stats: TensorDictBase,
- in_keys: Sequence[NestedKey] = None,
- out_keys: Sequence[NestedKey] | None = None,
- in_keys_inv: Sequence[NestedKey] | None = None,
- out_keys_inv: Sequence[NestedKey] | None = None,
+ stats: dict,
+ in_keys: list[str] = None,
+ out_keys: list[str] | None = None,
+ in_keys_inv: list[str] | None = None,
+ out_keys_inv: list[str] | None = None,
mode="mean_std",
):
- if out_keys is None:
- out_keys = in_keys
- if in_keys_inv is None:
- in_keys_inv = out_keys
- if out_keys_inv is None:
- out_keys_inv = in_keys
- super().__init__(
- in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, out_keys_inv=out_keys_inv
- )
+ super().__init__()
+ self.in_keys = in_keys
+ self.out_keys = in_keys if out_keys is None else out_keys
+ self.in_keys_inv = self.out_keys if in_keys_inv is None else in_keys_inv
+ self.out_keys_inv = self.in_keys if out_keys_inv is None else out_keys_inv
self.stats = stats
assert mode in ["mean_std", "min_max"]
self.mode = mode
- def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
- # _reset is called once when the environment reset to normalize the first observation
- tensordict_reset = self._call(tensordict_reset)
- return tensordict_reset
-
- @dispatch(source="in_keys", dest="out_keys")
- def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
- return self._call(tensordict)
-
- def _call(self, td: TensorDictBase) -> TensorDictBase:
+ def forward(self, item):
for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False):
- # TODO(rcadene): don't know how to do `inkey not in td`
- if td.get(inkey, None) is None:
+ if inkey not in item:
continue
if self.mode == "mean_std":
mean = self.stats[inkey]["mean"]
std = self.stats[inkey]["std"]
- td[outkey] = (td[inkey] - mean) / (std + 1e-8)
+ item[outkey] = (item[inkey] - mean) / (std + 1e-8)
else:
min = self.stats[inkey]["min"]
max = self.stats[inkey]["max"]
# normalize to [0,1]
- td[outkey] = (td[inkey] - min) / (max - min)
+ item[outkey] = (item[inkey] - min) / (max - min)
# normalize to [-1, 1]
- td[outkey] = td[outkey] * 2 - 1
- return td
+ item[outkey] = item[outkey] * 2 - 1
+ return item
- def _inv_call(self, td: TensorDictBase) -> TensorDictBase:
+ def inverse_transform(self, item):
for inkey, outkey in zip(self.in_keys_inv, self.out_keys_inv, strict=False):
- # TODO(rcadene): don't know how to do `inkey not in td`
- if td.get(inkey, None) is None:
+ if inkey not in item:
continue
if self.mode == "mean_std":
mean = self.stats[inkey]["mean"]
std = self.stats[inkey]["std"]
- td[outkey] = td[inkey] * std + mean
+ item[outkey] = item[inkey] * std + mean
else:
min = self.stats[inkey]["min"]
max = self.stats[inkey]["max"]
- td[outkey] = (td[inkey] + 1) / 2
- td[outkey] = td[outkey] * (max - min) + min
- return td
+ item[outkey] = (item[inkey] + 1) / 2
+ item[outkey] = item[outkey] * (max - min) + min
+ return item
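
A round-trip sketch for `NormalizeTransform` in `mean_std` mode; the stats dict layout mirrors the one used above, with illustrative values:

    import torch

    stats = {"action": {"mean": torch.tensor([0.5]), "std": torch.tensor([2.0])}}
    normalize = NormalizeTransform(stats, in_keys=["action"], mode="mean_std")

    item = normalize({"action": torch.tensor([4.5])})  # (4.5 - 0.5) / 2.0 ~= 2.0
    item = normalize.inverse_transform(item)           # 2.0 * 2.0 + 0.5 ~= 4.5
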
diff --git a/lerobot/common/utils.py b/lerobot/common/utils.py
index 7ed29334..373a3bbc 100644
--- a/lerobot/common/utils.py
+++ b/lerobot/common/utils.py
@@ -95,3 +95,16 @@ def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> D
)
cfg = hydra.compose(Path(config_path).stem, overrides)
return cfg
+
+
+def print_cuda_memory_usage():
+ """Use this function to locate and debug memory leak."""
+ import gc
+
+ gc.collect()
+ # also clear the cache to fully release the memory
+ torch.cuda.empty_cache()
+ print("Current GPU Memory Allocated: {:.2f} MB".format(torch.cuda.memory_allocated(0) / 1024**2))
+ print("Maximum GPU Memory Allocated: {:.2f} MB".format(torch.cuda.max_memory_allocated(0) / 1024**2))
+ print("Current GPU Memory Reserved: {:.2f} MB".format(torch.cuda.memory_reserved(0) / 1024**2))
+ print("Maximum GPU Memory Reserved: {:.2f} MB".format(torch.cuda.max_memory_reserved(0) / 1024**2))
diff --git a/lerobot/configs/env/aloha.yaml b/lerobot/configs/env/aloha.yaml
index 51569fea..6b836795 100644
--- a/lerobot/configs/env/aloha.yaml
+++ b/lerobot/configs/env/aloha.yaml
@@ -4,7 +4,7 @@ eval_episodes: 50
eval_freq: 7500
save_freq: 75000
log_freq: 250
-# TODO: same as simxarm, need to adjust
+# TODO: same as xarm, need to adjust
offline_steps: 25000
online_steps: 25000
@@ -14,11 +14,10 @@ dataset_id: aloha_sim_insertion_human
env:
name: aloha
- task: sim_insertion
+ task: AlohaInsertion-v0
from_pixels: True
pixels_only: False
image_size: [3, 480, 640]
- action_repeat: 1
episode_length: 400
fps: ${fps}
diff --git a/lerobot/configs/env/pusht.yaml b/lerobot/configs/env/pusht.yaml
index 0050530e..a7097ffd 100644
--- a/lerobot/configs/env/pusht.yaml
+++ b/lerobot/configs/env/pusht.yaml
@@ -4,7 +4,7 @@ eval_episodes: 50
eval_freq: 7500
save_freq: 75000
log_freq: 250
-# TODO: same as simxarm, need to adjust
+# TODO: same as xarm, need to adjust
offline_steps: 25000
online_steps: 25000
@@ -14,11 +14,10 @@ dataset_id: pusht
env:
name: pusht
- task: pusht
+ task: PushT-v0
from_pixels: True
pixels_only: False
image_size: 96
- action_repeat: 1
episode_length: 300
fps: ${fps}
diff --git a/lerobot/configs/env/simxarm.yaml b/lerobot/configs/env/xarm.yaml
similarity index 86%
rename from lerobot/configs/env/simxarm.yaml
rename to lerobot/configs/env/xarm.yaml
index f79db8f7..bcba659e 100644
--- a/lerobot/configs/env/simxarm.yaml
+++ b/lerobot/configs/env/xarm.yaml
@@ -12,12 +12,11 @@ fps: 15
dataset_id: xarm_lift_medium
env:
- name: simxarm
- task: lift
+ name: xarm
+ task: XarmLift-v0
from_pixels: True
pixels_only: False
image_size: 84
- action_repeat: 2
episode_length: 25
fps: ${fps}
diff --git a/lerobot/configs/policy/act.yaml b/lerobot/configs/policy/act.yaml
index a52c3f54..e2074b46 100644
--- a/lerobot/configs/policy/act.yaml
+++ b/lerobot/configs/policy/act.yaml
@@ -1,6 +1,6 @@
# @package _global_
-offline_steps: 1344000
+offline_steps: 80000
online_steps: 0
eval_episodes: 1
@@ -10,7 +10,6 @@ log_freq: 250
horizon: 100
n_obs_steps: 1
-n_latency_steps: 0
# when temporal_agg=False, n_action_steps=horizon
n_action_steps: ${horizon}
@@ -21,26 +20,27 @@ policy:
lr: 1e-5
lr_backbone: 1e-5
+ pretrained_backbone: true
weight_decay: 1e-4
grad_clip_norm: 10
backbone: resnet18
- num_queries: ${horizon} # chunk_size
horizon: ${horizon} # chunk_size
kl_weight: 10
- hidden_dim: 512
+ d_model: 512
dim_feedforward: 3200
+ vae_enc_layers: 4
enc_layers: 4
- dec_layers: 7
- nheads: 8
+ dec_layers: 1
+ num_heads: 8
#camera_names: [top, front_close, left_pillar, right_pillar]
camera_names: [top]
- position_embedding: sine
- masks: false
dilation: false
dropout: 0.1
pre_norm: false
+ activation: relu
+ latent_dim: 32
- vae: true
+ use_vae: true
batch_size: 8
@@ -51,8 +51,18 @@ policy:
utd: 1
n_obs_steps: ${n_obs_steps}
+ n_action_steps: ${n_action_steps}
temporal_agg: false
- state_dim: ???
- action_dim: ???
+ state_dim: 14
+ action_dim: 14
+
+ image_normalization:
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+
+ delta_timestamps:
+ observation.images.top: [0.0]
+ observation.state: [0.0]
+ action: "[i / ${fps} for i in range(${horizon})]"
diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml
index 4d6eedca..811ee824 100644
--- a/lerobot/configs/policy/diffusion.yaml
+++ b/lerobot/configs/policy/diffusion.yaml
@@ -16,7 +16,6 @@ seed: 100000
horizon: 16
n_obs_steps: 2
n_action_steps: 8
-n_latency_steps: 0
dataset_obs_steps: ${n_obs_steps}
past_action_visible: False
keypoint_visible_rate: 1.0
@@ -38,8 +37,8 @@ policy:
shape_meta: ${shape_meta}
horizon: ${horizon}
- # n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
n_obs_steps: ${n_obs_steps}
+ n_action_steps: ${n_action_steps}
num_inference_steps: 100
obs_as_global_cond: ${obs_as_global_cond}
# crop_shape: null
@@ -64,6 +63,11 @@ policy:
lr_warmup_steps: 500
grad_clip_norm: 10
+ delta_timestamps:
+ observation.image: [-0.1, 0]
+ observation.state: [-0.1, 0]
+ action: [-0.1, 0, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1.0, 1.1, 1.2, 1.3, 1.4]
+
noise_scheduler:
_target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
num_train_timesteps: 100
diff --git a/lerobot/configs/policy/tdmpc.yaml b/lerobot/configs/policy/tdmpc.yaml
index ff0e6b04..4fd2b6bb 100644
--- a/lerobot/configs/policy/tdmpc.yaml
+++ b/lerobot/configs/policy/tdmpc.yaml
@@ -1,6 +1,6 @@
# @package _global_
-n_action_steps: 1
+n_action_steps: 2
n_obs_steps: 1
policy:
@@ -77,3 +77,9 @@ policy:
num_q: 5
mlp_dim: 512
latent_dim: 50
+
+ delta_timestamps:
+ observation.image: "[i / ${fps} for i in range(6)]"
+ observation.state: "[i / ${fps} for i in range(6)]"
+ action: "[i / ${fps} for i in range(5)]"
+ next.reward: "[i / ${fps} for i in range(5)]"
diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index 216769d6..d676623e 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -32,23 +32,21 @@ import json
import logging
import threading
import time
+from copy import deepcopy
from datetime import datetime as dt
from pathlib import Path
import einops
+import gymnasium as gym
import imageio
import numpy as np
import torch
-import tqdm
from huggingface_hub import snapshot_download
-from tensordict.nn import TensorDictModule
-from torchrl.envs import EnvBase
-from torchrl.envs.batched_envs import BatchedEnvBase
-from lerobot.common.datasets.factory import make_offline_buffer
+from lerobot.common.datasets.factory import make_dataset
from lerobot.common.envs.factory import make_env
+from lerobot.common.envs.utils import postprocess_action, preprocess_observation
from lerobot.common.logger import log_output_dir
-from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
@@ -58,89 +56,200 @@ def write_video(video_path, stacked_frames, fps):
def eval_policy(
- env: BatchedEnvBase,
- policy: AbstractPolicy,
- num_episodes: int = 10,
- max_steps: int = 30,
- save_video: bool = False,
+ env: gym.vector.VectorEnv,
+ policy: torch.nn.Module,
+ max_episodes_rendered: int = 0,
video_dir: Path = None,
- fps: int = 15,
- return_first_video: bool = False,
+ # TODO(rcadene): make it possible to overwrite fps? we should use env.fps
+ transform: callable = None,
+ seed=None,
):
+ fps = env.unwrapped.metadata["render_fps"]
+
if policy is not None:
policy.eval()
+ device = "cpu" if policy is None else next(policy.parameters()).device
+
start = time.time()
sum_rewards = []
max_rewards = []
- successes = []
+ all_successes = []
seeds = []
threads = [] # for video saving threads
episode_counter = 0 # for saving the correct number of videos
+ num_episodes = len(env.envs)
+
# TODO(alexander-soare): if num_episodes is not evenly divisible by the batch size, this will do more work than
# needed as I'm currently taking a ceil.
- for i in tqdm.tqdm(range(-(-num_episodes // env.batch_size[0]))):
- ep_frames = []
+ ep_frames = []
- def maybe_render_frame(env: EnvBase, _):
- if save_video or (return_first_video and i == 0): # noqa: B023
- ep_frames.append(env.render()) # noqa: B023
+ def render_frame(env):
+ # noqa: B023
+ eps_rendered = min(max_episodes_rendered, len(env.envs))
+ visu = np.stack([env.envs[i].render() for i in range(eps_rendered)])
+ ep_frames.append(visu) # noqa: B023
- # Clear the policy's action queue before the start of a new rollout.
- if policy is not None:
- policy.clear_action_queue()
+ for _ in range(num_episodes):
+ seeds.append("TODO")
- if env.is_closed:
- env.start() # needed to be able to get the seeds the first time as BatchedEnvs are lazy
- seeds.extend(env._next_seed)
+ if hasattr(policy, "reset"):
+ policy.reset()
+ else:
+ logging.warning(
+ f"Policy {policy} doesnt have a `reset` method. It is required if the policy relies on an internal state during rollout."
+ )
+
+ # reset the environment
+ observation, info = env.reset(seed=seed)
+ if max_episodes_rendered > 0:
+ render_frame(env)
+
+ observations = []
+ actions = []
+ # episode
+ # frame_id
+ # timestamp
+ rewards = []
+ successes = []
+ dones = []
+
+ done = torch.tensor([False for _ in env.envs])
+ step = 0
+ while not done.all():
+ # format from env keys to lerobot keys
+ observation = preprocess_observation(observation)
+ observations.append(deepcopy(observation))
+
+ # apply transform to normalize the observations
+ for key in observation:
+ observation[key] = torch.stack([transform({key: item})[key] for item in observation[key]])
+
+ # send observation to device/gpu
+ observation = {key: observation[key].to(device, non_blocking=True) for key in observation}
+
+ # get the next action for the environment
with torch.inference_mode():
- # TODO(alexander-soare): When `break_when_any_done == False` this rolls out for max_steps even when all
- # envs are done the first time. But we only use the first rollout. This is a waste of compute.
- rollout = env.rollout(
- max_steps=max_steps,
- policy=policy,
- auto_cast_to_device=True,
- callback=maybe_render_frame,
- break_when_any_done=env.batch_size[0] == 1,
+ action = policy.select_action(observation, step)
+
+ # apply inverse transform to unnormalize the action
+ action = postprocess_action(action, transform)
+
+ # apply the next action
+ observation, reward, terminated, truncated, info = env.step(action)
+ if max_episodes_rendered > 0:
+ render_frame(env)
+
+ # TODO(rcadene): implement a wrapper over env to return torch tensors in float32 (and cuda?)
+ action = torch.from_numpy(action)
+ reward = torch.from_numpy(reward)
+ terminated = torch.from_numpy(terminated)
+ truncated = torch.from_numpy(truncated)
+ # the environment is considered done (no more steps) when the success state is reached (terminated is True),
+ # when the time limit is reached (truncated is True), or when it was previously done.
+ done = terminated | truncated | done
+
+ if "final_info" in info:
+ # VectorEnv stores is_success in `info["final_info"][env_id]["is_success"]` rather than `info["is_success"]`
+ success = [
+ env_info["is_success"] if env_info is not None else False for env_info in info["final_info"]
+ ]
+ else:
+ success = [False for _ in env.envs]
+ success = torch.tensor(success)
+
+ actions.append(action)
+ rewards.append(reward)
+ dones.append(done)
+ successes.append(success)
+
+ step += 1
+
+ env.close()
+
+ # add the last observation when the env is done
+ observation = preprocess_observation(observation)
+ observations.append(deepcopy(observation))
+
+ new_obses = {}
+ for key in observations[0].keys(): # noqa: SIM118
+ new_obses[key] = torch.stack([obs[key] for obs in observations], dim=1)
+ observations = new_obses
+ actions = torch.stack(actions, dim=1)
+ rewards = torch.stack(rewards, dim=1)
+ successes = torch.stack(successes, dim=1)
+ dones = torch.stack(dones, dim=1)
+
+ # Figure out where in each rollout sequence the first done condition was encountered (results after
+ # this won't be included).
+ # Note: this assumes that the shape of the done key is (batch_size, max_steps).
+ # Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker.
+ done_indices = torch.argmax(dones.to(int), dim=1)  # (batch_size,)
+ expand_done_indices = done_indices[:, None].expand(-1, step)
+ expand_step_indices = torch.arange(step)[None, :].expand(num_episodes, -1)
+ mask = (expand_step_indices <= expand_done_indices).int() # (batch_size, rollout_steps)
+ batch_sum_reward = einops.reduce((rewards * mask), "b n -> b", "sum")
+ batch_max_reward = einops.reduce((rewards * mask), "b n -> b", "max")
+ batch_success = einops.reduce((successes * mask), "b n -> b", "any")
+ sum_rewards.extend(batch_sum_reward.tolist())
+ max_rewards.extend(batch_max_reward.tolist())
+ all_successes.extend(batch_success.tolist())
+
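
The truncation logic above hinges on `argmax` returning the first occurrence among ties, so each rollout is cut at its first `done`. A small standalone check (two rollouts of five steps, illustrative values):

    import torch

    dones = torch.tensor([[False, False, True, True, True],
                          [False, False, False, False, True]])
    done_indices = torch.argmax(dones.to(int), dim=1)  # tensor([2, 4]): first True wins ties
    steps = torch.arange(dones.shape[1])[None, :]
    mask = (steps <= done_indices[:, None]).int()
    # mask[0] == [1, 1, 1, 0, 0]: rewards after the first done are zeroed by `rewards * mask`
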
+ # similar logic is implemented in dataset preprocessing
+ ep_dicts = []
+ num_episodes = dones.shape[0]
+ total_frames = 0
+ idx0 = idx1 = 0
+ data_ids_per_episode = {}
+ for ep_id in range(num_episodes):
+ num_frames = done_indices[ep_id].item() + 1
+ # TODO(rcadene): we need to add the missing last frame, which is the observation
+ # of a done state. It is critical to have this frame for tdmpc to predict a "done observation/state".
+ ep_dict = {
+ "action": actions[ep_id, :num_frames],
+ "episode": torch.tensor([ep_id] * num_frames),
+ "frame_id": torch.arange(0, num_frames, 1),
+ "timestamp": torch.arange(0, num_frames, 1) / fps,
+ "next.done": dones[ep_id, :num_frames],
+ "next.reward": rewards[ep_id, :num_frames].type(torch.float32),
+ }
+ for key in observations:
+ ep_dict[key] = observations[key][ep_id, :num_frames]
+ ep_dicts.append(ep_dict)
+
+ total_frames += num_frames
+ idx1 += num_frames
+
+ data_ids_per_episode[ep_id] = torch.arange(idx0, idx1, 1)
+
+ idx0 = idx1
+
+ # similar logic is implemented in dataset preprocessing
+ data_dict = {}
+ keys = ep_dicts[0].keys()
+ for key in keys:
+ data_dict[key] = torch.cat([x[key] for x in ep_dicts])
+ data_dict["index"] = torch.arange(0, total_frames, 1)
+
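
For reference, the flat layout this produces for, say, two episodes of 3 and 2 frames (illustrative values, matching the keys built above):

    import torch

    data_dict = {
        "episode":  torch.tensor([0, 0, 0, 1, 1]),
        "frame_id": torch.tensor([0, 1, 2, 0, 1]),
        "index":    torch.arange(5),
    }
    data_ids_per_episode = {0: torch.arange(0, 3), 1: torch.arange(3, 5)}
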
+ if max_episodes_rendered > 0:
+ batch_stacked_frames = np.stack(ep_frames, 1) # (b, t, *)
+
+ for stacked_frames, done_index in zip(
+ batch_stacked_frames, done_indices.flatten().tolist(), strict=False
+ ):
+ if episode_counter >= num_episodes:
+ continue
+ video_dir.mkdir(parents=True, exist_ok=True)
+ video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
+ thread = threading.Thread(
+ target=write_video,
+ args=(str(video_path), stacked_frames[:done_index], fps),
)
- # Figure out where in each rollout sequence the first done condition was encountered (results after
- # this won't be included).
- # Note: this assumes that the shape of the done key is (batch_size, max_steps, 1).
- # Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker.
- rollout_steps = rollout["next", "done"].shape[1]
- done_indices = torch.argmax(rollout["next", "done"].to(int), axis=1) # (batch_size, rollout_steps)
- mask = (torch.arange(rollout_steps) <= done_indices).unsqueeze(-1) # (batch_size, rollout_steps, 1)
- batch_sum_reward = einops.reduce((rollout["next", "reward"] * mask), "b n 1 -> b", "sum")
- batch_max_reward = einops.reduce((rollout["next", "reward"] * mask), "b n 1 -> b", "max")
- batch_success = einops.reduce((rollout["next", "success"] * mask), "b n 1 -> b", "any")
- sum_rewards.extend(batch_sum_reward.tolist())
- max_rewards.extend(batch_max_reward.tolist())
- successes.extend(batch_success.tolist())
+ thread.start()
+ threads.append(thread)
+ episode_counter += 1
- if save_video or (return_first_video and i == 0):
- batch_stacked_frames = np.stack(ep_frames) # (t, b, *)
- batch_stacked_frames = batch_stacked_frames.transpose(
- 1, 0, *range(2, batch_stacked_frames.ndim)
- ) # (b, t, *)
-
- if save_video:
- for stacked_frames, done_index in zip(
- batch_stacked_frames, done_indices.flatten().tolist(), strict=False
- ):
- if episode_counter >= num_episodes:
- continue
- video_dir.mkdir(parents=True, exist_ok=True)
- video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
- thread = threading.Thread(
- target=write_video,
- args=(str(video_path), stacked_frames[:done_index], fps),
- )
- thread.start()
- threads.append(thread)
- episode_counter += 1
-
- if return_first_video and i == 0:
- first_video = batch_stacked_frames[0].transpose(0, 3, 1, 2)
+ videos = einops.rearrange(batch_stacked_frames, "b t h w c -> b t c h w")
for thread in threads:
thread.join()
@@ -158,22 +267,26 @@ def eval_policy(
zip(
sum_rewards[:num_episodes],
max_rewards[:num_episodes],
- successes[:num_episodes],
+ all_successes[:num_episodes],
seeds[:num_episodes],
strict=True,
)
)
],
"aggregated": {
- "avg_sum_reward": np.nanmean(sum_rewards[:num_episodes]),
- "avg_max_reward": np.nanmean(max_rewards[:num_episodes]),
- "pc_success": np.nanmean(successes[:num_episodes]) * 100,
+ "avg_sum_reward": float(np.nanmean(sum_rewards[:num_episodes])),
+ "avg_max_reward": float(np.nanmean(max_rewards[:num_episodes])),
+ "pc_success": float(np.nanmean(all_successes[:num_episodes]) * 100),
"eval_s": time.time() - start,
"eval_ep_s": (time.time() - start) / num_episodes,
},
+ "episodes": {
+ "data_dict": data_dict,
+ "data_ids_per_episode": data_ids_per_episode,
+ },
}
- if return_first_video:
- return info, first_video
+ if max_episodes_rendered > 0:
+ info["videos"] = videos
return info
@@ -194,35 +307,29 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
logging.info("Making transforms.")
# TODO(alexander-soare): Completely decouple datasets from evaluation.
- offline_buffer = make_offline_buffer(cfg, stats_path=stats_path)
+ transform = make_dataset(cfg, stats_path=stats_path).transform
logging.info("Making environment.")
- env = make_env(cfg, transform=offline_buffer.transform)
+ env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
- if cfg.policy.pretrained_model_path:
- policy = make_policy(cfg)
- policy = TensorDictModule(
- policy,
- in_keys=["observation", "step_count"],
- out_keys=["action"],
- )
- else:
- # when policy is None, rollout a random policy
- policy = None
+ logging.info("Making policy.")
+ policy = make_policy(cfg)
info = eval_policy(
env,
- policy=policy,
- save_video=True,
+ policy,
+ max_episodes_rendered=10,
video_dir=Path(out_dir) / "eval",
- fps=cfg.env.fps,
- max_steps=cfg.env.episode_length,
- num_episodes=cfg.eval_episodes,
+ transform=transform,
+ seed=cfg.seed,
)
print(info["aggregated"])
# Save info
with open(Path(out_dir) / "eval_info.json", "w") as f:
+ # remove pytorch tensors, which are not JSON serializable, so that only the evaluation results are saved
+ del info["episodes"]
+ del info["videos"]
json.dump(info, f, indent=2)
logging.info("End of eval")
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index 18c3715b..03506f2a 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -1,18 +1,21 @@
import logging
+from copy import deepcopy
from pathlib import Path
import hydra
-import numpy as np
import torch
-from tensordict.nn import TensorDictModule
-from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
-from torchrl.data.replay_buffers import PrioritizedSliceSampler
-from lerobot.common.datasets.factory import make_offline_buffer
+from lerobot.common.datasets.factory import make_dataset
+from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy
-from lerobot.common.utils import format_big_number, get_safe_torch_device, init_logging, set_global_seed
+from lerobot.common.utils import (
+ format_big_number,
+ get_safe_torch_device,
+ init_logging,
+ set_global_seed,
+)
from lerobot.scripts.eval import eval_policy
@@ -34,19 +37,18 @@ def train_notebook(out_dir=None, job_name=None, config_name="default", config_pa
train(cfg, out_dir=out_dir, job_name=job_name)
-def log_train_info(logger, info, step, cfg, offline_buffer, is_offline):
+def log_train_info(logger, info, step, cfg, dataset, is_offline):
loss = info["loss"]
grad_norm = info["grad_norm"]
lr = info["lr"]
- data_s = info["data_s"]
update_s = info["update_s"]
# A sample is an (observation,action) pair, where observation and action
# can be on multiple timestamps. In a batch, we have `batch_size`` number of samples.
num_samples = (step + 1) * cfg.policy.batch_size
- avg_samples_per_ep = offline_buffer.num_samples / offline_buffer.num_episodes
+ avg_samples_per_ep = dataset.num_samples / dataset.num_episodes
num_episodes = num_samples / avg_samples_per_ep
- num_epochs = num_samples / offline_buffer.num_samples
+ num_epochs = num_samples / dataset.num_samples
log_items = [
f"step:{format_big_number(step)}",
# number of samples seen during training
@@ -59,7 +61,6 @@ def log_train_info(logger, info, step, cfg, offline_buffer, is_offline):
f"grdn:{grad_norm:.3f}",
f"lr:{lr:0.1e}",
# in seconds
- f"data_s:{data_s:.3f}",
f"updt_s:{update_s:.3f}",
]
logging.info(" ".join(log_items))
@@ -73,7 +74,7 @@ def log_train_info(logger, info, step, cfg, offline_buffer, is_offline):
logger.log_dict(info, step, mode="train")
-def log_eval_info(logger, info, step, cfg, offline_buffer, is_offline):
+def log_eval_info(logger, info, step, cfg, dataset, is_offline):
eval_s = info["eval_s"]
avg_sum_reward = info["avg_sum_reward"]
pc_success = info["pc_success"]
@@ -81,9 +82,9 @@ def log_eval_info(logger, info, step, cfg, offline_buffer, is_offline):
# A sample is an (observation,action) pair, where observation and action
# can be on multiple timestamps. In a batch, we have `batch_size`` number of samples.
num_samples = (step + 1) * cfg.policy.batch_size
- avg_samples_per_ep = offline_buffer.num_samples / offline_buffer.num_episodes
+ avg_samples_per_ep = dataset.num_samples / dataset.num_episodes
num_episodes = num_samples / avg_samples_per_ep
- num_epochs = num_samples / offline_buffer.num_samples
+ num_epochs = num_samples / dataset.num_samples
log_items = [
f"step:{format_big_number(step)}",
# number of samples seen during training
@@ -107,6 +108,64 @@ def log_eval_info(logger, info, step, cfg, offline_buffer, is_offline):
logger.log_dict(info, step, mode="eval")
+def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float):
+ """
+ Calculate the sampling weight to be assigned to samples so that a specified percentage of the batch comes from the online dataset (on average).
+
+ Parameters:
+ - n_off (int): Number of offline samples, each with a sampling weight of 1.
+ - n_on (int): Number of online samples.
+ - pc_on (float): Desired percentage of online samples in decimal form (e.g., 50% as 0.5).
+
+ The total weight of offline samples is n_off * 1.0.
+ The total weight of online samples is n_on * w.
+ The total combined weight of all samples is n_off + n_on * w.
+ The fraction of the weight that is online is n_on * w / (n_off + n_on * w).
+ We want this fraction to equal pc_on, so we set up the equation n_on * w / (n_off + n_on * w) = pc_on.
+ The solution is w = - (n_off * pc_on) / (n_on * (pc_on - 1))
+ """
+ assert 0.0 <= pc_on <= 1.0
+ return -(n_off * pc_on) / (n_on * (pc_on - 1))
+
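
A quick numeric check of the closed-form weight: with 1000 offline and 100 online samples and a 50% target, each online sample must weigh 10x an offline one.

    n_off, n_on, pc_on = 1000, 100, 0.5
    w = -(n_off * pc_on) / (n_on * (pc_on - 1))  # = 10.0
    online_fraction = n_on * w / (n_off + n_on * w)
    assert online_fraction == pc_on  # 1000 / (1000 + 1000) == 0.5
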
+
+def add_episodes_inplace(episodes, online_dataset, concat_dataset, sampler, pc_online_samples):
+ data_dict = episodes["data_dict"]
+ data_ids_per_episode = episodes["data_ids_per_episode"]
+
+ if len(online_dataset) == 0:
+ # initialize online dataset
+ online_dataset.data_dict = data_dict
+ online_dataset.data_ids_per_episode = data_ids_per_episode
+ else:
+ # find the episode index and frame indices that follow the last episode stored in online_dataset
+ start_episode = max(online_dataset.data_ids_per_episode.keys()) + 1
+ start_index = online_dataset.data_dict["index"][-1].item() + 1
+ data_dict["episode"] += start_episode
+ data_dict["index"] += start_index
+
+ # extend online dataset
+ for key in data_dict:
+ # TODO(rcadene): avoid reallocating memory at every step by preallocating memory or changing our data structure
+ online_dataset.data_dict[key] = torch.cat([online_dataset.data_dict[key], data_dict[key]])
+ for ep_id in data_ids_per_episode:
+ online_dataset.data_ids_per_episode[ep_id + start_episode] = (
+ data_ids_per_episode[ep_id] + start_index
+ )
+
+ # update the concatenated dataset length used during sampling
+ concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
+
+ # update the sampling weights for each frame so that online frames get sampled a certain percentage of times
+ len_online = len(online_dataset)
+ len_offline = len(concat_dataset) - len_online
+ weight_offline = 1.0
+ weight_online = calculate_online_sample_weight(len_offline, len_online, pc_online_samples)
+ sampler.weights = torch.tensor([weight_offline] * len_offline + [weight_online] * len_online)
+
+ # update the total number of samples used during sampling
+ sampler.num_samples = len(concat_dataset)
+
+
def train(cfg: dict, out_dir=None, job_name=None):
if out_dir is None:
raise NotImplementedError()
@@ -124,30 +183,11 @@ def train(cfg: dict, out_dir=None, job_name=None):
torch.backends.cuda.matmul.allow_tf32 = True
set_global_seed(cfg.seed)
- logging.info("make_offline_buffer")
- offline_buffer = make_offline_buffer(cfg)
-
- # TODO(rcadene): move balanced_sampling, per_alpha, per_beta outside policy
- if cfg.policy.balanced_sampling:
- logging.info("make online_buffer")
- num_traj_per_batch = cfg.policy.batch_size
-
- online_sampler = PrioritizedSliceSampler(
- max_capacity=100_000,
- alpha=cfg.policy.per_alpha,
- beta=cfg.policy.per_beta,
- num_slices=num_traj_per_batch,
- strict_length=True,
- )
-
- online_buffer = TensorDictReplayBuffer(
- storage=LazyMemmapStorage(100_000),
- sampler=online_sampler,
- transform=offline_buffer.transform,
- )
+ logging.info("make_dataset")
+ offline_dataset = make_dataset(cfg)
logging.info("make_env")
- env = make_env(cfg, transform=offline_buffer.transform)
+ env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
logging.info("make_policy")
policy = make_policy(cfg)
@@ -155,8 +195,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())
- td_policy = TensorDictModule(policy, in_keys=["observation", "step_count"], out_keys=["action"])
-
# log metrics to terminal and wandb
logger = Logger(out_dir, job_name, cfg)
@@ -164,9 +202,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})")
logging.info(f"{cfg.online_steps=}")
- logging.info(f"{cfg.env.action_repeat=}")
- logging.info(f"{offline_buffer.num_samples=} ({format_big_number(offline_buffer.num_samples)})")
- logging.info(f"{offline_buffer.num_episodes=}")
+ logging.info(f"{offline_dataset.num_samples=} ({format_big_number(offline_dataset.num_samples)})")
+ logging.info(f"{offline_dataset.num_episodes=}")
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
@@ -174,18 +211,17 @@ def train(cfg: dict, out_dir=None, job_name=None):
def _maybe_eval_and_maybe_save(step):
if step % cfg.eval_freq == 0:
logging.info(f"Eval policy at step {step}")
- eval_info, first_video = eval_policy(
+ eval_info = eval_policy(
env,
- td_policy,
- num_episodes=cfg.eval_episodes,
- max_steps=cfg.env.episode_length,
- return_first_video=True,
+ policy,
video_dir=Path(out_dir) / "eval",
- save_video=True,
+ max_episodes_rendered=4,
+ transform=offline_dataset.transform,
+ seed=cfg.seed,
)
- log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_buffer, is_offline)
+ log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
if cfg.wandb.enable:
- logger.log_video(first_video, step, mode="eval")
+ logger.log_video(eval_info["videos"][0], step, mode="eval")
logging.info("Resume training")
if cfg.save_model and step % cfg.save_freq == 0:
@@ -193,17 +229,33 @@ def train(cfg: dict, out_dir=None, job_name=None):
logger.save_model(policy, identifier=step)
logging.info("Resume training")
- step = 0 # number of policy update (forward + backward + optim)
+ # create dataloader for offline training
+ dataloader = torch.utils.data.DataLoader(
+ offline_dataset,
+ num_workers=4,
+ batch_size=cfg.policy.batch_size,
+ shuffle=True,
+ pin_memory=cfg.device != "cpu",
+ drop_last=False,
+ )
+ dl_iter = cycle(dataloader)
+ step = 0 # number of policy update (forward + backward + optim)
is_offline = True
for offline_step in range(cfg.offline_steps):
if offline_step == 0:
logging.info("Start offline training on a fixed dataset")
- # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
policy.train()
- train_info = policy.update(offline_buffer, step)
+ batch = next(dl_iter)
+
+ for key in batch:
+ batch[key] = batch[key].to(cfg.device, non_blocking=True)
+
+ train_info = policy(batch, step)
+
+ # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
if step % cfg.log_freq == 0:
- log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline)
+ log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline)
# Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass in
# step + 1.
@@ -211,59 +263,60 @@ def train(cfg: dict, out_dir=None, job_name=None):
step += 1
- demo_buffer = offline_buffer if cfg.policy.balanced_sampling else None
+ # create an env dedicated to online episodes collection from policy rollout
+ rollout_env = make_env(cfg, num_parallel_envs=1)
+
+ # create an empty online dataset similar to offline dataset
+ online_dataset = deepcopy(offline_dataset)
+ online_dataset.data_dict = {}
+ online_dataset.data_ids_per_episode = {}
+
+ # create dataloader for online training
+ concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
+ weights = [1.0] * len(concat_dataset)
+ sampler = torch.utils.data.WeightedRandomSampler(
+ weights, num_samples=len(concat_dataset), replacement=True
+ )
+ dataloader = torch.utils.data.DataLoader(
+ concat_dataset,
+ num_workers=4,
+ batch_size=cfg.policy.batch_size,
+ sampler=sampler,
+ pin_memory=cfg.device != "cpu",
+ drop_last=False,
+ )
+ dl_iter = cycle(dataloader)
+
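
`cycle` (imported from lerobot.common.datasets.utils) lets the loop below call `next(dl_iter)` indefinitely without restarting the dataloader by hand at each epoch boundary. A plausible minimal implementation, stated as an assumption about the helper rather than its actual code:

    def cycle(iterable):
        # re-iterating the DataLoader creates a fresh iterator each pass,
        # so shuffling (or a sampler) is re-applied every epoch
        while True:
            yield from iterable
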
online_step = 0
is_offline = False
for env_step in range(cfg.online_steps):
if env_step == 0:
logging.info("Start online training by interacting with environment")
- # TODO: add configurable number of rollout? (default=1)
+
with torch.no_grad():
- rollout = env.rollout(
- max_steps=cfg.env.episode_length,
- policy=td_policy,
- auto_cast_to_device=True,
+ eval_info = eval_policy(
+ rollout_env,
+ policy,
+ transform=offline_dataset.transform,
+ seed=cfg.seed,
)
- assert (
- len(rollout.batch_size) == 2
- ), "2 dimensions expected: number of env in parallel x max number of steps during rollout"
-
- num_parallel_env = rollout.batch_size[0]
- if num_parallel_env != 1:
- # TODO(rcadene): when num_parallel_env > 1, rollout["episode"] needs to be properly set and we need to add tests
- raise NotImplementedError()
-
- num_max_steps = rollout.batch_size[1]
- assert num_max_steps <= cfg.env.episode_length
-
- # reshape to have a list of steps to insert into online_buffer
- rollout = rollout.reshape(num_parallel_env * num_max_steps)
-
- # set same episode index for all time steps contained in this rollout
- rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int)
- online_buffer.extend(rollout)
-
- ep_sum_reward = rollout["next", "reward"].sum()
- ep_max_reward = rollout["next", "reward"].max()
- ep_success = rollout["next", "success"].any()
- rollout_info = {
- "avg_sum_reward": np.nanmean(ep_sum_reward),
- "avg_max_reward": np.nanmean(ep_max_reward),
- "pc_success": np.nanmean(ep_success) * 100,
- "env_step": env_step,
- "ep_length": len(rollout),
- }
+ online_pc_sampling = cfg.get("demo_schedule", 0.5)
+ add_episodes_inplace(
+ eval_info["episodes"], online_dataset, concat_dataset, sampler, online_pc_sampling
+ )
for _ in range(cfg.policy.utd):
- train_info = policy.update(
- online_buffer,
- step,
- demo_buffer=demo_buffer,
- )
+ policy.train()
+ batch = next(dl_iter)
+
+ for key in batch:
+ batch[key] = batch[key].to(cfg.device, non_blocking=True)
+
+ train_info = policy(batch, step)
+
if step % cfg.log_freq == 0:
- train_info.update(rollout_info)
- log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline)
+ log_train_info(logger, train_info, step, cfg, online_dataset, is_offline)
# Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass
# in step + 1.
diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py
index 3dd7cdfa..4b7b7d6c 100644
--- a/lerobot/scripts/visualize_dataset.py
+++ b/lerobot/scripts/visualize_dataset.py
@@ -6,11 +6,8 @@ import einops
import hydra
import imageio
import torch
-from torchrl.data.replay_buffers import (
- SamplerWithoutReplacement,
-)
-from lerobot.common.datasets.factory import make_offline_buffer
+from lerobot.common.datasets.factory import make_dataset
from lerobot.common.logger import log_output_dir
from lerobot.common.utils import init_logging
@@ -39,85 +36,62 @@ def visualize_dataset(cfg: dict, out_dir=None):
init_logging()
log_output_dir(out_dir)
- # we expect frames of each episode to be stored next to each others sequentially
- sampler = SamplerWithoutReplacement(
- shuffle=False,
- )
-
- logging.info("make_offline_buffer")
- offline_buffer = make_offline_buffer(
+ logging.info("make_dataset")
+ dataset = make_dataset(
cfg,
- overwrite_sampler=sampler,
# remove all transformations such as rescale images from [0,255] to [0,1] or normalization
normalize=False,
- overwrite_batch_size=1,
- overwrite_prefetch=12,
)
logging.info("Start rendering episodes from offline buffer")
- video_paths = render_dataset(offline_buffer, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER, cfg.fps)
+ video_paths = render_dataset(dataset, out_dir, NUM_EPISODES_TO_RENDER)
for video_path in video_paths:
logging.info(video_path)
-def render_dataset(offline_buffer, out_dir, max_num_samples, fps):
+def render_dataset(dataset, out_dir, max_num_episodes):
out_dir = Path(out_dir)
video_paths = []
threads = []
- frames = {}
- current_ep_idx = 0
- logging.info(f"Visualizing episode {current_ep_idx}")
- for i in range(max_num_samples):
- # TODO(rcadene): make it work with bsize > 1
- ep_td = offline_buffer.sample(1)
- ep_idx = ep_td["episode"][FIRST_FRAME].item()
- # TODO(rcadene): modify offline_buffer._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames
- num_frames_left = offline_buffer._sampler._sample_list.numel()
- episode_is_done = ep_idx != current_ep_idx
+ dataloader = torch.utils.data.DataLoader(
+ dataset,
+ num_workers=4,
+ batch_size=1,
+ shuffle=False,
+ )
+ dl_iter = iter(dataloader)
- if episode_is_done:
- logging.info(f"Rendering episode {current_ep_idx}")
+ num_episodes = len(dataset.data_ids_per_episode)
+ for ep_id in range(min(max_num_episodes, num_episodes)):
+ logging.info(f"Rendering episode {ep_id}")
- for im_key in offline_buffer.image_keys:
- if not episode_is_done and num_frames_left > 0 and i < (max_num_samples - 1):
+ frames = {}
+ for _ in dataset.data_ids_per_episode[ep_id]:
+ item = next(dl_iter)
+
+ for im_key in dataset.image_keys:
# when first frame of episode, initialize frames dict
if im_key not in frames:
frames[im_key] = []
# add current frame to list of frames to render
- frames[im_key].append(ep_td[im_key])
+ frames[im_key].append(item[im_key])
+
+ out_dir.mkdir(parents=True, exist_ok=True)
+ for im_key in dataset.image_keys:
+ if len(dataset.image_keys) > 1:
+ im_name = im_key.replace("observation.images.", "")
+ video_path = out_dir / f"episode_{ep_id}_{im_name}.mp4"
else:
- # When episode has no more frame in its list of observation,
- # one frame still remains. It is the result of the last action taken.
- # It is stored in `"next"`, so we add it to the list of frames to render.
- frames[im_key].append(ep_td["next"][im_key])
+ video_path = out_dir / f"episode_{ep_id}.mp4"
+ video_paths.append(video_path)
- out_dir.mkdir(parents=True, exist_ok=True)
- if len(offline_buffer.image_keys) > 1:
- camera = im_key[-1]
- video_path = out_dir / f"episode_{current_ep_idx}_{camera}.mp4"
- else:
- video_path = out_dir / f"episode_{current_ep_idx}.mp4"
- video_paths.append(str(video_path))
-
- thread = threading.Thread(
- target=cat_and_write_video,
- args=(str(video_path), frames[im_key], fps),
- )
- thread.start()
- threads.append(thread)
-
- current_ep_idx = ep_idx
-
- # reset list of frames
- del frames[im_key]
-
- if num_frames_left == 0:
- logging.info("Ran out of frames")
- break
-
- if current_ep_idx == NUM_EPISODES_TO_RENDER:
- break
+ thread = threading.Thread(
+ target=cat_and_write_video,
+ args=(str(video_path), frames[im_key], dataset.fps),
+ )
+ thread.start()
+ threads.append(thread)
for thread in threads:
thread.join()
diff --git a/poetry.lock b/poetry.lock
index 72397001..0133b3ed 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.
[[package]]
name = "absl-py"
@@ -521,7 +521,7 @@ files = [
name = "dm-control"
version = "1.0.14"
description = "Continuous control environments and MuJoCo Python bindings."
-optional = false
+optional = true
python-versions = ">=3.8"
files = [
{file = "dm_control-1.0.14-py3-none-any.whl", hash = "sha256:883c63244a7ebf598700a97564ed19fffd3479ca79efd090aed881609cdb9fc6"},
@@ -552,7 +552,7 @@ hdf5 = ["h5py"]
name = "dm-env"
version = "1.6"
description = "A Python interface for Reinforcement Learning environments."
-optional = false
+optional = true
python-versions = ">=3.7"
files = [
{file = "dm-env-1.6.tar.gz", hash = "sha256:a436eb1c654c39e0c986a516cee218bea7140b510fceff63f97eb4fcff3d93de"},
@@ -568,7 +568,7 @@ numpy = "*"
name = "dm-tree"
version = "0.1.8"
description = "Tree is a library for working with nested data structures."
-optional = false
+optional = true
python-versions = "*"
files = [
{file = "dm-tree-0.1.8.tar.gz", hash = "sha256:0fcaabbb14e7980377439e7140bd05552739ca5e515ecb3119f234acee4b9430"},
@@ -692,18 +692,18 @@ files = [
[[package]]
name = "filelock"
-version = "3.13.1"
+version = "3.13.3"
description = "A platform independent file lock."
optional = false
python-versions = ">=3.8"
files = [
- {file = "filelock-3.13.1-py3-none-any.whl", hash = "sha256:57dbda9b35157b05fb3e58ee91448612eb674172fab98ee235ccb0b5bee19a1c"},
- {file = "filelock-3.13.1.tar.gz", hash = "sha256:521f5f56c50f8426f5e03ad3b281b490a87ef15bc6c526f168290f0c7148d44e"},
+ {file = "filelock-3.13.3-py3-none-any.whl", hash = "sha256:5ffa845303983e7a0b7ae17636509bc97997d58afeafa72fb141a17b152284cb"},
+ {file = "filelock-3.13.3.tar.gz", hash = "sha256:a79895a25bbefdf55d1a2a0a80968f7dbb28edcd6d4234a0afb3f37ecde4b546"},
]
[package.extras]
-docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.24)"]
-testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"]
+docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"]
+testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"]
typing = ["typing-extensions (>=4.8)"]
[[package]]
@@ -777,26 +777,27 @@ smmap = ">=3.0.1,<6"
[[package]]
name = "gitpython"
-version = "3.1.42"
+version = "3.1.43"
description = "GitPython is a Python library used to interact with Git repositories"
optional = false
python-versions = ">=3.7"
files = [
- {file = "GitPython-3.1.42-py3-none-any.whl", hash = "sha256:1bf9cd7c9e7255f77778ea54359e54ac22a72a5b51288c457c881057b7bb9ecd"},
- {file = "GitPython-3.1.42.tar.gz", hash = "sha256:2d99869e0fef71a73cbd242528105af1d6c1b108c60dfabd994bf292f76c3ceb"},
+ {file = "GitPython-3.1.43-py3-none-any.whl", hash = "sha256:eec7ec56b92aad751f9912a73404bc02ba212a23adb2c7098ee668417051a1ff"},
+ {file = "GitPython-3.1.43.tar.gz", hash = "sha256:35f314a9f878467f5453cc1fee295c3e18e52f1b99f10f6cf5b1682e968a9e7c"},
]
[package.dependencies]
gitdb = ">=4.0.1,<5"
[package.extras]
-test = ["black", "coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre-commit", "pytest (>=7.3.1)", "pytest-cov", "pytest-instafail", "pytest-mock", "pytest-sugar"]
+doc = ["sphinx (==4.3.2)", "sphinx-autodoc-typehints", "sphinx-rtd-theme", "sphinxcontrib-applehelp (>=1.0.2,<=1.0.4)", "sphinxcontrib-devhelp (==1.0.2)", "sphinxcontrib-htmlhelp (>=2.0.0,<=2.0.1)", "sphinxcontrib-qthelp (==1.0.3)", "sphinxcontrib-serializinghtml (==1.1.5)"]
+test = ["coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre-commit", "pytest (>=7.3.1)", "pytest-cov", "pytest-instafail", "pytest-mock", "pytest-sugar", "typing-extensions"]
[[package]]
name = "glfw"
version = "2.7.0"
description = "A ctypes-based wrapper for GLFW3."
-optional = false
+optional = true
python-versions = "*"
files = [
{file = "glfw-2.7.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38-none-macosx_10_6_intel.whl", hash = "sha256:bd82849edcceda4e262bd1227afaa74b94f9f0731c1197863cd25c15bfc613fc"},
@@ -879,6 +880,69 @@ files = [
[package.extras]
protobuf = ["grpcio-tools (>=1.62.1)"]
+[[package]]
+name = "gym-aloha"
+version = "0.1.0"
+description = "A gym environment for ALOHA"
+optional = true
+python-versions = "^3.10"
+files = []
+develop = false
+
+[package.dependencies]
+dm-control = "1.0.14"
+gymnasium = "^0.29.1"
+mujoco = "^2.3.7"
+
+[package.source]
+type = "git"
+url = "git@github.com:huggingface/gym-aloha.git"
+reference = "HEAD"
+resolved_reference = "c636f05ba0d1760df94537da84c860be1487e17f"
+
+[[package]]
+name = "gym-pusht"
+version = "0.1.0"
+description = "A gymnasium environment for PushT."
+optional = true
+python-versions = "^3.10"
+files = []
+develop = false
+
+[package.dependencies]
+gymnasium = "^0.29.1"
+opencv-python = "^4.9.0.80"
+pygame = "^2.5.2"
+pymunk = "^6.6.0"
+scikit-image = "^0.22.0"
+shapely = "^2.0.3"
+
+[package.source]
+type = "git"
+url = "git@github.com:huggingface/gym-pusht.git"
+reference = "HEAD"
+resolved_reference = "080d4ce4d8d3140b2fd204ed628bda14dc58ff06"
+
+[[package]]
+name = "gym-xarm"
+version = "0.1.0"
+description = "A gym environment for xArm"
+optional = true
+python-versions = "^3.10"
+files = []
+develop = false
+
+[package.dependencies]
+gymnasium = "^0.29.1"
+gymnasium-robotics = "^1.2.4"
+mujoco = "^2.3.7"
+
+[package.source]
+type = "git"
+url = "git@github.com:huggingface/gym-xarm.git"
+reference = "HEAD"
+resolved_reference = "6a88f7d63833705dfbec4b997bf36cac6b4a448c"
+
[[package]]
name = "gymnasium"
version = "0.29.1"
@@ -913,7 +977,7 @@ toy-text = ["pygame (>=2.1.3)", "pygame (>=2.1.3)"]
name = "gymnasium-robotics"
version = "1.2.4"
description = "Robotics environments for the Gymnasium repo."
-optional = false
+optional = true
python-versions = ">=3.8"
files = [
{file = "gymnasium-robotics-1.2.4.tar.gz", hash = "sha256:d304192b066f8b800599dfbe3d9d90bba9b761ee884472bdc4d05968a8bc61cb"},
@@ -1218,7 +1282,7 @@ i18n = ["Babel (>=2.7)"]
name = "labmaze"
version = "1.0.6"
description = "LabMaze: DeepMind Lab's text maze generator."
-optional = false
+optional = true
python-versions = "*"
files = [
{file = "labmaze-1.0.6-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:b2ddef976dfd8d992b19cfa6c633f2eba7576d759c2082da534e3f727479a84a"},
@@ -1262,7 +1326,7 @@ setuptools = "!=50.0.0"
name = "lazy-loader"
version = "0.3"
description = "lazy_loader"
-optional = false
+optional = true
python-versions = ">=3.7"
files = [
{file = "lazy_loader-0.3-py3-none-any.whl", hash = "sha256:1e9e76ee8631e264c62ce10006718e80b2cfc74340d17d1031e0f84af7478554"},
@@ -1305,96 +1369,174 @@ files = [
[[package]]
name = "lxml"
-version = "5.1.0"
+version = "5.2.1"
description = "Powerful and Pythonic XML processing library combining libxml2/libxslt with the ElementTree API."
-optional = false
+optional = true
python-versions = ">=3.6"
files = [
- {file = "lxml-5.1.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:704f5572ff473a5f897745abebc6df40f22d4133c1e0a1f124e4f2bd3330ff7e"},
- {file = "lxml-5.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9d3c0f8567ffe7502d969c2c1b809892dc793b5d0665f602aad19895f8d508da"},
- {file = "lxml-5.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5fcfbebdb0c5d8d18b84118842f31965d59ee3e66996ac842e21f957eb76138c"},
- {file = "lxml-5.1.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2f37c6d7106a9d6f0708d4e164b707037b7380fcd0b04c5bd9cae1fb46a856fb"},
- {file = "lxml-5.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2befa20a13f1a75c751f47e00929fb3433d67eb9923c2c0b364de449121f447c"},
- {file = "lxml-5.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:22b7ee4c35f374e2c20337a95502057964d7e35b996b1c667b5c65c567d2252a"},
- {file = "lxml-5.1.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:bf8443781533b8d37b295016a4b53c1494fa9a03573c09ca5104550c138d5c05"},
- {file = "lxml-5.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:82bddf0e72cb2af3cbba7cec1d2fd11fda0de6be8f4492223d4a268713ef2147"},
- {file = "lxml-5.1.0-cp310-cp310-win32.whl", hash = "sha256:b66aa6357b265670bb574f050ffceefb98549c721cf28351b748be1ef9577d93"},
- {file = "lxml-5.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:4946e7f59b7b6a9e27bef34422f645e9a368cb2be11bf1ef3cafc39a1f6ba68d"},
- {file = "lxml-5.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:14deca1460b4b0f6b01f1ddc9557704e8b365f55c63070463f6c18619ebf964f"},
- {file = "lxml-5.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ed8c3d2cd329bf779b7ed38db176738f3f8be637bb395ce9629fc76f78afe3d4"},
- {file = "lxml-5.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:436a943c2900bb98123b06437cdd30580a61340fbdb7b28aaf345a459c19046a"},
- {file = "lxml-5.1.0-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:acb6b2f96f60f70e7f34efe0c3ea34ca63f19ca63ce90019c6cbca6b676e81fa"},
- {file = "lxml-5.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:af8920ce4a55ff41167ddbc20077f5698c2e710ad3353d32a07d3264f3a2021e"},
- {file = "lxml-5.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7cfced4a069003d8913408e10ca8ed092c49a7f6cefee9bb74b6b3e860683b45"},
- {file = "lxml-5.1.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:9e5ac3437746189a9b4121db2a7b86056ac8786b12e88838696899328fc44bb2"},
- {file = "lxml-5.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f4c9bda132ad108b387c33fabfea47866af87f4ea6ffb79418004f0521e63204"},
- {file = "lxml-5.1.0-cp311-cp311-win32.whl", hash = "sha256:bc64d1b1dab08f679fb89c368f4c05693f58a9faf744c4d390d7ed1d8223869b"},
- {file = "lxml-5.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:a5ab722ae5a873d8dcee1f5f45ddd93c34210aed44ff2dc643b5025981908cda"},
- {file = "lxml-5.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:9aa543980ab1fbf1720969af1d99095a548ea42e00361e727c58a40832439114"},
- {file = "lxml-5.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6f11b77ec0979f7e4dc5ae081325a2946f1fe424148d3945f943ceaede98adb8"},
- {file = "lxml-5.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a36c506e5f8aeb40680491d39ed94670487ce6614b9d27cabe45d94cd5d63e1e"},
- {file = "lxml-5.1.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f643ffd2669ffd4b5a3e9b41c909b72b2a1d5e4915da90a77e119b8d48ce867a"},
- {file = "lxml-5.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:16dd953fb719f0ffc5bc067428fc9e88f599e15723a85618c45847c96f11f431"},
- {file = "lxml-5.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:16018f7099245157564d7148165132c70adb272fb5a17c048ba70d9cc542a1a1"},
- {file = "lxml-5.1.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:82cd34f1081ae4ea2ede3d52f71b7be313756e99b4b5f829f89b12da552d3aa3"},
- {file = "lxml-5.1.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:19a1bc898ae9f06bccb7c3e1dfd73897ecbbd2c96afe9095a6026016e5ca97b8"},
- {file = "lxml-5.1.0-cp312-cp312-win32.whl", hash = "sha256:13521a321a25c641b9ea127ef478b580b5ec82aa2e9fc076c86169d161798b01"},
- {file = "lxml-5.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:1ad17c20e3666c035db502c78b86e58ff6b5991906e55bdbef94977700c72623"},
- {file = "lxml-5.1.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:24ef5a4631c0b6cceaf2dbca21687e29725b7c4e171f33a8f8ce23c12558ded1"},
- {file = "lxml-5.1.0-cp36-cp36m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8d2900b7f5318bc7ad8631d3d40190b95ef2aa8cc59473b73b294e4a55e9f30f"},
- {file = "lxml-5.1.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:601f4a75797d7a770daed8b42b97cd1bb1ba18bd51a9382077a6a247a12aa38d"},
- {file = "lxml-5.1.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b4b68c961b5cc402cbd99cca5eb2547e46ce77260eb705f4d117fd9c3f932b95"},
- {file = "lxml-5.1.0-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:afd825e30f8d1f521713a5669b63657bcfe5980a916c95855060048b88e1adb7"},
- {file = "lxml-5.1.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:262bc5f512a66b527d026518507e78c2f9c2bd9eb5c8aeeb9f0eb43fcb69dc67"},
- {file = "lxml-5.1.0-cp36-cp36m-win32.whl", hash = "sha256:e856c1c7255c739434489ec9c8aa9cdf5179785d10ff20add308b5d673bed5cd"},
- {file = "lxml-5.1.0-cp36-cp36m-win_amd64.whl", hash = "sha256:c7257171bb8d4432fe9d6fdde4d55fdbe663a63636a17f7f9aaba9bcb3153ad7"},
- {file = "lxml-5.1.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b9e240ae0ba96477682aa87899d94ddec1cc7926f9df29b1dd57b39e797d5ab5"},
- {file = "lxml-5.1.0-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a96f02ba1bcd330807fc060ed91d1f7a20853da6dd449e5da4b09bfcc08fdcf5"},
- {file = "lxml-5.1.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e3898ae2b58eeafedfe99e542a17859017d72d7f6a63de0f04f99c2cb125936"},
- {file = "lxml-5.1.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:61c5a7edbd7c695e54fca029ceb351fc45cd8860119a0f83e48be44e1c464862"},
- {file = "lxml-5.1.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:3aeca824b38ca78d9ee2ab82bd9883083d0492d9d17df065ba3b94e88e4d7ee6"},
- {file = "lxml-5.1.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:8f52fe6859b9db71ee609b0c0a70fea5f1e71c3462ecf144ca800d3f434f0764"},
- {file = "lxml-5.1.0-cp37-cp37m-win32.whl", hash = "sha256:d42e3a3fc18acc88b838efded0e6ec3edf3e328a58c68fbd36a7263a874906c8"},
- {file = "lxml-5.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:eac68f96539b32fce2c9b47eb7c25bb2582bdaf1bbb360d25f564ee9e04c542b"},
- {file = "lxml-5.1.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ae15347a88cf8af0949a9872b57a320d2605ae069bcdf047677318bc0bba45b1"},
- {file = "lxml-5.1.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c26aab6ea9c54d3bed716b8851c8bfc40cb249b8e9880e250d1eddde9f709bf5"},
- {file = "lxml-5.1.0-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:342e95bddec3a698ac24378d61996b3ee5ba9acfeb253986002ac53c9a5f6f84"},
- {file = "lxml-5.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:725e171e0b99a66ec8605ac77fa12239dbe061482ac854d25720e2294652eeaa"},
- {file = "lxml-5.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d184e0d5c918cff04cdde9dbdf9600e960161d773666958c9d7b565ccc60c45"},
- {file = "lxml-5.1.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:98f3f020a2b736566c707c8e034945c02aa94e124c24f77ca097c446f81b01f1"},
- {file = "lxml-5.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:6d48fc57e7c1e3df57be5ae8614bab6d4e7b60f65c5457915c26892c41afc59e"},
- {file = "lxml-5.1.0-cp38-cp38-win32.whl", hash = "sha256:7ec465e6549ed97e9f1e5ed51c657c9ede767bc1c11552f7f4d022c4df4a977a"},
- {file = "lxml-5.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:b21b4031b53d25b0858d4e124f2f9131ffc1530431c6d1321805c90da78388d1"},
- {file = "lxml-5.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:52427a7eadc98f9e62cb1368a5079ae826f94f05755d2d567d93ee1bc3ceb354"},
- {file = "lxml-5.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6a2a2c724d97c1eb8cf966b16ca2915566a4904b9aad2ed9a09c748ffe14f969"},
- {file = "lxml-5.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:843b9c835580d52828d8f69ea4302537337a21e6b4f1ec711a52241ba4a824f3"},
- {file = "lxml-5.1.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9b99f564659cfa704a2dd82d0684207b1aadf7d02d33e54845f9fc78e06b7581"},
- {file = "lxml-5.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f8b0c78e7aac24979ef09b7f50da871c2de2def043d468c4b41f512d831e912"},
- {file = "lxml-5.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9bcf86dfc8ff3e992fed847c077bd875d9e0ba2fa25d859c3a0f0f76f07f0c8d"},
- {file = "lxml-5.1.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:49a9b4af45e8b925e1cd6f3b15bbba2c81e7dba6dce170c677c9cda547411e14"},
- {file = "lxml-5.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:280f3edf15c2a967d923bcfb1f8f15337ad36f93525828b40a0f9d6c2ad24890"},
- {file = "lxml-5.1.0-cp39-cp39-win32.whl", hash = "sha256:ed7326563024b6e91fef6b6c7a1a2ff0a71b97793ac33dbbcf38f6005e51ff6e"},
- {file = "lxml-5.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:8d7b4beebb178e9183138f552238f7e6613162a42164233e2bda00cb3afac58f"},
- {file = "lxml-5.1.0-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:9bd0ae7cc2b85320abd5e0abad5ccee5564ed5f0cc90245d2f9a8ef330a8deae"},
- {file = "lxml-5.1.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d8c1d679df4361408b628f42b26a5d62bd3e9ba7f0c0e7969f925021554755aa"},
- {file = "lxml-5.1.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:2ad3a8ce9e8a767131061a22cd28fdffa3cd2dc193f399ff7b81777f3520e372"},
- {file = "lxml-5.1.0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:304128394c9c22b6569eba2a6d98392b56fbdfbad58f83ea702530be80d0f9df"},
- {file = "lxml-5.1.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d74fcaf87132ffc0447b3c685a9f862ffb5b43e70ea6beec2fb8057d5d2a1fea"},
- {file = "lxml-5.1.0-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:8cf5877f7ed384dabfdcc37922c3191bf27e55b498fecece9fd5c2c7aaa34c33"},
- {file = "lxml-5.1.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:877efb968c3d7eb2dad540b6cabf2f1d3c0fbf4b2d309a3c141f79c7e0061324"},
- {file = "lxml-5.1.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3f14a4fb1c1c402a22e6a341a24c1341b4a3def81b41cd354386dcb795f83897"},
- {file = "lxml-5.1.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:25663d6e99659544ee8fe1b89b1a8c0aaa5e34b103fab124b17fa958c4a324a6"},
- {file = "lxml-5.1.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:8b9f19df998761babaa7f09e6bc169294eefafd6149aaa272081cbddc7ba4ca3"},
- {file = "lxml-5.1.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5e53d7e6a98b64fe54775d23a7c669763451340c3d44ad5e3a3b48a1efbdc96f"},
- {file = "lxml-5.1.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:c3cd1fc1dc7c376c54440aeaaa0dcc803d2126732ff5c6b68ccd619f2e64be4f"},
- {file = "lxml-5.1.0.tar.gz", hash = "sha256:3eea6ed6e6c918e468e693c41ef07f3c3acc310b70ddd9cc72d9ef84bc9564ca"},
+ {file = "lxml-5.2.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:1f7785f4f789fdb522729ae465adcaa099e2a3441519df750ebdccc481d961a1"},
+ {file = "lxml-5.2.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6cc6ee342fb7fa2471bd9b6d6fdfc78925a697bf5c2bcd0a302e98b0d35bfad3"},
+ {file = "lxml-5.2.1-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:794f04eec78f1d0e35d9e0c36cbbb22e42d370dda1609fb03bcd7aeb458c6377"},
+ {file = "lxml-5.2.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c817d420c60a5183953c783b0547d9eb43b7b344a2c46f69513d5952a78cddf3"},
+ {file = "lxml-5.2.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2213afee476546a7f37c7a9b4ad4d74b1e112a6fafffc9185d6d21f043128c81"},
+ {file = "lxml-5.2.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b070bbe8d3f0f6147689bed981d19bbb33070225373338df755a46893528104a"},
+ {file = "lxml-5.2.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e02c5175f63effbd7c5e590399c118d5db6183bbfe8e0d118bdb5c2d1b48d937"},
+ {file = "lxml-5.2.1-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:3dc773b2861b37b41a6136e0b72a1a44689a9c4c101e0cddb6b854016acc0aa8"},
+ {file = "lxml-5.2.1-cp310-cp310-manylinux_2_28_ppc64le.whl", hash = "sha256:d7520db34088c96cc0e0a3ad51a4fd5b401f279ee112aa2b7f8f976d8582606d"},
+ {file = "lxml-5.2.1-cp310-cp310-manylinux_2_28_s390x.whl", hash = "sha256:bcbf4af004f98793a95355980764b3d80d47117678118a44a80b721c9913436a"},
+ {file = "lxml-5.2.1-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:a2b44bec7adf3e9305ce6cbfa47a4395667e744097faed97abb4728748ba7d47"},
+ {file = "lxml-5.2.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:1c5bb205e9212d0ebddf946bc07e73fa245c864a5f90f341d11ce7b0b854475d"},
+ {file = "lxml-5.2.1-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:2c9d147f754b1b0e723e6afb7ba1566ecb162fe4ea657f53d2139bbf894d050a"},
+ {file = "lxml-5.2.1-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:3545039fa4779be2df51d6395e91a810f57122290864918b172d5dc7ca5bb433"},
+ {file = "lxml-5.2.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a91481dbcddf1736c98a80b122afa0f7296eeb80b72344d7f45dc9f781551f56"},
+ {file = "lxml-5.2.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:2ddfe41ddc81f29a4c44c8ce239eda5ade4e7fc305fb7311759dd6229a080052"},
+ {file = "lxml-5.2.1-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:a7baf9ffc238e4bf401299f50e971a45bfcc10a785522541a6e3179c83eabf0a"},
+ {file = "lxml-5.2.1-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:31e9a882013c2f6bd2f2c974241bf4ba68c85eba943648ce88936d23209a2e01"},
+ {file = "lxml-5.2.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:0a15438253b34e6362b2dc41475e7f80de76320f335e70c5528b7148cac253a1"},
+ {file = "lxml-5.2.1-cp310-cp310-win32.whl", hash = "sha256:6992030d43b916407c9aa52e9673612ff39a575523c5f4cf72cdef75365709a5"},
+ {file = "lxml-5.2.1-cp310-cp310-win_amd64.whl", hash = "sha256:da052e7962ea2d5e5ef5bc0355d55007407087392cf465b7ad84ce5f3e25fe0f"},
+ {file = "lxml-5.2.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:70ac664a48aa64e5e635ae5566f5227f2ab7f66a3990d67566d9907edcbbf867"},
+ {file = "lxml-5.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:1ae67b4e737cddc96c99461d2f75d218bdf7a0c3d3ad5604d1f5e7464a2f9ffe"},
+ {file = "lxml-5.2.1-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f18a5a84e16886898e51ab4b1d43acb3083c39b14c8caeb3589aabff0ee0b270"},
+ {file = "lxml-5.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c6f2c8372b98208ce609c9e1d707f6918cc118fea4e2c754c9f0812c04ca116d"},
+ {file = "lxml-5.2.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:394ed3924d7a01b5bd9a0d9d946136e1c2f7b3dc337196d99e61740ed4bc6fe1"},
+ {file = "lxml-5.2.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d077bc40a1fe984e1a9931e801e42959a1e6598edc8a3223b061d30fbd26bbc"},
+ {file = "lxml-5.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:764b521b75701f60683500d8621841bec41a65eb739b8466000c6fdbc256c240"},
+ {file = "lxml-5.2.1-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:3a6b45da02336895da82b9d472cd274b22dc27a5cea1d4b793874eead23dd14f"},
+ {file = "lxml-5.2.1-cp311-cp311-manylinux_2_28_ppc64le.whl", hash = "sha256:5ea7b6766ac2dfe4bcac8b8595107665a18ef01f8c8343f00710b85096d1b53a"},
+ {file = "lxml-5.2.1-cp311-cp311-manylinux_2_28_s390x.whl", hash = "sha256:e196a4ff48310ba62e53a8e0f97ca2bca83cdd2fe2934d8b5cb0df0a841b193a"},
+ {file = "lxml-5.2.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:200e63525948e325d6a13a76ba2911f927ad399ef64f57898cf7c74e69b71095"},
+ {file = "lxml-5.2.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:dae0ed02f6b075426accbf6b2863c3d0a7eacc1b41fb40f2251d931e50188dad"},
+ {file = "lxml-5.2.1-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:ab31a88a651039a07a3ae327d68ebdd8bc589b16938c09ef3f32a4b809dc96ef"},
+ {file = "lxml-5.2.1-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:df2e6f546c4df14bc81f9498bbc007fbb87669f1bb707c6138878c46b06f6510"},
+ {file = "lxml-5.2.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5dd1537e7cc06efd81371f5d1a992bd5ab156b2b4f88834ca852de4a8ea523fa"},
+ {file = "lxml-5.2.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:9b9ec9c9978b708d488bec36b9e4c94d88fd12ccac3e62134a9d17ddba910ea9"},
+ {file = "lxml-5.2.1-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:8e77c69d5892cb5ba71703c4057091e31ccf534bd7f129307a4d084d90d014b8"},
+ {file = "lxml-5.2.1-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:a8d5c70e04aac1eda5c829a26d1f75c6e5286c74743133d9f742cda8e53b9c2f"},
+ {file = "lxml-5.2.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:c94e75445b00319c1fad60f3c98b09cd63fe1134a8a953dcd48989ef42318534"},
+ {file = "lxml-5.2.1-cp311-cp311-win32.whl", hash = "sha256:4951e4f7a5680a2db62f7f4ab2f84617674d36d2d76a729b9a8be4b59b3659be"},
+ {file = "lxml-5.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:5c670c0406bdc845b474b680b9a5456c561c65cf366f8db5a60154088c92d102"},
+ {file = "lxml-5.2.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:abc25c3cab9ec7fcd299b9bcb3b8d4a1231877e425c650fa1c7576c5107ab851"},
+ {file = "lxml-5.2.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6935bbf153f9a965f1e07c2649c0849d29832487c52bb4a5c5066031d8b44fd5"},
+ {file = "lxml-5.2.1-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d793bebb202a6000390a5390078e945bbb49855c29c7e4d56a85901326c3b5d9"},
+ {file = "lxml-5.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:afd5562927cdef7c4f5550374acbc117fd4ecc05b5007bdfa57cc5355864e0a4"},
+ {file = "lxml-5.2.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0e7259016bc4345a31af861fdce942b77c99049d6c2107ca07dc2bba2435c1d9"},
+ {file = "lxml-5.2.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:530e7c04f72002d2f334d5257c8a51bf409db0316feee7c87e4385043be136af"},
+ {file = "lxml-5.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:59689a75ba8d7ffca577aefd017d08d659d86ad4585ccc73e43edbfc7476781a"},
+ {file = "lxml-5.2.1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:f9737bf36262046213a28e789cc82d82c6ef19c85a0cf05e75c670a33342ac2c"},
+ {file = "lxml-5.2.1-cp312-cp312-manylinux_2_28_ppc64le.whl", hash = "sha256:3a74c4f27167cb95c1d4af1c0b59e88b7f3e0182138db2501c353555f7ec57f4"},
+ {file = "lxml-5.2.1-cp312-cp312-manylinux_2_28_s390x.whl", hash = "sha256:68a2610dbe138fa8c5826b3f6d98a7cfc29707b850ddcc3e21910a6fe51f6ca0"},
+ {file = "lxml-5.2.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:f0a1bc63a465b6d72569a9bba9f2ef0334c4e03958e043da1920299100bc7c08"},
+ {file = "lxml-5.2.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:c2d35a1d047efd68027817b32ab1586c1169e60ca02c65d428ae815b593e65d4"},
+ {file = "lxml-5.2.1-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:79bd05260359170f78b181b59ce871673ed01ba048deef4bf49a36ab3e72e80b"},
+ {file = "lxml-5.2.1-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:865bad62df277c04beed9478fe665b9ef63eb28fe026d5dedcb89b537d2e2ea6"},
+ {file = "lxml-5.2.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:44f6c7caff88d988db017b9b0e4ab04934f11e3e72d478031efc7edcac6c622f"},
+ {file = "lxml-5.2.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:71e97313406ccf55d32cc98a533ee05c61e15d11b99215b237346171c179c0b0"},
+ {file = "lxml-5.2.1-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:057cdc6b86ab732cf361f8b4d8af87cf195a1f6dc5b0ff3de2dced242c2015e0"},
+ {file = "lxml-5.2.1-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:f3bbbc998d42f8e561f347e798b85513ba4da324c2b3f9b7969e9c45b10f6169"},
+ {file = "lxml-5.2.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:491755202eb21a5e350dae00c6d9a17247769c64dcf62d8c788b5c135e179dc4"},
+ {file = "lxml-5.2.1-cp312-cp312-win32.whl", hash = "sha256:8de8f9d6caa7f25b204fc861718815d41cbcf27ee8f028c89c882a0cf4ae4134"},
+ {file = "lxml-5.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:f2a9efc53d5b714b8df2b4b3e992accf8ce5bbdfe544d74d5c6766c9e1146a3a"},
+ {file = "lxml-5.2.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:70a9768e1b9d79edca17890175ba915654ee1725975d69ab64813dd785a2bd5c"},
+ {file = "lxml-5.2.1-cp36-cp36m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c38d7b9a690b090de999835f0443d8aa93ce5f2064035dfc48f27f02b4afc3d0"},
+ {file = "lxml-5.2.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5670fb70a828663cc37552a2a85bf2ac38475572b0e9b91283dc09efb52c41d1"},
+ {file = "lxml-5.2.1-cp36-cp36m-manylinux_2_28_x86_64.whl", hash = "sha256:958244ad566c3ffc385f47dddde4145088a0ab893504b54b52c041987a8c1863"},
+ {file = "lxml-5.2.1-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:b6241d4eee5f89453307c2f2bfa03b50362052ca0af1efecf9fef9a41a22bb4f"},
+ {file = "lxml-5.2.1-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:2a66bf12fbd4666dd023b6f51223aed3d9f3b40fef06ce404cb75bafd3d89536"},
+ {file = "lxml-5.2.1-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:9123716666e25b7b71c4e1789ec829ed18663152008b58544d95b008ed9e21e9"},
+ {file = "lxml-5.2.1-cp36-cp36m-musllinux_1_1_s390x.whl", hash = "sha256:0c3f67e2aeda739d1cc0b1102c9a9129f7dc83901226cc24dd72ba275ced4218"},
+ {file = "lxml-5.2.1-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:5d5792e9b3fb8d16a19f46aa8208987cfeafe082363ee2745ea8b643d9cc5b45"},
+ {file = "lxml-5.2.1-cp36-cp36m-musllinux_1_2_aarch64.whl", hash = "sha256:88e22fc0a6684337d25c994381ed8a1580a6f5ebebd5ad41f89f663ff4ec2885"},
+ {file = "lxml-5.2.1-cp36-cp36m-musllinux_1_2_ppc64le.whl", hash = "sha256:21c2e6b09565ba5b45ae161b438e033a86ad1736b8c838c766146eff8ceffff9"},
+ {file = "lxml-5.2.1-cp36-cp36m-musllinux_1_2_s390x.whl", hash = "sha256:afbbdb120d1e78d2ba8064a68058001b871154cc57787031b645c9142b937a62"},
+ {file = "lxml-5.2.1-cp36-cp36m-musllinux_1_2_x86_64.whl", hash = "sha256:627402ad8dea044dde2eccde4370560a2b750ef894c9578e1d4f8ffd54000461"},
+ {file = "lxml-5.2.1-cp36-cp36m-win32.whl", hash = "sha256:e89580a581bf478d8dcb97d9cd011d567768e8bc4095f8557b21c4d4c5fea7d0"},
+ {file = "lxml-5.2.1-cp36-cp36m-win_amd64.whl", hash = "sha256:59565f10607c244bc4c05c0c5fa0c190c990996e0c719d05deec7030c2aa8289"},
+ {file = "lxml-5.2.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:857500f88b17a6479202ff5fe5f580fc3404922cd02ab3716197adf1ef628029"},
+ {file = "lxml-5.2.1-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:56c22432809085b3f3ae04e6e7bdd36883d7258fcd90e53ba7b2e463efc7a6af"},
+ {file = "lxml-5.2.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a55ee573116ba208932e2d1a037cc4b10d2c1cb264ced2184d00b18ce585b2c0"},
+ {file = "lxml-5.2.1-cp37-cp37m-manylinux_2_28_x86_64.whl", hash = "sha256:6cf58416653c5901e12624e4013708b6e11142956e7f35e7a83f1ab02f3fe456"},
+ {file = "lxml-5.2.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:64c2baa7774bc22dd4474248ba16fe1a7f611c13ac6123408694d4cc93d66dbd"},
+ {file = "lxml-5.2.1-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:74b28c6334cca4dd704e8004cba1955af0b778cf449142e581e404bd211fb619"},
+ {file = "lxml-5.2.1-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:7221d49259aa1e5a8f00d3d28b1e0b76031655ca74bb287123ef56c3db92f213"},
+ {file = "lxml-5.2.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:3dbe858ee582cbb2c6294dc85f55b5f19c918c2597855e950f34b660f1a5ede6"},
+ {file = "lxml-5.2.1-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:04ab5415bf6c86e0518d57240a96c4d1fcfc3cb370bb2ac2a732b67f579e5a04"},
+ {file = "lxml-5.2.1-cp37-cp37m-musllinux_1_2_ppc64le.whl", hash = "sha256:6ab833e4735a7e5533711a6ea2df26459b96f9eec36d23f74cafe03631647c41"},
+ {file = "lxml-5.2.1-cp37-cp37m-musllinux_1_2_s390x.whl", hash = "sha256:f443cdef978430887ed55112b491f670bba6462cea7a7742ff8f14b7abb98d75"},
+ {file = "lxml-5.2.1-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:9e2addd2d1866fe112bc6f80117bcc6bc25191c5ed1bfbcf9f1386a884252ae8"},
+ {file = "lxml-5.2.1-cp37-cp37m-win32.whl", hash = "sha256:f51969bac61441fd31f028d7b3b45962f3ecebf691a510495e5d2cd8c8092dbd"},
+ {file = "lxml-5.2.1-cp37-cp37m-win_amd64.whl", hash = "sha256:b0b58fbfa1bf7367dde8a557994e3b1637294be6cf2169810375caf8571a085c"},
+ {file = "lxml-5.2.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:3e183c6e3298a2ed5af9d7a356ea823bccaab4ec2349dc9ed83999fd289d14d5"},
+ {file = "lxml-5.2.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:804f74efe22b6a227306dd890eecc4f8c59ff25ca35f1f14e7482bbce96ef10b"},
+ {file = "lxml-5.2.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:08802f0c56ed150cc6885ae0788a321b73505d2263ee56dad84d200cab11c07a"},
+ {file = "lxml-5.2.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0f8c09ed18ecb4ebf23e02b8e7a22a05d6411911e6fabef3a36e4f371f4f2585"},
+ {file = "lxml-5.2.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3d30321949861404323c50aebeb1943461a67cd51d4200ab02babc58bd06a86"},
+ {file = "lxml-5.2.1-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:b560e3aa4b1d49e0e6c847d72665384db35b2f5d45f8e6a5c0072e0283430533"},
+ {file = "lxml-5.2.1-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:058a1308914f20784c9f4674036527e7c04f7be6fb60f5d61353545aa7fcb739"},
+ {file = "lxml-5.2.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:adfb84ca6b87e06bc6b146dc7da7623395db1e31621c4785ad0658c5028b37d7"},
+ {file = "lxml-5.2.1-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:417d14450f06d51f363e41cace6488519038f940676ce9664b34ebf5653433a5"},
+ {file = "lxml-5.2.1-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:a2dfe7e2473f9b59496247aad6e23b405ddf2e12ef0765677b0081c02d6c2c0b"},
+ {file = "lxml-5.2.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:bf2e2458345d9bffb0d9ec16557d8858c9c88d2d11fed53998512504cd9df49b"},
+ {file = "lxml-5.2.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:58278b29cb89f3e43ff3e0c756abbd1518f3ee6adad9e35b51fb101c1c1daaec"},
+ {file = "lxml-5.2.1-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:64641a6068a16201366476731301441ce93457eb8452056f570133a6ceb15fca"},
+ {file = "lxml-5.2.1-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:78bfa756eab503673991bdcf464917ef7845a964903d3302c5f68417ecdc948c"},
+ {file = "lxml-5.2.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:11a04306fcba10cd9637e669fd73aa274c1c09ca64af79c041aa820ea992b637"},
+ {file = "lxml-5.2.1-cp38-cp38-win32.whl", hash = "sha256:66bc5eb8a323ed9894f8fa0ee6cb3e3fb2403d99aee635078fd19a8bc7a5a5da"},
+ {file = "lxml-5.2.1-cp38-cp38-win_amd64.whl", hash = "sha256:9676bfc686fa6a3fa10cd4ae6b76cae8be26eb5ec6811d2a325636c460da1806"},
+ {file = "lxml-5.2.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:cf22b41fdae514ee2f1691b6c3cdeae666d8b7fa9434de445f12bbeee0cf48dd"},
+ {file = "lxml-5.2.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ec42088248c596dbd61d4ae8a5b004f97a4d91a9fd286f632e42e60b706718d7"},
+ {file = "lxml-5.2.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cd53553ddad4a9c2f1f022756ae64abe16da1feb497edf4d9f87f99ec7cf86bd"},
+ {file = "lxml-5.2.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:feaa45c0eae424d3e90d78823f3828e7dc42a42f21ed420db98da2c4ecf0a2cb"},
+ {file = "lxml-5.2.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ddc678fb4c7e30cf830a2b5a8d869538bc55b28d6c68544d09c7d0d8f17694dc"},
+ {file = "lxml-5.2.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:853e074d4931dbcba7480d4dcab23d5c56bd9607f92825ab80ee2bd916edea53"},
+ {file = "lxml-5.2.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cc4691d60512798304acb9207987e7b2b7c44627ea88b9d77489bbe3e6cc3bd4"},
+ {file = "lxml-5.2.1-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:beb72935a941965c52990f3a32d7f07ce869fe21c6af8b34bf6a277b33a345d3"},
+ {file = "lxml-5.2.1-cp39-cp39-manylinux_2_28_ppc64le.whl", hash = "sha256:6588c459c5627fefa30139be4d2e28a2c2a1d0d1c265aad2ba1935a7863a4913"},
+ {file = "lxml-5.2.1-cp39-cp39-manylinux_2_28_s390x.whl", hash = "sha256:588008b8497667f1ddca7c99f2f85ce8511f8f7871b4a06ceede68ab62dff64b"},
+ {file = "lxml-5.2.1-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:b6787b643356111dfd4032b5bffe26d2f8331556ecb79e15dacb9275da02866e"},
+ {file = "lxml-5.2.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:7c17b64b0a6ef4e5affae6a3724010a7a66bda48a62cfe0674dabd46642e8b54"},
+ {file = "lxml-5.2.1-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:27aa20d45c2e0b8cd05da6d4759649170e8dfc4f4e5ef33a34d06f2d79075d57"},
+ {file = "lxml-5.2.1-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:d4f2cc7060dc3646632d7f15fe68e2fa98f58e35dd5666cd525f3b35d3fed7f8"},
+ {file = "lxml-5.2.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ff46d772d5f6f73564979cd77a4fffe55c916a05f3cb70e7c9c0590059fb29ef"},
+ {file = "lxml-5.2.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:96323338e6c14e958d775700ec8a88346014a85e5de73ac7967db0367582049b"},
+ {file = "lxml-5.2.1-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:52421b41ac99e9d91934e4d0d0fe7da9f02bfa7536bb4431b4c05c906c8c6919"},
+ {file = "lxml-5.2.1-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:7a7efd5b6d3e30d81ec68ab8a88252d7c7c6f13aaa875009fe3097eb4e30b84c"},
+ {file = "lxml-5.2.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:0ed777c1e8c99b63037b91f9d73a6aad20fd035d77ac84afcc205225f8f41188"},
+ {file = "lxml-5.2.1-cp39-cp39-win32.whl", hash = "sha256:644df54d729ef810dcd0f7732e50e5ad1bd0a135278ed8d6bcb06f33b6b6f708"},
+ {file = "lxml-5.2.1-cp39-cp39-win_amd64.whl", hash = "sha256:9ca66b8e90daca431b7ca1408cae085d025326570e57749695d6a01454790e95"},
+ {file = "lxml-5.2.1-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:9b0ff53900566bc6325ecde9181d89afadc59c5ffa39bddf084aaedfe3b06a11"},
+ {file = "lxml-5.2.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fd6037392f2d57793ab98d9e26798f44b8b4da2f2464388588f48ac52c489ea1"},
+ {file = "lxml-5.2.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b9c07e7a45bb64e21df4b6aa623cb8ba214dfb47d2027d90eac197329bb5e94"},
+ {file = "lxml-5.2.1-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:3249cc2989d9090eeac5467e50e9ec2d40704fea9ab72f36b034ea34ee65ca98"},
+ {file = "lxml-5.2.1-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:f42038016852ae51b4088b2862126535cc4fc85802bfe30dea3500fdfaf1864e"},
+ {file = "lxml-5.2.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:533658f8fbf056b70e434dff7e7aa611bcacb33e01f75de7f821810e48d1bb66"},
+ {file = "lxml-5.2.1-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:622020d4521e22fb371e15f580d153134bfb68d6a429d1342a25f051ec72df1c"},
+ {file = "lxml-5.2.1-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:efa7b51824aa0ee957ccd5a741c73e6851de55f40d807f08069eb4c5a26b2baa"},
+ {file = "lxml-5.2.1-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c6ad0fbf105f6bcc9300c00010a2ffa44ea6f555df1a2ad95c88f5656104817"},
+ {file = "lxml-5.2.1-pp37-pypy37_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:e233db59c8f76630c512ab4a4daf5a5986da5c3d5b44b8e9fc742f2a24dbd460"},
+ {file = "lxml-5.2.1-pp37-pypy37_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:6a014510830df1475176466b6087fc0c08b47a36714823e58d8b8d7709132a96"},
+ {file = "lxml-5.2.1-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:d38c8f50ecf57f0463399569aa388b232cf1a2ffb8f0a9a5412d0db57e054860"},
+ {file = "lxml-5.2.1-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:5aea8212fb823e006b995c4dda533edcf98a893d941f173f6c9506126188860d"},
+ {file = "lxml-5.2.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ff097ae562e637409b429a7ac958a20aab237a0378c42dabaa1e3abf2f896e5f"},
+ {file = "lxml-5.2.1-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f5d65c39f16717a47c36c756af0fb36144069c4718824b7533f803ecdf91138"},
+ {file = "lxml-5.2.1-pp38-pypy38_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:3d0c3dd24bb4605439bf91068598d00c6370684f8de4a67c2992683f6c309d6b"},
+ {file = "lxml-5.2.1-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:e32be23d538753a8adb6c85bd539f5fd3b15cb987404327c569dfc5fd8366e85"},
+ {file = "lxml-5.2.1-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:cc518cea79fd1e2f6c90baafa28906d4309d24f3a63e801d855e7424c5b34144"},
+ {file = "lxml-5.2.1-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:a0af35bd8ebf84888373630f73f24e86bf016642fb8576fba49d3d6b560b7cbc"},
+ {file = "lxml-5.2.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f8aca2e3a72f37bfc7b14ba96d4056244001ddcc18382bd0daa087fd2e68a354"},
+ {file = "lxml-5.2.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ca1e8188b26a819387b29c3895c47a5e618708fe6f787f3b1a471de2c4a94d9"},
+ {file = "lxml-5.2.1-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:c8ba129e6d3b0136a0f50345b2cb3db53f6bda5dd8c7f5d83fbccba97fb5dcb5"},
+ {file = "lxml-5.2.1-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:e998e304036198b4f6914e6a1e2b6f925208a20e2042563d9734881150c6c246"},
+ {file = "lxml-5.2.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:d3be9b2076112e51b323bdf6d5a7f8a798de55fb8d95fcb64bd179460cdc0704"},
+ {file = "lxml-5.2.1.tar.gz", hash = "sha256:3f7765e69bbce0906a7c74d5fe46d2c7a7596147318dbc08e4a2431f3060e306"},
]
[package.extras]
cssselect = ["cssselect (>=0.7)"]
+html-clean = ["lxml-html-clean"]
html5 = ["html5lib"]
htmlsoup = ["BeautifulSoup4"]
-source = ["Cython (>=3.0.7)"]
+source = ["Cython (>=3.0.10)"]
[[package]]
name = "markdown"
@@ -1525,7 +1667,7 @@ tests = ["pytest (>=4.6)"]
name = "mujoco"
version = "2.3.7"
description = "MuJoCo Physics Simulator"
-optional = false
+optional = true
python-versions = ">=3.8"
files = [
{file = "mujoco-2.3.7-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:e8714a5ff6a1561b364b7b4648d4c0c8d13e751874cf7401c309b9d23fa9598b"},
@@ -1563,20 +1705,20 @@ pyopengl = "*"
[[package]]
name = "networkx"
-version = "3.2.1"
+version = "3.3"
description = "Python package for creating and manipulating graphs and networks"
optional = false
-python-versions = ">=3.9"
+python-versions = ">=3.10"
files = [
- {file = "networkx-3.2.1-py3-none-any.whl", hash = "sha256:f18c69adc97877c42332c170849c96cefa91881c99a7cb3e95b7c659ebdc1ec2"},
- {file = "networkx-3.2.1.tar.gz", hash = "sha256:9f1bb5cf3409bf324e0a722c20bdb4c20ee39bf1c30ce8ae499c8502b0b5e0c6"},
+ {file = "networkx-3.3-py3-none-any.whl", hash = "sha256:28575580c6ebdaf4505b22c6256a2b9de86b316dc63ba9e93abde3d78dfdbcf2"},
+ {file = "networkx-3.3.tar.gz", hash = "sha256:0c127d8b2f4865f59ae9cb8aafcd60b5c70f3241ebd66f7defad7c4ab90126c9"},
]
[package.extras]
-default = ["matplotlib (>=3.5)", "numpy (>=1.22)", "pandas (>=1.4)", "scipy (>=1.9,!=1.11.0,!=1.11.1)"]
-developer = ["changelist (==0.4)", "mypy (>=1.1)", "pre-commit (>=3.2)", "rtoml"]
-doc = ["nb2plots (>=0.7)", "nbconvert (<7.9)", "numpydoc (>=1.6)", "pillow (>=9.4)", "pydata-sphinx-theme (>=0.14)", "sphinx (>=7)", "sphinx-gallery (>=0.14)", "texext (>=0.6.7)"]
-extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.11)", "sympy (>=1.10)"]
+default = ["matplotlib (>=3.6)", "numpy (>=1.23)", "pandas (>=1.4)", "scipy (>=1.9,!=1.11.0,!=1.11.1)"]
+developer = ["changelist (==0.5)", "mypy (>=1.1)", "pre-commit (>=3.2)", "rtoml"]
+doc = ["myst-nb (>=1.0)", "numpydoc (>=1.7)", "pillow (>=9.4)", "pydata-sphinx-theme (>=0.14)", "sphinx (>=7)", "sphinx-gallery (>=0.14)", "texext (>=0.6.7)"]
+extra = ["lxml (>=4.6)", "pydot (>=2.0)", "pygraphviz (>=1.12)", "sympy (>=1.10)"]
test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"]
[[package]]
@@ -1833,13 +1975,13 @@ files = [
[[package]]
name = "nvidia-nvjitlink-cu12"
-version = "12.4.99"
+version = "12.4.127"
description = "Nvidia JIT LTO Library"
optional = false
python-versions = ">=3"
files = [
- {file = "nvidia_nvjitlink_cu12-12.4.99-py3-none-manylinux2014_x86_64.whl", hash = "sha256:c6428836d20fe7e327191c175791d38570e10762edc588fb46749217cd444c74"},
- {file = "nvidia_nvjitlink_cu12-12.4.99-py3-none-win_amd64.whl", hash = "sha256:991905ffa2144cb603d8ca7962d75c35334ae82bf92820b6ba78157277da1ad2"},
+ {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57"},
+ {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:fd9020c501d27d135f983c6d3e244b197a7ccad769e34df53a42e276b0e25fa1"},
]
[[package]]
@@ -1980,7 +2122,7 @@ xml = ["lxml (>=4.9.2)"]
name = "pettingzoo"
version = "1.24.3"
description = "Gymnasium for multi-agent reinforcement learning."
-optional = false
+optional = true
python-versions = ">=3.8"
files = [
{file = "pettingzoo-1.24.3-py3-none-any.whl", hash = "sha256:23ed90517d2e8a7098bdaf5e31234b3a7f7b73ca578d70d1ca7b9d0cb0e37982"},
@@ -2003,79 +2145,80 @@ testing = ["AutoROM", "pre-commit", "pynput", "pytest", "pytest-cov", "pytest-ma
[[package]]
name = "pillow"
-version = "10.2.0"
+version = "10.3.0"
description = "Python Imaging Library (Fork)"
optional = false
python-versions = ">=3.8"
files = [
- {file = "pillow-10.2.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:7823bdd049099efa16e4246bdf15e5a13dbb18a51b68fa06d6c1d4d8b99a796e"},
- {file = "pillow-10.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:83b2021f2ade7d1ed556bc50a399127d7fb245e725aa0113ebd05cfe88aaf588"},
- {file = "pillow-10.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6fad5ff2f13d69b7e74ce5b4ecd12cc0ec530fcee76356cac6742785ff71c452"},
- {file = "pillow-10.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da2b52b37dad6d9ec64e653637a096905b258d2fc2b984c41ae7d08b938a67e4"},
- {file = "pillow-10.2.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:47c0995fc4e7f79b5cfcab1fc437ff2890b770440f7696a3ba065ee0fd496563"},
- {file = "pillow-10.2.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:322bdf3c9b556e9ffb18f93462e5f749d3444ce081290352c6070d014c93feb2"},
- {file = "pillow-10.2.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:51f1a1bffc50e2e9492e87d8e09a17c5eea8409cda8d3f277eb6edc82813c17c"},
- {file = "pillow-10.2.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:69ffdd6120a4737710a9eee73e1d2e37db89b620f702754b8f6e62594471dee0"},
- {file = "pillow-10.2.0-cp310-cp310-win32.whl", hash = "sha256:c6dafac9e0f2b3c78df97e79af707cdc5ef8e88208d686a4847bab8266870023"},
- {file = "pillow-10.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:aebb6044806f2e16ecc07b2a2637ee1ef67a11840a66752751714a0d924adf72"},
- {file = "pillow-10.2.0-cp310-cp310-win_arm64.whl", hash = "sha256:7049e301399273a0136ff39b84c3678e314f2158f50f517bc50285fb5ec847ad"},
- {file = "pillow-10.2.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:35bb52c37f256f662abdfa49d2dfa6ce5d93281d323a9af377a120e89a9eafb5"},
- {file = "pillow-10.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9c23f307202661071d94b5e384e1e1dc7dfb972a28a2310e4ee16103e66ddb67"},
- {file = "pillow-10.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:773efe0603db30c281521a7c0214cad7836c03b8ccff897beae9b47c0b657d61"},
- {file = "pillow-10.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:11fa2e5984b949b0dd6d7a94d967743d87c577ff0b83392f17cb3990d0d2fd6e"},
- {file = "pillow-10.2.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:716d30ed977be8b37d3ef185fecb9e5a1d62d110dfbdcd1e2a122ab46fddb03f"},
- {file = "pillow-10.2.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:a086c2af425c5f62a65e12fbf385f7c9fcb8f107d0849dba5839461a129cf311"},
- {file = "pillow-10.2.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c8de2789052ed501dd829e9cae8d3dcce7acb4777ea4a479c14521c942d395b1"},
- {file = "pillow-10.2.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:609448742444d9290fd687940ac0b57fb35e6fd92bdb65386e08e99af60bf757"},
- {file = "pillow-10.2.0-cp311-cp311-win32.whl", hash = "sha256:823ef7a27cf86df6597fa0671066c1b596f69eba53efa3d1e1cb8b30f3533068"},
- {file = "pillow-10.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:1da3b2703afd040cf65ec97efea81cfba59cdbed9c11d8efc5ab09df9509fc56"},
- {file = "pillow-10.2.0-cp311-cp311-win_arm64.whl", hash = "sha256:edca80cbfb2b68d7b56930b84a0e45ae1694aeba0541f798e908a49d66b837f1"},
- {file = "pillow-10.2.0-cp312-cp312-macosx_10_10_x86_64.whl", hash = "sha256:1b5e1b74d1bd1b78bc3477528919414874748dd363e6272efd5abf7654e68bef"},
- {file = "pillow-10.2.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0eae2073305f451d8ecacb5474997c08569fb4eb4ac231ffa4ad7d342fdc25ac"},
- {file = "pillow-10.2.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b7c2286c23cd350b80d2fc9d424fc797575fb16f854b831d16fd47ceec078f2c"},
- {file = "pillow-10.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1e23412b5c41e58cec602f1135c57dfcf15482013ce6e5f093a86db69646a5aa"},
- {file = "pillow-10.2.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:52a50aa3fb3acb9cf7213573ef55d31d6eca37f5709c69e6858fe3bc04a5c2a2"},
- {file = "pillow-10.2.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:127cee571038f252a552760076407f9cff79761c3d436a12af6000cd182a9d04"},
- {file = "pillow-10.2.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:8d12251f02d69d8310b046e82572ed486685c38f02176bd08baf216746eb947f"},
- {file = "pillow-10.2.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:54f1852cd531aa981bc0965b7d609f5f6cc8ce8c41b1139f6ed6b3c54ab82bfb"},
- {file = "pillow-10.2.0-cp312-cp312-win32.whl", hash = "sha256:257d8788df5ca62c980314053197f4d46eefedf4e6175bc9412f14412ec4ea2f"},
- {file = "pillow-10.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:154e939c5f0053a383de4fd3d3da48d9427a7e985f58af8e94d0b3c9fcfcf4f9"},
- {file = "pillow-10.2.0-cp312-cp312-win_arm64.whl", hash = "sha256:f379abd2f1e3dddb2b61bc67977a6b5a0a3f7485538bcc6f39ec76163891ee48"},
- {file = "pillow-10.2.0-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:8373c6c251f7ef8bda6675dd6d2b3a0fcc31edf1201266b5cf608b62a37407f9"},
- {file = "pillow-10.2.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:870ea1ada0899fd0b79643990809323b389d4d1d46c192f97342eeb6ee0b8483"},
- {file = "pillow-10.2.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b4b6b1e20608493548b1f32bce8cca185bf0480983890403d3b8753e44077129"},
- {file = "pillow-10.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3031709084b6e7852d00479fd1d310b07d0ba82765f973b543c8af5061cf990e"},
- {file = "pillow-10.2.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:3ff074fc97dd4e80543a3e91f69d58889baf2002b6be64347ea8cf5533188213"},
- {file = "pillow-10.2.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:cb4c38abeef13c61d6916f264d4845fab99d7b711be96c326b84df9e3e0ff62d"},
- {file = "pillow-10.2.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:b1b3020d90c2d8e1dae29cf3ce54f8094f7938460fb5ce8bc5c01450b01fbaf6"},
- {file = "pillow-10.2.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:170aeb00224ab3dc54230c797f8404507240dd868cf52066f66a41b33169bdbe"},
- {file = "pillow-10.2.0-cp38-cp38-win32.whl", hash = "sha256:c4225f5220f46b2fde568c74fca27ae9771536c2e29d7c04f4fb62c83275ac4e"},
- {file = "pillow-10.2.0-cp38-cp38-win_amd64.whl", hash = "sha256:0689b5a8c5288bc0504d9fcee48f61a6a586b9b98514d7d29b840143d6734f39"},
- {file = "pillow-10.2.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:b792a349405fbc0163190fde0dc7b3fef3c9268292586cf5645598b48e63dc67"},
- {file = "pillow-10.2.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c570f24be1e468e3f0ce7ef56a89a60f0e05b30a3669a459e419c6eac2c35364"},
- {file = "pillow-10.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d8ecd059fdaf60c1963c58ceb8997b32e9dc1b911f5da5307aab614f1ce5c2fb"},
- {file = "pillow-10.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c365fd1703040de1ec284b176d6af5abe21b427cb3a5ff68e0759e1e313a5e7e"},
- {file = "pillow-10.2.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:70c61d4c475835a19b3a5aa42492409878bbca7438554a1f89d20d58a7c75c01"},
- {file = "pillow-10.2.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:b6f491cdf80ae540738859d9766783e3b3c8e5bd37f5dfa0b76abdecc5081f13"},
- {file = "pillow-10.2.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9d189550615b4948f45252d7f005e53c2040cea1af5b60d6f79491a6e147eef7"},
- {file = "pillow-10.2.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:49d9ba1ed0ef3e061088cd1e7538a0759aab559e2e0a80a36f9fd9d8c0c21591"},
- {file = "pillow-10.2.0-cp39-cp39-win32.whl", hash = "sha256:babf5acfede515f176833ed6028754cbcd0d206f7f614ea3447d67c33be12516"},
- {file = "pillow-10.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:0304004f8067386b477d20a518b50f3fa658a28d44e4116970abfcd94fac34a8"},
- {file = "pillow-10.2.0-cp39-cp39-win_arm64.whl", hash = "sha256:0fb3e7fc88a14eacd303e90481ad983fd5b69c761e9e6ef94c983f91025da869"},
- {file = "pillow-10.2.0-pp310-pypy310_pp73-macosx_10_10_x86_64.whl", hash = "sha256:322209c642aabdd6207517e9739c704dc9f9db943015535783239022002f054a"},
- {file = "pillow-10.2.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3eedd52442c0a5ff4f887fab0c1c0bb164d8635b32c894bc1faf4c618dd89df2"},
- {file = "pillow-10.2.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cb28c753fd5eb3dd859b4ee95de66cc62af91bcff5db5f2571d32a520baf1f04"},
- {file = "pillow-10.2.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:33870dc4653c5017bf4c8873e5488d8f8d5f8935e2f1fb9a2208c47cdd66efd2"},
- {file = "pillow-10.2.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:3c31822339516fb3c82d03f30e22b1d038da87ef27b6a78c9549888f8ceda39a"},
- {file = "pillow-10.2.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:a2b56ba36e05f973d450582fb015594aaa78834fefe8dfb8fcd79b93e64ba4c6"},
- {file = "pillow-10.2.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:d8e6aeb9201e655354b3ad049cb77d19813ad4ece0df1249d3c793de3774f8c7"},
- {file = "pillow-10.2.0-pp39-pypy39_pp73-macosx_10_10_x86_64.whl", hash = "sha256:2247178effb34a77c11c0e8ac355c7a741ceca0a732b27bf11e747bbc950722f"},
- {file = "pillow-10.2.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:15587643b9e5eb26c48e49a7b33659790d28f190fc514a322d55da2fb5c2950e"},
- {file = "pillow-10.2.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:753cd8f2086b2b80180d9b3010dd4ed147efc167c90d3bf593fe2af21265e5a5"},
- {file = "pillow-10.2.0-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:7c8f97e8e7a9009bcacbe3766a36175056c12f9a44e6e6f2d5caad06dcfbf03b"},
- {file = "pillow-10.2.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:d1b35bcd6c5543b9cb547dee3150c93008f8dd0f1fef78fc0cd2b141c5baf58a"},
- {file = "pillow-10.2.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:fe4c15f6c9285dc54ce6553a3ce908ed37c8f3825b5a51a15c91442bb955b868"},
- {file = "pillow-10.2.0.tar.gz", hash = "sha256:e87f0b2c78157e12d7686b27d63c070fd65d994e8ddae6f328e0dcf4a0cd007e"},
+ {file = "pillow-10.3.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:90b9e29824800e90c84e4022dd5cc16eb2d9605ee13f05d47641eb183cd73d45"},
+ {file = "pillow-10.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a2c405445c79c3f5a124573a051062300936b0281fee57637e706453e452746c"},
+ {file = "pillow-10.3.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:78618cdbccaa74d3f88d0ad6cb8ac3007f1a6fa5c6f19af64b55ca170bfa1edf"},
+ {file = "pillow-10.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:261ddb7ca91fcf71757979534fb4c128448b5b4c55cb6152d280312062f69599"},
+ {file = "pillow-10.3.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:ce49c67f4ea0609933d01c0731b34b8695a7a748d6c8d186f95e7d085d2fe475"},
+ {file = "pillow-10.3.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:b14f16f94cbc61215115b9b1236f9c18403c15dd3c52cf629072afa9d54c1cbf"},
+ {file = "pillow-10.3.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:d33891be6df59d93df4d846640f0e46f1a807339f09e79a8040bc887bdcd7ed3"},
+ {file = "pillow-10.3.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b50811d664d392f02f7761621303eba9d1b056fb1868c8cdf4231279645c25f5"},
+ {file = "pillow-10.3.0-cp310-cp310-win32.whl", hash = "sha256:ca2870d5d10d8726a27396d3ca4cf7976cec0f3cb706debe88e3a5bd4610f7d2"},
+ {file = "pillow-10.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:f0d0591a0aeaefdaf9a5e545e7485f89910c977087e7de2b6c388aec32011e9f"},
+ {file = "pillow-10.3.0-cp310-cp310-win_arm64.whl", hash = "sha256:ccce24b7ad89adb5a1e34a6ba96ac2530046763912806ad4c247356a8f33a67b"},
+ {file = "pillow-10.3.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:5f77cf66e96ae734717d341c145c5949c63180842a545c47a0ce7ae52ca83795"},
+ {file = "pillow-10.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e4b878386c4bf293578b48fc570b84ecfe477d3b77ba39a6e87150af77f40c57"},
+ {file = "pillow-10.3.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fdcbb4068117dfd9ce0138d068ac512843c52295ed996ae6dd1faf537b6dbc27"},
+ {file = "pillow-10.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9797a6c8fe16f25749b371c02e2ade0efb51155e767a971c61734b1bf6293994"},
+ {file = "pillow-10.3.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:9e91179a242bbc99be65e139e30690e081fe6cb91a8e77faf4c409653de39451"},
+ {file = "pillow-10.3.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:1b87bd9d81d179bd8ab871603bd80d8645729939f90b71e62914e816a76fc6bd"},
+ {file = "pillow-10.3.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:81d09caa7b27ef4e61cb7d8fbf1714f5aec1c6b6c5270ee53504981e6e9121ad"},
+ {file = "pillow-10.3.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:048ad577748b9fa4a99a0548c64f2cb8d672d5bf2e643a739ac8faff1164238c"},
+ {file = "pillow-10.3.0-cp311-cp311-win32.whl", hash = "sha256:7161ec49ef0800947dc5570f86568a7bb36fa97dd09e9827dc02b718c5643f09"},
+ {file = "pillow-10.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:8eb0908e954d093b02a543dc963984d6e99ad2b5e36503d8a0aaf040505f747d"},
+ {file = "pillow-10.3.0-cp311-cp311-win_arm64.whl", hash = "sha256:4e6f7d1c414191c1199f8996d3f2282b9ebea0945693fb67392c75a3a320941f"},
+ {file = "pillow-10.3.0-cp312-cp312-macosx_10_10_x86_64.whl", hash = "sha256:e46f38133e5a060d46bd630faa4d9fa0202377495df1f068a8299fd78c84de84"},
+ {file = "pillow-10.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:50b8eae8f7334ec826d6eeffaeeb00e36b5e24aa0b9df322c247539714c6df19"},
+ {file = "pillow-10.3.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9d3bea1c75f8c53ee4d505c3e67d8c158ad4df0d83170605b50b64025917f338"},
+ {file = "pillow-10.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:19aeb96d43902f0a783946a0a87dbdad5c84c936025b8419da0a0cd7724356b1"},
+ {file = "pillow-10.3.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:74d28c17412d9caa1066f7a31df8403ec23d5268ba46cd0ad2c50fb82ae40462"},
+ {file = "pillow-10.3.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:ff61bfd9253c3915e6d41c651d5f962da23eda633cf02262990094a18a55371a"},
+ {file = "pillow-10.3.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d886f5d353333b4771d21267c7ecc75b710f1a73d72d03ca06df49b09015a9ef"},
+ {file = "pillow-10.3.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4b5ec25d8b17217d635f8935dbc1b9aa5907962fae29dff220f2659487891cd3"},
+ {file = "pillow-10.3.0-cp312-cp312-win32.whl", hash = "sha256:51243f1ed5161b9945011a7360e997729776f6e5d7005ba0c6879267d4c5139d"},
+ {file = "pillow-10.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:412444afb8c4c7a6cc11a47dade32982439925537e483be7c0ae0cf96c4f6a0b"},
+ {file = "pillow-10.3.0-cp312-cp312-win_arm64.whl", hash = "sha256:798232c92e7665fe82ac085f9d8e8ca98826f8e27859d9a96b41d519ecd2e49a"},
+ {file = "pillow-10.3.0-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:4eaa22f0d22b1a7e93ff0a596d57fdede2e550aecffb5a1ef1106aaece48e96b"},
+ {file = "pillow-10.3.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:cd5e14fbf22a87321b24c88669aad3a51ec052eb145315b3da3b7e3cc105b9a2"},
+ {file = "pillow-10.3.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1530e8f3a4b965eb6a7785cf17a426c779333eb62c9a7d1bbcf3ffd5bf77a4aa"},
+ {file = "pillow-10.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d512aafa1d32efa014fa041d38868fda85028e3f930a96f85d49c7d8ddc0383"},
+ {file = "pillow-10.3.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:339894035d0ede518b16073bdc2feef4c991ee991a29774b33e515f1d308e08d"},
+ {file = "pillow-10.3.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:aa7e402ce11f0885305bfb6afb3434b3cd8f53b563ac065452d9d5654c7b86fd"},
+ {file = "pillow-10.3.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:0ea2a783a2bdf2a561808fe4a7a12e9aa3799b701ba305de596bc48b8bdfce9d"},
+ {file = "pillow-10.3.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:c78e1b00a87ce43bb37642c0812315b411e856a905d58d597750eb79802aaaa3"},
+ {file = "pillow-10.3.0-cp38-cp38-win32.whl", hash = "sha256:72d622d262e463dfb7595202d229f5f3ab4b852289a1cd09650362db23b9eb0b"},
+ {file = "pillow-10.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:2034f6759a722da3a3dbd91a81148cf884e91d1b747992ca288ab88c1de15999"},
+ {file = "pillow-10.3.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:2ed854e716a89b1afcedea551cd85f2eb2a807613752ab997b9974aaa0d56936"},
+ {file = "pillow-10.3.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:dc1a390a82755a8c26c9964d457d4c9cbec5405896cba94cf51f36ea0d855002"},
+ {file = "pillow-10.3.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4203efca580f0dd6f882ca211f923168548f7ba334c189e9eab1178ab840bf60"},
+ {file = "pillow-10.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3102045a10945173d38336f6e71a8dc71bcaeed55c3123ad4af82c52807b9375"},
+ {file = "pillow-10.3.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:6fb1b30043271ec92dc65f6d9f0b7a830c210b8a96423074b15c7bc999975f57"},
+ {file = "pillow-10.3.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:1dfc94946bc60ea375cc39cff0b8da6c7e5f8fcdc1d946beb8da5c216156ddd8"},
+ {file = "pillow-10.3.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b09b86b27a064c9624d0a6c54da01c1beaf5b6cadfa609cf63789b1d08a797b9"},
+ {file = "pillow-10.3.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d3b2348a78bc939b4fed6552abfd2e7988e0f81443ef3911a4b8498ca084f6eb"},
+ {file = "pillow-10.3.0-cp39-cp39-win32.whl", hash = "sha256:45ebc7b45406febf07fef35d856f0293a92e7417ae7933207e90bf9090b70572"},
+ {file = "pillow-10.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:0ba26351b137ca4e0db0342d5d00d2e355eb29372c05afd544ebf47c0956ffeb"},
+ {file = "pillow-10.3.0-cp39-cp39-win_arm64.whl", hash = "sha256:50fd3f6b26e3441ae07b7c979309638b72abc1a25da31a81a7fbd9495713ef4f"},
+ {file = "pillow-10.3.0-pp310-pypy310_pp73-macosx_10_10_x86_64.whl", hash = "sha256:6b02471b72526ab8a18c39cb7967b72d194ec53c1fd0a70b050565a0f366d355"},
+ {file = "pillow-10.3.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:8ab74c06ffdab957d7670c2a5a6e1a70181cd10b727cd788c4dd9005b6a8acd9"},
+ {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:048eeade4c33fdf7e08da40ef402e748df113fd0b4584e32c4af74fe78baaeb2"},
+ {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e2ec1e921fd07c7cda7962bad283acc2f2a9ccc1b971ee4b216b75fad6f0463"},
+ {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:4c8e73e99da7db1b4cad7f8d682cf6abad7844da39834c288fbfa394a47bbced"},
+ {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:16563993329b79513f59142a6b02055e10514c1a8e86dca8b48a893e33cf91e3"},
+ {file = "pillow-10.3.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:dd78700f5788ae180b5ee8902c6aea5a5726bac7c364b202b4b3e3ba2d293170"},
+ {file = "pillow-10.3.0-pp39-pypy39_pp73-macosx_10_10_x86_64.whl", hash = "sha256:aff76a55a8aa8364d25400a210a65ff59d0168e0b4285ba6bf2bd83cf675ba32"},
+ {file = "pillow-10.3.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:b7bc2176354defba3edc2b9a777744462da2f8e921fbaf61e52acb95bafa9828"},
+ {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:793b4e24db2e8742ca6423d3fde8396db336698c55cd34b660663ee9e45ed37f"},
+ {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d93480005693d247f8346bc8ee28c72a2191bdf1f6b5db469c096c0c867ac015"},
+ {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:c83341b89884e2b2e55886e8fbbf37c3fa5efd6c8907124aeb72f285ae5696e5"},
+ {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:1a1d1915db1a4fdb2754b9de292642a39a7fb28f1736699527bb649484fb966a"},
+ {file = "pillow-10.3.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:a0eaa93d054751ee9964afa21c06247779b90440ca41d184aeb5d410f20ff591"},
+ {file = "pillow-10.3.0.tar.gz", hash = "sha256:9d2455fbf44c914840c793e89aa82d0e1763a14253a000743719ae5946814b2d"},
]
[package.extras]
@@ -2118,13 +2261,13 @@ testing = ["pytest", "pytest-benchmark"]
[[package]]
name = "pre-commit"
-version = "3.6.2"
+version = "3.7.0"
description = "A framework for managing and maintaining multi-language pre-commit hooks."
optional = false
python-versions = ">=3.9"
files = [
- {file = "pre_commit-3.6.2-py2.py3-none-any.whl", hash = "sha256:ba637c2d7a670c10daedc059f5c49b5bd0aadbccfcd7ec15592cf9665117532c"},
- {file = "pre_commit-3.6.2.tar.gz", hash = "sha256:c3ef34f463045c88658c5b99f38c1e297abdcc0ff13f98d3370055fbbfabc67e"},
+ {file = "pre_commit-3.7.0-py2.py3-none-any.whl", hash = "sha256:5eae9e10c2b5ac51577c3452ec0a490455c45a0533f7960f993a0d01e59decab"},
+ {file = "pre_commit-3.7.0.tar.gz", hash = "sha256:e209d61b8acdcf742404408531f0c37d49d2c734fd7cff2d6076083d191cb060"},
]
[package.dependencies]
@@ -2198,13 +2341,13 @@ test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"]
[[package]]
name = "pycparser"
-version = "2.21"
+version = "2.22"
description = "C parser in Python"
optional = false
-python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
+python-versions = ">=3.8"
files = [
- {file = "pycparser-2.21-py2.py3-none-any.whl", hash = "sha256:8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9"},
- {file = "pycparser-2.21.tar.gz", hash = "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206"},
+ {file = "pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc"},
+ {file = "pycparser-2.22.tar.gz", hash = "sha256:491c8be9c040f5390f5bf44a5b07752bd07f56edf992381b05c701439eec10f6"},
]
[[package]]
@@ -2348,7 +2491,7 @@ dev = ["aafigure", "matplotlib", "pygame", "pyglet (<2.0.0)", "sphinx", "wheel"]
name = "pyopengl"
version = "3.1.7"
description = "Standard OpenGL bindings for Python"
-optional = false
+optional = true
python-versions = "*"
files = [
{file = "PyOpenGL-3.1.7-py3-none-any.whl", hash = "sha256:a6ab19cf290df6101aaf7470843a9c46207789855746399d0af92521a0a92b7a"},
@@ -2359,7 +2502,7 @@ files = [
name = "pyparsing"
version = "3.1.2"
description = "pyparsing module - Classes and methods to define and execute parsing grammars"
-optional = false
+optional = true
python-versions = ">=3.6.8"
files = [
{file = "pyparsing-3.1.2-py3-none-any.whl", hash = "sha256:f9db75911801ed778fe61bb643079ff86601aca99fcae6345aa67292038fb742"},
@@ -2790,7 +2933,7 @@ torch = ["safetensors[numpy]", "torch (>=1.10)"]
name = "scikit-image"
version = "0.22.0"
description = "Image processing in Python"
-optional = false
+optional = true
python-versions = ">=3.9"
files = [
{file = "scikit_image-0.22.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:74ec5c1d4693506842cc7c9487c89d8fc32aed064e9363def7af08b8f8cbb31d"},
@@ -2836,55 +2979,55 @@ test = ["asv", "matplotlib (>=3.5)", "numpydoc (>=1.5)", "pooch (>=1.6.0)", "pyt
[[package]]
name = "scipy"
-version = "1.12.0"
+version = "1.13.0"
description = "Fundamental algorithms for scientific computing in Python"
-optional = false
+optional = true
python-versions = ">=3.9"
files = [
- {file = "scipy-1.12.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:78e4402e140879387187f7f25d91cc592b3501a2e51dfb320f48dfb73565f10b"},
- {file = "scipy-1.12.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:f5f00ebaf8de24d14b8449981a2842d404152774c1a1d880c901bf454cb8e2a1"},
- {file = "scipy-1.12.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e53958531a7c695ff66c2e7bb7b79560ffdc562e2051644c5576c39ff8efb563"},
- {file = "scipy-1.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5e32847e08da8d895ce09d108a494d9eb78974cf6de23063f93306a3e419960c"},
- {file = "scipy-1.12.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4c1020cad92772bf44b8e4cdabc1df5d87376cb219742549ef69fc9fd86282dd"},
- {file = "scipy-1.12.0-cp310-cp310-win_amd64.whl", hash = "sha256:75ea2a144096b5e39402e2ff53a36fecfd3b960d786b7efd3c180e29c39e53f2"},
- {file = "scipy-1.12.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:408c68423f9de16cb9e602528be4ce0d6312b05001f3de61fe9ec8b1263cad08"},
- {file = "scipy-1.12.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:5adfad5dbf0163397beb4aca679187d24aec085343755fcdbdeb32b3679f254c"},
- {file = "scipy-1.12.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c3003652496f6e7c387b1cf63f4bb720951cfa18907e998ea551e6de51a04467"},
- {file = "scipy-1.12.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b8066bce124ee5531d12a74b617d9ac0ea59245246410e19bca549656d9a40a"},
- {file = "scipy-1.12.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:8bee4993817e204d761dba10dbab0774ba5a8612e57e81319ea04d84945375ba"},
- {file = "scipy-1.12.0-cp311-cp311-win_amd64.whl", hash = "sha256:a24024d45ce9a675c1fb8494e8e5244efea1c7a09c60beb1eeb80373d0fecc70"},
- {file = "scipy-1.12.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:e7e76cc48638228212c747ada851ef355c2bb5e7f939e10952bc504c11f4e372"},
- {file = "scipy-1.12.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:f7ce148dffcd64ade37b2df9315541f9adad6efcaa86866ee7dd5db0c8f041c3"},
- {file = "scipy-1.12.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9c39f92041f490422924dfdb782527a4abddf4707616e07b021de33467f917bc"},
- {file = "scipy-1.12.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a7ebda398f86e56178c2fa94cad15bf457a218a54a35c2a7b4490b9f9cb2676c"},
- {file = "scipy-1.12.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:95e5c750d55cf518c398a8240571b0e0782c2d5a703250872f36eaf737751338"},
- {file = "scipy-1.12.0-cp312-cp312-win_amd64.whl", hash = "sha256:e646d8571804a304e1da01040d21577685ce8e2db08ac58e543eaca063453e1c"},
- {file = "scipy-1.12.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:913d6e7956c3a671de3b05ccb66b11bc293f56bfdef040583a7221d9e22a2e35"},
- {file = "scipy-1.12.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:bba1b0c7256ad75401c73e4b3cf09d1f176e9bd4248f0d3112170fb2ec4db067"},
- {file = "scipy-1.12.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:730badef9b827b368f351eacae2e82da414e13cf8bd5051b4bdfd720271a5371"},
- {file = "scipy-1.12.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6546dc2c11a9df6926afcbdd8a3edec28566e4e785b915e849348c6dd9f3f490"},
- {file = "scipy-1.12.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:196ebad3a4882081f62a5bf4aeb7326aa34b110e533aab23e4374fcccb0890dc"},
- {file = "scipy-1.12.0-cp39-cp39-win_amd64.whl", hash = "sha256:b360f1b6b2f742781299514e99ff560d1fe9bd1bff2712894b52abe528d1fd1e"},
- {file = "scipy-1.12.0.tar.gz", hash = "sha256:4bf5abab8a36d20193c698b0f1fc282c1d083c94723902c447e5d2f1780936a3"},
+ {file = "scipy-1.13.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ba419578ab343a4e0a77c0ef82f088238a93eef141b2b8017e46149776dfad4d"},
+ {file = "scipy-1.13.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:22789b56a999265431c417d462e5b7f2b487e831ca7bef5edeb56efe4c93f86e"},
+ {file = "scipy-1.13.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:05f1432ba070e90d42d7fd836462c50bf98bd08bed0aa616c359eed8a04e3922"},
+ {file = "scipy-1.13.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8434f6f3fa49f631fae84afee424e2483289dfc30a47755b4b4e6b07b2633a4"},
+ {file = "scipy-1.13.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:dcbb9ea49b0167de4167c40eeee6e167caeef11effb0670b554d10b1e693a8b9"},
+ {file = "scipy-1.13.0-cp310-cp310-win_amd64.whl", hash = "sha256:1d2f7bb14c178f8b13ebae93f67e42b0a6b0fc50eba1cd8021c9b6e08e8fb1cd"},
+ {file = "scipy-1.13.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0fbcf8abaf5aa2dc8d6400566c1a727aed338b5fe880cde64907596a89d576fa"},
+ {file = "scipy-1.13.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:5e4a756355522eb60fcd61f8372ac2549073c8788f6114449b37e9e8104f15a5"},
+ {file = "scipy-1.13.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b5acd8e1dbd8dbe38d0004b1497019b2dbbc3d70691e65d69615f8a7292865d7"},
+ {file = "scipy-1.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ff7dad5d24a8045d836671e082a490848e8639cabb3dbdacb29f943a678683d"},
+ {file = "scipy-1.13.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4dca18c3ffee287ddd3bc8f1dabaf45f5305c5afc9f8ab9cbfab855e70b2df5c"},
+ {file = "scipy-1.13.0-cp311-cp311-win_amd64.whl", hash = "sha256:a2f471de4d01200718b2b8927f7d76b5d9bde18047ea0fa8bd15c5ba3f26a1d6"},
+ {file = "scipy-1.13.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d0de696f589681c2802f9090fff730c218f7c51ff49bf252b6a97ec4a5d19e8b"},
+ {file = "scipy-1.13.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:b2a3ff461ec4756b7e8e42e1c681077349a038f0686132d623fa404c0bee2551"},
+ {file = "scipy-1.13.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6bf9fe63e7a4bf01d3645b13ff2aa6dea023d38993f42aaac81a18b1bda7a82a"},
+ {file = "scipy-1.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1e7626dfd91cdea5714f343ce1176b6c4745155d234f1033584154f60ef1ff42"},
+ {file = "scipy-1.13.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:109d391d720fcebf2fbe008621952b08e52907cf4c8c7efc7376822151820820"},
+ {file = "scipy-1.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:8930ae3ea371d6b91c203b1032b9600d69c568e537b7988a3073dfe4d4774f21"},
+ {file = "scipy-1.13.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5407708195cb38d70fd2d6bb04b1b9dd5c92297d86e9f9daae1576bd9e06f602"},
+ {file = "scipy-1.13.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:ac38c4c92951ac0f729c4c48c9e13eb3675d9986cc0c83943784d7390d540c78"},
+ {file = "scipy-1.13.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09c74543c4fbeb67af6ce457f6a6a28e5d3739a87f62412e4a16e46f164f0ae5"},
+ {file = "scipy-1.13.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:28e286bf9ac422d6beb559bc61312c348ca9b0f0dae0d7c5afde7f722d6ea13d"},
+ {file = "scipy-1.13.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:33fde20efc380bd23a78a4d26d59fc8704e9b5fd9b08841693eb46716ba13d86"},
+ {file = "scipy-1.13.0-cp39-cp39-win_amd64.whl", hash = "sha256:45c08bec71d3546d606989ba6e7daa6f0992918171e2a6f7fbedfa7361c2de1e"},
+ {file = "scipy-1.13.0.tar.gz", hash = "sha256:58569af537ea29d3f78e5abd18398459f195546bb3be23d16677fb26616cc11e"},
]
[package.dependencies]
-numpy = ">=1.22.4,<1.29.0"
+numpy = ">=1.22.4,<2.3"
[package.extras]
-dev = ["click", "cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy", "pycodestyle", "pydevtool", "rich-click", "ruff", "types-psutil", "typing_extensions"]
-doc = ["jupytext", "matplotlib (>2)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (==0.9.0)", "sphinx (!=4.1.0)", "sphinx-design (>=0.2.0)"]
-test = ["asv", "gmpy2", "hypothesis", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"]
+dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy", "pycodestyle", "pydevtool", "rich-click", "ruff", "types-psutil", "typing_extensions"]
+doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.12.0)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0)", "sphinx-design (>=0.4.0)"]
+test = ["array-api-strict", "asv", "gmpy2", "hypothesis (>=6.30)", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"]
[[package]]
name = "sentry-sdk"
-version = "1.43.0"
+version = "1.44.1"
description = "Python client for Sentry (https://sentry.io)"
optional = false
python-versions = "*"
files = [
- {file = "sentry-sdk-1.43.0.tar.gz", hash = "sha256:41df73af89d22921d8733714fb0fc5586c3461907e06688e6537d01a27e0e0f6"},
- {file = "sentry_sdk-1.43.0-py2.py3-none-any.whl", hash = "sha256:8d768724839ca18d7b4c7463ef7528c40b7aa2bfbf7fe554d5f9a7c044acfd36"},
+ {file = "sentry-sdk-1.44.1.tar.gz", hash = "sha256:24e6a53eeabffd2f95d952aa35ca52f0f4201d17f820ac9d3ff7244c665aaf68"},
+ {file = "sentry_sdk-1.44.1-py2.py3-none-any.whl", hash = "sha256:5f75eb91d8ab6037c754a87b8501cc581b2827e923682f593bed3539ce5b3999"},
]
[package.dependencies]
@@ -3043,7 +3186,7 @@ testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jar
name = "shapely"
version = "2.0.3"
description = "Manipulation and analysis of geometric objects"
-optional = false
+optional = true
python-versions = ">=3.7"
files = [
{file = "shapely-2.0.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:af7e9abe180b189431b0f490638281b43b84a33a960620e6b2e8d3e3458b61a1"},
@@ -3192,31 +3335,6 @@ numpy = "*"
packaging = "*"
protobuf = ">=3.20"
-[[package]]
-name = "tensordict"
-version = "0.4.0+b4c91e8"
-description = ""
-optional = false
-python-versions = "*"
-files = []
-develop = false
-
-[package.dependencies]
-cloudpickle = "*"
-numpy = "*"
-torch = ">=2.1.0"
-
-[package.extras]
-checkpointing = ["torchsnapshot-nightly"]
-h5 = ["h5py (>=3.8)"]
-tests = ["pytest", "pytest-benchmark", "pytest-instafail", "pytest-rerunfailures", "pyyaml"]
-
-[package.source]
-type = "git"
-url = "https://github.com/pytorch/tensordict"
-reference = "HEAD"
-resolved_reference = "b4c91e8828c538ca0a50d8383fd99311a9afb078"
-
[[package]]
name = "termcolor"
version = "2.4.0"
@@ -3235,7 +3353,7 @@ tests = ["pytest", "pytest-cov"]
name = "tifffile"
version = "2024.2.12"
description = "Read and write TIFF files"
-optional = false
+optional = true
python-versions = ">=3.9"
files = [
{file = "tifffile-2024.2.12-py3-none-any.whl", hash = "sha256:870998f82fbc94ff7c3528884c1b0ae54863504ff51dbebea431ac3fa8fb7c21"},
@@ -3261,36 +3379,36 @@ files = [
[[package]]
name = "torch"
-version = "2.2.1"
+version = "2.2.2"
description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
optional = false
python-versions = ">=3.8.0"
files = [
- {file = "torch-2.2.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:8d3bad336dd2c93c6bcb3268e8e9876185bda50ebde325ef211fb565c7d15273"},
- {file = "torch-2.2.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:5297f13370fdaca05959134b26a06a7f232ae254bf2e11a50eddec62525c9006"},
- {file = "torch-2.2.1-cp310-cp310-win_amd64.whl", hash = "sha256:5f5dee8433798888ca1415055f5e3faf28a3bad660e4c29e1014acd3275ab11a"},
- {file = "torch-2.2.1-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:b6d78338acabf1fb2e88bf4559d837d30230cf9c3e4337261f4d83200df1fcbe"},
- {file = "torch-2.2.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:6ab3ea2e29d1aac962e905142bbe50943758f55292f1b4fdfb6f4792aae3323e"},
- {file = "torch-2.2.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:d86664ec85902967d902e78272e97d1aff1d331f7619d398d3ffab1c9b8e9157"},
- {file = "torch-2.2.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:d6227060f268894f92c61af0a44c0d8212e19cb98d05c20141c73312d923bc0a"},
- {file = "torch-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:77e990af75fb1675490deb374d36e726f84732cd5677d16f19124934b2409ce9"},
- {file = "torch-2.2.1-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:46085e328d9b738c261f470231e987930f4cc9472d9ffb7087c7a1343826ac51"},
- {file = "torch-2.2.1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:2d9e7e5ecbb002257cf98fae13003abbd620196c35f85c9e34c2adfb961321ec"},
- {file = "torch-2.2.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:ada53aebede1c89570e56861b08d12ba4518a1f8b82d467c32665ec4d1f4b3c8"},
- {file = "torch-2.2.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:be21d4c41ecebed9e99430dac87de1439a8c7882faf23bba7fea3fea7b906ac1"},
- {file = "torch-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:79848f46196750367dcdf1d2132b722180b9d889571e14d579ae82d2f50596c5"},
- {file = "torch-2.2.1-cp312-none-macosx_10_9_x86_64.whl", hash = "sha256:7ee804847be6be0032fbd2d1e6742fea2814c92bebccb177f0d3b8e92b2d2b18"},
- {file = "torch-2.2.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:84b2fb322ab091039fdfe74e17442ff046b258eb5e513a28093152c5b07325a7"},
- {file = "torch-2.2.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:5c0c83aa7d94569997f1f474595e808072d80b04d34912ce6f1a0e1c24b0c12a"},
- {file = "torch-2.2.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:91a1b598055ba06b2c386415d2e7f6ac818545e94c5def597a74754940188513"},
- {file = "torch-2.2.1-cp38-cp38-win_amd64.whl", hash = "sha256:8f93ddf3001ecec16568390b507652644a3a103baa72de3ad3b9c530e3277098"},
- {file = "torch-2.2.1-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:0e8bdd4c77ac2584f33ee14c6cd3b12767b4da508ec4eed109520be7212d1069"},
- {file = "torch-2.2.1-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:6a21bcd7076677c97ca7db7506d683e4e9db137e8420eb4a68fb67c3668232a7"},
- {file = "torch-2.2.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:f1b90ac61f862634039265cd0f746cc9879feee03ff962c803486301b778714b"},
- {file = "torch-2.2.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:ed9e29eb94cd493b36bca9cb0b1fd7f06a0688215ad1e4b3ab4931726e0ec092"},
- {file = "torch-2.2.1-cp39-cp39-win_amd64.whl", hash = "sha256:c47bc25744c743f3835831a20efdcfd60aeb7c3f9804a213f61e45803d16c2a5"},
- {file = "torch-2.2.1-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:0952549bcb43448c8d860d5e3e947dd18cbab491b14638e21750cb3090d5ad3e"},
- {file = "torch-2.2.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:26bd2272ec46fc62dcf7d24b2fb284d44fcb7be9d529ebf336b9860350d674ed"},
+ {file = "torch-2.2.2-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:bc889d311a855dd2dfd164daf8cc903a6b7273a747189cebafdd89106e4ad585"},
+ {file = "torch-2.2.2-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:15dffa4cc3261fa73d02f0ed25f5fa49ecc9e12bf1ae0a4c1e7a88bbfaad9030"},
+ {file = "torch-2.2.2-cp310-cp310-win_amd64.whl", hash = "sha256:11e8fe261233aeabd67696d6b993eeb0896faa175c6b41b9a6c9f0334bdad1c5"},
+ {file = "torch-2.2.2-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:b2e2200b245bd9f263a0d41b6a2dab69c4aca635a01b30cca78064b0ef5b109e"},
+ {file = "torch-2.2.2-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:877b3e6593b5e00b35bbe111b7057464e76a7dd186a287280d941b564b0563c2"},
+ {file = "torch-2.2.2-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:ad4c03b786e074f46606f4151c0a1e3740268bcf29fbd2fdf6666d66341c1dcb"},
+ {file = "torch-2.2.2-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:32827fa1fbe5da8851686256b4cd94cc7b11be962862c2293811c94eea9457bf"},
+ {file = "torch-2.2.2-cp311-cp311-win_amd64.whl", hash = "sha256:f9ef0a648310435511e76905f9b89612e45ef2c8b023bee294f5e6f7e73a3e7c"},
+ {file = "torch-2.2.2-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:95b9b44f3bcebd8b6cd8d37ec802048c872d9c567ba52c894bba90863a439059"},
+ {file = "torch-2.2.2-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:49aa4126ede714c5aeef7ae92969b4b0bbe67f19665106463c39f22e0a1860d1"},
+ {file = "torch-2.2.2-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:cf12cdb66c9c940227ad647bc9cf5dba7e8640772ae10dfe7569a0c1e2a28aca"},
+ {file = "torch-2.2.2-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:89ddac2a8c1fb6569b90890955de0c34e1724f87431cacff4c1979b5f769203c"},
+ {file = "torch-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:451331406b760f4b1ab298ddd536486ab3cfb1312614cfe0532133535be60bea"},
+ {file = "torch-2.2.2-cp312-none-macosx_10_9_x86_64.whl", hash = "sha256:eb4d6e9d3663e26cd27dc3ad266b34445a16b54908e74725adb241aa56987533"},
+ {file = "torch-2.2.2-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:bf9558da7d2bf7463390b3b2a61a6a3dbb0b45b161ee1dd5ec640bf579d479fc"},
+ {file = "torch-2.2.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:cd2bf7697c9e95fb5d97cc1d525486d8cf11a084c6af1345c2c2c22a6b0029d0"},
+ {file = "torch-2.2.2-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:b421448d194496e1114d87a8b8d6506bce949544e513742b097e2ab8f7efef32"},
+ {file = "torch-2.2.2-cp38-cp38-win_amd64.whl", hash = "sha256:3dbcd563a9b792161640c0cffe17e3270d85e8f4243b1f1ed19cca43d28d235b"},
+ {file = "torch-2.2.2-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:31f4310210e7dda49f1fb52b0ec9e59382cfcb938693f6d5378f25b43d7c1d29"},
+ {file = "torch-2.2.2-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:c795feb7e8ce2e0ef63f75f8e1ab52e7fd5e1a4d7d0c31367ade1e3de35c9e95"},
+ {file = "torch-2.2.2-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:a6e5770d68158d07456bfcb5318b173886f579fdfbf747543901ce718ea94782"},
+ {file = "torch-2.2.2-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:67dcd726edff108e2cd6c51ff0e416fd260c869904de95750e80051358680d24"},
+ {file = "torch-2.2.2-cp39-cp39-win_amd64.whl", hash = "sha256:539d5ef6c4ce15bd3bd47a7b4a6e7c10d49d4d21c0baaa87c7d2ef8698632dfb"},
+ {file = "torch-2.2.2-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:dff696de90d6f6d1e8200e9892861fd4677306d0ef604cb18f2134186f719f82"},
+ {file = "torch-2.2.2-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:3a4dd910663fd7a124c056c878a52c2b0be4a5a424188058fe97109d4436ee42"},
]
[package.dependencies]
@@ -3317,78 +3435,44 @@ typing-extensions = ">=4.8.0"
opt-einsum = ["opt-einsum (>=3.3)"]
optree = ["optree (>=0.9.1)"]
-[[package]]
-name = "torchrl"
-version = "0.4.0+13bef42"
-description = ""
-optional = false
-python-versions = "*"
-files = []
-develop = false
-
-[package.dependencies]
-cloudpickle = "*"
-numpy = "*"
-packaging = "*"
-tensordict = ">=0.4.0"
-torch = ">=2.1.0"
-
-[package.extras]
-all = ["ale-py", "atari-py", "dm_control", "git", "gym", "gym[accept-rom-license]", "gymnasium", "h5py", "huggingface_hub", "hydra-core (>=1.1)", "hydra-submitit-launcher", "minari", "moviepy", "mujoco", "pandas", "pettingzoo (>=1.24.1)", "pillow", "pygame", "pytest", "pytest-instafail", "pyyaml", "requests", "scikit-learn", "scipy", "tensorboard", "torchsnapshot", "torchvision", "tqdm", "vmas (>=1.2.10)", "wandb"]
-atari = ["ale-py", "atari-py", "gym", "gym[accept-rom-license]", "pygame"]
-checkpointing = ["torchsnapshot"]
-dm-control = ["dm_control"]
-gym-continuous = ["gymnasium", "mujoco"]
-marl = ["pettingzoo (>=1.24.1)", "vmas (>=1.2.10)"]
-offline-data = ["h5py", "huggingface_hub", "minari", "pandas", "pillow", "requests", "scikit-learn", "torchvision", "tqdm"]
-rendering = ["moviepy"]
-tests = ["pytest", "pytest-instafail", "pyyaml", "scipy"]
-utils = ["git", "hydra-core (>=1.1)", "hydra-submitit-launcher", "tensorboard", "tqdm", "wandb"]
-
-[package.source]
-type = "git"
-url = "https://github.com/pytorch/rl"
-reference = "13bef426dcfa5887c6e5034a6e9697993fa92c37"
-resolved_reference = "13bef426dcfa5887c6e5034a6e9697993fa92c37"
-
[[package]]
name = "torchvision"
-version = "0.17.1"
+version = "0.17.2"
description = "image and video datasets and models for torch deep learning"
optional = false
python-versions = ">=3.8"
files = [
- {file = "torchvision-0.17.1-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:06418880212b66e45e855dd39f536e7fd48b4e6b034a11dd9fe9e2384afb51ec"},
- {file = "torchvision-0.17.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:33d65d0c7fdcb3f7bc1dd8ed30ea3cd7e0587b4ad1b104b5677c8191a8bad9f1"},
- {file = "torchvision-0.17.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:aaefef2be6a02f206085ce4bb6c0078b03ebf48cb6ff82bd762ff6248475e08e"},
- {file = "torchvision-0.17.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:ebe5fdb466aff8a8e8e755de84a843418b6f8d500624752c05eaa638d7700f3d"},
- {file = "torchvision-0.17.1-cp310-cp310-win_amd64.whl", hash = "sha256:9d4d45a996f4313e9c5db4da71d31508d44f7ccfbf29d3442bdcc2ad13e0b6f3"},
- {file = "torchvision-0.17.1-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:ea2ccdbf5974e0bf27fd6644a33b19cb0700297cf397bb0469e762c11c6c4105"},
- {file = "torchvision-0.17.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9106e32c9f1e70afa8172cf1b064cf9c2998d8dff0769ec69d537b20209ee43d"},
- {file = "torchvision-0.17.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:5966936c669a08870f6547cd0a90d08b157aeda03293f79e2adbb934687175ed"},
- {file = "torchvision-0.17.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:e74f5a26ef8190eab0c38b3f63914fea94e58e3b2f0e5466611c9f63bd91a80b"},
- {file = "torchvision-0.17.1-cp311-cp311-win_amd64.whl", hash = "sha256:a2109c1a1dcf71e8940d43e91f78c4dd5bf0fcefb3a0a42244102752009f5862"},
- {file = "torchvision-0.17.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:5d241d2a5fb4e608677fccf6f80b34a124446d324ee40c7814ce54bce888275b"},
- {file = "torchvision-0.17.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:e0fe98d9d92c23d2262ff82f973242951b9357fb640f8888ac50848bd00f5b45"},
- {file = "torchvision-0.17.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:32dc5de86d2ade399e11087095674ca08a1649fb322cfe69336d28add467edcb"},
- {file = "torchvision-0.17.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:54902877410ffb5458ee52b6d0de4b25cf01496bee736d6825301a5f0398536e"},
- {file = "torchvision-0.17.1-cp312-cp312-win_amd64.whl", hash = "sha256:cc22c1ed0f1aba3f98fd72b6f60021f57aec1d2f6af518522e8a0a83848de3a8"},
- {file = "torchvision-0.17.1-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:2621097065fa1c827885e2b52102e839a3541b933b7a90e0fa3c42c3de1bc3cf"},
- {file = "torchvision-0.17.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5ce76466af2b5a30573939cae1e6e62e29316ceb3ee748091002f312ab0912f6"},
- {file = "torchvision-0.17.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:bd5dcd14a32945c72f5c19341add94aa7c23dd7bca2bafde44d0f3c4344d17ed"},
- {file = "torchvision-0.17.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:dca22795cc02ca0d5ddc08c1422ff620bc9899f63d15dc36f71ef37250e17b75"},
- {file = "torchvision-0.17.1-cp38-cp38-win_amd64.whl", hash = "sha256:524405457dd97d9ab0e48df502f819d0f41a113ce8f00470bb9926d9d36efcf1"},
- {file = "torchvision-0.17.1-cp39-cp39-macosx_10_13_x86_64.whl", hash = "sha256:58299a724b37b893c7ce4d0b32ea1480c30e467cc114167964b45f6013f6c2d3"},
- {file = "torchvision-0.17.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8a1b17fb158b2b881f2c8796fe1839a624e49d5fd07aa61f6dae60ba4819421a"},
- {file = "torchvision-0.17.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:429d63eb7551aa4d8f6cdf08d109b5570c20cbcce36d9cb95b24556418e4dc82"},
- {file = "torchvision-0.17.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:0ecc9a58171bd555aed583bf2f72e7fd6cc4f767c14f8b80b6a8725eacf4ceb1"},
- {file = "torchvision-0.17.1-cp39-cp39-win_amd64.whl", hash = "sha256:5f427ebee15521edcd836bfe05e86feb5189b5c943b9e3999ed0e3f391fbaa1d"},
+ {file = "torchvision-0.17.2-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:1f2910fe3c21ad6875b2720d46fad835b2e4b336e9553d31ca364d24c90b1d4f"},
+ {file = "torchvision-0.17.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ecc1c503fa8a54fbab777e06a7c228032b8ab78efebf35b28bc8f22f544f51f1"},
+ {file = "torchvision-0.17.2-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:f400145fc108833e7c2fc28486a04989ca742146d7a2a2cc48878ebbb40cdbbd"},
+ {file = "torchvision-0.17.2-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:e9e4bed404af33dfc92eecc2b513d21ddc4c242a7fd8708b3b09d3a26aa6f444"},
+ {file = "torchvision-0.17.2-cp310-cp310-win_amd64.whl", hash = "sha256:ba2e62f233eab3d42b648c122a3a29c47cc108ca314dfd5cbb59cd3a143fd623"},
+ {file = "torchvision-0.17.2-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:9b83e55ee7d0a1704f52b9c0ac87388e7a6d1d98a6bde7b0b35f9ab54d7bda54"},
+ {file = "torchvision-0.17.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e031004a1bc432c980a7bd642f6c189a3efc316e423fc30b5569837166a4e28d"},
+ {file = "torchvision-0.17.2-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:3bbc24b7713e8f22766992562547d8b4b10001208d372fe599255af84bfd1a69"},
+ {file = "torchvision-0.17.2-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:833fd2e4216ced924c8aca0525733fe727f9a1af66dfad7c5be7257e97c39678"},
+ {file = "torchvision-0.17.2-cp311-cp311-win_amd64.whl", hash = "sha256:6835897df852fad1015e6a106c167c83848114cbcc7d86112384a973404e4431"},
+ {file = "torchvision-0.17.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:14fd1d4a033c325bdba2d03a69c3450cab6d3a625f85cc375781d9237ca5d04d"},
+ {file = "torchvision-0.17.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:9c3acbebbe379af112b62b535820174277b1f3eed30df264a4e458d58ee4e5b2"},
+ {file = "torchvision-0.17.2-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:77d680adf6ce367166a186d2c7fda3a73807ab9a03b2c31a03fa8812c8c5335b"},
+ {file = "torchvision-0.17.2-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:f1c9ab3152cfb27f83aca072cac93a3a4c4e4ab0261cf0f2d516b9868a4e96f3"},
+ {file = "torchvision-0.17.2-cp312-cp312-win_amd64.whl", hash = "sha256:3f784381419f3ed3f2ec2aa42fb4aeec5bf4135e298d1631e41c926e6f1a0dff"},
+ {file = "torchvision-0.17.2-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:b83aac8d78f48981146d582168d75b6c947cfb0a7693f76e219f1926f6e595a3"},
+ {file = "torchvision-0.17.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:1ece40557e122d79975860a005aa7e2a9e2e6c350a03e78a00ec1450083312fd"},
+ {file = "torchvision-0.17.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:32dbeba3987e20f2dc1bce8d1504139fff582898346dfe8ad98d649f97ca78fa"},
+ {file = "torchvision-0.17.2-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:35ba5c1600c3203549d2316422a659bd20c0cfda1b6085eec94fb9f35f55ca43"},
+ {file = "torchvision-0.17.2-cp38-cp38-win_amd64.whl", hash = "sha256:2f69570f50b1d195e51bc03feffb7b7728207bc36efcfb1f0813712b2379d881"},
+ {file = "torchvision-0.17.2-cp39-cp39-macosx_10_13_x86_64.whl", hash = "sha256:4868bbfa55758c8107e69a0e7dd5e77b89056035cd38b767ad5b98cdb71c0f0d"},
+ {file = "torchvision-0.17.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:efd6d0dd0668e15d01a2cffadc74068433b32cbcf5692e0c4aa15fc5cb250ce7"},
+ {file = "torchvision-0.17.2-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:7dc85b397f6c6d9ef12716ce0d6e11ac2b803f5cccff6fe3966db248e7774478"},
+ {file = "torchvision-0.17.2-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:d506854c5acd69b20a8b6641f01fe841685a21c5406b56813184f1c9fc94279e"},
+ {file = "torchvision-0.17.2-cp39-cp39-win_amd64.whl", hash = "sha256:067095e87a020a7a251ac1d38483aa591c5ccb81e815527c54db88a982fc9267"},
]
[package.dependencies]
numpy = "*"
pillow = ">=5.3.0,<8.3.dev0 || >=8.4.dev0"
-torch = "2.2.1"
+torch = "2.2.2"
[package.extras]
scipy = ["scipy"]
@@ -3438,13 +3522,13 @@ tutorials = ["matplotlib", "pandas", "tabulate", "torch"]
[[package]]
name = "typing-extensions"
-version = "4.10.0"
+version = "4.11.0"
description = "Backported and Experimental Type Hints for Python 3.8+"
optional = false
python-versions = ">=3.8"
files = [
- {file = "typing_extensions-4.10.0-py3-none-any.whl", hash = "sha256:69b1a937c3a517342112fb4c6df7e72fc39a38e7891a5730ed4985b5214b5475"},
- {file = "typing_extensions-4.10.0.tar.gz", hash = "sha256:b0abd7c89e8fb96f98db18d86106ff1d90ab692004eb746cf6eda2682f91b3cb"},
+ {file = "typing_extensions-4.11.0-py3-none-any.whl", hash = "sha256:c1f94d72897edaf4ce775bb7558d5b79d8126906a14ea5ed1635921406c0387a"},
+ {file = "typing_extensions-4.11.0.tar.gz", hash = "sha256:83f085bd5ca59c80295fc2a82ab5dac679cbe02b9f33f7d83af68e241bea51b0"},
]
[[package]]
@@ -3497,13 +3581,13 @@ test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess
[[package]]
name = "wandb"
-version = "0.16.4"
+version = "0.16.6"
description = "A CLI and library for interacting with the Weights & Biases API."
optional = false
python-versions = ">=3.7"
files = [
- {file = "wandb-0.16.4-py3-none-any.whl", hash = "sha256:bb9eb5aa2c2c85e11c76040c4271366f54d4975167aa6320ba86c3f2d97fe5fa"},
- {file = "wandb-0.16.4.tar.gz", hash = "sha256:8752c67d1347a4c29777e64dc1e1a742a66c5ecde03aebadf2b0d62183fa307c"},
+ {file = "wandb-0.16.6-py3-none-any.whl", hash = "sha256:5810019a3b981c796e98ea58557a7c380f18834e0c6bdaed15df115522e5616e"},
+ {file = "wandb-0.16.6.tar.gz", hash = "sha256:86f491e3012d715e0d7d7421a4d6de41abef643b7403046261f962f3e512fe1c"},
]
[package.dependencies]
@@ -3535,13 +3619,13 @@ sweeps = ["sweeps (>=0.2.0)"]
[[package]]
name = "werkzeug"
-version = "3.0.1"
+version = "3.0.2"
description = "The comprehensive WSGI web application library."
optional = false
python-versions = ">=3.8"
files = [
- {file = "werkzeug-3.0.1-py3-none-any.whl", hash = "sha256:90a285dc0e42ad56b34e696398b8122ee4c681833fb35b8334a095d82c56da10"},
- {file = "werkzeug-3.0.1.tar.gz", hash = "sha256:507e811ecea72b18a404947aded4b3390e1db8f826b494d76550ef45bb3b1dcc"},
+ {file = "werkzeug-3.0.2-py3-none-any.whl", hash = "sha256:3aac3f5da756f93030740bc235d3e09449efcf65f2f55e3602e1d851b8f48795"},
+ {file = "werkzeug-3.0.2.tar.gz", hash = "sha256:e39b645a6ac92822588e7b39a692e7828724ceae0b0d702ef96701f90e70128d"},
]
[package.dependencies]
@@ -3552,20 +3636,20 @@ watchdog = ["watchdog (>=2.3)"]
[[package]]
name = "zarr"
-version = "2.17.1"
+version = "2.17.2"
description = "An implementation of chunked, compressed, N-dimensional arrays for Python"
optional = false
python-versions = ">=3.9"
files = [
- {file = "zarr-2.17.1-py3-none-any.whl", hash = "sha256:e25df2741a6e92645f3890f30f3136d5b57a0f8f831094b024bbcab5f2797bc7"},
- {file = "zarr-2.17.1.tar.gz", hash = "sha256:564b3aa072122546fe69a0fa21736f466b20fad41754334b62619f088ce46261"},
+ {file = "zarr-2.17.2-py3-none-any.whl", hash = "sha256:70d7cc07c24280c380ef80644151d136b7503b0d83c9f214e8000ddc0f57f69b"},
+ {file = "zarr-2.17.2.tar.gz", hash = "sha256:2cbaa6cb4e342d45152d4a7a4b2013c337fcd3a8e7bc98253560180de60552ce"},
]
[package.dependencies]
asciitree = "*"
fasteners = {version = "*", markers = "sys_platform != \"emscripten\""}
numcodecs = ">=0.10.0"
-numpy = ">=1.21.1"
+numpy = ">=1.23"
[package.extras]
docs = ["numcodecs[msgpack]", "numpydoc", "pydata-sphinx-theme", "sphinx", "sphinx-automodapi", "sphinx-copybutton", "sphinx-design", "sphinx-issues"]
@@ -3586,7 +3670,12 @@ files = [
docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"]
+[extras]
+aloha = ["gym-aloha"]
+pusht = ["gym-pusht"]
+xarm = ["gym-xarm"]
+
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
-content-hash = "174c7d42f8039eedd2c447a4e6cae5169782cbd94346b5606572a0010194ca05"
+content-hash = "7ec0310f8dd0ffa4d92fa78e06513bce98c3657692b3753ff34aadd297a3766c"
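The lock changes above demote the simulation environments to optional extras (aloha, pusht, xarm) pulled from git. A minimal sketch of opting into them at install time, assuming Poetry >= 1.2 and SSH access to github.com (the gym-* sources use git@github.com: URLs, so an SSH key is needed):

    # install the base project, then the base project plus simulation extras
    poetry install
    poetry install --extras "pusht"
    poetry install --extras "aloha xarm"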
diff --git a/pyproject.toml b/pyproject.toml
index 972c1b61..743dece8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,7 +23,6 @@ packages = [{include = "lerobot"}]
python = "^3.10"
termcolor = "^2.4.0"
omegaconf = "^2.3.0"
-dm-env = "^1.6"
pandas = "^2.2.1"
wandb = "^0.16.3"
moviepy = "^1.0.3"
@@ -34,29 +33,40 @@ einops = "^0.7.0"
pygame = "^2.5.2"
pymunk = "^6.6.0"
zarr = "^2.17.0"
-shapely = "^2.0.3"
-scikit-image = "^0.22.0"
numba = "^0.59.0"
mpmath = "^1.3.0"
torch = "^2.2.1"
-tensordict = {git = "https://github.com/pytorch/tensordict"}
-torchrl = {git = "https://github.com/pytorch/rl", rev = "13bef426dcfa5887c6e5034a6e9697993fa92c37"}
-mujoco = "^2.3.7"
opencv-python = "^4.9.0.80"
diffusers = "^0.26.3"
torchvision = "^0.17.1"
h5py = "^3.10.0"
-dm-control = "1.0.14"
huggingface-hub = {extras = ["hf-transfer"], version = "^0.21.4"}
robomimic = "0.2.0"
-gymnasium-robotics = "^1.2.4"
gymnasium = "^0.29.1"
cmake = "^3.29.0.1"
+gym-pusht = { git = "git@github.com:huggingface/gym-pusht.git", optional = true}
+gym-xarm = { git = "git@github.com:huggingface/gym-xarm.git", optional = true}
+gym-aloha = { git = "git@github.com:huggingface/gym-aloha.git", optional = true}
+# gym-pusht = { path = "../gym-pusht", develop = true, optional = true}
+# gym-xarm = { path = "../gym-xarm", develop = true, optional = true}
+# gym-aloha = { path = "../gym-aloha", develop = true, optional = true}
+
+[tool.poetry.extras]
+pusht = ["gym-pusht"]
+xarm = ["gym-xarm"]
+aloha = ["gym-aloha"]
+
+
+[tool.poetry.group.dev]
+optional = true
[tool.poetry.group.dev.dependencies]
pre-commit = "^3.6.2"
debugpy = "^1.8.1"
+
+
+[tool.poetry.group.test.dependencies]
pytest = "^8.1.0"
pytest-cov = "^5.0.0"
@@ -100,3 +110,6 @@ enable = true
[build-system]
requires = ["poetry-core>=1.0.0", "poetry-dynamic-versioning>=1.0.0,<2.0.0"]
build-backend = "poetry_dynamic_versioning.backend"
+
+[tool.black]
+line-length = 110
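With the dev group now marked optional, its dependencies (pre-commit, debugpy) are skipped by default, while the new test group stays non-optional and installs with the project. A hedged sketch of the resulting workflows:

    # default install: runtime deps plus the test group, no dev tools
    poetry install
    # opt into the dev tools as well
    poetry install --with dev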
diff --git a/sbatch.sh b/sbatch.sh
deleted file mode 100644
index cb5b285a..00000000
--- a/sbatch.sh
+++ /dev/null
@@ -1,23 +0,0 @@
-#!/bin/bash
-#SBATCH --nodes=1 # total number of nodes (N to be defined)
-#SBATCH --ntasks-per-node=1 # number of tasks per node (here 1 task)
-#SBATCH --gres=gpu:1 # number of GPUs reserved per node (here 1)
-#SBATCH --cpus-per-task=8 # number of CPU cores per task (here 8)
-#SBATCH --time=2-00:00:00
-#SBATCH --output=/home/rcadene/slurm/%j.out
-#SBATCH --error=/home/rcadene/slurm/%j.err
-#SBATCH --qos=medium
-#SBATCH --mail-user=re.cadene@gmail.com
-#SBATCH --mail-type=ALL
-
-CMD=$@
-echo "command: $CMD"
-
-apptainer exec --nv \
-~/apptainer/nvidia_cuda:12.2.2-devel-ubuntu22.04.sif $SHELL
-
-source ~/.bashrc
-#conda activate fowm
-conda activate lerobot
-
-srun $CMD
diff --git a/sbatch_hopper.sh b/sbatch_hopper.sh
deleted file mode 100644
index cc410048..00000000
--- a/sbatch_hopper.sh
+++ /dev/null
@@ -1,17 +0,0 @@
-#!/bin/bash
-#SBATCH --nodes=1 # total number of nodes (N to be defined)
-#SBATCH --ntasks-per-node=1 # number of tasks per node (here 8 tasks, or 1 task per GPU)
-#SBATCH --qos=normal # quality of service for the job
-#SBATCH --partition=hopper-prod
-#SBATCH --gres=gpu:1 # number of GPUs reserved per node (here 1)
-#SBATCH --cpus-per-task=12 # number of cores per task
-#SBATCH --mem-per-cpu=11G
-#SBATCH --time=12:00:00
-#SBATCH --output=/admin/home/remi_cadene/slurm/%j.out
-#SBATCH --error=/admin/home/remi_cadene/slurm/%j.err
-#SBATCH --mail-user=remi_cadene@huggingface.co
-#SBATCH --mail-type=ALL
-
-CMD=$@
-echo "command: $CMD"
-srun $CMD
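Both deleted SLURM launchers simply echoed and forwarded their command-line arguments to srun inside the allocation, so jobs were presumably submitted along these lines (the training command below is a hypothetical placeholder, not taken from this diff):

    # hypothetical usage of the removed launchers
    sbatch sbatch_hopper.sh python lerobot/scripts/train.py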
diff --git a/tests/data/aloha_sim_insertion_human/data_dict.pth b/tests/data/aloha_sim_insertion_human/data_dict.pth
new file mode 100644
index 00000000..1370c9ea
Binary files /dev/null and b/tests/data/aloha_sim_insertion_human/data_dict.pth differ
diff --git a/tests/data/aloha_sim_insertion_human/data_ids_per_episode.pth b/tests/data/aloha_sim_insertion_human/data_ids_per_episode.pth
new file mode 100644
index 00000000..a1d481dd
Binary files /dev/null and b/tests/data/aloha_sim_insertion_human/data_ids_per_episode.pth differ
diff --git a/tests/data/aloha_sim_insertion_human/replay_buffer/action.memmap b/tests/data/aloha_sim_insertion_human/replay_buffer/action.memmap
deleted file mode 100644
index f64b2989..00000000
--- a/tests/data/aloha_sim_insertion_human/replay_buffer/action.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:d789deddb081a9f4b626342391de8f48949d38fb5fdead87b5c0737b46c0877a
-size 2800
diff --git a/tests/data/aloha_sim_insertion_human/replay_buffer/episode.memmap b/tests/data/aloha_sim_insertion_human/replay_buffer/episode.memmap
deleted file mode 100644
index af9fb07f..00000000
--- a/tests/data/aloha_sim_insertion_human/replay_buffer/episode.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:7a12e561363385e9dfeeab326368731c030ed4b374e7f5897ac819159d2884c5
-size 400
diff --git a/tests/data/aloha_sim_insertion_human/replay_buffer/frame_id.memmap b/tests/data/aloha_sim_insertion_human/replay_buffer/frame_id.memmap
deleted file mode 100644
index dc2f585c..00000000
--- a/tests/data/aloha_sim_insertion_human/replay_buffer/frame_id.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:c202d9cfc7858fd49d522047e16948359bbbb2eda2d3825d552e45a78d5f8585
-size 400
diff --git a/tests/data/aloha_sim_insertion_human/replay_buffer/meta.json b/tests/data/aloha_sim_insertion_human/replay_buffer/meta.json
deleted file mode 100644
index 2a0cf0a2..00000000
--- a/tests/data/aloha_sim_insertion_human/replay_buffer/meta.json
+++ /dev/null
@@ -1 +0,0 @@
-{"action": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "episode": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "frame_id": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_insertion_human/replay_buffer/next/done.memmap b/tests/data/aloha_sim_insertion_human/replay_buffer/next/done.memmap
deleted file mode 100644
index 44fd709f..00000000
--- a/tests/data/aloha_sim_insertion_human/replay_buffer/next/done.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:cc2786e1f9910a9d811400edcddaf7075195f7a16b216dcbefba3bc7c4f2ae51
-size 50
diff --git a/tests/data/aloha_sim_insertion_human/replay_buffer/next/meta.json b/tests/data/aloha_sim_insertion_human/replay_buffer/next/meta.json
deleted file mode 100644
index 3bfa9bd7..00000000
--- a/tests/data/aloha_sim_insertion_human/replay_buffer/next/meta.json
+++ /dev/null
@@ -1 +0,0 @@
-{"done": {"device": "cpu", "shape": [50, 1], "dtype": "torch.bool"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_insertion_human/replay_buffer/next/observation/image/meta.json b/tests/data/aloha_sim_insertion_human/replay_buffer/next/observation/image/meta.json
deleted file mode 100644
index cb29a5ab..00000000
--- a/tests/data/aloha_sim_insertion_human/replay_buffer/next/observation/image/meta.json
+++ /dev/null
@@ -1 +0,0 @@
-{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_insertion_human/replay_buffer/next/observation/image/top.memmap b/tests/data/aloha_sim_insertion_human/replay_buffer/next/observation/image/top.memmap
deleted file mode 100644
index d3d8bd1c..00000000
--- a/tests/data/aloha_sim_insertion_human/replay_buffer/next/observation/image/top.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:5c632e3cb06be729e5d673e3ecca1d6f6527b0f48cfe3dc03d7eea4f9eb3bbd7
-size 46080000
diff --git a/tests/data/aloha_sim_insertion_human/replay_buffer/next/observation/meta.json b/tests/data/aloha_sim_insertion_human/replay_buffer/next/observation/meta.json
deleted file mode 100644
index 65ce1ca2..00000000
--- a/tests/data/aloha_sim_insertion_human/replay_buffer/next/observation/meta.json
+++ /dev/null
@@ -1 +0,0 @@
-{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_insertion_human/replay_buffer/next/observation/state.memmap b/tests/data/aloha_sim_insertion_human/replay_buffer/next/observation/state.memmap
deleted file mode 100644
index 1f087a60..00000000
--- a/tests/data/aloha_sim_insertion_human/replay_buffer/next/observation/state.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:e231f2e07e1cd030137ea2e938b570b112db2c694c6d21b37ceb8f8559e19088
-size 2800
diff --git a/tests/data/aloha_sim_insertion_human/replay_buffer/observation/image/meta.json b/tests/data/aloha_sim_insertion_human/replay_buffer/observation/image/meta.json
deleted file mode 100644
index cb29a5ab..00000000
--- a/tests/data/aloha_sim_insertion_human/replay_buffer/observation/image/meta.json
+++ /dev/null
@@ -1 +0,0 @@
-{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_insertion_human/replay_buffer/observation/image/top.memmap b/tests/data/aloha_sim_insertion_human/replay_buffer/observation/image/top.memmap
deleted file mode 100644
index 00c0b783..00000000
--- a/tests/data/aloha_sim_insertion_human/replay_buffer/observation/image/top.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:a1ba64c89f4fcf9135fe34c26abf582dd5f0d573506db5c96af3ffe40a52c818
-size 46080000
diff --git a/tests/data/aloha_sim_insertion_human/replay_buffer/observation/meta.json b/tests/data/aloha_sim_insertion_human/replay_buffer/observation/meta.json
deleted file mode 100644
index 65ce1ca2..00000000
--- a/tests/data/aloha_sim_insertion_human/replay_buffer/observation/meta.json
+++ /dev/null
@@ -1 +0,0 @@
-{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_insertion_human/replay_buffer/observation/state.memmap b/tests/data/aloha_sim_insertion_human/replay_buffer/observation/state.memmap
deleted file mode 100644
index a1131179..00000000
--- a/tests/data/aloha_sim_insertion_human/replay_buffer/observation/state.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:85405686bc065c6ab6c915907920a0391a57cf097b74de058a8c30be0548ade5
-size 2800
diff --git a/tests/data/aloha_sim_insertion_human/stats.pth b/tests/data/aloha_sim_insertion_human/stats.pth
index 87b18a24..a7b9248f 100644
Binary files a/tests/data/aloha_sim_insertion_human/stats.pth and b/tests/data/aloha_sim_insertion_human/stats.pth differ
diff --git a/tests/data/aloha_sim_insertion_scripted/data_dict.pth b/tests/data/aloha_sim_insertion_scripted/data_dict.pth
new file mode 100644
index 00000000..00c9f335
Binary files /dev/null and b/tests/data/aloha_sim_insertion_scripted/data_dict.pth differ
diff --git a/tests/data/aloha_sim_insertion_scripted/data_ids_per_episode.pth b/tests/data/aloha_sim_insertion_scripted/data_ids_per_episode.pth
new file mode 100644
index 00000000..a1d481dd
Binary files /dev/null and b/tests/data/aloha_sim_insertion_scripted/data_ids_per_episode.pth differ
diff --git a/tests/data/aloha_sim_insertion_scripted/replay_buffer/action.memmap b/tests/data/aloha_sim_insertion_scripted/replay_buffer/action.memmap
deleted file mode 100644
index e4068b75..00000000
--- a/tests/data/aloha_sim_insertion_scripted/replay_buffer/action.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:1f5fe053b760e8471885b82c10f4a6ea40874098036337ae5cc300c4775546be
-size 2800
diff --git a/tests/data/aloha_sim_insertion_scripted/replay_buffer/episode.memmap b/tests/data/aloha_sim_insertion_scripted/replay_buffer/episode.memmap
deleted file mode 100644
index af9fb07f..00000000
--- a/tests/data/aloha_sim_insertion_scripted/replay_buffer/episode.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:7a12e561363385e9dfeeab326368731c030ed4b374e7f5897ac819159d2884c5
-size 400
diff --git a/tests/data/aloha_sim_insertion_scripted/replay_buffer/frame_id.memmap b/tests/data/aloha_sim_insertion_scripted/replay_buffer/frame_id.memmap
deleted file mode 100644
index dc2f585c..00000000
--- a/tests/data/aloha_sim_insertion_scripted/replay_buffer/frame_id.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:c202d9cfc7858fd49d522047e16948359bbbb2eda2d3825d552e45a78d5f8585
-size 400
diff --git a/tests/data/aloha_sim_insertion_scripted/replay_buffer/meta.json b/tests/data/aloha_sim_insertion_scripted/replay_buffer/meta.json
deleted file mode 100644
index 2a0cf0a2..00000000
--- a/tests/data/aloha_sim_insertion_scripted/replay_buffer/meta.json
+++ /dev/null
@@ -1 +0,0 @@
-{"action": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "episode": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "frame_id": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/done.memmap b/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/done.memmap
deleted file mode 100644
index 44fd709f..00000000
--- a/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/done.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:cc2786e1f9910a9d811400edcddaf7075195f7a16b216dcbefba3bc7c4f2ae51
-size 50
diff --git a/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/meta.json b/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/meta.json
deleted file mode 100644
index 3bfa9bd7..00000000
--- a/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/meta.json
+++ /dev/null
@@ -1 +0,0 @@
-{"done": {"device": "cpu", "shape": [50, 1], "dtype": "torch.bool"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/observation/image/meta.json b/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/observation/image/meta.json
deleted file mode 100644
index cb29a5ab..00000000
--- a/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/observation/image/meta.json
+++ /dev/null
@@ -1 +0,0 @@
-{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/observation/image/top.memmap b/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/observation/image/top.memmap
deleted file mode 100644
index 83911729..00000000
--- a/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/observation/image/top.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:daed2bb10498ba2557983d0d7e89399882fea7585e7ceff910e23c621bfdbf88
-size 46080000
diff --git a/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/observation/meta.json b/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/observation/meta.json
deleted file mode 100644
index 65ce1ca2..00000000
--- a/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/observation/meta.json
+++ /dev/null
@@ -1 +0,0 @@
-{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/observation/state.memmap b/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/observation/state.memmap
deleted file mode 100644
index aef69da0..00000000
--- a/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/observation/state.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:bbad0302af70112ee312efe0eb0f44a2f1c8f6c5ef82ea4fb34625cdafbef057
-size 2800
diff --git a/tests/data/aloha_sim_insertion_scripted/replay_buffer/observation/image/meta.json b/tests/data/aloha_sim_insertion_scripted/replay_buffer/observation/image/meta.json
deleted file mode 100644
index cb29a5ab..00000000
--- a/tests/data/aloha_sim_insertion_scripted/replay_buffer/observation/image/meta.json
+++ /dev/null
@@ -1 +0,0 @@
-{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_insertion_scripted/replay_buffer/observation/image/top.memmap b/tests/data/aloha_sim_insertion_scripted/replay_buffer/observation/image/top.memmap
deleted file mode 100644
index f9f0a759..00000000
--- a/tests/data/aloha_sim_insertion_scripted/replay_buffer/observation/image/top.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:aba55ebb9dd004bf68444b9ebf024ed7713436099c06a0b8e541100ecbc69290
-size 46080000
diff --git a/tests/data/aloha_sim_insertion_scripted/replay_buffer/observation/meta.json b/tests/data/aloha_sim_insertion_scripted/replay_buffer/observation/meta.json
deleted file mode 100644
index 65ce1ca2..00000000
--- a/tests/data/aloha_sim_insertion_scripted/replay_buffer/observation/meta.json
+++ /dev/null
@@ -1 +0,0 @@
-{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_insertion_scripted/replay_buffer/observation/state.memmap b/tests/data/aloha_sim_insertion_scripted/replay_buffer/observation/state.memmap
deleted file mode 100644
index 91875055..00000000
--- a/tests/data/aloha_sim_insertion_scripted/replay_buffer/observation/state.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:dd4e7e14abf57561ca9839c910581266be90956e41bfb3bb21362ea0c321e77d
-size 2800
diff --git a/tests/data/aloha_sim_insertion_scripted/stats.pth b/tests/data/aloha_sim_insertion_scripted/stats.pth
index 7d149ca4..990d4647 100644
Binary files a/tests/data/aloha_sim_insertion_scripted/stats.pth and b/tests/data/aloha_sim_insertion_scripted/stats.pth differ
diff --git a/tests/data/aloha_sim_transfer_cube_human/data_dict.pth b/tests/data/aloha_sim_transfer_cube_human/data_dict.pth
new file mode 100644
index 00000000..ab851779
Binary files /dev/null and b/tests/data/aloha_sim_transfer_cube_human/data_dict.pth differ
diff --git a/tests/data/aloha_sim_transfer_cube_human/data_ids_per_episode.pth b/tests/data/aloha_sim_transfer_cube_human/data_ids_per_episode.pth
new file mode 100644
index 00000000..a1d481dd
Binary files /dev/null and b/tests/data/aloha_sim_transfer_cube_human/data_ids_per_episode.pth differ
diff --git a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/action.memmap b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/action.memmap
deleted file mode 100644
index 9b4fef33..00000000
--- a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/action.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:14fed0eed3d529a8ac0dd25a6d41585020772d02f9137fc9d604713b2f0f7076
-size 2800
diff --git a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/episode.memmap b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/episode.memmap
deleted file mode 100644
index af9fb07f..00000000
--- a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/episode.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:7a12e561363385e9dfeeab326368731c030ed4b374e7f5897ac819159d2884c5
-size 400
diff --git a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/frame_id.memmap b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/frame_id.memmap
deleted file mode 100644
index dc2f585c..00000000
--- a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/frame_id.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:c202d9cfc7858fd49d522047e16948359bbbb2eda2d3825d552e45a78d5f8585
-size 400
diff --git a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/meta.json b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/meta.json
deleted file mode 100644
index 2a0cf0a2..00000000
--- a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/meta.json
+++ /dev/null
@@ -1 +0,0 @@
-{"action": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "episode": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "frame_id": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/done.memmap b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/done.memmap
deleted file mode 100644
index 44fd709f..00000000
--- a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/done.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:cc2786e1f9910a9d811400edcddaf7075195f7a16b216dcbefba3bc7c4f2ae51
-size 50
diff --git a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/meta.json b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/meta.json
deleted file mode 100644
index 3bfa9bd7..00000000
--- a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/meta.json
+++ /dev/null
@@ -1 +0,0 @@
-{"done": {"device": "cpu", "shape": [50, 1], "dtype": "torch.bool"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/observation/image/meta.json b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/observation/image/meta.json
deleted file mode 100644
index cb29a5ab..00000000
--- a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/observation/image/meta.json
+++ /dev/null
@@ -1 +0,0 @@
-{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/observation/image/top.memmap b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/observation/image/top.memmap
deleted file mode 100644
index cd2e7c06..00000000
--- a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/observation/image/top.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:2f713ea7fc19e592ea409a5e0bdfde403e5b86f834cbabe3463b791e8437fafc
-size 46080000
diff --git a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/observation/meta.json b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/observation/meta.json
deleted file mode 100644
index 65ce1ca2..00000000
--- a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/observation/meta.json
+++ /dev/null
@@ -1 +0,0 @@
-{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/observation/state.memmap b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/observation/state.memmap
deleted file mode 100644
index 37feaad6..00000000
--- a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/observation/state.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:8c103e2c9d63c9f7cf9645bd24d9a2c4e8e08825dc75e230ebc793b8f9c213b0
-size 2800
diff --git a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/observation/image/meta.json b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/observation/image/meta.json
deleted file mode 100644
index cb29a5ab..00000000
--- a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/observation/image/meta.json
+++ /dev/null
@@ -1 +0,0 @@
-{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/observation/image/top.memmap b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/observation/image/top.memmap
deleted file mode 100644
index 1188590c..00000000
--- a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/observation/image/top.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:7dbf4aa01b184d0eaa21ea999078d7cff86e1ca484a109614176fdc49f1ee05c
-size 46080000
diff --git a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/observation/meta.json b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/observation/meta.json
deleted file mode 100644
index 65ce1ca2..00000000
--- a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/observation/meta.json
+++ /dev/null
@@ -1 +0,0 @@
-{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/observation/state.memmap b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/observation/state.memmap
deleted file mode 100644
index 9ef4cfd6..00000000
--- a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/observation/state.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:4fa0b9c870d4615037b6fee9e9e85e54d84352e173f2c7c1035232272fe2a3dd
-size 2800
diff --git a/tests/data/aloha_sim_transfer_cube_human/stats.pth b/tests/data/aloha_sim_transfer_cube_human/stats.pth
index 22f3e4d9..1ae356e3 100644
Binary files a/tests/data/aloha_sim_transfer_cube_human/stats.pth and b/tests/data/aloha_sim_transfer_cube_human/stats.pth differ
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/data_dict.pth b/tests/data/aloha_sim_transfer_cube_scripted/data_dict.pth
new file mode 100644
index 00000000..bd308bb0
Binary files /dev/null and b/tests/data/aloha_sim_transfer_cube_scripted/data_dict.pth differ
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/data_ids_per_episode.pth b/tests/data/aloha_sim_transfer_cube_scripted/data_ids_per_episode.pth
new file mode 100644
index 00000000..a1d481dd
Binary files /dev/null and b/tests/data/aloha_sim_transfer_cube_scripted/data_ids_per_episode.pth differ
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/action.memmap b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/action.memmap
deleted file mode 100644
index 8ac0c726..00000000
--- a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/action.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:c0e199a82e2b7462e84406dbced5448a99f1dad9ce172771dfc3feb4b8597115
-size 2800
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/episode.memmap b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/episode.memmap
deleted file mode 100644
index af9fb07f..00000000
--- a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/episode.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:7a12e561363385e9dfeeab326368731c030ed4b374e7f5897ac819159d2884c5
-size 400
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/frame_id.memmap b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/frame_id.memmap
deleted file mode 100644
index dc2f585c..00000000
--- a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/frame_id.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:c202d9cfc7858fd49d522047e16948359bbbb2eda2d3825d552e45a78d5f8585
-size 400
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/meta.json b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/meta.json
deleted file mode 100644
index 2a0cf0a2..00000000
--- a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/meta.json
+++ /dev/null
@@ -1 +0,0 @@
-{"action": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "episode": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "frame_id": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/done.memmap b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/done.memmap
deleted file mode 100644
index 44fd709f..00000000
--- a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/done.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:cc2786e1f9910a9d811400edcddaf7075195f7a16b216dcbefba3bc7c4f2ae51
-size 50
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/meta.json b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/meta.json
deleted file mode 100644
index 3bfa9bd7..00000000
--- a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/meta.json
+++ /dev/null
@@ -1 +0,0 @@
-{"done": {"device": "cpu", "shape": [50, 1], "dtype": "torch.bool"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/observation/image/meta.json b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/observation/image/meta.json
deleted file mode 100644
index cb29a5ab..00000000
--- a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/observation/image/meta.json
+++ /dev/null
@@ -1 +0,0 @@
-{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/observation/image/top.memmap b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/observation/image/top.memmap
deleted file mode 100644
index 8e5f533e..00000000
--- a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/observation/image/top.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:30b44d38cc4d68e06c716a875d39cbdbeacbfdc1657d6366f58c279efd27c52b
-size 46080000
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/observation/meta.json b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/observation/meta.json
deleted file mode 100644
index 65ce1ca2..00000000
--- a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/observation/meta.json
+++ /dev/null
@@ -1 +0,0 @@
-{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/observation/state.memmap b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/observation/state.memmap
deleted file mode 100644
index e88320d1..00000000
--- a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/observation/state.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:9f484e7ea4f5f612dd53ee2c0f7891b8f7b2168a54fc81941ac2f2447260c294
-size 2800
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/observation/image/meta.json b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/observation/image/meta.json
deleted file mode 100644
index cb29a5ab..00000000
--- a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/observation/image/meta.json
+++ /dev/null
@@ -1 +0,0 @@
-{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/observation/image/top.memmap b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/observation/image/top.memmap
deleted file mode 100644
index d415da0a..00000000
--- a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/observation/image/top.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:09600206f56cc5b52dfb896204b0044c4e830da368da141d7bd10e52181f6835
-size 46080000
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/observation/meta.json b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/observation/meta.json
deleted file mode 100644
index 65ce1ca2..00000000
--- a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/observation/meta.json
+++ /dev/null
@@ -1 +0,0 @@
-{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/observation/state.memmap b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/observation/state.memmap
deleted file mode 100644
index be3436fb..00000000
--- a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/observation/state.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:60dcb547cf9a6372b78a455217a2408b6bece4371fba1df2a302b334d45c42a8
-size 2800
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/stats.pth b/tests/data/aloha_sim_transfer_cube_scripted/stats.pth
index 63465344..71547f09 100644
Binary files a/tests/data/aloha_sim_transfer_cube_scripted/stats.pth and b/tests/data/aloha_sim_transfer_cube_scripted/stats.pth differ
diff --git a/tests/data/pusht/data_dict.pth b/tests/data/pusht/data_dict.pth
new file mode 100644
index 00000000..a083c86c
Binary files /dev/null and b/tests/data/pusht/data_dict.pth differ
diff --git a/tests/data/pusht/data_ids_per_episode.pth b/tests/data/pusht/data_ids_per_episode.pth
new file mode 100644
index 00000000..a1d481dd
Binary files /dev/null and b/tests/data/pusht/data_ids_per_episode.pth differ
diff --git a/tests/data/pusht/replay_buffer/action.memmap b/tests/data/pusht/replay_buffer/action.memmap
deleted file mode 100644
index f4127fb1..00000000
--- a/tests/data/pusht/replay_buffer/action.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:ba17d8e5c30151ea5f7f6fc31f19e12a68ce2113774b74c8aca0c7ef962a75f4
-size 400
diff --git a/tests/data/pusht/replay_buffer/episode.memmap b/tests/data/pusht/replay_buffer/episode.memmap
deleted file mode 100644
index af9fb07f..00000000
--- a/tests/data/pusht/replay_buffer/episode.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:7a12e561363385e9dfeeab326368731c030ed4b374e7f5897ac819159d2884c5
-size 400
diff --git a/tests/data/pusht/replay_buffer/frame_id.memmap b/tests/data/pusht/replay_buffer/frame_id.memmap
deleted file mode 100644
index dc2f585c..00000000
--- a/tests/data/pusht/replay_buffer/frame_id.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:c202d9cfc7858fd49d522047e16948359bbbb2eda2d3825d552e45a78d5f8585
-size 400
diff --git a/tests/data/pusht/replay_buffer/meta.json b/tests/data/pusht/replay_buffer/meta.json
deleted file mode 100644
index 6f7c4218..00000000
--- a/tests/data/pusht/replay_buffer/meta.json
+++ /dev/null
@@ -1 +0,0 @@
-{"action": {"device": "cpu", "shape": [50, 2], "dtype": "torch.float32"}, "episode": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "frame_id": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/pusht/replay_buffer/next/done.memmap b/tests/data/pusht/replay_buffer/next/done.memmap
deleted file mode 100644
index 44fd709f..00000000
--- a/tests/data/pusht/replay_buffer/next/done.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:cc2786e1f9910a9d811400edcddaf7075195f7a16b216dcbefba3bc7c4f2ae51
-size 50
diff --git a/tests/data/pusht/replay_buffer/next/meta.json b/tests/data/pusht/replay_buffer/next/meta.json
deleted file mode 100644
index b29a9ff7..00000000
--- a/tests/data/pusht/replay_buffer/next/meta.json
+++ /dev/null
@@ -1 +0,0 @@
-{"reward": {"device": "cpu", "shape": [50, 1], "dtype": "torch.float32"}, "done": {"device": "cpu", "shape": [50, 1], "dtype": "torch.bool"}, "success": {"device": "cpu", "shape": [50, 1], "dtype": "torch.bool"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/pusht/replay_buffer/next/observation/image.memmap b/tests/data/pusht/replay_buffer/next/observation/image.memmap
deleted file mode 100644
index 68634378..00000000
--- a/tests/data/pusht/replay_buffer/next/observation/image.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:ff6a3748c8223a82e54c61442df7b8baf478a20497ee2353645a1e9ccd765162
-size 5529600
diff --git a/tests/data/pusht/replay_buffer/next/observation/meta.json b/tests/data/pusht/replay_buffer/next/observation/meta.json
deleted file mode 100644
index 57e0edea..00000000
--- a/tests/data/pusht/replay_buffer/next/observation/meta.json
+++ /dev/null
@@ -1 +0,0 @@
-{"image": {"device": "cpu", "shape": [50, 3, 96, 96], "dtype": "torch.float32"}, "state": {"device": "cpu", "shape": [50, 2], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/pusht/replay_buffer/next/observation/state.memmap b/tests/data/pusht/replay_buffer/next/observation/state.memmap
deleted file mode 100644
index 8dd28f2a..00000000
--- a/tests/data/pusht/replay_buffer/next/observation/state.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:fad4ece6d5fd66bbafa34f6ff383c483410082b8d7d4f4616808c3c458ce1d43
-size 400
diff --git a/tests/data/pusht/replay_buffer/next/reward.memmap b/tests/data/pusht/replay_buffer/next/reward.memmap
deleted file mode 100644
index 109ed5ad..00000000
--- a/tests/data/pusht/replay_buffer/next/reward.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:6d9c54dee5660c46886f32d80e57e9dd0ffa57ee0cd2a762b036d9c8e0c3a33a
-size 200
diff --git a/tests/data/pusht/replay_buffer/next/success.memmap b/tests/data/pusht/replay_buffer/next/success.memmap
deleted file mode 100644
index 44fd709f..00000000
--- a/tests/data/pusht/replay_buffer/next/success.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:cc2786e1f9910a9d811400edcddaf7075195f7a16b216dcbefba3bc7c4f2ae51
-size 50
diff --git a/tests/data/pusht/replay_buffer/observation/image.memmap b/tests/data/pusht/replay_buffer/observation/image.memmap
deleted file mode 100644
index 42c86ef0..00000000
--- a/tests/data/pusht/replay_buffer/observation/image.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:4bbde5cfd8cff9fd9fc6c9a57177f6fd31c8a03cf853b7d2234312f38380b0ba
-size 5529600
diff --git a/tests/data/pusht/replay_buffer/observation/meta.json b/tests/data/pusht/replay_buffer/observation/meta.json
deleted file mode 100644
index 57e0edea..00000000
--- a/tests/data/pusht/replay_buffer/observation/meta.json
+++ /dev/null
@@ -1 +0,0 @@
-{"image": {"device": "cpu", "shape": [50, 3, 96, 96], "dtype": "torch.float32"}, "state": {"device": "cpu", "shape": [50, 2], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/pusht/replay_buffer/observation/state.memmap b/tests/data/pusht/replay_buffer/observation/state.memmap
deleted file mode 100644
index 3ac8e4ab..00000000
--- a/tests/data/pusht/replay_buffer/observation/state.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:67c7e39090a16546fb1eade833d704f26464d574d7e431415f828159a154d2bf
-size 400
diff --git a/tests/data/pusht/stats.pth b/tests/data/pusht/stats.pth
index d7107185..636985fd 100644
Binary files a/tests/data/pusht/stats.pth and b/tests/data/pusht/stats.pth differ
diff --git a/tests/data/xarm_lift_medium/data_dict.pth b/tests/data/xarm_lift_medium/data_dict.pth
new file mode 100644
index 00000000..5c166576
Binary files /dev/null and b/tests/data/xarm_lift_medium/data_dict.pth differ
diff --git a/tests/data/xarm_lift_medium/data_ids_per_episode.pth b/tests/data/xarm_lift_medium/data_ids_per_episode.pth
new file mode 100644
index 00000000..21095017
Binary files /dev/null and b/tests/data/xarm_lift_medium/data_ids_per_episode.pth differ
diff --git a/tests/data/xarm_lift_medium/replay_buffer/action.memmap b/tests/data/xarm_lift_medium/replay_buffer/action.memmap
deleted file mode 100644
index c90afbe9..00000000
--- a/tests/data/xarm_lift_medium/replay_buffer/action.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:10ec2f944de18f1a2aa3fc2f9a4185c03e3a5afc31148c85c98b58602ac4186e
-size 800
diff --git a/tests/data/xarm_lift_medium/replay_buffer/episode.memmap b/tests/data/xarm_lift_medium/replay_buffer/episode.memmap
deleted file mode 100644
index 7924f028..00000000
--- a/tests/data/xarm_lift_medium/replay_buffer/episode.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:1a589cba6bf0dfce138110864b6509508a804d7ea5c519d0a3cd67c4a87fa2d0
-size 200
diff --git a/tests/data/xarm_lift_medium/replay_buffer/frame_id.memmap b/tests/data/xarm_lift_medium/replay_buffer/frame_id.memmap
deleted file mode 100644
index a633d346..00000000
--- a/tests/data/xarm_lift_medium/replay_buffer/frame_id.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:6afe7098f30bdc8564526517c085e62613f6cb67194153840567974cfa6f3815
-size 400
diff --git a/tests/data/xarm_lift_medium/replay_buffer/meta.json b/tests/data/xarm_lift_medium/replay_buffer/meta.json
deleted file mode 100644
index 33dc932c..00000000
--- a/tests/data/xarm_lift_medium/replay_buffer/meta.json
+++ /dev/null
@@ -1 +0,0 @@
-{"action": {"device": "cpu", "shape": [50, 4], "dtype": "torch.float32"}, "episode": {"device": "cpu", "shape": [50], "dtype": "torch.int32"}, "frame_id": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/xarm_lift_medium/replay_buffer/next/done.memmap b/tests/data/xarm_lift_medium/replay_buffer/next/done.memmap
deleted file mode 100644
index cf5e9cca..00000000
--- a/tests/data/xarm_lift_medium/replay_buffer/next/done.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:dab3a9712c413c4bfcd91c645752ab981306b23d25bcd4da4c422911574673f7
-size 50
diff --git a/tests/data/xarm_lift_medium/replay_buffer/next/meta.json b/tests/data/xarm_lift_medium/replay_buffer/next/meta.json
deleted file mode 100644
index d69cadad..00000000
--- a/tests/data/xarm_lift_medium/replay_buffer/next/meta.json
+++ /dev/null
@@ -1 +0,0 @@
-{"reward": {"device": "cpu", "shape": [50], "dtype": "torch.float32"}, "done": {"device": "cpu", "shape": [50], "dtype": "torch.bool"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/xarm_lift_medium/replay_buffer/next/observation/image.memmap b/tests/data/xarm_lift_medium/replay_buffer/next/observation/image.memmap
deleted file mode 100644
index 462d0117..00000000
--- a/tests/data/xarm_lift_medium/replay_buffer/next/observation/image.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:d6f9d1422ce3764e7253f70ed4da278f0c07fafef0d5386479f09d6b6b9d8259
-size 1058400
diff --git a/tests/data/xarm_lift_medium/replay_buffer/next/observation/meta.json b/tests/data/xarm_lift_medium/replay_buffer/next/observation/meta.json
deleted file mode 100644
index b13b8ec9..00000000
--- a/tests/data/xarm_lift_medium/replay_buffer/next/observation/meta.json
+++ /dev/null
@@ -1 +0,0 @@
-{"image": {"device": "cpu", "shape": [50, 3, 84, 84], "dtype": "torch.uint8"}, "state": {"device": "cpu", "shape": [50, 4], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/xarm_lift_medium/replay_buffer/next/observation/state.memmap b/tests/data/xarm_lift_medium/replay_buffer/next/observation/state.memmap
deleted file mode 100644
index 1dbe6024..00000000
--- a/tests/data/xarm_lift_medium/replay_buffer/next/observation/state.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:52e7c1a3c4fb2423b195e66ffee2c9e23b3ea0ad7c8bfc4dec30a35c65cadcbb
-size 800
diff --git a/tests/data/xarm_lift_medium/replay_buffer/next/reward.memmap b/tests/data/xarm_lift_medium/replay_buffer/next/reward.memmap
deleted file mode 100644
index 9ff5d5a1..00000000
--- a/tests/data/xarm_lift_medium/replay_buffer/next/reward.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:c4dbe8ea1966e5cc6da6daf5704805b9b5810f4575de7016b8f6cb1da1d7bb8a
-size 200
diff --git a/tests/data/xarm_lift_medium/replay_buffer/observation/image.memmap b/tests/data/xarm_lift_medium/replay_buffer/observation/image.memmap
deleted file mode 100644
index c9416940..00000000
--- a/tests/data/xarm_lift_medium/replay_buffer/observation/image.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:8fca8ddbda3f7bb2f6e7553297c18f3ab8f8b73d64b5c9f81a3695ad9379d403
-size 1058400
diff --git a/tests/data/xarm_lift_medium/replay_buffer/observation/meta.json b/tests/data/xarm_lift_medium/replay_buffer/observation/meta.json
deleted file mode 100644
index b13b8ec9..00000000
--- a/tests/data/xarm_lift_medium/replay_buffer/observation/meta.json
+++ /dev/null
@@ -1 +0,0 @@
-{"image": {"device": "cpu", "shape": [50, 3, 84, 84], "dtype": "torch.uint8"}, "state": {"device": "cpu", "shape": [50, 4], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/xarm_lift_medium/replay_buffer/observation/state.memmap b/tests/data/xarm_lift_medium/replay_buffer/observation/state.memmap
deleted file mode 100644
index 3bae16df..00000000
--- a/tests/data/xarm_lift_medium/replay_buffer/observation/state.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:7b3e3e12896d553c208ee152f6d447c877c435e15d010c4a6171966d5b8a0c0b
-size 800
diff --git a/tests/data/xarm_lift_medium/stats.pth b/tests/data/xarm_lift_medium/stats.pth
index 0accffb0..3ab4e05b 100644
Binary files a/tests/data/xarm_lift_medium/stats.pth and b/tests/data/xarm_lift_medium/stats.pth differ
diff --git a/tests/scripts/mock_dataset.py b/tests/scripts/mock_dataset.py
index d9c86464..72480666 100644
--- a/tests/scripts/mock_dataset.py
+++ b/tests/scripts/mock_dataset.py
@@ -18,28 +18,33 @@ Example:
import argparse
import shutil
-from tensordict import TensorDict
from pathlib import Path
+import torch
+
def mock_dataset(in_data_dir, out_data_dir, num_frames):
in_data_dir = Path(in_data_dir)
out_data_dir = Path(out_data_dir)
+ out_data_dir.mkdir(exist_ok=True, parents=True)
- # load full dataset as a tensor dict
- in_td_data = TensorDict.load_memmap(in_data_dir / "replay_buffer")
+ # copy the first `n` frames for each data key so that we have real data
+ in_data_dict = torch.load(in_data_dir / "data_dict.pth")
+ out_data_dict = {key: in_data_dict[key][:num_frames].clone() for key in in_data_dict}
+ torch.save(out_data_dict, out_data_dir / "data_dict.pth")
- # use 1 frame to know the specification of the dataset
- # and copy it over `n` frames in the test artifact directory
- out_td_data = in_td_data[0].expand(num_frames).memmap_like(out_data_dir / "replay_buffer")
+ # recreate data_ids_per_episode that corresponds to the subset
+ episodes = in_data_dict["episode"][:num_frames].tolist()
+ data_ids_per_episode = {}
+ for idx, ep_id in enumerate(episodes):
+ if ep_id not in data_ids_per_episode:
+ data_ids_per_episode[ep_id] = []
+ data_ids_per_episode[ep_id].append(idx)
+ for ep_id in data_ids_per_episode:
+ data_ids_per_episode[ep_id] = torch.tensor(data_ids_per_episode[ep_id])
+ torch.save(data_ids_per_episode, out_data_dir / "data_ids_per_episode.pth")
- # copy the first `n` frames so that we have real data
- out_td_data[:num_frames] = in_td_data[:num_frames].clone()
-
- # make sure everything has been properly written
- out_td_data.lock_()
-
- # copy the full statistics of dataset since it's pretty small
+    # copy the full statistics of the dataset since the file is small
in_stats_path = in_data_dir / "stats.pth"
out_stats_path = out_data_dir / "stats.pth"
shutil.copy(in_stats_path, out_stats_path)
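
The rewritten `tests/scripts/mock_dataset.py` above replaces the torchrl memmap replay buffer with two plain torch artifacts: `data_dict.pth` (one tensor per data key) and `data_ids_per_episode.pth` (frame indices grouped by episode). A minimal sketch of how a consumer could load and sanity-check the mocked output; the `tests/data/pusht` path is just one example directory:

```python
from pathlib import Path

import torch

data_dir = Path("tests/data/pusht")  # any mocked dataset directory

data_dict = torch.load(data_dir / "data_dict.pth")
data_ids_per_episode = torch.load(data_dir / "data_ids_per_episode.pth")

# every key holds the same number of frames along the first dimension
num_frames = len(data_dict["episode"])
assert all(len(data_dict[key]) == num_frames for key in data_dict)

# frame indices are grouped per episode and agree with data_dict["episode"]
for ep_id, frame_ids in data_ids_per_episode.items():
    assert (data_dict["episode"][frame_ids] == ep_id).all()
```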
diff --git a/tests/test_available.py b/tests/test_available.py
index 9cc91efa..be74a42a 100644
--- a/tests/test_available.py
+++ b/tests/test_available.py
@@ -1,25 +1,20 @@
"""
This test verifies that all environments, datasets, policies listed in `lerobot/__init__.py` can be successfully
-imported and that their class attributes (eg. `available_datasets`, `name`, `available_tasks`) corresponds.
+imported and that their class attributes (e.g. `available_datasets`, `name`, `available_tasks`) are valid.
-Note:
- When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
- 1. set the required class attributes:
- - for classes inheriting from `AbstractDataset`: `available_datasets`
- - for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
- - for classes inheriting from `AbstractPolicy`: `name`
- 2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
- 3. update variables in `tests/test_available.py` by importing your new class
+When implementing a new dataset (e.g. `AlohaDataset`), policy (e.g. `DiffusionPolicy`), or environment, follow these steps:
+- Set the required class attribute `available_datasets` on dataset classes.
+- Set the required class attribute `name` on policy classes.
+- Update the variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`).
+- Update the variables in `tests/test_available.py` by importing your new class.
"""
+import importlib
import pytest
import lerobot
+import gymnasium as gym
-from lerobot.common.envs.aloha.env import AlohaEnv
-from lerobot.common.envs.pusht.env import PushtEnv
-from lerobot.common.envs.simxarm.env import SimxarmEnv
-
-from lerobot.common.datasets.simxarm import SimxarmDataset
+from lerobot.common.datasets.xarm import XarmDataset
from lerobot.common.datasets.aloha import AlohaDataset
from lerobot.common.datasets.pusht import PushtDataset
@@ -29,36 +24,30 @@ from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
def test_available():
- pol_classes = [
+ policy_classes = [
ActionChunkingTransformerPolicy,
DiffusionPolicy,
TDMPCPolicy,
]
- env_classes = [
- AlohaEnv,
- PushtEnv,
- SimxarmEnv,
- ]
-
- dat_classes = [
- AlohaDataset,
- PushtDataset,
- SimxarmDataset,
- ]
+ dataset_class_per_env = {
+ "aloha": AlohaDataset,
+ "pusht": PushtDataset,
+ "xarm": XarmDataset,
+ }
- policies = [pol_cls.name for pol_cls in pol_classes]
- assert set(policies) == set(lerobot.available_policies)
+ policies = [pol_cls.name for pol_cls in policy_classes]
+ assert set(policies) == set(lerobot.available_policies), policies
- envs = [env_cls.name for env_cls in env_classes]
- assert set(envs) == set(lerobot.available_envs)
+ for env_name in lerobot.available_envs:
+ for task_name in lerobot.available_tasks_per_env[env_name]:
+ package_name = f"gym_{env_name}"
+ importlib.import_module(package_name)
+ gym_handle = f"{package_name}/{task_name}"
+ assert gym_handle in gym.envs.registry.keys(), gym_handle
- tasks_per_env = {env_cls.name: env_cls.available_tasks for env_cls in env_classes}
- for env in envs:
- assert set(tasks_per_env[env]) == set(lerobot.available_tasks_per_env[env])
-
- datasets_per_env = {env_cls.name: dat_cls.available_datasets for env_cls, dat_cls in zip(env_classes, dat_classes)}
- for env in envs:
- assert set(datasets_per_env[env]) == set(lerobot.available_datasets_per_env[env])
+ dataset_class = dataset_class_per_env[env_name]
+ available_datasets = lerobot.available_datasets_per_env[env_name]
+ assert set(available_datasets) == set(dataset_class.available_datasets), f"{env_name=} {available_datasets=}"
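
The registry check above relies on importing the `gym_<env_name>` package purely for its registration side effect. A minimal sketch of that mechanism, assuming the gym-pusht package added by this PR is installed; the `obs_type` value mirrors the ones exercised in `tests/test_envs.py`:

```python
import importlib

import gymnasium as gym

# importing the package registers its environments with gymnasium
importlib.import_module("gym_pusht")

gym_handle = "gym_pusht/PushT-v0"
assert gym_handle in gym.envs.registry.keys(), gym_handle

env = gym.make(gym_handle, obs_type="pixels_agent_pos")
obs, info = env.reset(seed=0)
env.close()
```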
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index df41b03f..71eefa9c 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -1,38 +1,91 @@
+import logging
+import os
+from pathlib import Path
import einops
import pytest
import torch
-from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
-from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
-from lerobot.common.datasets.factory import make_offline_buffer
+from lerobot.common.datasets.factory import make_dataset
+from lerobot.common.datasets.utils import compute_stats, get_stats_einops_patterns, load_data_with_delta_timestamps
+from lerobot.common.datasets.xarm import XarmDataset
+from lerobot.common.transforms import Prod
from lerobot.common.utils import init_hydra_config
from .utils import DEVICE, DEFAULT_CONFIG_PATH
@pytest.mark.parametrize(
- "env_name,dataset_id",
+ "env_name,dataset_id,policy_name",
[
- ("simxarm", "lift"),
- ("pusht", "pusht"),
- ("aloha", "sim_insertion_human"),
- ("aloha", "sim_insertion_scripted"),
- ("aloha", "sim_transfer_cube_human"),
- ("aloha", "sim_transfer_cube_scripted"),
+ ("xarm", "xarm_lift_medium", "tdmpc"),
+ ("pusht", "pusht", "diffusion"),
+ ("aloha", "aloha_sim_insertion_human", "act"),
+ ("aloha", "aloha_sim_insertion_scripted", "act"),
+ ("aloha", "aloha_sim_transfer_cube_human", "act"),
+ ("aloha", "aloha_sim_transfer_cube_scripted", "act"),
],
)
-def test_factory(env_name, dataset_id):
+def test_factory(env_name, dataset_id, policy_name):
cfg = init_hydra_config(
DEFAULT_CONFIG_PATH,
- overrides=[f"env={env_name}", f"env.task={dataset_id}", f"device={DEVICE}"]
+ overrides=[f"env={env_name}", f"dataset_id={dataset_id}", f"policy={policy_name}", f"device={DEVICE}"]
)
- offline_buffer = make_offline_buffer(cfg)
- for key in offline_buffer.image_keys:
- img = offline_buffer[0].get(key)
- assert img.dtype == torch.float32
- # TODO(rcadene): we assume for now that image normalization takes place in the model
- assert img.max() <= 1.0
- assert img.min() >= 0.0
+ dataset = make_dataset(cfg)
+ delta_timestamps = dataset.delta_timestamps
+ image_keys = dataset.image_keys
+
+ item = dataset[0]
+
+ keys_ndim_required = [
+ ("action", 1, True),
+ ("episode", 0, True),
+ ("frame_id", 0, True),
+ ("timestamp", 0, True),
+ # TODO(rcadene): should we rename it agent_pos?
+ ("observation.state", 1, True),
+ ("next.reward", 0, False),
+ ("next.done", 0, False),
+ ]
+
+ for key in image_keys:
+ keys_ndim_required.append(
+ (key, 3, True),
+ )
+ assert dataset.data_dict[key].dtype == torch.uint8, f"{key}"
+
+ # test number of dimensions
+ for key, ndim, required in keys_ndim_required:
+ if key not in item:
+ if required:
+ assert key in item, f"{key}"
+ else:
+ logging.warning(f'Missing key in dataset: "{key}" not in {dataset}.')
+ continue
+
+ if delta_timestamps is not None and key in delta_timestamps:
+ assert item[key].ndim == ndim + 1, f"{key}"
+ assert item[key].shape[0] == len(delta_timestamps[key]), f"{key}"
+ else:
+ assert item[key].ndim == ndim, f"{key}"
+
+ if key in image_keys:
+ assert item[key].dtype == torch.float32, f"{key}"
+ # TODO(rcadene): we assume for now that image normalization takes place in the model
+ assert item[key].max() <= 1.0, f"{key}"
+ assert item[key].min() >= 0.0, f"{key}"
+
+ if delta_timestamps is not None and key in delta_timestamps:
+ # test t,c,h,w
+ assert item[key].shape[1] == 3, f"{key}"
+ else:
+ # test c,h,w
+ assert item[key].shape[0] == 3, f"{key}"
+
+
+ if delta_timestamps is not None:
+ # test missing keys in delta_timestamps
+ for key in delta_timestamps:
+ assert key in item, f"{key}"
def test_compute_stats():
@@ -41,26 +94,98 @@ def test_compute_stats():
We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
because we are working with a small dataset).
"""
- cfg = init_hydra_config(
- DEFAULT_CONFIG_PATH, overrides=["env=aloha", "env.task=sim_transfer_cube_human"]
+ DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
+
+ # get transform to convert images from uint8 [0,255] to float32 [0,1]
+ transform = Prod(in_keys=XarmDataset.image_keys, prod=1 / 255.0)
+
+ dataset = XarmDataset(
+ dataset_id="xarm_lift_medium",
+ root=DATA_DIR,
+ transform=transform,
)
- buffer = make_offline_buffer(cfg)
- # Get all of the data.
- all_data = TensorDictReplayBuffer(
- storage=buffer._storage,
- batch_size=len(buffer),
- sampler=SamplerWithoutReplacement(),
- ).sample().float()
+
# Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
# computation of the statistics. While doing this, we also make sure it works when we don't divide the
# dataset into even batches.
- computed_stats = buffer._compute_stats(batch_size=int(len(all_data) * 0.75))
- for k, pattern in buffer.stats_patterns.items():
- expected_mean = einops.reduce(all_data[k], pattern, "mean")
- assert torch.allclose(computed_stats[k]["mean"], expected_mean)
- assert torch.allclose(
- computed_stats[k]["std"],
- torch.sqrt(einops.reduce((all_data[k] - expected_mean) ** 2, pattern, "mean"))
- )
- assert torch.allclose(computed_stats[k]["min"], einops.reduce(all_data[k], pattern, "min"))
- assert torch.allclose(computed_stats[k]["max"], einops.reduce(all_data[k], pattern, "max"))
+ computed_stats = compute_stats(dataset, batch_size=int(len(dataset) * 0.25))
+
+ # get einops patterns to aggregate batches and compute statistics
+ stats_patterns = get_stats_einops_patterns(dataset)
+
+ # get all frames from the dataset in the same dtype and range as during compute_stats
+ data_dict = transform(dataset.data_dict)
+
+ # compute stats based on all frames from the dataset without any batching
+ expected_stats = {}
+ for k, pattern in stats_patterns.items():
+ expected_stats[k] = {}
+ expected_stats[k]["mean"] = einops.reduce(data_dict[k], pattern, "mean")
+ expected_stats[k]["std"] = torch.sqrt(einops.reduce((data_dict[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean"))
+ expected_stats[k]["min"] = einops.reduce(data_dict[k], pattern, "min")
+ expected_stats[k]["max"] = einops.reduce(data_dict[k], pattern, "max")
+
+ # test computed stats match expected stats
+ for k in stats_patterns:
+ assert torch.allclose(computed_stats[k]["mean"], expected_stats[k]["mean"])
+ assert torch.allclose(computed_stats[k]["std"], expected_stats[k]["std"])
+ assert torch.allclose(computed_stats[k]["min"], expected_stats[k]["min"])
+ assert torch.allclose(computed_stats[k]["max"], expected_stats[k]["max"])
+
+ # TODO(rcadene): check that the stats used for training are correct too
+ # # load stats that are expected to match the ones returned by computed_stats
+ # assert (dataset.data_dir / "stats.pth").exists()
+ # loaded_stats = torch.load(dataset.data_dir / "stats.pth")
+
+ # # test loaded stats match expected stats
+ # for k in stats_patterns:
+ # assert torch.allclose(loaded_stats[k]["mean"], expected_stats[k]["mean"])
+ # assert torch.allclose(loaded_stats[k]["std"], expected_stats[k]["std"])
+ # assert torch.allclose(loaded_stats[k]["min"], expected_stats[k]["min"])
+ # assert torch.allclose(loaded_stats[k]["max"], expected_stats[k]["max"])
+
+
+def test_load_data_with_delta_timestamps_within_tolerance():
+ data_dict = {
+ "timestamp": torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]),
+ "index": torch.tensor([0, 1, 2, 3, 4]),
+ }
+ data_ids_per_episode = {0: torch.tensor([0, 1, 2, 3, 4])}
+ delta_timestamps = {"index": [-0.2, 0, 0.139]}
+ key = "index"
+ current_ts = 0.3
+ episode = 0
+ tol = 0.04
+ data, is_pad = load_data_with_delta_timestamps(data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode, tol)
+ assert not is_pad.any(), "Unexpected padding detected"
+ assert torch.equal(data, torch.tensor([0, 2, 3])), "Data does not match expected values"
+
+def test_load_data_with_delta_timestamps_outside_tolerance_inside_episode_range():
+ data_dict = {
+ "timestamp": torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]),
+ "index": torch.tensor([0, 1, 2, 3, 4]),
+ }
+ data_ids_per_episode = {0: torch.tensor([0, 1, 2, 3, 4])}
+ delta_timestamps = {"index": [-0.2, 0, 0.141]}
+ key = "index"
+ current_ts = 0.3
+ episode = 0
+ tol = 0.04
+ with pytest.raises(AssertionError):
+ load_data_with_delta_timestamps(data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode, tol)
+
+def test_load_data_with_delta_timestamps_outside_tolerance_outside_episode_range():
+ data_dict = {
+ "timestamp": torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]),
+ "index": torch.tensor([0, 1, 2, 3, 4]),
+ }
+ data_ids_per_episode = {0: torch.tensor([0, 1, 2, 3, 4])}
+ delta_timestamps = {"index": [-0.3, -0.24, 0, 0.26, 0.3]}
+ key = "index"
+ current_ts = 0.3
+ episode = 0
+ tol = 0.04
+ data, is_pad = load_data_with_delta_timestamps(data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode, tol)
+ assert torch.equal(is_pad, torch.tensor([True, False, False, True, True])), "Padding does not match expected values"
+ assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
+
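
The three tests above pin down the contract of `load_data_with_delta_timestamps`: queries matched within `tol` return real frames, queries outside the episode's time span are padded with the nearest boundary frame (flagged in `is_pad`), and a query inside the span with no frame within tolerance is an error. A minimal reference sketch of that behavior; the actual helper lives in `lerobot.common.datasets.utils` and may differ in implementation details:

```python
import torch


def load_with_deltas_sketch(data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode, tol):
    # frame indices and timestamps belonging to the requested episode
    ep_ids = data_ids_per_episode[episode]
    ep_ts = data_dict["timestamp"][ep_ids]

    # timestamps we want to gather, relative to the current timestamp
    query_ts = current_ts + torch.tensor(delta_timestamps[key])

    # distance from each query to every available frame; keep the closest
    dist = torch.cdist(query_ts[:, None], ep_ts[:, None], p=1)
    min_dist, argmin = dist.min(dim=1)

    # queries beyond tolerance are only acceptable outside the episode's
    # time span, where they are padded with the first/last frame
    is_pad = min_dist > tol
    assert ((query_ts[is_pad] < ep_ts.min()) | (query_ts[is_pad] > ep_ts.max())).all(), (
        f"One or more query timestamps is inside the episode but has no frame within tol={tol}."
    )

    data = data_dict[key][ep_ids[argmin]]
    return data, is_pad
```

With the fixtures from `test_load_data_with_delta_timestamps_outside_tolerance_outside_episode_range`, this sketch returns `data == [0, 0, 2, 4, 4]` and `is_pad == [True, False, False, True, True]`, matching the expected values asserted above.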
diff --git a/tests/test_envs.py b/tests/test_envs.py
index eb3746db..d25231b0 100644
--- a/tests/test_envs.py
+++ b/tests/test_envs.py
@@ -1,112 +1,46 @@
+import importlib
import pytest
-from tensordict import TensorDict
import torch
-from torchrl.envs.utils import check_env_specs, step_mdp
-from lerobot.common.datasets.factory import make_offline_buffer
+from lerobot.common.datasets.factory import make_dataset
+import gymnasium as gym
+from gymnasium.utils.env_checker import check_env
-from lerobot.common.envs.aloha.env import AlohaEnv
from lerobot.common.envs.factory import make_env
-from lerobot.common.envs.pusht.env import PushtEnv
-from lerobot.common.envs.simxarm.env import SimxarmEnv
from lerobot.common.utils import init_hydra_config
+from lerobot.common.envs.utils import preprocess_observation
+
from .utils import DEVICE, DEFAULT_CONFIG_PATH
-def print_spec_rollout(env):
- print("observation_spec:", env.observation_spec)
- print("action_spec:", env.action_spec)
- print("reward_spec:", env.reward_spec)
- print("done_spec:", env.done_spec)
-
- td = env.reset()
- print("reset tensordict", td)
-
- td = env.rand_step(td)
- print("random step tensordict", td)
-
- def simple_rollout(steps=100):
- # preallocate:
- data = TensorDict({}, [steps])
- # reset
- _data = env.reset()
- for i in range(steps):
- _data["action"] = env.action_spec.rand()
- _data = env.step(_data)
- data[i] = _data
- _data = step_mdp(_data, keep_other=True)
- return data
-
- print("data from rollout:", simple_rollout(100))
-
-
@pytest.mark.parametrize(
- "task,from_pixels,pixels_only",
+ "env_name, task, obs_type",
[
- ("sim_insertion", True, False),
- ("sim_insertion", True, True),
- ("sim_transfer_cube", True, False),
- ("sim_transfer_cube", True, True),
+ # ("AlohaInsertion-v0", "state"),
+ ("aloha", "AlohaInsertion-v0", "pixels"),
+ ("aloha", "AlohaInsertion-v0", "pixels_agent_pos"),
+ ("aloha", "AlohaTransferCube-v0", "pixels"),
+ ("aloha", "AlohaTransferCube-v0", "pixels_agent_pos"),
+ ("xarm", "XarmLift-v0", "state"),
+ ("xarm", "XarmLift-v0", "pixels"),
+ ("xarm", "XarmLift-v0", "pixels_agent_pos"),
+ ("pusht", "PushT-v0", "state"),
+ ("pusht", "PushT-v0", "pixels"),
+ ("pusht", "PushT-v0", "pixels_agent_pos"),
],
)
-def test_aloha(task, from_pixels, pixels_only):
- env = AlohaEnv(
- task,
- from_pixels=from_pixels,
- pixels_only=pixels_only,
- image_size=[3, 480, 640] if from_pixels else None,
- )
- # print_spec_rollout(env)
- check_env_specs(env)
-
-
-@pytest.mark.parametrize(
- "task,from_pixels,pixels_only",
- [
- ("lift", False, False),
- ("lift", True, False),
- ("lift", True, True),
- # TODO(aliberts): Add simxarm other tasks
- # ("reach", False, False),
- # ("reach", True, False),
- # ("push", False, False),
- # ("push", True, False),
- # ("peg_in_box", False, False),
- # ("peg_in_box", True, False),
- ],
-)
-def test_simxarm(task, from_pixels, pixels_only):
- env = SimxarmEnv(
- task,
- from_pixels=from_pixels,
- pixels_only=pixels_only,
- image_size=84 if from_pixels else None,
- )
- # print_spec_rollout(env)
- check_env_specs(env)
-
-
-@pytest.mark.parametrize(
- "from_pixels,pixels_only",
- [
- (True, False),
- ],
-)
-def test_pusht(from_pixels, pixels_only):
- env = PushtEnv(
- from_pixels=from_pixels,
- pixels_only=pixels_only,
- image_size=96 if from_pixels else None,
- )
- # print_spec_rollout(env)
- check_env_specs(env)
-
+def test_env(env_name, task, obs_type):
+ package_name = f"gym_{env_name}"
+ importlib.import_module(package_name)
+ env = gym.make(f"{package_name}/{task}", obs_type=obs_type)
+ check_env(env.unwrapped, skip_render_check=True)
+ env.close()
@pytest.mark.parametrize(
"env_name",
[
- "simxarm",
"pusht",
+ "xarm",
"aloha",
],
)
@@ -116,18 +50,16 @@ def test_factory(env_name):
overrides=[f"env={env_name}", f"device={DEVICE}"],
)
- offline_buffer = make_offline_buffer(cfg)
+ dataset = make_dataset(cfg)
- env = make_env(cfg)
- for key in offline_buffer.image_keys:
- assert env.reset().get(key).dtype == torch.uint8
- check_env_specs(env)
-
- env = make_env(cfg, transform=offline_buffer.transform)
- for key in offline_buffer.image_keys:
- img = env.reset().get(key)
+ env = make_env(cfg, num_parallel_envs=1)
+ obs, _ = env.reset()
+ obs = preprocess_observation(obs, transform=dataset.transform)
+ for key in dataset.image_keys:
+ img = obs[key]
assert img.dtype == torch.float32
# TODO(rcadene): we assume for now that image normalization takes place in the model
assert img.max() <= 1.0
assert img.min() >= 0.0
- check_env_specs(env)
+
+ env.close()
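
`test_factory` above exercises the new observation pipeline: raw gym observations are converted to dataset-style tensors by `preprocess_observation` before reaching a policy. A minimal sketch of that flow using the same factory helpers; the config path and override list are illustrative:

```python
import torch

from lerobot.common.datasets.factory import make_dataset
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.utils import init_hydra_config

cfg = init_hydra_config("lerobot/configs/default.yaml", overrides=["env=pusht"])  # hypothetical config path
dataset = make_dataset(cfg)
env = make_env(cfg, num_parallel_envs=1)

obs, _ = env.reset(seed=cfg.seed)
# map raw gym observations (numpy arrays, uint8 images) to dataset-style
# float32 tensors in [0, 1], keyed like the dataset (e.g. "observation.image")
obs = preprocess_observation(obs, transform=dataset.transform)
for key in dataset.image_keys:
    assert obs[key].dtype == torch.float32
    assert 0.0 <= obs[key].min() and obs[key].max() <= 1.0

env.close()
```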
diff --git a/tests/test_policies.py b/tests/test_policies.py
index 5d6b46d0..8ccc7c62 100644
--- a/tests/test_policies.py
+++ b/tests/test_policies.py
@@ -1,37 +1,35 @@
import pytest
-from tensordict import TensorDict
-from tensordict.nn import TensorDictModule
import torch
-from torchrl.data import UnboundedContinuousTensorSpec
-from torchrl.envs import EnvBase
+from lerobot.common.datasets.utils import cycle
+from lerobot.common.envs.utils import postprocess_action, preprocess_observation
from lerobot.common.policies.factory import make_policy
from lerobot.common.envs.factory import make_env
-from lerobot.common.datasets.factory import make_offline_buffer
-from lerobot.common.policies.abstract import AbstractPolicy
+from lerobot.common.datasets.factory import make_dataset
from lerobot.common.utils import init_hydra_config
from .utils import DEVICE, DEFAULT_CONFIG_PATH
@pytest.mark.parametrize(
"env_name,policy_name,extra_overrides",
[
- ("simxarm", "tdmpc", ["policy.mpc=true"]),
+ ("xarm", "tdmpc", ["policy.mpc=true"]),
("pusht", "tdmpc", ["policy.mpc=false"]),
("pusht", "diffusion", []),
- ("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_human"]),
- ("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_scripted"]),
- ("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_human"]),
- ("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_scripted"]),
- # TODO(aliberts): simxarm not working with diffusion
- # ("simxarm", "diffusion", []),
+ ("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_id=aloha_sim_insertion_human"]),
+ ("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_id=aloha_sim_insertion_scripted"]),
+ ("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_human"]),
+ ("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_scripted"]),
+ # TODO(aliberts): xarm not working with diffusion
+ # ("xarm", "diffusion", []),
],
)
-def test_concrete_policy(env_name, policy_name, extra_overrides):
+def test_policy(env_name, policy_name, extra_overrides):
"""
Tests:
- Making the policy object.
- Updating the policy.
- Using the policy to select actions at inference time.
+    - Applying the selected action to the environment.
"""
cfg = init_hydra_config(
DEFAULT_CONFIG_PATH,
@@ -45,92 +43,44 @@ def test_concrete_policy(env_name, policy_name, extra_overrides):
# Check that we can make the policy object.
policy = make_policy(cfg)
# Check that we run select_actions and get the appropriate output.
- offline_buffer = make_offline_buffer(cfg)
- env = make_env(cfg, transform=offline_buffer.transform)
+ dataset = make_dataset(cfg)
+ env = make_env(cfg, num_parallel_envs=2)
- if env_name != "aloha":
- # TODO(alexander-soare): Fix this part of the test. PrioritizedSliceSampler raises NotImplementedError:
- # seq_length as a list is not supported for now.
- policy.update(offline_buffer, torch.tensor(0, device=DEVICE))
-
- action = policy(
- env.observation_spec.rand()["observation"].to(DEVICE),
- torch.tensor(0, device=DEVICE),
+ dataloader = torch.utils.data.DataLoader(
+ dataset,
+ num_workers=4,
+ batch_size=2,
+ shuffle=True,
+ pin_memory=DEVICE != "cpu",
+ drop_last=True,
)
- assert action.shape == env.action_spec.shape
+ dl_iter = cycle(dataloader)
+ batch = next(dl_iter)
-def test_abstract_policy_forward():
- """
- Given an underlying policy that produces an action trajectory with n_action_steps actions, checks that:
- - The policy is invoked the expected number of times during a rollout.
- - The environment's termination condition is respected even when part way through an action trajectory.
- - The observations are returned correctly.
- """
+ for key in batch:
+ batch[key] = batch[key].to(DEVICE, non_blocking=True)
- n_action_steps = 8 # our test policy will output 8 action step horizons
- terminate_at = 10 # some number that is more than n_action_steps but not a multiple
- rollout_max_steps = terminate_at + 1 # some number greater than terminate_at
+ # Test updating the policy
+ policy(batch, step=0)
- # A minimal environment for testing.
- class StubEnv(EnvBase):
+ # reset the policy and environment
+ policy.reset()
+ observation, _ = env.reset(seed=cfg.seed)
- def __init__(self):
- super().__init__()
- self.action_spec = UnboundedContinuousTensorSpec(shape=(1,))
- self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,))
+ # apply transform to normalize the observations
+ observation = preprocess_observation(observation, dataset.transform)
- def _step(self, tensordict: TensorDict) -> TensorDict:
- self.invocation_count += 1
- return TensorDict(
- {
- "observation": torch.tensor([self.invocation_count]),
- "reward": torch.tensor([self.invocation_count]),
- "terminated": torch.tensor(
- tensordict["action"].item() == terminate_at
- ),
- }
- )
+ # send observation to device/gpu
+ observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation}
- def _reset(self, tensordict: TensorDict) -> TensorDict:
- self.invocation_count = 0
- return TensorDict(
- {
- "observation": torch.tensor([self.invocation_count]),
- "reward": torch.tensor([self.invocation_count]),
- }
- )
+ # get the next action for the environment
+ with torch.inference_mode():
+ action = policy.select_action(observation, step=0)
- def _set_seed(self, seed: int | None):
- return
+ # apply inverse transform to unnormalize the action
+ action = postprocess_action(action, dataset.transform)
- class StubPolicy(AbstractPolicy):
- name = "stub"
+    # Test stepping the environment with the selected action
+ env.step(action)
- def __init__(self):
- super().__init__(n_action_steps)
- self.n_policy_invocations = 0
-
- def update(self):
- pass
-
- def select_actions(self):
- self.n_policy_invocations += 1
- return torch.stack(
- [torch.tensor([i]) for i in range(self.n_action_steps)]
- ).unsqueeze(0)
-
- env = StubEnv()
- policy = StubPolicy()
- policy = TensorDictModule(
- policy,
- in_keys=[],
- out_keys=["action"],
- )
-
- # Keep track to make sure the policy is called the expected number of times
- rollout = env.rollout(rollout_max_steps, policy)
-
- assert len(rollout) == terminate_at + 1 # +1 for the reset observation
- assert policy.n_policy_invocations == (terminate_at // n_action_steps) + 1
- assert torch.equal(rollout["observation"].flatten(), torch.arange(terminate_at + 1))
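
Putting the pieces of `test_policy` together, one inference step with the new stack looks roughly like the sketch below. The helper names (`preprocess_observation`, `postprocess_action`, `policy.reset`, `policy.select_action`) follow the calls exercised above; everything else is illustrative:

```python
import torch

from lerobot.common.envs.utils import postprocess_action, preprocess_observation


def run_one_step(env, policy, transform, device, seed=0):
    policy.reset()
    observation, _ = env.reset(seed=seed)

    # normalize raw observations the same way the dataset transform does
    observation = preprocess_observation(observation, transform)
    observation = {key: observation[key].to(device, non_blocking=True) for key in observation}

    # select an action without tracking gradients
    with torch.inference_mode():
        action = policy.select_action(observation, step=0)

    # unnormalize the action back into the environment's action space
    action = postprocess_action(action, transform)
    return env.step(action)
```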