diff --git a/.github/poetry/cpu/poetry.lock b/.github/poetry/cpu/poetry.lock index ba820f34..98a3d58d 100644 --- a/.github/poetry/cpu/poetry.lock +++ b/.github/poetry/cpu/poetry.lock @@ -517,21 +517,11 @@ files = [ {file = "distlib-0.3.8.tar.gz", hash = "sha256:1530ea13e350031b6312d8580ddb6b27a104275a31106523b8f123787f494f64"}, ] -[[package]] -name = "dm" -version = "1.3" -description = "Dict to Data mapper" -optional = false -python-versions = "*" -files = [ - {file = "dm-1.3.tar.gz", hash = "sha256:ce77537bf346b5d8c0dc0b5d679cfc4a946faadcd5315e6c80ef6f3af824130d"}, -] - [[package]] name = "dm-control" version = "1.0.14" description = "Continuous control environments and MuJoCo Python bindings." -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "dm_control-1.0.14-py3-none-any.whl", hash = "sha256:883c63244a7ebf598700a97564ed19fffd3479ca79efd090aed881609cdb9fc6"}, @@ -562,7 +552,7 @@ hdf5 = ["h5py"] name = "dm-env" version = "1.6" description = "A Python interface for Reinforcement Learning environments." -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "dm-env-1.6.tar.gz", hash = "sha256:a436eb1c654c39e0c986a516cee218bea7140b510fceff63f97eb4fcff3d93de"}, @@ -578,7 +568,7 @@ numpy = "*" name = "dm-tree" version = "0.1.8" description = "Tree is a library for working with nested data structures." -optional = false +optional = true python-versions = "*" files = [ {file = "dm-tree-0.1.8.tar.gz", hash = "sha256:0fcaabbb14e7980377439e7140bd05552739ca5e515ecb3119f234acee4b9430"}, @@ -806,7 +796,7 @@ test = ["black", "coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre name = "glfw" version = "2.7.0" description = "A ctypes-based wrapper for GLFW3." -optional = false +optional = true python-versions = "*" files = [ {file = "glfw-2.7.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38-none-macosx_10_6_intel.whl", hash = "sha256:bd82849edcceda4e262bd1227afaa74b94f9f0731c1197863cd25c15bfc613fc"}, @@ -889,6 +879,69 @@ files = [ [package.extras] protobuf = ["grpcio-tools (>=1.62.1)"] +[[package]] +name = "gym-aloha" +version = "0.1.0" +description = "A gym environment for ALOHA" +optional = true +python-versions = "^3.10" +files = [] +develop = false + +[package.dependencies] +dm-control = "1.0.14" +gymnasium = "^0.29.1" +mujoco = "^2.3.7" + +[package.source] +type = "git" +url = "git@github.com:huggingface/gym-aloha.git" +reference = "HEAD" +resolved_reference = "c636f05ba0d1760df94537da84c860be1487e17f" + +[[package]] +name = "gym-pusht" +version = "0.1.0" +description = "A gymnasium environment for PushT." 
+optional = true +python-versions = "^3.10" +files = [] +develop = false + +[package.dependencies] +gymnasium = "^0.29.1" +opencv-python = "^4.9.0.80" +pygame = "^2.5.2" +pymunk = "^6.6.0" +scikit-image = "^0.22.0" +shapely = "^2.0.3" + +[package.source] +type = "git" +url = "git@github.com:huggingface/gym-pusht.git" +reference = "HEAD" +resolved_reference = "6c9893504f670ff069d0f759a733e971ea1efdbf" + +[[package]] +name = "gym-xarm" +version = "0.1.0" +description = "A gym environment for xArm" +optional = true +python-versions = "^3.10" +files = [] +develop = false + +[package.dependencies] +gymnasium = "^0.29.1" +gymnasium-robotics = "^1.2.4" +mujoco = "^2.3.7" + +[package.source] +type = "git" +url = "git@github.com:huggingface/gym-xarm.git" +reference = "HEAD" +resolved_reference = "6a88f7d63833705dfbec4b997bf36cac6b4a448c" + [[package]] name = "gymnasium" version = "0.29.1" @@ -923,7 +976,7 @@ toy-text = ["pygame (>=2.1.3)", "pygame (>=2.1.3)"] name = "gymnasium-robotics" version = "1.2.4" description = "Robotics environments for the Gymnasium repo." -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "gymnasium-robotics-1.2.4.tar.gz", hash = "sha256:d304192b066f8b800599dfbe3d9d90bba9b761ee884472bdc4d05968a8bc61cb"}, @@ -1155,7 +1208,7 @@ i18n = ["Babel (>=2.7)"] name = "labmaze" version = "1.0.6" description = "LabMaze: DeepMind Lab's text maze generator." -optional = false +optional = true python-versions = "*" files = [ {file = "labmaze-1.0.6-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:b2ddef976dfd8d992b19cfa6c633f2eba7576d759c2082da534e3f727479a84a"}, @@ -1199,7 +1252,7 @@ setuptools = "!=50.0.0" name = "lazy-loader" version = "0.3" description = "lazy_loader" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "lazy_loader-0.3-py3-none-any.whl", hash = "sha256:1e9e76ee8631e264c62ce10006718e80b2cfc74340d17d1031e0f84af7478554"}, @@ -1244,7 +1297,7 @@ files = [ name = "lxml" version = "5.1.0" description = "Powerful and Pythonic XML processing library combining libxml2/libxslt with the ElementTree API." -optional = false +optional = true python-versions = ">=3.6" files = [ {file = "lxml-5.1.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:704f5572ff473a5f897745abebc6df40f22d4133c1e0a1f124e4f2bd3330ff7e"}, @@ -1462,7 +1515,7 @@ tests = ["pytest (>=4.6)"] name = "mujoco" version = "2.3.7" description = "MuJoCo Physics Simulator" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "mujoco-2.3.7-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:e8714a5ff6a1561b364b7b4648d4c0c8d13e751874cf7401c309b9d23fa9598b"}, @@ -1776,7 +1829,7 @@ xml = ["lxml (>=4.9.2)"] name = "pettingzoo" version = "1.24.3" description = "Gymnasium for multi-agent reinforcement learning." 
-optional = false +optional = true python-versions = ">=3.8" files = [ {file = "pettingzoo-1.24.3-py3-none-any.whl", hash = "sha256:23ed90517d2e8a7098bdaf5e31234b3a7f7b73ca578d70d1ca7b9d0cb0e37982"}, @@ -2144,7 +2197,7 @@ dev = ["aafigure", "matplotlib", "pygame", "pyglet (<2.0.0)", "sphinx", "wheel"] name = "pyopengl" version = "3.1.7" description = "Standard OpenGL bindings for Python" -optional = false +optional = true python-versions = "*" files = [ {file = "PyOpenGL-3.1.7-py3-none-any.whl", hash = "sha256:a6ab19cf290df6101aaf7470843a9c46207789855746399d0af92521a0a92b7a"}, @@ -2155,7 +2208,7 @@ files = [ name = "pyparsing" version = "3.1.2" description = "pyparsing module - Classes and methods to define and execute parsing grammars" -optional = false +optional = true python-versions = ">=3.6.8" files = [ {file = "pyparsing-3.1.2-py3-none-any.whl", hash = "sha256:f9db75911801ed778fe61bb643079ff86601aca99fcae6345aa67292038fb742"}, @@ -2586,7 +2639,7 @@ torch = ["safetensors[numpy]", "torch (>=1.10)"] name = "scikit-image" version = "0.22.0" description = "Image processing in Python" -optional = false +optional = true python-versions = ">=3.9" files = [ {file = "scikit_image-0.22.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:74ec5c1d4693506842cc7c9487c89d8fc32aed064e9363def7af08b8f8cbb31d"}, @@ -2634,7 +2687,7 @@ test = ["asv", "matplotlib (>=3.5)", "numpydoc (>=1.5)", "pooch (>=1.6.0)", "pyt name = "scipy" version = "1.12.0" description = "Fundamental algorithms for scientific computing in Python" -optional = false +optional = true python-versions = ">=3.9" files = [ {file = "scipy-1.12.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:78e4402e140879387187f7f25d91cc592b3501a2e51dfb320f48dfb73565f10b"}, @@ -2839,7 +2892,7 @@ testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jar name = "shapely" version = "2.0.3" description = "Manipulation and analysis of geometric objects" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "shapely-2.0.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:af7e9abe180b189431b0f490638281b43b84a33a960620e6b2e8d3e3458b61a1"}, @@ -2988,31 +3041,6 @@ numpy = "*" packaging = "*" protobuf = ">=3.20" -[[package]] -name = "tensordict" -version = "0.4.0+b4c91e8" -description = "" -optional = false -python-versions = "*" -files = [] -develop = false - -[package.dependencies] -cloudpickle = "*" -numpy = "*" -torch = ">=2.1.0" - -[package.extras] -checkpointing = ["torchsnapshot-nightly"] -h5 = ["h5py (>=3.8)"] -tests = ["pytest", "pytest-benchmark", "pytest-instafail", "pytest-rerunfailures", "pyyaml"] - -[package.source] -type = "git" -url = "https://github.com/pytorch/tensordict" -reference = "HEAD" -resolved_reference = "b4c91e8828c538ca0a50d8383fd99311a9afb078" - [[package]] name = "termcolor" version = "2.4.0" @@ -3031,7 +3059,7 @@ tests = ["pytest", "pytest-cov"] name = "tifffile" version = "2024.2.12" description = "Read and write TIFF files" -optional = false +optional = true python-versions = ">=3.9" files = [ {file = "tifffile-2024.2.12-py3-none-any.whl", hash = "sha256:870998f82fbc94ff7c3528884c1b0ae54863504ff51dbebea431ac3fa8fb7c21"}, @@ -3091,40 +3119,6 @@ type = "legacy" url = "https://download.pytorch.org/whl/cpu" reference = "torch-cpu" -[[package]] -name = "torchrl" -version = "0.4.0+13bef42" -description = "" -optional = false -python-versions = "*" -files = [] -develop = false - -[package.dependencies] -cloudpickle = "*" -numpy = "*" -packaging = "*" -tensordict = ">=0.4.0" 
-torch = ">=2.1.0" - -[package.extras] -all = ["ale-py", "atari-py", "dm_control", "git", "gym", "gym[accept-rom-license]", "gymnasium", "h5py", "huggingface_hub", "hydra-core (>=1.1)", "hydra-submitit-launcher", "minari", "moviepy", "mujoco", "pandas", "pettingzoo (>=1.24.1)", "pillow", "pygame", "pytest", "pytest-instafail", "pyyaml", "requests", "scikit-learn", "scipy", "tensorboard", "torchsnapshot", "torchvision", "tqdm", "vmas (>=1.2.10)", "wandb"] -atari = ["ale-py", "atari-py", "gym", "gym[accept-rom-license]", "pygame"] -checkpointing = ["torchsnapshot"] -dm-control = ["dm_control"] -gym-continuous = ["gymnasium", "mujoco"] -marl = ["pettingzoo (>=1.24.1)", "vmas (>=1.2.10)"] -offline-data = ["h5py", "huggingface_hub", "minari", "pandas", "pillow", "requests", "scikit-learn", "torchvision", "tqdm"] -rendering = ["moviepy"] -tests = ["pytest", "pytest-instafail", "pyyaml", "scipy"] -utils = ["git", "hydra-core (>=1.1)", "hydra-submitit-launcher", "tensorboard", "tqdm", "wandb"] - -[package.source] -type = "git" -url = "https://github.com/pytorch/rl" -reference = "13bef426dcfa5887c6e5034a6e9697993fa92c37" -resolved_reference = "13bef426dcfa5887c6e5034a6e9697993fa92c37" - [[package]] name = "torchvision" version = "0.17.1+cpu" @@ -3327,7 +3321,12 @@ files = [ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] +[extras] +aloha = ["gym-aloha"] +pusht = ["gym-pusht"] +xarm = ["gym-xarm"] + [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "8800bb8b24312d17b765cd2ce2799f49436171dd5fbf1bec3b07f853cfa9befd" +content-hash = "8fa6dfc30e605741c24f5de58b89125d5b02153f550e5af7a44356956d6bb167" diff --git a/.github/poetry/cpu/pyproject.toml b/.github/poetry/cpu/pyproject.toml index e84b93c9..f5c439dc 100644 --- a/.github/poetry/cpu/pyproject.toml +++ b/.github/poetry/cpu/pyproject.toml @@ -23,7 +23,6 @@ packages = [{include = "lerobot"}] python = "^3.10" termcolor = "^2.4.0" omegaconf = "^2.3.0" -dm-env = "^1.6" pandas = "^2.2.1" wandb = "^0.16.3" moviepy = "^1.0.3" @@ -34,30 +33,41 @@ einops = "^0.7.0" pygame = "^2.5.2" pymunk = "^6.6.0" zarr = "^2.17.0" -shapely = "^2.0.3" -scikit-image = "^0.22.0" numba = "^0.59.0" mpmath = "^1.3.0" torch = {version = "^2.2.1", source = "torch-cpu"} -tensordict = {git = "https://github.com/pytorch/tensordict"} -torchrl = {git = "https://github.com/pytorch/rl", rev = "13bef426dcfa5887c6e5034a6e9697993fa92c37"} -mujoco = "^2.3.7" opencv-python = "^4.9.0.80" diffusers = "^0.26.3" torchvision = {version = "^0.17.1", source = "torch-cpu"} h5py = "^3.10.0" -dm = "^1.3" -dm-control = "1.0.14" robomimic = "0.2.0" huggingface-hub = "^0.21.4" -gymnasium-robotics = "^1.2.4" gymnasium = "^0.29.1" cmake = "^3.29.0.1" +gym-pusht = { git = "git@github.com:huggingface/gym-pusht.git", optional = true} +gym-xarm = { git = "git@github.com:huggingface/gym-xarm.git", optional = true} +gym-aloha = { git = "git@github.com:huggingface/gym-aloha.git", optional = true} +# gym-pusht = { path = "../gym-pusht", develop = true, optional = true} +# gym-xarm = { path = "../gym-xarm", develop = true, optional = true} +# gym-aloha = { path = "../gym-aloha", develop = true, optional = true} + + +[tool.poetry.extras] +pusht = ["gym-pusht"] +xarm = ["gym-xarm"] +aloha = 
["gym-aloha"] + + +[tool.poetry.group.dev] +optional = true [tool.poetry.group.dev.dependencies] pre-commit = "^3.6.2" debugpy = "^1.8.1" + + +[tool.poetry.group.test.dependencies] pytest = "^8.1.0" pytest-cov = "^5.0.0" diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 478be771..b3411e11 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -34,6 +34,11 @@ jobs: with: python-version: '3.10' + - name: Add SSH key for installing envs + uses: webfactory/ssh-agent@v0.9.0 + with: + ssh-private-key: ${{ secrets.SSH_PRIVATE_KEY }} + #---------------------------------------------- # install & configure poetry #---------------------------------------------- @@ -87,7 +92,7 @@ jobs: TMP: ~/tmp run: | mkdir ~/tmp - poetry install --no-interaction --no-root + poetry install --no-interaction --no-root --all-extras - name: Save cached venv if: | @@ -106,7 +111,7 @@ jobs: # install project #---------------------------------------------- - name: Install project - run: poetry install --no-interaction + run: poetry install --no-interaction --all-extras #---------------------------------------------- # run tests & coverage @@ -137,6 +142,7 @@ jobs: wandb.enable=False \ offline_steps=2 \ online_steps=0 \ + eval_episodes=1 \ device=cpu \ save_model=true \ save_freq=2 \ @@ -154,17 +160,6 @@ jobs: device=cpu \ policy.pretrained_model_path=tests/outputs/act/models/2.pt - # TODO(aliberts): This takes ~2mn to run, needs to be improved - # - name: Test eval ACT on ALOHA end-to-end (policy is None) - # run: | - # source .venv/bin/activate - # python lerobot/scripts/eval.py \ - # --config lerobot/configs/default.yaml \ - # policy=act \ - # env=aloha \ - # eval_episodes=1 \ - # device=cpu - - name: Test train Diffusion on PushT end-to-end run: | source .venv/bin/activate @@ -174,9 +169,11 @@ jobs: wandb.enable=False \ offline_steps=2 \ online_steps=0 \ + eval_episodes=1 \ device=cpu \ save_model=true \ save_freq=2 \ + policy.batch_size=2 \ hydra.run.dir=tests/outputs/diffusion/ - name: Test eval Diffusion on PushT end-to-end @@ -189,28 +186,20 @@ jobs: device=cpu \ policy.pretrained_model_path=tests/outputs/diffusion/models/2.pt - - name: Test eval Diffusion on PushT end-to-end (policy is None) - run: | - source .venv/bin/activate - python lerobot/scripts/eval.py \ - --config lerobot/configs/default.yaml \ - policy=diffusion \ - env=pusht \ - eval_episodes=1 \ - device=cpu - - name: Test train TDMPC on Simxarm end-to-end run: | source .venv/bin/activate python lerobot/scripts/train.py \ policy=tdmpc \ - env=simxarm \ + env=xarm \ wandb.enable=False \ offline_steps=1 \ online_steps=1 \ + eval_episodes=1 \ device=cpu \ save_model=true \ save_freq=2 \ + policy.batch_size=2 \ hydra.run.dir=tests/outputs/tdmpc/ - name: Test eval TDMPC on Simxarm end-to-end @@ -222,13 +211,3 @@ jobs: env.episode_length=8 \ device=cpu \ policy.pretrained_model_path=tests/outputs/tdmpc/models/2.pt - - - name: Test eval TDPMC on Simxarm end-to-end (policy is None) - run: | - source .venv/bin/activate - python lerobot/scripts/eval.py \ - --config lerobot/configs/default.yaml \ - policy=tdmpc \ - env=simxarm \ - eval_episodes=1 \ - device=cpu diff --git a/.gitignore b/.gitignore index ad9892d4..3132aba0 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,9 @@ rl nautilus/*.yaml *.key +# Slurm +sbatch*.sh + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/README.md b/README.md index 31fdde0a..25b8d1e4 100644 --- a/README.md +++ b/README.md @@ -62,21 +62,29 @@ Download our 
source code:
```bash
-git clone https://github.com/huggingface/lerobot.git
-cd lerobot
+git clone https://github.com/huggingface/lerobot.git && cd lerobot
```

Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniconda`](https://docs.anaconda.com/free/miniconda/index.html):
```bash
-conda create -y -n lerobot python=3.10
-conda activate lerobot
+conda create -y -n lerobot python=3.10 && conda activate lerobot
```

-Then, install 🤗 LeRobot:
+Install 🤗 LeRobot:
```bash
python -m pip install .
```

+For simulations, 🤗 LeRobot comes with gymnasium environments that can be installed as extras:
+- [aloha](https://github.com/huggingface/gym-aloha)
+- [xarm](https://github.com/huggingface/gym-xarm)
+- [pusht](https://github.com/huggingface/gym-pusht)
+
+For instance, to install 🤗 LeRobot with aloha and pusht, use:
+```bash
+python -m pip install ".[aloha, pusht]"
+```
+
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiments tracking, log in with
```bash
wandb login
@@ -89,11 +97,11 @@ wandb login
├── lerobot
| ├── configs # contains hydra yaml files with all options that you can override in the command line
| | ├── default.yaml # selected by default, it loads pusht environment and diffusion policy
-| | ├── env # various sim environments and their datasets: aloha.yaml, pusht.yaml, simxarm.yaml
+| | ├── env # various sim environments and their datasets: aloha.yaml, pusht.yaml, xarm.yaml
| | └── policy # various policies: act.yaml, diffusion.yaml, tdmpc.yaml
| ├── common # contains classes and utilities
-| | ├── datasets # various datasets of human demonstrations: aloha, pusht, simxarm
-| | ├── envs # various sim environments: aloha, pusht, simxarm
+| | ├── datasets # various datasets of human demonstrations: aloha, pusht, xarm
+| | ├── envs # various sim environments: aloha, pusht, xarm
| | └── policies # various policies: act, diffusion, tdmpc
| └── scripts # contains functions to execute via command line
| ├── visualize_dataset.py # load a dataset and render its demonstrations
@@ -112,34 +120,32 @@ wandb login
You can import our dataset class, download the data from the HuggingFace hub and use our rendering utilities:
```python
""" Copy pasted from `examples/1_visualize_dataset.py` """
+import os
+from pathlib import Path
+
 import lerobot
-from lerobot.common.datasets.aloha import AlohaDataset
-from torchrl.data.replay_buffers import SamplerWithoutReplacement
+from lerobot.common.datasets.pusht import PushtDataset
 from lerobot.scripts.visualize_dataset import render_dataset

 print(lerobot.available_datasets)
 # >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted', 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted', 'pusht', 'xarm_lift_medium']

-# we use this sampler to sample 1 frame after the other
-sampler = SamplerWithoutReplacement(shuffle=False)
-
-dataset = AlohaDataset("aloha_sim_transfer_cube_human", sampler=sampler)
+# TODO(rcadene): remove DATA_DIR
+dataset = PushtDataset("pusht", root=Path(os.environ.get("DATA_DIR")))

 video_paths = render_dataset(
     dataset,
     out_dir="outputs/visualize_dataset/example",
-    max_num_samples=300,
-    fps=50,
+    max_num_episodes=1,
 )
 print(video_paths)
-# >>> ['outputs/visualize_dataset/example/episode_0.mp4']
+# ['outputs/visualize_dataset/example/episode_0.mp4']
```

Or you can achieve the same result by executing our script from the command line:
```bash
python lerobot/scripts/visualize_dataset.py \
-env=aloha \
-task=sim_sim_transfer_cube_human \
+env=pusht \
hydra.run.dir=outputs/visualize_dataset/example
# >>> ['outputs/visualize_dataset/example/episode_0.mp4']
```

@@ -198,21 +204,33 @@ pre-commit install
pre-commit
```

-### Add dependencies
+### Dependencies

Instead of using `pip` directly, we use `poetry` for development purposes to easily track our dependencies.
If you don't have it already, follow the [instructions](https://python-poetry.org/docs/#installation) to install it.

-Install the project with:
+Install the project with dev dependencies and all environments:
```bash
-poetry install
+poetry install --sync --with dev --all-extras
+```
+This command should be run when pulling code with an updated version of `pyproject.toml` and `poetry.lock` in order to synchronize your virtual environment with the dependencies.
+
+To selectively install environments (for example aloha and pusht), use:
+```bash
+poetry install --sync --with dev --extras "aloha pusht"
```

-Then, the equivalent of `pip install some-package`, would just be:
+The equivalent of `pip install some-package` would just be:
```bash
poetry add some-package
```

+When changes are made to the poetry sections of the `pyproject.toml`, you should run the following command to lock dependencies:
+```bash
+poetry lock --no-update
+```
+
+
**NOTE:** Currently, to ensure the CI works properly, any new package must also be added in the CPU-only environment dedicated to the CI. To do this, you should create a separate environment and add the new package there as well. For example:
```bash
# Add the new package to your main poetry env
diff --git a/examples/1_visualize_dataset.py b/examples/1_visualize_dataset.py
index f52ab76a..15e0e54d 100644
--- a/examples/1_visualize_dataset.py
+++ b/examples/1_visualize_dataset.py
@@ -1,24 +1,20 @@
 import os
-
-from torchrl.data.replay_buffers import SamplerWithoutReplacement
+from pathlib import Path

 import lerobot
-from lerobot.common.datasets.aloha import AlohaDataset
+from lerobot.common.datasets.pusht import PushtDataset
 from lerobot.scripts.visualize_dataset import render_dataset

 print(lerobot.available_datasets)
 # >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted', 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted', 'pusht', 'xarm_lift_medium']

-# we use this sampler to sample 1 frame after the other
-sampler = SamplerWithoutReplacement(shuffle=False)
-
-dataset = AlohaDataset("aloha_sim_transfer_cube_human", sampler=sampler, root=os.environ.get("DATA_DIR"))
+# TODO(rcadene): remove DATA_DIR
+dataset = PushtDataset("pusht", root=Path(os.environ.get("DATA_DIR")))

 video_paths = render_dataset(
     dataset,
     out_dir="outputs/visualize_dataset/example",
-    max_num_samples=300,
-    fps=50,
+    max_num_episodes=1,
 )
 print(video_paths)
 # ['outputs/visualize_dataset/example/episode_0.mp4']
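Not part of the diff: since the refactored datasets are now plain `torch.utils.data.Dataset` subclasses whose `__getitem__` returns a flat dict of tensors, they compose directly with standard PyTorch tooling. A minimal sketch of that pattern, assuming `DATA_DIR` points at a local copy of the datasets and reusing the `PushtDataset` class and batch keys introduced in this PR:

```python
import os
from pathlib import Path

import torch
from lerobot.common.datasets.pusht import PushtDataset

# each item is a dict of tensors keyed by "observation.image", "observation.state", "action", ...
dataset = PushtDataset("pusht", root=Path(os.environ.get("DATA_DIR")))
loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

batch = next(iter(loader))
print(batch["observation.image"].shape)  # (32, c, h, w), uint8 in [0, 255] before transforms
print(batch["observation.state"].shape)  # (32, state_dim)
```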
diff --git a/examples/3_train_policy.py b/examples/3_train_policy.py
index 01a4cf76..238f953d 100644
--- a/examples/3_train_policy.py
+++ b/examples/3_train_policy.py
@@ -9,9 +9,8 @@ from pathlib import Path

 import torch
 from omegaconf import OmegaConf
-from tqdm import trange

-from lerobot.common.datasets.factory import make_offline_buffer
+from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.policies.diffusion.policy import DiffusionPolicy
 from lerobot.common.utils import init_hydra_config

@@ -37,19 +36,33 @@ policy = DiffusionPolicy(
     cfg_obs_encoder=cfg.obs_encoder,
     cfg_optimizer=cfg.optimizer,
     cfg_ema=cfg.ema,
-    n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
     **cfg.policy,
 )
 policy.train()

-offline_buffer = make_offline_buffer(cfg)
+dataset = make_dataset(cfg)
+
+# create dataloader for offline training
+dataloader = torch.utils.data.DataLoader(
+    dataset,
+    num_workers=4,
+    batch_size=cfg.policy.batch_size,
+    shuffle=True,
+    pin_memory=cfg.device != "cpu",
+    drop_last=True,
+)
+
+for step, batch in enumerate(dataloader):
+    info = policy(batch, step)
+
+    if step % cfg.log_freq == 0:
+        num_samples = (step + 1) * cfg.policy.batch_size
+        loss = info["loss"]
+        update_s = info["update_s"]
+        print(f"step:{step} samples:{num_samples} loss:{loss:.3f} update_time:{update_s:.3f}(seconds)")

-for offline_step in trange(cfg.offline_steps):
-    train_info = policy.update(offline_buffer, offline_step)
-    if offline_step % cfg.log_freq == 0:
-        print(train_info)

 # Save the policy, configuration, and normalization stats for later use.
 policy.save(output_directory / "model.pt")
 OmegaConf.save(cfg, output_directory / "config.yaml")
-torch.save(offline_buffer.transform[-1].stats, output_directory / "stats.pth")
+torch.save(dataset.transform.transforms[-1].stats, output_directory / "stats.pth")
diff --git a/lerobot/__init__.py b/lerobot/__init__.py
index 5cf8bdb8..8ab95df8 100644
--- a/lerobot/__init__.py
+++ b/lerobot/__init__.py
@@ -12,14 +12,11 @@ Example:
     print(lerobot.available_policies)
```

-Note:
-    When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
-    1. set the required class attributes:
-        - for classes inheriting from `AbstractDataset`: `available_datasets`
-        - for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
-        - for classes inheriting from `AbstractPolicy`: `name`
-    2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
-    3. update variables in `tests/test_available.py` by importing your new class
+When implementing a new dataset (e.g. `AlohaDataset`), policy (e.g. `DiffusionPolicy`), or environment, follow these steps:
+- For datasets, set the required class attribute: `available_datasets`.
+- For policies and environments, set the required class attribute: `name`.
+- Update variables in `lerobot/__init__.py` (e.g. 
`available_envs`, `available_datasets_per_envs`, `available_policies`) +- Update variables in `tests/test_available.py` by importing your new class """ from lerobot.__version__ import __version__ # noqa: F401 @@ -27,16 +24,16 @@ from lerobot.__version__ import __version__ # noqa: F401 available_envs = [ "aloha", "pusht", - "simxarm", + "xarm", ] available_tasks_per_env = { "aloha": [ - "sim_insertion", - "sim_transfer_cube", + "AlohaInsertion-v0", + "AlohaTransferCube-v0", ], - "pusht": ["pusht"], - "simxarm": ["lift"], + "pusht": ["PushT-v0"], + "xarm": ["XarmLift-v0"], } available_datasets_per_env = { @@ -47,7 +44,7 @@ available_datasets_per_env = { "aloha_sim_transfer_cube_scripted", ], "pusht": ["pusht"], - "simxarm": ["xarm_lift_medium"], + "xarm": ["xarm_lift_medium"], } available_datasets = [dataset for env in available_envs for dataset in available_datasets_per_env[env]] diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py deleted file mode 100644 index e9e9c610..00000000 --- a/lerobot/common/datasets/abstract.py +++ /dev/null @@ -1,234 +0,0 @@ -import logging -from copy import deepcopy -from math import ceil -from pathlib import Path -from typing import Callable - -import einops -import torch -import torchrl -import tqdm -from huggingface_hub import snapshot_download -from tensordict import TensorDict -from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer -from torchrl.data.replay_buffers.samplers import Sampler, SamplerWithoutReplacement -from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id -from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer -from torchrl.envs.transforms.transforms import Compose - -HF_USER = "lerobot" - - -class AbstractDataset(TensorDictReplayBuffer): - """ - AbstractDataset represents a dataset in the context of imitation learning or reinforcement learning. - This class is designed to be subclassed by concrete implementations that specify particular types of datasets. - These implementations can vary based on the source of the data, the environment the data pertains to, - or the specific kind of data manipulation applied. - - Note: - - `TensorDictReplayBuffer` is the base class from which `AbstractDataset` inherits. It provides the foundational - functionality for storing and retrieving `TensorDict`-like data. - - `available_datasets` should be overridden by concrete subclasses to list the specific dataset variants supported. - It is expected that these variants correspond to a HuggingFace dataset on the hub. - For instance, the `AlohaDataset` which inherites from `AbstractDataset` has 4 available dataset variants: - - [aloha_sim_transfer_cube_scripted](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted) - - [aloha_sim_insertion_scripted](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted) - - [aloha_sim_transfer_cube_human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human) - - [aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human) - - When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to: - 1. set the required class attributes: - - for classes inheriting from `AbstractDataset`: `available_datasets` - - for classes inheriting from `AbstractEnv`: `name`, `available_tasks` - - for classes inheriting from `AbstractPolicy`: `name` - 2. update variables in `lerobot/__init__.py` (e.g. 
`available_envs`, `available_datasets_per_envs`, `available_policies`) - 3. update variables in `tests/test_available.py` by importing your new class - """ - - available_datasets: list[str] | None = None - - def __init__( - self, - dataset_id: str, - version: str | None = None, - batch_size: int | None = None, - *, - shuffle: bool = True, - root: Path | None = None, - pin_memory: bool = False, - prefetch: int = None, - sampler: Sampler | None = None, - collate_fn: Callable | None = None, - writer: Writer | None = None, - transform: "torchrl.envs.Transform" = None, - ): - assert ( - self.available_datasets is not None - ), "Subclasses of `AbstractDataset` should set the `available_datasets` class attribute." - assert ( - dataset_id in self.available_datasets - ), f"The provided dataset ({dataset_id}) is not on the list of available datasets {self.available_datasets}." - - self.dataset_id = dataset_id - self.version = version - self.shuffle = shuffle - self.root = root if root is None else Path(root) - - if self.root is not None and self.version is not None: - logging.warning( - f"The version of the dataset ({self.version}) is not enforced when root is provided ({self.root})." - ) - - storage = self._download_or_load_dataset() - - super().__init__( - storage=storage, - sampler=sampler, - writer=ImmutableDatasetWriter() if writer is None else writer, - collate_fn=_collate_id if collate_fn is None else collate_fn, - pin_memory=pin_memory, - prefetch=prefetch, - batch_size=batch_size, - transform=transform, - ) - - @property - def stats_patterns(self) -> dict: - return { - ("observation", "state"): "b c -> c", - ("observation", "image"): "b c h w -> c 1 1", - ("action",): "b c -> c", - } - - @property - def image_keys(self) -> list: - return [("observation", "image")] - - @property - def num_cameras(self) -> int: - return len(self.image_keys) - - @property - def num_samples(self) -> int: - return len(self) - - @property - def num_episodes(self) -> int: - return len(self._storage._storage["episode"].unique()) - - @property - def transform(self): - return self._transform - - def set_transform(self, transform): - if not isinstance(transform, Compose): - # required since torchrl calls `len(self._transform)` downstream - if isinstance(transform, list): - self._transform = Compose(*transform) - else: - self._transform = Compose(transform) - else: - self._transform = transform - - def compute_or_load_stats(self, batch_size: int = 32) -> TensorDict: - stats_path = self.data_dir / "stats.pth" - if stats_path.exists(): - stats = torch.load(stats_path) - else: - logging.info(f"compute_stats and save to {stats_path}") - stats = self._compute_stats(batch_size) - torch.save(stats, stats_path) - return stats - - def _download_or_load_dataset(self) -> torch.StorageBase: - if self.root is None: - self.data_dir = Path( - snapshot_download( - repo_id=f"{HF_USER}/{self.dataset_id}", repo_type="dataset", revision=self.version - ) - ) - else: - self.data_dir = self.root / self.dataset_id - return TensorStorage(TensorDict.load_memmap(self.data_dir / "replay_buffer")) - - def _compute_stats(self, batch_size: int = 32): - """Compute dataset statistics including minimum, maximum, mean, and standard deviation. - - TODO(alexander-soare): Add a num_batches argument which essentially allows one to use a subset of the - full dataset (for handling very large datasets). The sampling would then have to be random - (preferably without replacement). Both stats computation loops would ideally sample the same - items. 
- """ - rb = TensorDictReplayBuffer( - storage=self._storage, - batch_size=32, - prefetch=True, - # Note: Due to be refactored soon. The point is that we should go through the whole dataset. - sampler=SamplerWithoutReplacement(drop_last=False, shuffle=False), - ) - - # mean and std will be computed incrementally while max and min will track the running value. - mean, std, max, min = {}, {}, {}, {} - for key in self.stats_patterns: - mean[key] = torch.tensor(0.0).float() - std[key] = torch.tensor(0.0).float() - max[key] = torch.tensor(-float("inf")).float() - min[key] = torch.tensor(float("inf")).float() - - # Compute mean, min, max. - # Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get - # surprises when rerunning the sampler. - first_batch = None - running_item_count = 0 # for online mean computation - for _ in tqdm.tqdm(range(ceil(len(rb) / batch_size))): - batch = rb.sample() - this_batch_size = batch.batch_size[0] - running_item_count += this_batch_size - if first_batch is None: - first_batch = deepcopy(batch) - for key, pattern in self.stats_patterns.items(): - batch[key] = batch[key].float() - # Numerically stable update step for mean computation. - batch_mean = einops.reduce(batch[key], pattern, "mean") - # Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents - # the update step, N is the running item count, B is this batch size, x̄ is the running mean, - # and x is the current batch mean. Some rearrangement is then required to avoid risking - # numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields - # x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ - mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count - max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max")) - min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min")) - - # Compute std. - first_batch_ = None - running_item_count = 0 # for online std computation - for _ in tqdm.tqdm(range(ceil(len(rb) / batch_size))): - batch = rb.sample() - this_batch_size = batch.batch_size[0] - running_item_count += this_batch_size - # Sanity check to make sure the batches are still in the same order as before. - if first_batch_ is None: - first_batch_ = deepcopy(batch) - for key in self.stats_patterns: - assert torch.equal(first_batch_[key], first_batch[key]) - for key, pattern in self.stats_patterns.items(): - batch[key] = batch[key].float() - # Numerically stable update step for mean computation (where the mean is over squared - # residuals).See notes in the mean computation loop above. 
- batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean") - std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count - - for key in self.stats_patterns: - std[key] = torch.sqrt(std[key]) - - stats = TensorDict({}, batch_size=[]) - for key in self.stats_patterns: - stats[(*key, "mean")] = mean[key] - stats[(*key, "std")] = std[key] - stats[(*key, "max")] = max[key] - stats[(*key, "min")] = min[key] - - if key[0] == "observation": - # use same stats for the next observations - stats[("next", *key)] = stats[key] - return stats diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py index 031c2cd3..4b241ad8 100644 --- a/lerobot/common/datasets/aloha.py +++ b/lerobot/common/datasets/aloha.py @@ -1,26 +1,13 @@ import logging from pathlib import Path -from typing import Callable import einops import gdown import h5py import torch -import torchrl import tqdm -from tensordict import TensorDict -from torchrl.data.replay_buffers.samplers import Sampler -from torchrl.data.replay_buffers.storages import TensorStorage -from torchrl.data.replay_buffers.writers import Writer -from lerobot.common.datasets.abstract import AbstractDataset - -DATASET_IDS = [ - "aloha_sim_insertion_human", - "aloha_sim_insertion_scripted", - "aloha_sim_transfer_cube_human", - "aloha_sim_transfer_cube_scripted", -] +from lerobot.common.datasets.utils import load_data_with_delta_timestamps FOLDER_URLS = { "aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF", @@ -66,7 +53,6 @@ CAMERAS = { def download(data_dir, dataset_id): - assert dataset_id in DATASET_IDS assert dataset_id in FOLDER_URLS assert dataset_id in EP48_URLS assert dataset_id in EP49_URLS @@ -80,51 +66,80 @@ def download(data_dir, dataset_id): gdown.download(EP49_URLS[dataset_id], output=str(data_dir / "episode_49.hdf5"), fuzzy=True) -class AlohaDataset(AbstractDataset): - available_datasets = DATASET_IDS +class AlohaDataset(torch.utils.data.Dataset): + available_datasets = [ + "aloha_sim_insertion_human", + "aloha_sim_insertion_scripted", + "aloha_sim_transfer_cube_human", + "aloha_sim_transfer_cube_scripted", + ] + fps = 50 + image_keys = ["observation.images.top"] def __init__( self, dataset_id: str, version: str | None = "v1.2", - batch_size: int | None = None, - *, - shuffle: bool = True, root: Path | None = None, - pin_memory: bool = False, - prefetch: int = None, - sampler: Sampler | None = None, - collate_fn: Callable | None = None, - writer: Writer | None = None, - transform: "torchrl.envs.Transform" = None, + transform: callable = None, + delta_timestamps: dict[list[float]] | None = None, ): - super().__init__( - dataset_id, - version, - batch_size, - shuffle=shuffle, - root=root, - pin_memory=pin_memory, - prefetch=prefetch, - sampler=sampler, - collate_fn=collate_fn, - writer=writer, - transform=transform, - ) + super().__init__() + self.dataset_id = dataset_id + self.version = version + self.root = root + self.transform = transform + self.delta_timestamps = delta_timestamps + + self.data_dir = self.root / f"{self.dataset_id}" + if (self.data_dir / "data_dict.pth").exists() and ( + self.data_dir / "data_ids_per_episode.pth" + ).exists(): + self.data_dict = torch.load(self.data_dir / "data_dict.pth") + self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth") + else: + self._download_and_preproc_obsolete() + self.data_dir.mkdir(parents=True, exist_ok=True) + torch.save(self.data_dict, self.data_dir / 
"data_dict.pth") + torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth") @property - def stats_patterns(self) -> dict: - d = { - ("observation", "state"): "b c -> c", - ("action",): "b c -> c", - } - for cam in CAMERAS[self.dataset_id]: - d[("observation", "image", cam)] = "b c h w -> c 1 1" - return d + def num_samples(self) -> int: + return len(self.data_dict["index"]) if "index" in self.data_dict else 0 @property - def image_keys(self) -> list: - return [("observation", "image", cam) for cam in CAMERAS[self.dataset_id]] + def num_episodes(self) -> int: + return len(self.data_ids_per_episode) + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + item = {} + + # get episode id and timestamp of the sampled frame + current_ts = self.data_dict["timestamp"][idx].item() + episode = self.data_dict["episode"][idx].item() + + for key in self.data_dict: + if self.delta_timestamps is not None and key in self.delta_timestamps: + data, is_pad = load_data_with_delta_timestamps( + self.data_dict, + self.data_ids_per_episode, + self.delta_timestamps, + key, + current_ts, + episode, + ) + item[key] = data + item[f"{key}_is_pad"] = is_pad + else: + item[key] = self.data_dict[key][idx] + + if self.transform is not None: + item = self.transform(item) + + return item def _download_and_preproc_obsolete(self): assert self.root is not None @@ -132,54 +147,61 @@ class AlohaDataset(AbstractDataset): if not raw_dir.is_dir(): download(raw_dir, self.dataset_id) - total_num_frames = 0 + total_frames = 0 logging.info("Compute total number of frames to initialize offline buffer") for ep_id in range(NUM_EPISODES[self.dataset_id]): ep_path = raw_dir / f"episode_{ep_id}.hdf5" with h5py.File(ep_path, "r") as ep: - total_num_frames += ep["/action"].shape[0] - 1 - logging.info(f"{total_num_frames=}") + total_frames += ep["/action"].shape[0] - 1 + logging.info(f"{total_frames=}") - logging.info("Initialize and feed offline buffer") - idxtd = 0 + self.data_ids_per_episode = {} + ep_dicts = [] + + frame_idx = 0 for ep_id in tqdm.tqdm(range(NUM_EPISODES[self.dataset_id])): ep_path = raw_dir / f"episode_{ep_id}.hdf5" with h5py.File(ep_path, "r") as ep: - ep_num_frames = ep["/action"].shape[0] + num_frames = ep["/action"].shape[0] # last step of demonstration is considered done - done = torch.zeros(ep_num_frames, 1, dtype=torch.bool) + done = torch.zeros(num_frames, dtype=torch.bool) done[-1] = True state = torch.from_numpy(ep["/observations/qpos"][:]) action = torch.from_numpy(ep["/action"][:]) - ep_td = TensorDict( - { - ("observation", "state"): state[:-1], - "action": action[:-1], - "episode": torch.tensor([ep_id] * (ep_num_frames - 1)), - "frame_id": torch.arange(0, ep_num_frames - 1, 1), - ("next", "observation", "state"): state[1:], - # TODO: compute reward and success - # ("next", "reward"): reward[1:], - ("next", "done"): done[1:], - # ("next", "success"): success[1:], - }, - batch_size=ep_num_frames - 1, - ) + ep_dict = { + "observation.state": state, + "action": action, + "episode": torch.tensor([ep_id] * num_frames), + "frame_id": torch.arange(0, num_frames, 1), + "timestamp": torch.arange(0, num_frames, 1) / self.fps, + # "next.observation.state": state, + # TODO(rcadene): compute reward and success + # "next.reward": reward[1:], + "next.done": done[1:], + # "next.success": success[1:], + } for cam in CAMERAS[self.dataset_id]: image = torch.from_numpy(ep[f"/observations/images/{cam}"][:]) image = einops.rearrange(image, "b h w c -> b c h w").contiguous() - 
ep_td["observation", "image", cam] = image[:-1] - ep_td["next", "observation", "image", cam] = image[1:] + ep_dict[f"observation.images.{cam}"] = image[:-1] + # ep_dict[f"next.observation.images.{cam}"] = image[1:] - if ep_id == 0: - # hack to initialize tensordict data structure to store episodes - td_data = ep_td[0].expand(total_num_frames).memmap_like(self.root / f"{self.dataset_id}") + assert isinstance(ep_id, int) + self.data_ids_per_episode[ep_id] = torch.arange(frame_idx, frame_idx + num_frames, 1) + assert len(self.data_ids_per_episode[ep_id]) == num_frames - td_data[idxtd : idxtd + len(ep_td)] = ep_td - idxtd = idxtd + len(ep_td) + ep_dicts.append(ep_dict) - return TensorStorage(td_data.lock_()) + frame_idx += num_frames + + self.data_dict = {} + + keys = ep_dicts[0].keys() + for key in keys: + self.data_dict[key] = torch.cat([x[key] for x in ep_dicts]) + + self.data_dict["index"] = torch.arange(0, total_frames, 1) diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 04077034..4ae161f6 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -3,8 +3,9 @@ import os from pathlib import Path import torch -from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler +from torchvision.transforms import v2 +from lerobot.common.datasets.utils import compute_stats from lerobot.common.transforms import NormalizeTransform, Prod # DATA_DIR specifies to location where datasets are loaded. By default, DATA_DIR is None and @@ -13,61 +14,16 @@ from lerobot.common.transforms import NormalizeTransform, Prod DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None -def make_offline_buffer( +def make_dataset( cfg, - overwrite_sampler=None, # set normalize=False to remove all transformations and keep images unnormalized in [0,255] normalize=True, - overwrite_batch_size=None, - overwrite_prefetch=None, stats_path=None, ): - if cfg.policy.balanced_sampling: - assert cfg.online_steps > 0 - batch_size = None - pin_memory = False - prefetch = None - else: - assert cfg.online_steps == 0 - num_slices = cfg.policy.batch_size - batch_size = cfg.policy.horizon * num_slices - pin_memory = cfg.device == "cuda" - prefetch = cfg.prefetch + if cfg.env.name == "xarm": + from lerobot.common.datasets.xarm import XarmDataset - if overwrite_batch_size is not None: - batch_size = overwrite_batch_size - - if overwrite_prefetch is not None: - prefetch = overwrite_prefetch - - if overwrite_sampler is None: - # TODO(rcadene): move batch_size outside - num_traj_per_batch = cfg.policy.batch_size # // cfg.horizon - # TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size. - # We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size. 
- - if cfg.offline_prioritized_sampler: - logging.info("use prioritized sampler for offline dataset") - sampler = PrioritizedSliceSampler( - max_capacity=100_000, - alpha=cfg.policy.per_alpha, - beta=cfg.policy.per_beta, - num_slices=num_traj_per_batch, - strict_length=False, - ) - else: - logging.info("use simple sampler for offline dataset") - sampler = SliceSampler( - num_slices=num_traj_per_batch, - strict_length=False, - ) - else: - sampler = overwrite_sampler - - if cfg.env.name == "simxarm": - from lerobot.common.datasets.simxarm import SimxarmDataset - - clsfunc = SimxarmDataset + clsfunc = XarmDataset elif cfg.env.name == "pusht": from lerobot.common.datasets.pusht import PushtDataset @@ -81,56 +37,66 @@ def make_offline_buffer( else: raise ValueError(cfg.env.name) - offline_buffer = clsfunc( - dataset_id=cfg.dataset_id, - sampler=sampler, - batch_size=batch_size, - root=DATA_DIR, - pin_memory=pin_memory, - prefetch=prefetch if isinstance(prefetch, int) else None, - ) - - if cfg.policy.name == "tdmpc": - img_keys = [] - for key in offline_buffer.image_keys: - img_keys.append(("next", *key)) - img_keys += offline_buffer.image_keys - else: - img_keys = offline_buffer.image_keys - + transforms = None if normalize: - transforms = [Prod(in_keys=img_keys, prod=1 / 255)] - # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, # min_max_from_spec - stats = offline_buffer.compute_or_load_stats() if stats_path is None else torch.load(stats_path) - - # we only normalize the state and action, since the images are usually normalized inside the model for - # now (except for tdmpc: see the following) - in_keys = [("observation", "state"), ("action")] - - if cfg.policy.name == "tdmpc": - # TODO(rcadene): we add img_keys to the keys to normalize for tdmpc only, since diffusion and act policies normalize the image inside the model for now - in_keys += img_keys - # TODO(racdene): since we use next observations in tdmpc, we also add them to the normalization. We are wasting a bit of compute on this for now. - in_keys += [("next", *key) for key in img_keys] - in_keys.append(("next", "observation", "state")) - - if cfg.policy.name == "diffusion" and cfg.env.name == "pusht": - # TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this - stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32) - stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32) - stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32) - stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32) - # TODO(rcadene): remove this and put it in config. 
Ideally we want to reproduce SOTA results just with mean_std
        normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
-        transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode))
-    offline_buffer.set_transform(transforms)
+        if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
+            stats = {}
+            # TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
+            stats["observation.state"] = {}
+            stats["observation.state"]["min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
+            stats["observation.state"]["max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
+            stats["action"] = {}
+            stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
+            stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
+        elif stats_path is None:
+            # instantiate a one frame dataset with light transform
+            stats_dataset = clsfunc(
+                dataset_id=cfg.dataset_id,
+                root=DATA_DIR,
+                transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
+            )
-    if not overwrite_sampler:
-        index = torch.arange(0, offline_buffer.num_samples, 1)
-        sampler.extend(index)
+            # load stats if the file exists already or compute stats and save it
+            precomputed_stats_path = stats_dataset.data_dir / "stats.pth"
+            if precomputed_stats_path.exists():
+                stats = torch.load(precomputed_stats_path)
+            else:
+                logging.info(f"compute_stats and save to {precomputed_stats_path}")
+                stats = compute_stats(stats_dataset)
+                torch.save(stats, precomputed_stats_path)
+        else:
+            stats = torch.load(stats_path)
-    return offline_buffer
+        transforms = v2.Compose(
+            [
+                Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
+                NormalizeTransform(
+                    stats,
+                    in_keys=[
+                        "observation.state",
+                        "action",
+                    ],
+                    mode=normalization_mode,
+                ),
+            ]
+        )
+
+    delta_timestamps = cfg.policy.get("delta_timestamps")
+    if delta_timestamps is not None:
+        for key in delta_timestamps:
+            if isinstance(delta_timestamps[key], str):
+                delta_timestamps[key] = eval(delta_timestamps[key])
+
+    dataset = clsfunc(
+        dataset_id=cfg.dataset_id,
+        root=DATA_DIR,
+        delta_timestamps=delta_timestamps,
+        transform=transforms,
+    )
+
+    return dataset
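Not part of the diff: the `delta_timestamps` argument threaded through `make_dataset` above is the mechanism that replaces torchrl's slice samplers. For each key listed, `__getitem__` returns one frame per requested offset (in seconds, relative to the sampled frame) plus a `{key}_is_pad` mask for offsets that fall outside the episode; the lookup itself lives in `load_data_with_delta_timestamps` (see `lerobot/common/datasets/utils.py` below). A hedged sketch, with illustrative offset values and shapes inferred from the loading code:

```python
import os
from pathlib import Path

from lerobot.common.datasets.pusht import PushtDataset

delta_timestamps = {
    # the frame 0.1 s before the sampled frame, and the sampled frame itself (PushT runs at fps = 10)
    "observation.image": [-0.1, 0.0],
    "observation.state": [-0.1, 0.0],
    # the current action and the next 15 actions
    "action": [t / 10 for t in range(16)],
}

dataset = PushtDataset(
    "pusht",
    root=Path(os.environ.get("DATA_DIR")),
    delta_timestamps=delta_timestamps,
)

item = dataset[0]
print(item["action"].shape)   # expected (16, action_dim): one row per requested timestamp
print(item["action_is_pad"])  # bool mask, True where a timestamp fell outside the episode
```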
diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py
index 624fb140..34d92daa 100644
--- a/lerobot/common/datasets/pusht.py
+++ b/lerobot/common/datasets/pusht.py
@@ -1,21 +1,11 @@
 from pathlib import Path
-from typing import Callable

 import einops
 import numpy as np
-import pygame
-import pymunk
 import torch
-import torchrl
 import tqdm
-from tensordict import TensorDict
-from torchrl.data.replay_buffers.samplers import Sampler
-from torchrl.data.replay_buffers.storages import TensorStorage
-from torchrl.data.replay_buffers.writers import Writer

-from lerobot.common.datasets.abstract import AbstractDataset
-from lerobot.common.datasets.utils import download_and_extract_zip
-from lerobot.common.envs.pusht.pusht_env import pymunk_to_shapely
+from lerobot.common.datasets.utils import download_and_extract_zip, load_data_with_delta_timestamps
 from lerobot.common.policies.diffusion.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer

 # as define in env
@@ -25,97 +15,93 @@ PUSHT_URL = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
 PUSHT_ZARR = Path("pusht/pusht_cchi_v7_replay.zarr")

-def get_goal_pose_body(pose):
-    mass = 1
-    inertia = pymunk.moment_for_box(mass, (50, 100))
-    body = pymunk.Body(mass, inertia)
-    # preserving the legacy assignment order for compatibility
-    # the order here doesn't matter somehow, maybe because CoM is aligned with body origin
-    body.position = pose[:2].tolist()
-    body.angle = pose[2]
-    return body
+class PushtDataset(torch.utils.data.Dataset):
+    """
+    Arguments
+    ----------
+    delta_timestamps : dict[list[float]] | None, optional
+        Loads data from frames at timestamps shifted relative to the current frame, with a different set of shifts for each data key (e.g. state, action or image).
+        If `None`, no shift is applied and only the data from the current frame is loaded.
+    """

-def add_segment(space, a, b, radius):
-    shape = pymunk.Segment(space.static_body, a, b, radius)
-    shape.color = pygame.Color("LightGray")  # https://htmlcolorcodes.com/color-names
-    return shape
-
-
-def add_tee(
-    space,
-    position,
-    angle,
-    scale=30,
-    color="LightSlateGray",
-    mask=None,
-):
-    if mask is None:
-        mask = pymunk.ShapeFilter.ALL_MASKS()
-    mass = 1
-    length = 4
-    vertices1 = [
-        (-length * scale / 2, scale),
-        (length * scale / 2, scale),
-        (length * scale / 2, 0),
-        (-length * scale / 2, 0),
-    ]
-    inertia1 = pymunk.moment_for_poly(mass, vertices=vertices1)
-    vertices2 = [
-        (-scale / 2, scale),
-        (-scale / 2, length * scale),
-        (scale / 2, length * scale),
-        (scale / 2, scale),
-    ]
-    inertia2 = pymunk.moment_for_poly(mass, vertices=vertices1)
-    body = pymunk.Body(mass, inertia1 + inertia2)
-    shape1 = pymunk.Poly(body, vertices1)
-    shape2 = pymunk.Poly(body, vertices2)
-    shape1.color = pygame.Color(color)
-    shape2.color = pygame.Color(color)
-    shape1.filter = pymunk.ShapeFilter(mask=mask)
-    shape2.filter = pymunk.ShapeFilter(mask=mask)
-    body.center_of_gravity = (shape1.center_of_gravity + shape2.center_of_gravity) / 2
-    body.position = position
-    body.angle = angle
-    body.friction = 1
-    space.add(body, shape1, shape2)
-    return body
-
-
-class PushtDataset(AbstractDataset):
     available_datasets = ["pusht"]
+    fps = 10
+    image_keys = ["observation.image"]

     def __init__(
         self,
         dataset_id: str,
         version: str | None = "v1.2",
-        batch_size: int | None = None,
-        *,
-        shuffle: bool = True,
         root: Path | None = None,
-        pin_memory: bool = False,
-        prefetch: int = None,
-        sampler: Sampler | None = None,
-        collate_fn: Callable | None = None,
-        writer: Writer | None = None,
-        transform: "torchrl.envs.Transform" = None,
+        transform: callable = None,
+        delta_timestamps: dict[list[float]] | None = None,
     ):
-        super().__init__(
-            dataset_id,
-            version,
-            batch_size,
-            shuffle=shuffle,
-            root=root,
-            pin_memory=pin_memory,
-            prefetch=prefetch,
-            sampler=sampler,
-            collate_fn=collate_fn,
-            writer=writer,
-            transform=transform,
-        )
+        super().__init__()
+        self.dataset_id = dataset_id
+        self.version = version
+        self.root = root
+        self.transform = transform
+        self.delta_timestamps = delta_timestamps
+
+        self.data_dir = self.root / f"{self.dataset_id}"
+        if (self.data_dir / "data_dict.pth").exists() and (
+            self.data_dir / "data_ids_per_episode.pth"
+        ).exists():
+            self.data_dict = torch.load(self.data_dir / "data_dict.pth")
+            self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
+        else:
+            self._download_and_preproc_obsolete()
+            self.data_dir.mkdir(parents=True, exist_ok=True)
+            torch.save(self.data_dict, self.data_dir / "data_dict.pth")
+            torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")
+
+    @property
+    def num_samples(self) -> int:
+        return len(self.data_dict["index"]) if "index" in self.data_dict else 0
+
+    @property
+    def num_episodes(self) -> int:
+        return len(self.data_ids_per_episode)
+
+    def __len__(self):
+        return self.num_samples
+
+    def __getitem__(self, idx):
+        item = {}
+
+        # get episode id and timestamp of the sampled frame
+        current_ts = self.data_dict["timestamp"][idx].item()
+        episode = self.data_dict["episode"][idx].item()
+
+        for key in self.data_dict:
+            if self.delta_timestamps is not None and key in self.delta_timestamps:
+                data, is_pad = load_data_with_delta_timestamps(
+                    self.data_dict,
+                    self.data_ids_per_episode,
+                    self.delta_timestamps,
+                    key,
+                    current_ts,
+                    episode,
+                )
+                item[key] = data
+                item[f"{key}_is_pad"] = is_pad
+            else:
+                item[key] = self.data_dict[key][idx]
+
+        if self.transform is not None:
+            item = self.transform(item)
+
+        return item

     def _download_and_preproc_obsolete(self):
+        try:
+            import pymunk
+            from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
+        except ModuleNotFoundError as e:
+            print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[pusht]'`")
+            raise e
+
         assert self.root is not None
         raw_dir = self.root / f"{self.dataset_id}_raw"
         zarr_path = (raw_dir / PUSHT_ZARR).resolve()
@@ -140,34 +126,37 @@ class PushtDataset(AbstractDataset):

         # TODO: verify that goal pose is expected to be fixed
         goal_pos_angle = np.array([256, 256, np.pi / 4])  # x, y, theta (in radians)
-        goal_body = get_goal_pose_body(goal_pos_angle)
+        goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle)

         imgs = torch.from_numpy(dataset_dict["img"])
         imgs = einops.rearrange(imgs, "b h w c -> b c h w")
         states = torch.from_numpy(dataset_dict["state"])
         actions = torch.from_numpy(dataset_dict["action"])

+        self.data_ids_per_episode = {}
+        ep_dicts = []
+
         idx0 = 0
-        idxtd = 0
         for episode_id in tqdm.tqdm(range(num_episodes)):
             idx1 = dataset_dict.meta["episode_ends"][episode_id]
-            # to create test artifact
-            # idx1 = 51

             num_frames = idx1 - idx0

             assert (episode_ids[idx0:idx1] == episode_id).all()

             image = imgs[idx0:idx1]
+            assert image.min() >= 0.0
+            assert image.max() <= 255.0
+            image = image.type(torch.uint8)

             state = states[idx0:idx1]
             agent_pos = state[:, :2]
             block_pos = state[:, 2:4]
             block_angle = state[:, 4]

-            reward = torch.zeros(num_frames, 1)
-            success = torch.zeros(num_frames, 1, dtype=torch.bool)
-            done = torch.zeros(num_frames, 1, dtype=torch.bool)
+            reward = torch.zeros(num_frames)
+            success = torch.zeros(num_frames, dtype=torch.bool)
+            done = torch.zeros(num_frames, dtype=torch.bool)
             for i in range(num_frames):
                 space = pymunk.Space()
                 space.gravity = 0, 0
@@ -175,14 +164,14 @@

                 # Add walls.
walls = [ - add_segment(space, (5, 506), (5, 5), 2), - add_segment(space, (5, 5), (506, 5), 2), - add_segment(space, (506, 5), (506, 506), 2), - add_segment(space, (5, 506), (506, 506), 2), + PushTEnv.add_segment(space, (5, 506), (5, 5), 2), + PushTEnv.add_segment(space, (5, 5), (506, 5), 2), + PushTEnv.add_segment(space, (506, 5), (506, 506), 2), + PushTEnv.add_segment(space, (5, 506), (506, 506), 2), ] space.add(*walls) - block_body = add_tee(space, block_pos[i].tolist(), block_angle[i].item()) + block_body = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item()) goal_geom = pymunk_to_shapely(goal_body, block_body.shapes) block_geom = pymunk_to_shapely(block_body, block_body.shapes) intersection_area = goal_geom.intersection(block_geom).area @@ -194,30 +183,32 @@ class PushtDataset(AbstractDataset): # last step of demonstration is considered done done[-1] = True - ep_td = TensorDict( - { - ("observation", "image"): image[:-1], - ("observation", "state"): agent_pos[:-1], - "action": actions[idx0:idx1][:-1], - "episode": episode_ids[idx0:idx1][:-1], - "frame_id": torch.arange(0, num_frames - 1, 1), - ("next", "observation", "image"): image[1:], - ("next", "observation", "state"): agent_pos[1:], - # TODO: verify that reward and done are aligned with image and agent_pos - ("next", "reward"): reward[1:], - ("next", "done"): done[1:], - ("next", "success"): success[1:], - }, - batch_size=num_frames - 1, - ) + ep_dict = { + "observation.image": image, + "observation.state": agent_pos, + "action": actions[idx0:idx1], + "episode": torch.tensor([episode_id] * num_frames, dtype=torch.int), + "frame_id": torch.arange(0, num_frames, 1), + "timestamp": torch.arange(0, num_frames, 1) / self.fps, + # "next.observation.image": image[1:], + # "next.observation.state": agent_pos[1:], + # TODO(rcadene): verify that reward and done are aligned with image and agent_pos + "next.reward": torch.cat([reward[1:], reward[[-1]]]), + "next.done": torch.cat([done[1:], done[[-1]]]), + "next.success": torch.cat([success[1:], success[[-1]]]), + } + ep_dicts.append(ep_dict) - if episode_id == 0: - # hack to initialize tensordict data structure to store episodes - td_data = ep_td[0].expand(total_frames).memmap_like(self.root / f"{self.dataset_id}") - - td_data[idxtd : idxtd + len(ep_td)] = ep_td + assert isinstance(episode_id, int) + self.data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1) + assert len(self.data_ids_per_episode[episode_id]) == num_frames idx0 = idx1 - idxtd = idxtd + len(ep_td) - return TensorStorage(td_data.lock_()) + self.data_dict = {} + + keys = ep_dicts[0].keys() + for key in keys: + self.data_dict[key] = torch.cat([x[key] for x in ep_dicts]) + + self.data_dict["index"] = torch.arange(0, total_frames, 1) diff --git a/lerobot/common/datasets/simxarm.py b/lerobot/common/datasets/simxarm.py deleted file mode 100644 index dc30e69e..00000000 --- a/lerobot/common/datasets/simxarm.py +++ /dev/null @@ -1,127 +0,0 @@ -import pickle -import zipfile -from pathlib import Path -from typing import Callable - -import torch -import torchrl -import tqdm -from tensordict import TensorDict -from torchrl.data.replay_buffers.samplers import ( - Sampler, -) -from torchrl.data.replay_buffers.storages import TensorStorage -from torchrl.data.replay_buffers.writers import Writer - -from lerobot.common.datasets.abstract import AbstractDataset - - -def download(): - raise NotImplementedError() - import gdown - - url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya" - download_path = 
"data.zip" - gdown.download(url, download_path, quiet=False) - print("Extracting...") - with zipfile.ZipFile(download_path, "r") as zip_f: - for member in zip_f.namelist(): - if member.startswith("data/xarm") and member.endswith(".pkl"): - print(member) - zip_f.extract(member=member) - Path(download_path).unlink() - - -class SimxarmDataset(AbstractDataset): - available_datasets = [ - "xarm_lift_medium", - ] - - def __init__( - self, - dataset_id: str, - version: str | None = "v1.1", - batch_size: int | None = None, - *, - shuffle: bool = True, - root: Path | None = None, - pin_memory: bool = False, - prefetch: int = None, - sampler: Sampler | None = None, - collate_fn: Callable | None = None, - writer: Writer | None = None, - transform: "torchrl.envs.Transform" = None, - ): - super().__init__( - dataset_id, - version, - batch_size, - shuffle=shuffle, - root=root, - pin_memory=pin_memory, - prefetch=prefetch, - sampler=sampler, - collate_fn=collate_fn, - writer=writer, - transform=transform, - ) - - def _download_and_preproc_obsolete(self): - # assert self.root is not None - # TODO(rcadene): finish download - # download() - - dataset_path = self.root / f"{self.dataset_id}" / "buffer.pkl" - print(f"Using offline dataset '{dataset_path}'") - with open(dataset_path, "rb") as f: - dataset_dict = pickle.load(f) - - total_frames = dataset_dict["actions"].shape[0] - - idx0 = 0 - idx1 = 0 - episode_id = 0 - for i in tqdm.tqdm(range(total_frames)): - idx1 += 1 - - if not dataset_dict["dones"][i]: - continue - - num_frames = idx1 - idx0 - - image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1]) - state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1]) - next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1]) - next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1]) - next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1]) - next_done = torch.tensor(dataset_dict["dones"][idx0:idx1]) - - episode = TensorDict( - { - ("observation", "image"): image, - ("observation", "state"): state, - "action": torch.tensor(dataset_dict["actions"][idx0:idx1]), - "episode": torch.tensor([episode_id] * num_frames, dtype=torch.int), - "frame_id": torch.arange(0, num_frames, 1), - ("next", "observation", "image"): next_image, - ("next", "observation", "state"): next_state, - ("next", "reward"): next_reward, - ("next", "done"): next_done, - }, - batch_size=num_frames, - ) - - if episode_id == 0: - # hack to initialize tensordict data structure to store episodes - td_data = ( - episode[0] - .expand(total_frames) - .memmap_like(self.root / f"{self.dataset_id}" / "replay_buffer") - ) - - td_data[idx0:idx1] = episode - - episode_id += 1 - idx0 = idx1 - - return TensorStorage(td_data.lock_()) diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 0ad43a65..e67d8a04 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -1,8 +1,12 @@ import io import zipfile +from copy import deepcopy +from math import ceil from pathlib import Path +import einops import requests +import torch import tqdm @@ -28,3 +32,185 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool: return True else: return False + + +def load_data_with_delta_timestamps( + data_dict: dict[torch.Tensor], + data_ids_per_episode: dict[torch.Tensor], + delta_timestamps: list[float], + key: str, + current_ts: float, + episode: int, + tol: float = 0.04, +): + """ + Given a current timestamp (e.g. 
current_ts=0.6) and a list of timestamp differences (e.g. delta_timestamps=[-0.8, -0.2, 0, 0.2]), + this function computes the query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest frames of the specified modality (e.g. key="observation.image"). + + Importantly, when no frame can be found around a query timestamp within a specified tolerance window (e.g. tol=0.04), this function raises an AssertionError. + When a timestamp is queried before the first available timestamp of the episode or after the last available timestamp, + the violation of the tolerance doesn't raise an AssertionError, and the function populates a boolean array indicating which frames are outside the episode range. + For instance, this boolean array is useful during batched training to avoid supervising actions associated with timestamps coming after the end of the episode, + or to pad the observations in a specific way. Note that by default the observation frames before the start of the episode are the same as the first frame of the episode. + + Parameters: + - data_dict (dict): A dictionary containing the data, where each key corresponds to a different modality (e.g., "timestamp", "observation.image", "action"). + - data_ids_per_episode (dict): A dictionary where keys are episode identifiers and values are lists of indices corresponding to frames associated with each episode. + - delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible key to be retrieved. These deltas are added to the current_ts to form the query timestamps. + - key (str): The key specifying which data modality is to be retrieved from the data_dict. + - current_ts (float): The current timestamp to which the delta timestamps are added to form the query timestamps. + - episode (int): The identifier of the episode from which frames are to be retrieved. + - tol (float, optional): The tolerance level used to determine if a data point is close enough to the query timestamp. Defaults to 0.04. + + Returns: + - tuple: A tuple containing two elements: + - The first element is the data retrieved from the specified modality based on the closest match to the query timestamps. + - The second element is a boolean array indicating which frames were considered padding (True if the distance to the closest timestamp was greater than the tolerance level). + + Raises: + - AssertionError: If any of the frames unexpectedly violate the tolerance level. This could indicate synchronization issues with timestamps during data collection.
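+
+    Example:
+        A minimal, illustrative call on synthetic data (the tensors and values below
+        are made up for this docstring and do not come from any dataset):
+
+        >>> data_dict = {
+        ...     "timestamp": torch.tensor([0.0, 0.1, 0.2, 0.3]),
+        ...     "action": torch.arange(4).float().unsqueeze(1),
+        ... }
+        >>> data_ids_per_episode = {0: torch.arange(4)}
+        >>> delta_timestamps = {"action": [-0.1, 0.0, 0.1]}
+        >>> data, is_pad = load_data_with_delta_timestamps(
+        ...     data_dict, data_ids_per_episode, delta_timestamps,
+        ...     key="action", current_ts=0.3, episode=0,
+        ... )
+        >>> data.squeeze(1)  # frames at 0.2 and 0.3, then 0.3 repeated as padding
+        tensor([2., 3., 3.])
+        >>> is_pad  # the query at 0.4 lands 0.1 > tol past the episode end
+        tensor([False, False,  True])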
+ """ + # get indices of the frames associated to the episode, and their timestamps + ep_data_ids = data_ids_per_episode[episode] + ep_timestamps = data_dict["timestamp"][ep_data_ids] + + # we make the assumption that the timestamps are sorted + ep_first_ts = ep_timestamps[0] + ep_last_ts = ep_timestamps[-1] + + # get timestamps used as query to retrieve data of previous/future frames + delta_ts = delta_timestamps[key] + query_ts = current_ts + torch.tensor(delta_ts) + + # compute distances between each query timestamp and all timestamps of all the frames belonging to the episode + dist = torch.cdist(query_ts[:, None], ep_timestamps[:, None], p=1) + min_, argmin_ = dist.min(1) + + # get the indices of the data that are closest to the query timestamps + data_ids = ep_data_ids[argmin_] + # closest_ts = ep_timestamps[argmin_] + + # get the data + data = data_dict[key][data_ids].clone() + + # TODO(rcadene): synchronize timestamps + interpolation if needed + + is_pad = min_ > tol + + # check violated query timestamps are all outside the episode range + assert ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all(), ( + f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tol=}) inside episode range." + "This might be due to synchronization issues with timestamps during data collection." + ) + + return data, is_pad + + +def get_stats_einops_patterns(dataset): + """These einops patterns will be used to aggregate batches and compute statistics.""" + stats_patterns = { + "action": "b c -> c", + "observation.state": "b c -> c", + } + for key in dataset.image_keys: + stats_patterns[key] = "b c h w -> c 1 1" + return stats_patterns + + +def compute_stats(dataset, batch_size=32, max_num_samples=None): + if max_num_samples is None: + max_num_samples = len(dataset) + else: + raise NotImplementedError("We need to set shuffle=True, but this violate an assert for now.") + + dataloader = torch.utils.data.DataLoader( + dataset, + num_workers=4, + batch_size=batch_size, + shuffle=False, + # pin_memory=cfg.device != "cpu", + drop_last=False, + ) + + # get einops patterns to aggregate batches and compute statistics + stats_patterns = get_stats_einops_patterns(dataset) + + # mean and std will be computed incrementally while max and min will track the running value. + mean, std, max, min = {}, {}, {}, {} + for key in stats_patterns: + mean[key] = torch.tensor(0.0).float() + std[key] = torch.tensor(0.0).float() + max[key] = torch.tensor(-float("inf")).float() + min[key] = torch.tensor(float("inf")).float() + + # Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get + # surprises when rerunning the sampler. + first_batch = None + running_item_count = 0 # for online mean computation + for i, batch in enumerate( + tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max") + ): + this_batch_size = len(batch["index"]) + running_item_count += this_batch_size + if first_batch is None: + first_batch = deepcopy(batch) + for key, pattern in stats_patterns.items(): + batch[key] = batch[key].float() + # Numerically stable update step for mean computation. + batch_mean = einops.reduce(batch[key], pattern, "mean") + # Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents + # the update step, N is the running item count, B is this batch size, x̄ is the running mean, + # and x is the current batch mean. 
Some rearrangement is then required to avoid risking + # numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields + # x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ + mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count + max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max")) + min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min")) + + if i == ceil(max_num_samples / batch_size) - 1: + break + + first_batch_ = None + running_item_count = 0 # for online std computation + for i, batch in enumerate( + tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std") + ): + this_batch_size = len(batch["index"]) + running_item_count += this_batch_size + # Sanity check to make sure the batches are still in the same order as before. + if first_batch_ is None: + first_batch_ = deepcopy(batch) + for key in stats_patterns: + assert torch.equal(first_batch_[key], first_batch[key]) + for key, pattern in stats_patterns.items(): + batch[key] = batch[key].float() + # Numerically stable update step for mean computation (where the mean is over squared + # residuals). See notes in the mean computation loop above. + batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean") + std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count + + if i == ceil(max_num_samples / batch_size) - 1: + break + + for key in stats_patterns: + std[key] = torch.sqrt(std[key]) + + stats = {} + for key in stats_patterns: + stats[key] = { + "mean": mean[key], + "std": std[key], + "max": max[key], + "min": min[key], + } + + return stats + + +def cycle(iterable): + iterator = iter(iterable) + while True: + try: + yield next(iterator) + except StopIteration: + iterator = iter(iterable) diff --git a/lerobot/common/datasets/xarm.py b/lerobot/common/datasets/xarm.py new file mode 100644 index 00000000..0dfcc5c9 --- /dev/null +++ b/lerobot/common/datasets/xarm.py @@ -0,0 +1,163 @@ +import pickle +import zipfile +from pathlib import Path + +import torch +import tqdm + +from lerobot.common.datasets.utils import load_data_with_delta_timestamps + + +def download(raw_dir): + import gdown + + raw_dir.mkdir(parents=True, exist_ok=True) + url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya" + zip_path = raw_dir / "data.zip" + gdown.download(url, str(zip_path), quiet=False) + print("Extracting...") + with zipfile.ZipFile(str(zip_path), "r") as zip_f: + for member in zip_f.namelist(): + if member.startswith("data/xarm") and member.endswith(".pkl"): + print(member) + zip_f.extract(member=member) + zip_path.unlink() + + +class XarmDataset(torch.utils.data.Dataset): + available_datasets = [ + "xarm_lift_medium", + ] + fps = 15 + image_keys = ["observation.image"] + + def __init__( + self, + dataset_id: str, + version: str | None = "v1.1", + root: Path | None = None, + transform: callable = None, + delta_timestamps: dict[str, list[float]] | None = None, + ): + super().__init__() + self.dataset_id = dataset_id + self.version = version + self.root = root + self.transform = transform + self.delta_timestamps = delta_timestamps + + self.data_dir = self.root / f"{self.dataset_id}" + if (self.data_dir / "data_dict.pth").exists() and ( + self.data_dir / "data_ids_per_episode.pth" + ).exists(): + self.data_dict = torch.load(self.data_dir / "data_dict.pth") + self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth") + else: + self._download_and_preproc_obsolete() + 
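+            # Cache the preprocessed tensors so the next instantiation takes the fast
+            # path above instead of downloading and preprocessing from scratch.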
self.data_dir.mkdir(parents=True, exist_ok=True) + torch.save(self.data_dict, self.data_dir / "data_dict.pth") + torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth") + + @property + def num_samples(self) -> int: + return len(self.data_dict["index"]) if "index" in self.data_dict else 0 + + @property + def num_episodes(self) -> int: + return len(self.data_ids_per_episode) + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + item = {} + + # get episode id and timestamp of the sampled frame + current_ts = self.data_dict["timestamp"][idx].item() + episode = self.data_dict["episode"][idx].item() + + for key in self.data_dict: + if self.delta_timestamps is not None and key in self.delta_timestamps: + data, is_pad = load_data_with_delta_timestamps( + self.data_dict, + self.data_ids_per_episode, + self.delta_timestamps, + key, + current_ts, + episode, + ) + item[key] = data + item[f"{key}_is_pad"] = is_pad + else: + item[key] = self.data_dict[key][idx] + + if self.transform is not None: + item = self.transform(item) + + return item + + def _download_and_preproc_obsolete(self): + assert self.root is not None + raw_dir = self.root / f"{self.dataset_id}_raw" + if not raw_dir.exists(): + download(raw_dir) + + dataset_path = self.root / f"{self.dataset_id}" / "buffer.pkl" + print(f"Using offline dataset '{dataset_path}'") + with open(dataset_path, "rb") as f: + dataset_dict = pickle.load(f) + + total_frames = dataset_dict["actions"].shape[0] + + self.data_ids_per_episode = {} + ep_dicts = [] + + idx0 = 0 + idx1 = 0 + episode_id = 0 + for i in tqdm.tqdm(range(total_frames)): + idx1 += 1 + + if not dataset_dict["dones"][i]: + continue + + num_frames = idx1 - idx0 + + image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1]) + state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1]) + action = torch.tensor(dataset_dict["actions"][idx0:idx1]) + # TODO(rcadene): we have a missing last frame which is the observation when the env is done + # it is critical to have this frame for tdmpc to predict a "done observation/state" + # next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1]) + # next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1]) + next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1]) + next_done = torch.tensor(dataset_dict["dones"][idx0:idx1]) + + ep_dict = { + "observation.image": image, + "observation.state": state, + "action": action, + "episode": torch.tensor([episode_id] * num_frames, dtype=torch.int), + "frame_id": torch.arange(0, num_frames, 1), + "timestamp": torch.arange(0, num_frames, 1) / self.fps, + # "next.observation.image": next_image, + # "next.observation.state": next_state, + "next.reward": next_reward, + "next.done": next_done, + } + ep_dicts.append(ep_dict) + + assert isinstance(episode_id, int) + self.data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1) + assert len(self.data_ids_per_episode[episode_id]) == num_frames + + idx0 = idx1 + episode_id += 1 + + self.data_dict = {} + + keys = ep_dicts[0].keys() + for key in keys: + self.data_dict[key] = torch.cat([x[key] for x in ep_dicts]) + + self.data_dict["index"] = torch.arange(0, total_frames, 1) diff --git a/lerobot/common/envs/abstract.py b/lerobot/common/envs/abstract.py deleted file mode 100644 index ea5ce3da..00000000 --- a/lerobot/common/envs/abstract.py +++ /dev/null @@ -1,92 +0,0 @@ -from collections import deque -from typing import Optional - -from tensordict import 
TensorDict -from torchrl.envs import EnvBase - -from lerobot.common.utils import set_global_seed - - -class AbstractEnv(EnvBase): - """ - Note: - When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to: - 1. set the required class attributes: - - for classes inheriting from `AbstractDataset`: `available_datasets` - - for classes inheriting from `AbstractEnv`: `name`, `available_tasks` - - for classes inheriting from `AbstractPolicy`: `name` - 2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`) - 3. update variables in `tests/test_available.py` by importing your new class - """ - - name: str | None = None # same name should be used to instantiate the environment in factory.py - available_tasks: list[str] | None = None # for instance: sim_insertion, sim_transfer_cube, pusht, lift - - def __init__( - self, - task, - frame_skip: int = 1, - from_pixels: bool = False, - pixels_only: bool = False, - image_size=None, - seed=1337, - device="cpu", - num_prev_obs=1, - num_prev_action=0, - ): - super().__init__(device=device, batch_size=[]) - assert self.name is not None, "Subclasses of `AbstractEnv` should set the `name` class attribute." - assert ( - self.available_tasks is not None - ), "Subclasses of `AbstractEnv` should set the `available_tasks` class attribute." - assert ( - task in self.available_tasks - ), f"The provided task ({task}) is not on the list of available tasks {self.available_tasks}." - - self.task = task - self.frame_skip = frame_skip - self.from_pixels = from_pixels - self.pixels_only = pixels_only - self.image_size = image_size - self.num_prev_obs = num_prev_obs - self.num_prev_action = num_prev_action - - if pixels_only: - assert from_pixels - if from_pixels: - assert image_size - - self._make_env() - self._make_spec() - - # self._next_seed will be used for the next reset. It is recommended that when self.set_seed is called - # you store the return value in self._next_seed (it will be a new randomly generated seed). - self._next_seed = seed - # Don't store the result of this in self._next_seed, as we want to make sure that the first time - # self._reset is called, we use seed. 
- self.set_seed(seed) - - if self.num_prev_obs > 0: - self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs) - self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs) - if self.num_prev_action > 0: - raise NotImplementedError() - # self._prev_action_queue = deque(maxlen=self.num_prev_action) - - def render(self, mode="rgb_array", width=640, height=480): - raise NotImplementedError("Abstract method") - - def _reset(self, tensordict: Optional[TensorDict] = None): - raise NotImplementedError("Abstract method") - - def _step(self, tensordict: TensorDict): - raise NotImplementedError("Abstract method") - - def _make_env(self): - raise NotImplementedError("Abstract method") - - def _make_spec(self): - raise NotImplementedError("Abstract method") - - def _set_seed(self, seed: Optional[int]): - set_global_seed(seed) diff --git a/lerobot/common/envs/aloha/assets/bimanual_viperx_end_effector_insertion.xml b/lerobot/common/envs/aloha/assets/bimanual_viperx_end_effector_insertion.xml deleted file mode 100644 index 8002838c..00000000 --- a/lerobot/common/envs/aloha/assets/bimanual_viperx_end_effector_insertion.xml +++ /dev/null @@ -1,59 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/lerobot/common/envs/aloha/assets/bimanual_viperx_end_effector_transfer_cube.xml b/lerobot/common/envs/aloha/assets/bimanual_viperx_end_effector_transfer_cube.xml deleted file mode 100644 index 05249ad2..00000000 --- a/lerobot/common/envs/aloha/assets/bimanual_viperx_end_effector_transfer_cube.xml +++ /dev/null @@ -1,48 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/lerobot/common/envs/aloha/assets/bimanual_viperx_insertion.xml b/lerobot/common/envs/aloha/assets/bimanual_viperx_insertion.xml deleted file mode 100644 index 511f7947..00000000 --- a/lerobot/common/envs/aloha/assets/bimanual_viperx_insertion.xml +++ /dev/null @@ -1,53 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/lerobot/common/envs/aloha/assets/bimanual_viperx_transfer_cube.xml b/lerobot/common/envs/aloha/assets/bimanual_viperx_transfer_cube.xml deleted file mode 100644 index 2d85a47c..00000000 --- a/lerobot/common/envs/aloha/assets/bimanual_viperx_transfer_cube.xml +++ /dev/null @@ -1,42 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/lerobot/common/envs/aloha/assets/scene.xml b/lerobot/common/envs/aloha/assets/scene.xml deleted file mode 100644 index 0f61b8a5..00000000 --- a/lerobot/common/envs/aloha/assets/scene.xml +++ /dev/null @@ -1,38 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/lerobot/common/envs/aloha/assets/tabletop.stl b/lerobot/common/envs/aloha/assets/tabletop.stl deleted file mode 100644 index 1c17d3f0..00000000 --- a/lerobot/common/envs/aloha/assets/tabletop.stl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:76a1571d1aa36520f2bd81c268991b99816c2a7819464d718e0fd9976fe30dce -size 684 diff --git a/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_left.stl b/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_left.stl deleted file mode 100644 index ef1f3f35..00000000 --- a/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_left.stl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid 
sha256:df73ae5b9058e5d50a6409ac2ab687dade75053a86591bb5e23ab051dbf2d659 -size 83384 diff --git a/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_right.stl b/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_right.stl deleted file mode 100644 index 7eb8aefd..00000000 --- a/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_right.stl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:56fb3cc1236d4193106038adf8e457c7252ae9e86c7cee6dabf0578c53666358 -size 83384 diff --git a/lerobot/common/envs/aloha/assets/vx300s_10_gripper_finger.stl b/lerobot/common/envs/aloha/assets/vx300s_10_gripper_finger.stl deleted file mode 100644 index 4c2b3a1f..00000000 --- a/lerobot/common/envs/aloha/assets/vx300s_10_gripper_finger.stl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:a4baacd9a64df1be60ea5e98f50f3c660e1b7a1fe9684aace6004c5058c09483 -size 42884 diff --git a/lerobot/common/envs/aloha/assets/vx300s_11_ar_tag.stl b/lerobot/common/envs/aloha/assets/vx300s_11_ar_tag.stl deleted file mode 100644 index 8a30f7cc..00000000 --- a/lerobot/common/envs/aloha/assets/vx300s_11_ar_tag.stl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:a18a1601074d29ed1d546ead70cd18fbb063f1db7b5b96b9f0365be714f3136a -size 3884 diff --git a/lerobot/common/envs/aloha/assets/vx300s_1_base.stl b/lerobot/common/envs/aloha/assets/vx300s_1_base.stl deleted file mode 100644 index 9198e625..00000000 --- a/lerobot/common/envs/aloha/assets/vx300s_1_base.stl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:d100cafe656671ca8fde98fb6a4cf2d1b746995c51c61c25ad9ea2715635d146 -size 99984 diff --git a/lerobot/common/envs/aloha/assets/vx300s_2_shoulder.stl b/lerobot/common/envs/aloha/assets/vx300s_2_shoulder.stl deleted file mode 100644 index ab3d9570..00000000 --- a/lerobot/common/envs/aloha/assets/vx300s_2_shoulder.stl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:139745a74055cb0b23430bb5bc032bf68cf7bea5e4975c8f4c04107ae005f7f0 -size 63884 diff --git a/lerobot/common/envs/aloha/assets/vx300s_3_upper_arm.stl b/lerobot/common/envs/aloha/assets/vx300s_3_upper_arm.stl deleted file mode 100644 index 3d6f663c..00000000 --- a/lerobot/common/envs/aloha/assets/vx300s_3_upper_arm.stl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:900f236320dd3d500870c5fde763b2d47502d51e043a5c377875e70237108729 -size 102984 diff --git a/lerobot/common/envs/aloha/assets/vx300s_4_upper_forearm.stl b/lerobot/common/envs/aloha/assets/vx300s_4_upper_forearm.stl deleted file mode 100644 index 4eb249e7..00000000 --- a/lerobot/common/envs/aloha/assets/vx300s_4_upper_forearm.stl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:4104fc54bbfb8a9b533029f1e7e3ade3d54d638372b3195daa0c98f57e0295b5 -size 49584 diff --git a/lerobot/common/envs/aloha/assets/vx300s_5_lower_forearm.stl b/lerobot/common/envs/aloha/assets/vx300s_5_lower_forearm.stl deleted file mode 100644 index 34c76221..00000000 --- a/lerobot/common/envs/aloha/assets/vx300s_5_lower_forearm.stl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:66814e27fa728056416e25e02e89eb7d34c51d51c51e7c3df873829037ddc6b8 -size 99884 diff --git a/lerobot/common/envs/aloha/assets/vx300s_6_wrist.stl b/lerobot/common/envs/aloha/assets/vx300s_6_wrist.stl deleted file mode 100644 index 232fabf7..00000000 --- 
a/lerobot/common/envs/aloha/assets/vx300s_6_wrist.stl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:90eb145c85627968c3776ae6de23ccff7e112c9dd713c46bc9acdfdaa859a048 -size 70784 diff --git a/lerobot/common/envs/aloha/assets/vx300s_7_gripper.stl b/lerobot/common/envs/aloha/assets/vx300s_7_gripper.stl deleted file mode 100644 index 946c3c86..00000000 --- a/lerobot/common/envs/aloha/assets/vx300s_7_gripper.stl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:786c1077bfd226f14219581b11d5f19464ca95b17132e0bb7532503568f5af90 -size 450084 diff --git a/lerobot/common/envs/aloha/assets/vx300s_8_gripper_prop.stl b/lerobot/common/envs/aloha/assets/vx300s_8_gripper_prop.stl deleted file mode 100644 index 28d5bd76..00000000 --- a/lerobot/common/envs/aloha/assets/vx300s_8_gripper_prop.stl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:d1275a93fe2157c83dbc095617fb7e672888bdd48ec070a35ef4ab9ebd9755b0 -size 31684 diff --git a/lerobot/common/envs/aloha/assets/vx300s_9_gripper_bar.stl b/lerobot/common/envs/aloha/assets/vx300s_9_gripper_bar.stl deleted file mode 100644 index 5201d5ea..00000000 --- a/lerobot/common/envs/aloha/assets/vx300s_9_gripper_bar.stl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:a4de62c9a2ed2c78433010e4c05530a1254b1774a7651967f406120c9bf8973e -size 379484 diff --git a/lerobot/common/envs/aloha/assets/vx300s_dependencies.xml b/lerobot/common/envs/aloha/assets/vx300s_dependencies.xml deleted file mode 100644 index 93037ab7..00000000 --- a/lerobot/common/envs/aloha/assets/vx300s_dependencies.xml +++ /dev/null @@ -1,17 +0,0 @@ - - - - - - - - - - - - - - - - - diff --git a/lerobot/common/envs/aloha/assets/vx300s_left.xml b/lerobot/common/envs/aloha/assets/vx300s_left.xml deleted file mode 100644 index 3af6c235..00000000 --- a/lerobot/common/envs/aloha/assets/vx300s_left.xml +++ /dev/null @@ -1,59 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/lerobot/common/envs/aloha/assets/vx300s_right.xml b/lerobot/common/envs/aloha/assets/vx300s_right.xml deleted file mode 100644 index 495df478..00000000 --- a/lerobot/common/envs/aloha/assets/vx300s_right.xml +++ /dev/null @@ -1,59 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/lerobot/common/envs/aloha/constants.py b/lerobot/common/envs/aloha/constants.py deleted file mode 100644 index e582e5f3..00000000 --- a/lerobot/common/envs/aloha/constants.py +++ /dev/null @@ -1,163 +0,0 @@ -from pathlib import Path - -### Simulation envs fixed constants -DT = 0.02 # 0.02 ms -> 1/0.2 = 50 hz -FPS = 50 - - -JOINTS = [ - # absolute joint position - "left_arm_waist", - "left_arm_shoulder", - "left_arm_elbow", - "left_arm_forearm_roll", - "left_arm_wrist_angle", - "left_arm_wrist_rotate", - # normalized gripper position 0: close, 1: open - "left_arm_gripper", - # absolute joint position - "right_arm_waist", - "right_arm_shoulder", - "right_arm_elbow", - "right_arm_forearm_roll", - "right_arm_wrist_angle", - "right_arm_wrist_rotate", - # normalized gripper position 0: close, 1: open - "right_arm_gripper", -] - -ACTIONS = [ - # position and quaternion for end effector - "left_arm_waist", - "left_arm_shoulder", - "left_arm_elbow", - "left_arm_forearm_roll", - "left_arm_wrist_angle", - "left_arm_wrist_rotate", - # normalized gripper 
position (0: close, 1: open) - "left_arm_gripper", - "right_arm_waist", - "right_arm_shoulder", - "right_arm_elbow", - "right_arm_forearm_roll", - "right_arm_wrist_angle", - "right_arm_wrist_rotate", - # normalized gripper position (0: close, 1: open) - "right_arm_gripper", -] - - -START_ARM_POSE = [ - 0, - -0.96, - 1.16, - 0, - -0.3, - 0, - 0.02239, - -0.02239, - 0, - -0.96, - 1.16, - 0, - -0.3, - 0, - 0.02239, - -0.02239, -] - -ASSETS_DIR = Path(__file__).parent.resolve() / "assets" # note: absolute path - -# Left finger position limits (qpos[7]), right_finger = -1 * left_finger -MASTER_GRIPPER_POSITION_OPEN = 0.02417 -MASTER_GRIPPER_POSITION_CLOSE = 0.01244 -PUPPET_GRIPPER_POSITION_OPEN = 0.05800 -PUPPET_GRIPPER_POSITION_CLOSE = 0.01844 - -# Gripper joint limits (qpos[6]) -MASTER_GRIPPER_JOINT_OPEN = 0.3083 -MASTER_GRIPPER_JOINT_CLOSE = -0.6842 -PUPPET_GRIPPER_JOINT_OPEN = 1.4910 -PUPPET_GRIPPER_JOINT_CLOSE = -0.6213 - -MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE) / 2 - -############################ Helper functions ############################ - - -def normalize_master_gripper_position(x): - return (x - MASTER_GRIPPER_POSITION_CLOSE) / ( - MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE - ) - - -def normalize_puppet_gripper_position(x): - return (x - PUPPET_GRIPPER_POSITION_CLOSE) / ( - PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE - ) - - -def unnormalize_master_gripper_position(x): - return x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE - - -def unnormalize_puppet_gripper_position(x): - return x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE - - -def convert_position_from_master_to_puppet(x): - return unnormalize_puppet_gripper_position(normalize_master_gripper_position(x)) - - -def normalizer_master_gripper_joint(x): - return (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) - - -def normalize_puppet_gripper_joint(x): - return (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) - - -def unnormalize_master_gripper_joint(x): - return x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE - - -def unnormalize_puppet_gripper_joint(x): - return x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE - - -def convert_join_from_master_to_puppet(x): - return unnormalize_puppet_gripper_joint(normalizer_master_gripper_joint(x)) - - -def normalize_master_gripper_velocity(x): - return x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) - - -def normalize_puppet_gripper_velocity(x): - return x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) - - -def convert_master_from_position_to_joint(x): - return ( - normalize_master_gripper_position(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) - + MASTER_GRIPPER_JOINT_CLOSE - ) - - -def convert_master_from_joint_to_position(x): - return unnormalize_master_gripper_position( - (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) - ) - - -def convert_puppet_from_position_to_join(x): - return ( - normalize_puppet_gripper_position(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) - + PUPPET_GRIPPER_JOINT_CLOSE - ) - - -def convert_puppet_from_joint_to_position(x): - return unnormalize_puppet_gripper_position( - (x - PUPPET_GRIPPER_JOINT_CLOSE) / 
(PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) - ) diff --git a/lerobot/common/envs/aloha/env.py b/lerobot/common/envs/aloha/env.py deleted file mode 100644 index 8f907650..00000000 --- a/lerobot/common/envs/aloha/env.py +++ /dev/null @@ -1,298 +0,0 @@ -import importlib -import logging -from collections import deque -from typing import Optional - -import einops -import numpy as np -import torch -from dm_control import mujoco -from dm_control.rl import control -from tensordict import TensorDict -from torchrl.data.tensor_specs import ( - BoundedTensorSpec, - CompositeSpec, - DiscreteTensorSpec, - UnboundedContinuousTensorSpec, -) - -from lerobot.common.envs.abstract import AbstractEnv -from lerobot.common.envs.aloha.constants import ( - ACTIONS, - ASSETS_DIR, - DT, - JOINTS, -) -from lerobot.common.envs.aloha.tasks.sim import BOX_POSE, InsertionTask, TransferCubeTask -from lerobot.common.envs.aloha.tasks.sim_end_effector import ( - InsertionEndEffectorTask, - TransferCubeEndEffectorTask, -) -from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose -from lerobot.common.utils import set_global_seed - -_has_gym = importlib.util.find_spec("gymnasium") is not None - - -class AlohaEnv(AbstractEnv): - name = "aloha" - available_tasks = ["sim_insertion", "sim_transfer_cube"] - _reset_warning_issued = False - - def __init__( - self, - task, - frame_skip: int = 1, - from_pixels: bool = False, - pixels_only: bool = False, - image_size=None, - seed=1337, - device="cpu", - num_prev_obs=1, - num_prev_action=0, - ): - super().__init__( - task=task, - frame_skip=frame_skip, - from_pixels=from_pixels, - pixels_only=pixels_only, - image_size=image_size, - seed=seed, - device=device, - num_prev_obs=num_prev_obs, - num_prev_action=num_prev_action, - ) - - def _make_env(self): - if not _has_gym: - raise ImportError("Cannot import gymnasium.") - - if not self.from_pixels: - raise NotImplementedError() - - self._env = self._make_env_task(self.task) - - def render(self, mode="rgb_array", width=640, height=480): - # TODO(rcadene): render and visualizer several cameras (e.g. 
angle, front_close) - image = self._env.physics.render(height=height, width=width, camera_id="top") - return image - - def _make_env_task(self, task_name): - # time limit is controlled by StepCounter in env factory - time_limit = float("inf") - - if "sim_transfer_cube" in task_name: - xml_path = ASSETS_DIR / "bimanual_viperx_transfer_cube.xml" - physics = mujoco.Physics.from_xml_path(str(xml_path)) - task = TransferCubeTask(random=False) - elif "sim_insertion" in task_name: - xml_path = ASSETS_DIR / "bimanual_viperx_insertion.xml" - physics = mujoco.Physics.from_xml_path(str(xml_path)) - task = InsertionTask(random=False) - elif "sim_end_effector_transfer_cube" in task_name: - raise NotImplementedError() - xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_transfer_cube.xml" - physics = mujoco.Physics.from_xml_path(str(xml_path)) - task = TransferCubeEndEffectorTask(random=False) - elif "sim_end_effector_insertion" in task_name: - raise NotImplementedError() - xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_insertion.xml" - physics = mujoco.Physics.from_xml_path(str(xml_path)) - task = InsertionEndEffectorTask(random=False) - else: - raise NotImplementedError(task_name) - - env = control.Environment( - physics, task, time_limit, control_timestep=DT, n_sub_steps=None, flat_observation=False - ) - return env - - def _format_raw_obs(self, raw_obs): - if self.from_pixels: - image = torch.from_numpy(raw_obs["images"]["top"].copy()) - image = einops.rearrange(image, "h w c -> c h w") - assert image.dtype == torch.uint8 - obs = {"image": {"top": image}} - - if not self.pixels_only: - obs["state"] = torch.from_numpy(raw_obs["qpos"]).type(torch.float32) - else: - # TODO(rcadene): - raise NotImplementedError() - # obs = {"state": torch.from_numpy(raw_obs["observation"]).type(torch.float32)} - - return obs - - def _reset(self, tensordict: Optional[TensorDict] = None): - if tensordict is not None and not AlohaEnv._reset_warning_issued: - logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.") - AlohaEnv._reset_warning_issued = True - - # Seed the environment and update the seed to be used for the next reset. - self._next_seed = self.set_seed(self._next_seed) - - # TODO(rcadene): do not use global variable for this - if "sim_transfer_cube" in self.task: - BOX_POSE[0] = sample_box_pose() # used in sim reset - elif "sim_insertion" in self.task: - BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset - - raw_obs = self._env.reset() - - obs = self._format_raw_obs(raw_obs.observation) - - if self.num_prev_obs > 0: - stacked_obs = {} - if "image" in obs: - self._prev_obs_image_queue = deque( - [obs["image"]["top"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1) - ) - stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))} - if "state" in obs: - self._prev_obs_state_queue = deque( - [obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1) - ) - stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) - obs = stacked_obs - - td = TensorDict( - { - "observation": TensorDict(obs, batch_size=[]), - "done": torch.tensor([False], dtype=torch.bool), - }, - batch_size=[], - ) - - return td - - def _step(self, tensordict: TensorDict): - td = tensordict - action = td["action"].numpy() - assert action.ndim == 1 - # TODO(rcadene): add info["is_success"] and info["success"] ? 
- - _, reward, _, raw_obs = self._env.step(action) - - # TODO(rcadene): add an enum - success = done = reward == 4 - obs = self._format_raw_obs(raw_obs) - - if self.num_prev_obs > 0: - stacked_obs = {} - if "image" in obs: - self._prev_obs_image_queue.append(obs["image"]["top"]) - stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))} - if "state" in obs: - self._prev_obs_state_queue.append(obs["state"]) - stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) - obs = stacked_obs - - td = TensorDict( - { - "observation": TensorDict(obs, batch_size=[]), - "reward": torch.tensor([reward], dtype=torch.float32), - # success and done are true when coverage > self.success_threshold in env - "done": torch.tensor([done], dtype=torch.bool), - "success": torch.tensor([success], dtype=torch.bool), - }, - batch_size=[], - ) - return td - - def _make_spec(self): - obs = {} - from omegaconf import OmegaConf - - if self.from_pixels: - if isinstance(self.image_size, int): - image_shape = (3, self.image_size, self.image_size) - elif OmegaConf.is_list(self.image_size) or isinstance(self.image_size, list): - assert len(self.image_size) == 3 # c h w - assert self.image_size[0] == 3 # c is RGB - image_shape = tuple(self.image_size) - else: - raise ValueError(self.image_size) - if self.num_prev_obs > 0: - image_shape = (self.num_prev_obs + 1, *image_shape) - - obs["image"] = { - "top": BoundedTensorSpec( - low=0, - high=255, - shape=image_shape, - dtype=torch.uint8, - device=self.device, - ) - } - if not self.pixels_only: - state_shape = (len(JOINTS),) - if self.num_prev_obs > 0: - state_shape = (self.num_prev_obs + 1, *state_shape) - - obs["state"] = UnboundedContinuousTensorSpec( - # TODO: add low and high bounds - shape=state_shape, - dtype=torch.float32, - device=self.device, - ) - else: - # TODO(rcadene): add observation_space achieved_goal and desired_goal? - state_shape = (len(JOINTS),) - if self.num_prev_obs > 0: - state_shape = (self.num_prev_obs + 1, *state_shape) - - obs["state"] = UnboundedContinuousTensorSpec( - # TODO: add low and high bounds - shape=state_shape, - dtype=torch.float32, - device=self.device, - ) - self.observation_spec = CompositeSpec({"observation": obs}) - - # TODO(rcadene): valid when controling end effector? - # action_space = self._env.action_spec() - # self.action_spec = BoundedTensorSpec( - # low=action_space.minimum, - # high=action_space.maximum, - # shape=action_space.shape, - # dtype=torch.float32, - # device=self.device, - # ) - - # TODO(rcaene): add bounds (where are they????) 
- self.action_spec = BoundedTensorSpec( - shape=(len(ACTIONS)), - low=-1, - high=1, - dtype=torch.float32, - device=self.device, - ) - - self.reward_spec = UnboundedContinuousTensorSpec( - shape=(1,), - dtype=torch.float32, - device=self.device, - ) - - self.done_spec = CompositeSpec( - { - "done": DiscreteTensorSpec( - 2, - shape=(1,), - dtype=torch.bool, - device=self.device, - ), - "success": DiscreteTensorSpec( - 2, - shape=(1,), - dtype=torch.bool, - device=self.device, - ), - } - ) - - def _set_seed(self, seed: Optional[int]): - set_global_seed(seed) - # TODO(rcadene): seed the env - # self._env.seed(seed) - logging.warning("Aloha env is not seeded") diff --git a/lerobot/common/envs/aloha/tasks/sim.py b/lerobot/common/envs/aloha/tasks/sim.py deleted file mode 100644 index ee1d0927..00000000 --- a/lerobot/common/envs/aloha/tasks/sim.py +++ /dev/null @@ -1,219 +0,0 @@ -import collections - -import numpy as np -from dm_control.suite import base - -from lerobot.common.envs.aloha.constants import ( - START_ARM_POSE, - normalize_puppet_gripper_position, - normalize_puppet_gripper_velocity, - unnormalize_puppet_gripper_position, -) - -BOX_POSE = [None] # to be changed from outside - -""" -Environment for simulated robot bi-manual manipulation, with joint position control -Action space: [left_arm_qpos (6), # absolute joint position - left_gripper_positions (1), # normalized gripper position (0: close, 1: open) - right_arm_qpos (6), # absolute joint position - right_gripper_positions (1),] # normalized gripper position (0: close, 1: open) - -Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position - left_gripper_position (1), # normalized gripper position (0: close, 1: open) - right_arm_qpos (6), # absolute joint position - right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open) - "qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad) - left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing) - right_arm_qvel (6), # absolute joint velocity (rad) - right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing) - "images": {"main": (480x640x3)} # h, w, c, dtype='uint8' -""" - - -class BimanualViperXTask(base.Task): - def __init__(self, random=None): - super().__init__(random=random) - - def before_step(self, action, physics): - left_arm_action = action[:6] - right_arm_action = action[7 : 7 + 6] - normalized_left_gripper_action = action[6] - normalized_right_gripper_action = action[7 + 6] - - left_gripper_action = unnormalize_puppet_gripper_position(normalized_left_gripper_action) - right_gripper_action = unnormalize_puppet_gripper_position(normalized_right_gripper_action) - - full_left_gripper_action = [left_gripper_action, -left_gripper_action] - full_right_gripper_action = [right_gripper_action, -right_gripper_action] - - env_action = np.concatenate( - [left_arm_action, full_left_gripper_action, right_arm_action, full_right_gripper_action] - ) - super().before_step(env_action, physics) - return - - def initialize_episode(self, physics): - """Sets the state of the environment at the start of each episode.""" - super().initialize_episode(physics) - - @staticmethod - def get_qpos(physics): - qpos_raw = physics.data.qpos.copy() - left_qpos_raw = qpos_raw[:8] - right_qpos_raw = qpos_raw[8:16] - left_arm_qpos = left_qpos_raw[:6] - right_arm_qpos = right_qpos_raw[:6] - left_gripper_qpos = [normalize_puppet_gripper_position(left_qpos_raw[6])] - right_gripper_qpos = 
[normalize_puppet_gripper_position(right_qpos_raw[6])] - return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos]) - - @staticmethod - def get_qvel(physics): - qvel_raw = physics.data.qvel.copy() - left_qvel_raw = qvel_raw[:8] - right_qvel_raw = qvel_raw[8:16] - left_arm_qvel = left_qvel_raw[:6] - right_arm_qvel = right_qvel_raw[:6] - left_gripper_qvel = [normalize_puppet_gripper_velocity(left_qvel_raw[6])] - right_gripper_qvel = [normalize_puppet_gripper_velocity(right_qvel_raw[6])] - return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel]) - - @staticmethod - def get_env_state(physics): - raise NotImplementedError - - def get_observation(self, physics): - obs = collections.OrderedDict() - obs["qpos"] = self.get_qpos(physics) - obs["qvel"] = self.get_qvel(physics) - obs["env_state"] = self.get_env_state(physics) - obs["images"] = {} - obs["images"]["top"] = physics.render(height=480, width=640, camera_id="top") - obs["images"]["angle"] = physics.render(height=480, width=640, camera_id="angle") - obs["images"]["vis"] = physics.render(height=480, width=640, camera_id="front_close") - - return obs - - def get_reward(self, physics): - # return whether left gripper is holding the box - raise NotImplementedError - - -class TransferCubeTask(BimanualViperXTask): - def __init__(self, random=None): - super().__init__(random=random) - self.max_reward = 4 - - def initialize_episode(self, physics): - """Sets the state of the environment at the start of each episode.""" - # TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside - # reset qpos, control and box position - with physics.reset_context(): - physics.named.data.qpos[:16] = START_ARM_POSE - np.copyto(physics.data.ctrl, START_ARM_POSE) - assert BOX_POSE[0] is not None - physics.named.data.qpos[-7:] = BOX_POSE[0] - # print(f"{BOX_POSE=}") - super().initialize_episode(physics) - - @staticmethod - def get_env_state(physics): - env_state = physics.data.qpos.copy()[16:] - return env_state - - def get_reward(self, physics): - # return whether left gripper is holding the box - all_contact_pairs = [] - for i_contact in range(physics.data.ncon): - id_geom_1 = physics.data.contact[i_contact].geom1 - id_geom_2 = physics.data.contact[i_contact].geom2 - name_geom_1 = physics.model.id2name(id_geom_1, "geom") - name_geom_2 = physics.model.id2name(id_geom_2, "geom") - contact_pair = (name_geom_1, name_geom_2) - all_contact_pairs.append(contact_pair) - - touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs - touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs - touch_table = ("red_box", "table") in all_contact_pairs - - reward = 0 - if touch_right_gripper: - reward = 1 - if touch_right_gripper and not touch_table: # lifted - reward = 2 - if touch_left_gripper: # attempted transfer - reward = 3 - if touch_left_gripper and not touch_table: # successful transfer - reward = 4 - return reward - - -class InsertionTask(BimanualViperXTask): - def __init__(self, random=None): - super().__init__(random=random) - self.max_reward = 4 - - def initialize_episode(self, physics): - """Sets the state of the environment at the start of each episode.""" - # TODO Notice: this function does not randomize the env configuration. 
Instead, set BOX_POSE from outside - # reset qpos, control and box position - with physics.reset_context(): - physics.named.data.qpos[:16] = START_ARM_POSE - np.copyto(physics.data.ctrl, START_ARM_POSE) - assert BOX_POSE[0] is not None - physics.named.data.qpos[-7 * 2 :] = BOX_POSE[0] # two objects - # print(f"{BOX_POSE=}") - super().initialize_episode(physics) - - @staticmethod - def get_env_state(physics): - env_state = physics.data.qpos.copy()[16:] - return env_state - - def get_reward(self, physics): - # return whether peg touches the pin - all_contact_pairs = [] - for i_contact in range(physics.data.ncon): - id_geom_1 = physics.data.contact[i_contact].geom1 - id_geom_2 = physics.data.contact[i_contact].geom2 - name_geom_1 = physics.model.id2name(id_geom_1, "geom") - name_geom_2 = physics.model.id2name(id_geom_2, "geom") - contact_pair = (name_geom_1, name_geom_2) - all_contact_pairs.append(contact_pair) - - touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs - touch_left_gripper = ( - ("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs - or ("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs - or ("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs - or ("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs - ) - - peg_touch_table = ("red_peg", "table") in all_contact_pairs - socket_touch_table = ( - ("socket-1", "table") in all_contact_pairs - or ("socket-2", "table") in all_contact_pairs - or ("socket-3", "table") in all_contact_pairs - or ("socket-4", "table") in all_contact_pairs - ) - peg_touch_socket = ( - ("red_peg", "socket-1") in all_contact_pairs - or ("red_peg", "socket-2") in all_contact_pairs - or ("red_peg", "socket-3") in all_contact_pairs - or ("red_peg", "socket-4") in all_contact_pairs - ) - pin_touched = ("red_peg", "pin") in all_contact_pairs - - reward = 0 - if touch_left_gripper and touch_right_gripper: # touch both - reward = 1 - if ( - touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table) - ): # grasp both - reward = 2 - if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching - reward = 3 - if pin_touched: # successful insertion - reward = 4 - return reward diff --git a/lerobot/common/envs/aloha/tasks/sim_end_effector.py b/lerobot/common/envs/aloha/tasks/sim_end_effector.py deleted file mode 100644 index d93c8330..00000000 --- a/lerobot/common/envs/aloha/tasks/sim_end_effector.py +++ /dev/null @@ -1,263 +0,0 @@ -import collections - -import numpy as np -from dm_control.suite import base - -from lerobot.common.envs.aloha.constants import ( - PUPPET_GRIPPER_POSITION_CLOSE, - START_ARM_POSE, - normalize_puppet_gripper_position, - normalize_puppet_gripper_velocity, - unnormalize_puppet_gripper_position, -) -from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose - -""" -Environment for simulated robot bi-manual manipulation, with end-effector control. 
-Action space: [left_arm_pose (7), # position and quaternion for end effector - left_gripper_positions (1), # normalized gripper position (0: close, 1: open) - right_arm_pose (7), # position and quaternion for end effector - right_gripper_positions (1),] # normalized gripper position (0: close, 1: open) - -Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position - left_gripper_position (1), # normalized gripper position (0: close, 1: open) - right_arm_qpos (6), # absolute joint position - right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open) - "qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad) - left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing) - right_arm_qvel (6), # absolute joint velocity (rad) - right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing) - "images": {"main": (480x640x3)} # h, w, c, dtype='uint8' -""" - - -class BimanualViperXEndEffectorTask(base.Task): - def __init__(self, random=None): - super().__init__(random=random) - - def before_step(self, action, physics): - a_len = len(action) // 2 - action_left = action[:a_len] - action_right = action[a_len:] - - # set mocap position and quat - # left - np.copyto(physics.data.mocap_pos[0], action_left[:3]) - np.copyto(physics.data.mocap_quat[0], action_left[3:7]) - # right - np.copyto(physics.data.mocap_pos[1], action_right[:3]) - np.copyto(physics.data.mocap_quat[1], action_right[3:7]) - - # set gripper - g_left_ctrl = unnormalize_puppet_gripper_position(action_left[7]) - g_right_ctrl = unnormalize_puppet_gripper_position(action_right[7]) - np.copyto(physics.data.ctrl, np.array([g_left_ctrl, -g_left_ctrl, g_right_ctrl, -g_right_ctrl])) - - def initialize_robots(self, physics): - # reset joint position - physics.named.data.qpos[:16] = START_ARM_POSE - - # reset mocap to align with end effector - # to obtain these numbers: - # (1) make an ee_sim env and reset to the same start_pose - # (2) get env._physics.named.data.xpos['vx300s_left/gripper_link'] - # get env._physics.named.data.xquat['vx300s_left/gripper_link'] - # repeat the same for right side - np.copyto(physics.data.mocap_pos[0], [-0.31718881, 0.5, 0.29525084]) - np.copyto(physics.data.mocap_quat[0], [1, 0, 0, 0]) - # right - np.copyto(physics.data.mocap_pos[1], np.array([0.31718881, 0.49999888, 0.29525084])) - np.copyto(physics.data.mocap_quat[1], [1, 0, 0, 0]) - - # reset gripper control - close_gripper_control = np.array( - [ - PUPPET_GRIPPER_POSITION_CLOSE, - -PUPPET_GRIPPER_POSITION_CLOSE, - PUPPET_GRIPPER_POSITION_CLOSE, - -PUPPET_GRIPPER_POSITION_CLOSE, - ] - ) - np.copyto(physics.data.ctrl, close_gripper_control) - - def initialize_episode(self, physics): - """Sets the state of the environment at the start of each episode.""" - super().initialize_episode(physics) - - @staticmethod - def get_qpos(physics): - qpos_raw = physics.data.qpos.copy() - left_qpos_raw = qpos_raw[:8] - right_qpos_raw = qpos_raw[8:16] - left_arm_qpos = left_qpos_raw[:6] - right_arm_qpos = right_qpos_raw[:6] - left_gripper_qpos = [normalize_puppet_gripper_position(left_qpos_raw[6])] - right_gripper_qpos = [normalize_puppet_gripper_position(right_qpos_raw[6])] - return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos]) - - @staticmethod - def get_qvel(physics): - qvel_raw = physics.data.qvel.copy() - left_qvel_raw = qvel_raw[:8] - right_qvel_raw = qvel_raw[8:16] - left_arm_qvel = left_qvel_raw[:6] - right_arm_qvel = right_qvel_raw[:6] - 
left_gripper_qvel = [normalize_puppet_gripper_velocity(left_qvel_raw[6])] - right_gripper_qvel = [normalize_puppet_gripper_velocity(right_qvel_raw[6])] - return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel]) - - @staticmethod - def get_env_state(physics): - raise NotImplementedError - - def get_observation(self, physics): - # note: it is important to do .copy() - obs = collections.OrderedDict() - obs["qpos"] = self.get_qpos(physics) - obs["qvel"] = self.get_qvel(physics) - obs["env_state"] = self.get_env_state(physics) - obs["images"] = {} - obs["images"]["top"] = physics.render(height=480, width=640, camera_id="top") - obs["images"]["angle"] = physics.render(height=480, width=640, camera_id="angle") - obs["images"]["vis"] = physics.render(height=480, width=640, camera_id="front_close") - # used in scripted policy to obtain starting pose - obs["mocap_pose_left"] = np.concatenate( - [physics.data.mocap_pos[0], physics.data.mocap_quat[0]] - ).copy() - obs["mocap_pose_right"] = np.concatenate( - [physics.data.mocap_pos[1], physics.data.mocap_quat[1]] - ).copy() - - # used when replaying joint trajectory - obs["gripper_ctrl"] = physics.data.ctrl.copy() - return obs - - def get_reward(self, physics): - raise NotImplementedError - - -class TransferCubeEndEffectorTask(BimanualViperXEndEffectorTask): - def __init__(self, random=None): - super().__init__(random=random) - self.max_reward = 4 - - def initialize_episode(self, physics): - """Sets the state of the environment at the start of each episode.""" - self.initialize_robots(physics) - # randomize box position - cube_pose = sample_box_pose() - box_start_idx = physics.model.name2id("red_box_joint", "joint") - np.copyto(physics.data.qpos[box_start_idx : box_start_idx + 7], cube_pose) - # print(f"randomized cube position to {cube_position}") - - super().initialize_episode(physics) - - @staticmethod - def get_env_state(physics): - env_state = physics.data.qpos.copy()[16:] - return env_state - - def get_reward(self, physics): - # return whether left gripper is holding the box - all_contact_pairs = [] - for i_contact in range(physics.data.ncon): - id_geom_1 = physics.data.contact[i_contact].geom1 - id_geom_2 = physics.data.contact[i_contact].geom2 - name_geom_1 = physics.model.id2name(id_geom_1, "geom") - name_geom_2 = physics.model.id2name(id_geom_2, "geom") - contact_pair = (name_geom_1, name_geom_2) - all_contact_pairs.append(contact_pair) - - touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs - touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs - touch_table = ("red_box", "table") in all_contact_pairs - - reward = 0 - if touch_right_gripper: - reward = 1 - if touch_right_gripper and not touch_table: # lifted - reward = 2 - if touch_left_gripper: # attempted transfer - reward = 3 - if touch_left_gripper and not touch_table: # successful transfer - reward = 4 - return reward - - -class InsertionEndEffectorTask(BimanualViperXEndEffectorTask): - def __init__(self, random=None): - super().__init__(random=random) - self.max_reward = 4 - - def initialize_episode(self, physics): - """Sets the state of the environment at the start of each episode.""" - self.initialize_robots(physics) - # randomize peg and socket position - peg_pose, socket_pose = sample_insertion_pose() - - def id2index(j_id): - return 16 + (j_id - 16) * 7 # first 16 is robot qpos, 7 is pose dim # hacky - - peg_start_id = physics.model.name2id("red_peg_joint", 
"joint") - peg_start_idx = id2index(peg_start_id) - np.copyto(physics.data.qpos[peg_start_idx : peg_start_idx + 7], peg_pose) - # print(f"randomized cube position to {cube_position}") - - socket_start_id = physics.model.name2id("blue_socket_joint", "joint") - socket_start_idx = id2index(socket_start_id) - np.copyto(physics.data.qpos[socket_start_idx : socket_start_idx + 7], socket_pose) - # print(f"randomized cube position to {cube_position}") - - super().initialize_episode(physics) - - @staticmethod - def get_env_state(physics): - env_state = physics.data.qpos.copy()[16:] - return env_state - - def get_reward(self, physics): - # return whether peg touches the pin - all_contact_pairs = [] - for i_contact in range(physics.data.ncon): - id_geom_1 = physics.data.contact[i_contact].geom1 - id_geom_2 = physics.data.contact[i_contact].geom2 - name_geom_1 = physics.model.id2name(id_geom_1, "geom") - name_geom_2 = physics.model.id2name(id_geom_2, "geom") - contact_pair = (name_geom_1, name_geom_2) - all_contact_pairs.append(contact_pair) - - touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs - touch_left_gripper = ( - ("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs - or ("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs - or ("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs - or ("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs - ) - - peg_touch_table = ("red_peg", "table") in all_contact_pairs - socket_touch_table = ( - ("socket-1", "table") in all_contact_pairs - or ("socket-2", "table") in all_contact_pairs - or ("socket-3", "table") in all_contact_pairs - or ("socket-4", "table") in all_contact_pairs - ) - peg_touch_socket = ( - ("red_peg", "socket-1") in all_contact_pairs - or ("red_peg", "socket-2") in all_contact_pairs - or ("red_peg", "socket-3") in all_contact_pairs - or ("red_peg", "socket-4") in all_contact_pairs - ) - pin_touched = ("red_peg", "pin") in all_contact_pairs - - reward = 0 - if touch_left_gripper and touch_right_gripper: # touch both - reward = 1 - if ( - touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table) - ): # grasp both - reward = 2 - if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching - reward = 3 - if pin_touched: # successful insertion - reward = 4 - return reward diff --git a/lerobot/common/envs/aloha/utils.py b/lerobot/common/envs/aloha/utils.py deleted file mode 100644 index 5ac8b955..00000000 --- a/lerobot/common/envs/aloha/utils.py +++ /dev/null @@ -1,39 +0,0 @@ -import numpy as np - - -def sample_box_pose(): - x_range = [0.0, 0.2] - y_range = [0.4, 0.6] - z_range = [0.05, 0.05] - - ranges = np.vstack([x_range, y_range, z_range]) - cube_position = np.random.uniform(ranges[:, 0], ranges[:, 1]) - - cube_quat = np.array([1, 0, 0, 0]) - return np.concatenate([cube_position, cube_quat]) - - -def sample_insertion_pose(): - # Peg - x_range = [0.1, 0.2] - y_range = [0.4, 0.6] - z_range = [0.05, 0.05] - - ranges = np.vstack([x_range, y_range, z_range]) - peg_position = np.random.uniform(ranges[:, 0], ranges[:, 1]) - - peg_quat = np.array([1, 0, 0, 0]) - peg_pose = np.concatenate([peg_position, peg_quat]) - - # Socket - x_range = [-0.2, -0.1] - y_range = [0.4, 0.6] - z_range = [0.05, 0.05] - - ranges = np.vstack([x_range, y_range, z_range]) - socket_position = np.random.uniform(ranges[:, 0], ranges[:, 1]) - - socket_quat = np.array([1, 0, 0, 0]) 
- socket_pose = np.concatenate([socket_position, socket_quat]) - - return peg_pose, socket_pose diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index 855e073b..d5571935 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -1,64 +1,42 @@ -from torchrl.envs import SerialEnv -from torchrl.envs.transforms import Compose, StepCounter, Transform, TransformedEnv +import importlib + +import gymnasium as gym -def make_env(cfg, transform=None): +def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv: """ - Note: The returned environment is wrapped in a torchrl.SerialEnv with cfg.rollout_batch_size underlying - environments. The env therefore returns batches.` + Note: When `num_parallel_envs > 0`, this function returns a `SyncVectorEnv` which takes batched action as input and + returns batched observation, reward, terminated, truncated of `num_parallel_envs` items. """ - kwargs = { - "frame_skip": cfg.env.action_repeat, - "from_pixels": cfg.env.from_pixels, - "pixels_only": cfg.env.pixels_only, - "image_size": cfg.env.image_size, - "num_prev_obs": cfg.n_obs_steps - 1, + "obs_type": "pixels_agent_pos", + "render_mode": "rgb_array", + "max_episode_steps": cfg.env.episode_length, + "visualization_width": 384, + "visualization_height": 384, } - if cfg.env.name == "simxarm": - from lerobot.common.envs.simxarm.env import SimxarmEnv + package_name = f"gym_{cfg.env.name}" - kwargs["task"] = cfg.env.task - clsfunc = SimxarmEnv - elif cfg.env.name == "pusht": - from lerobot.common.envs.pusht.env import PushtEnv + try: + importlib.import_module(package_name) + except ModuleNotFoundError as e: + print( + f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.env.name}]'`" + ) + raise e - # assert kwargs["seed"] > 200, "Seed 0-200 are used for the demonstration dataset, so we don't want to seed the eval env with this range." 
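The commented-out assert above records a convention rather than enforcing it: seeds 0-200 are reserved for the demonstration dataset, so evaluation should draw from a disjoint range. A hedged sketch of one way to respect that with gymnasium's per-reset seeding (the constant and helper are illustrative, not from the repo):

```python
DEMO_SEED_MAX = 200  # per the comment above: seeds 0-200 belong to the demo dataset


def eval_seed(episode_index: int) -> int:
    """Map an evaluation episode index to a seed outside the demonstration range."""
    return DEMO_SEED_MAX + 1 + episode_index


# obs, info = env.reset(seed=eval_seed(episode_index))
```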
+ gym_handle = f"{package_name}/{cfg.env.task}" - clsfunc = PushtEnv - elif cfg.env.name == "aloha": - from lerobot.common.envs.aloha.env import AlohaEnv - - kwargs["task"] = cfg.env.task - clsfunc = AlohaEnv + if num_parallel_envs == 0: + # non-batched version of the env that returns an observation of shape (c) + env = gym.make(gym_handle, disable_env_checker=True, **kwargs) else: - raise ValueError(cfg.env.name) - - def _make_env(seed): - nonlocal kwargs - kwargs["seed"] = seed - env = clsfunc(**kwargs) - - # limit rollout to max_steps - env = TransformedEnv(env, StepCounter(max_steps=cfg.env.episode_length)) - - if transform is not None: - # useful to add normalization - if isinstance(transform, Compose): - for tf in transform: - env.append_transform(tf.clone()) - elif isinstance(transform, Transform): - env.append_transform(transform.clone()) - else: - raise NotImplementedError() - - return env - - return SerialEnv( - cfg.rollout_batch_size, - create_env_fn=_make_env, - create_env_kwargs=[ - {"seed": env_seed} for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size) - ], - ) + # batched version of the env that returns an observation of shape (b, c) + env = gym.vector.SyncVectorEnv( + [ + lambda: gym.make(gym_handle, disable_env_checker=True, **kwargs) + for _ in range(num_parallel_envs) + ] + ) + return env diff --git a/lerobot/common/envs/pusht/env.py b/lerobot/common/envs/pusht/env.py deleted file mode 100644 index 5f7fb2c3..00000000 --- a/lerobot/common/envs/pusht/env.py +++ /dev/null @@ -1,245 +0,0 @@ -import importlib -import logging -from collections import deque -from typing import Optional - -import cv2 -import numpy as np -import torch -from tensordict import TensorDict -from torchrl.data.tensor_specs import ( - BoundedTensorSpec, - CompositeSpec, - DiscreteTensorSpec, - UnboundedContinuousTensorSpec, -) -from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform - -from lerobot.common.envs.abstract import AbstractEnv -from lerobot.common.utils import set_global_seed - -_has_gym = importlib.util.find_spec("gymnasium") is not None - - -class PushtEnv(AbstractEnv): - name = "pusht" - available_tasks = ["pusht"] - _reset_warning_issued = False - - def __init__( - self, - task="pusht", - frame_skip: int = 1, - from_pixels: bool = False, - pixels_only: bool = False, - image_size=None, - seed=1337, - device="cpu", - num_prev_obs=1, - num_prev_action=0, - ): - super().__init__( - task=task, - frame_skip=frame_skip, - from_pixels=from_pixels, - pixels_only=pixels_only, - image_size=image_size, - seed=seed, - device=device, - num_prev_obs=num_prev_obs, - num_prev_action=num_prev_action, - ) - - def _make_env(self): - if not _has_gym: - raise ImportError("Cannot import gymnasium.") - - # TODO(rcadene) (PushTEnv is similar to PushTImageEnv, but without the image rendering, it's faster to iterate on) - # from lerobot.common.envs.pusht.pusht_env import PushTEnv - - if not self.from_pixels: - raise NotImplementedError("Use PushTEnv, instead of PushTImageEnv") - from lerobot.common.envs.pusht.pusht_image_env import PushTImageEnv - - self._env = PushTImageEnv(render_size=self.image_size) - - def render(self, mode="rgb_array", width=96, height=96, with_marker=True): - """ - with_marker adds a cursor showing the targeted action for the controller. 
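A usage sketch for the new `make_env` factory introduced above (assuming a config with `env.name = "pusht"` and `env.task` set to a registered task id; the names are illustrative):

```python
env = make_env(cfg, num_parallel_envs=4)  # returns a gym.vector.SyncVectorEnv
obs, info = env.reset(seed=cfg.seed)      # batched reset: each leaf has a leading dim of 4
actions = env.action_space.sample()       # the vector env exposes a batched action space
obs, reward, terminated, truncated, info = env.step(actions)
```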
- """ - if width != height: - raise NotImplementedError() - tmp = self._env.render_size - if width != self._env.render_size: - self._env.render_cache = None - self._env.render_size = width - out = self._env.render(mode).copy() - if with_marker and self._env.latest_action is not None: - action = np.array(self._env.latest_action) - coord = (action / 512 * self._env.render_size).astype(np.int32) - marker_size = int(8 / 96 * self._env.render_size) - thickness = int(1 / 96 * self._env.render_size) - cv2.drawMarker( - out, - coord, - color=(255, 0, 0), - markerType=cv2.MARKER_CROSS, - markerSize=marker_size, - thickness=thickness, - ) - self._env.render_size = tmp - return out - - def _format_raw_obs(self, raw_obs): - if self.from_pixels: - image = torch.from_numpy(raw_obs["image"]) - obs = {"image": image} - - if not self.pixels_only: - obs["state"] = torch.from_numpy(raw_obs["agent_pos"]).type(torch.float32) - else: - # TODO: - obs = {"state": torch.from_numpy(raw_obs["observation"]).type(torch.float32)} - - return obs - - def _reset(self, tensordict: Optional[TensorDict] = None): - if tensordict is not None and not PushtEnv._reset_warning_issued: - logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.") - PushtEnv._reset_warning_issued = True - - # Seed the environment and update the seed to be used for the next reset. - self._next_seed = self.set_seed(self._next_seed) - raw_obs = self._env.reset() - - obs = self._format_raw_obs(raw_obs) - - if self.num_prev_obs > 0: - stacked_obs = {} - if "image" in obs: - self._prev_obs_image_queue = deque( - [obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1) - ) - stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue)) - if "state" in obs: - self._prev_obs_state_queue = deque( - [obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1) - ) - stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) - obs = stacked_obs - - td = TensorDict( - { - "observation": TensorDict(obs, batch_size=[]), - "done": torch.tensor([False], dtype=torch.bool), - }, - batch_size=[], - ) - - return td - - def _step(self, tensordict: TensorDict): - td = tensordict - action = td["action"].numpy() - assert action.ndim == 1 - # TODO(rcadene): add info["is_success"] and info["success"] ? 
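The `_reset` method above (and `_step` below) share one observation-stacking idiom: a fixed-length deque holds the current frame plus the `num_prev_obs` previous ones, seeded at reset by repeating the first frame. A self-contained sketch of that pattern:

```python
from collections import deque

import torch

num_prev_obs = 3
first_frame = torch.zeros(3, 96, 96, dtype=torch.uint8)
# at reset: fill the queue by repeating the first frame
queue = deque([first_frame] * (num_prev_obs + 1), maxlen=num_prev_obs + 1)


def push_and_stack(new_frame: torch.Tensor) -> torch.Tensor:
    queue.append(new_frame)          # the oldest frame is evicted automatically
    return torch.stack(list(queue))  # shape: (num_prev_obs + 1, C, H, W)
```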
- - raw_obs, reward, done, info = self._env.step(action) - - obs = self._format_raw_obs(raw_obs) - - if self.num_prev_obs > 0: - stacked_obs = {} - if "image" in obs: - self._prev_obs_image_queue.append(obs["image"]) - stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue)) - if "state" in obs: - self._prev_obs_state_queue.append(obs["state"]) - stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) - obs = stacked_obs - - td = TensorDict( - { - "observation": TensorDict(obs, batch_size=[]), - "reward": torch.tensor([reward], dtype=torch.float32), - # success and done are true when coverage > self.success_threshold in env - "done": torch.tensor([done], dtype=torch.bool), - "success": torch.tensor([done], dtype=torch.bool), - }, - batch_size=[], - ) - return td - - def _make_spec(self): - obs = {} - if self.from_pixels: - image_shape = (3, self.image_size, self.image_size) - if self.num_prev_obs > 0: - image_shape = (self.num_prev_obs + 1, *image_shape) - - obs["image"] = BoundedTensorSpec( - low=0, - high=255, - shape=image_shape, - dtype=torch.uint8, - device=self.device, - ) - if not self.pixels_only: - state_shape = self._env.observation_space["agent_pos"].shape - if self.num_prev_obs > 0: - state_shape = (self.num_prev_obs + 1, *state_shape) - - obs["state"] = BoundedTensorSpec( - low=0, - high=512, - shape=state_shape, - dtype=torch.float32, - device=self.device, - ) - else: - # TODO(rcadene): add observation_space achieved_goal and desired_goal? - state_shape = self._env.observation_space["observation"].shape - if self.num_prev_obs > 0: - state_shape = (self.num_prev_obs + 1, *state_shape) - - obs["state"] = UnboundedContinuousTensorSpec( - # TODO: - shape=state_shape, - dtype=torch.float32, - device=self.device, - ) - self.observation_spec = CompositeSpec({"observation": obs}) - - self.action_spec = _gym_to_torchrl_spec_transform( - self._env.action_space, - device=self.device, - ) - - self.reward_spec = UnboundedContinuousTensorSpec( - shape=(1,), - dtype=torch.float32, - device=self.device, - ) - - self.done_spec = CompositeSpec( - { - "done": DiscreteTensorSpec( - 2, - shape=(1,), - dtype=torch.bool, - device=self.device, - ), - "success": DiscreteTensorSpec( - 2, - shape=(1,), - dtype=torch.bool, - device=self.device, - ), - } - ) - - def _set_seed(self, seed: Optional[int]): - # Set global seed. - set_global_seed(seed) - # Set PushTImageEnv seed as it relies on it's own internal _seed attribute. 
- self._env.seed(seed) diff --git a/lerobot/common/envs/pusht/pusht_env.py b/lerobot/common/envs/pusht/pusht_env.py deleted file mode 100644 index 6ef70aec..00000000 --- a/lerobot/common/envs/pusht/pusht_env.py +++ /dev/null @@ -1,378 +0,0 @@ -import collections - -import cv2 -import gymnasium as gym -import numpy as np -import pygame -import pymunk -import pymunk.pygame_util -import shapely.geometry as sg -import skimage.transform as st -from gymnasium import spaces -from pymunk.vec2d import Vec2d - -from lerobot.common.envs.pusht.pymunk_override import DrawOptions - - -def pymunk_to_shapely(body, shapes): - geoms = [] - for shape in shapes: - if isinstance(shape, pymunk.shapes.Poly): - verts = [body.local_to_world(v) for v in shape.get_vertices()] - verts += [verts[0]] - geoms.append(sg.Polygon(verts)) - else: - raise RuntimeError(f"Unsupported shape type {type(shape)}") - geom = sg.MultiPolygon(geoms) - return geom - - -class PushTEnv(gym.Env): - metadata = {"render.modes": ["human", "rgb_array"], "video.frames_per_second": 10} - reward_range = (0.0, 1.0) - - def __init__( - self, - legacy=True, # compatibility with original - block_cog=None, - damping=None, - render_action=True, - render_size=96, - reset_to_state=None, - ): - self._seed = None - self.seed() - self.window_size = ws = 512 # The size of the PyGame window - self.render_size = render_size - self.sim_hz = 100 - # Local controller params. - self.k_p, self.k_v = 100, 20 # PD control.z - self.control_hz = self.metadata["video.frames_per_second"] - # legcay set_state for data compatibility - self.legacy = legacy - - # agent_pos, block_pos, block_angle - self.observation_space = spaces.Box( - low=np.array([0, 0, 0, 0, 0], dtype=np.float64), - high=np.array([ws, ws, ws, ws, np.pi * 2], dtype=np.float64), - shape=(5,), - dtype=np.float64, - ) - - # positional goal for agent - self.action_space = spaces.Box( - low=np.array([0, 0], dtype=np.float64), - high=np.array([ws, ws], dtype=np.float64), - shape=(2,), - dtype=np.float64, - ) - - self.block_cog = block_cog - self.damping = damping - self.render_action = render_action - - """ - If human-rendering is used, `self.window` will be a reference - to the window that we draw to. `self.clock` will be a clock that is used - to ensure that the environment is rendered at the correct framerate in - human-mode. They will remain `None` until human-mode is used for the - first time. - """ - self.window = None - self.clock = None - self.screen = None - - self.space = None - self.teleop = None - self.render_buffer = None - self.latest_action = None - self.reset_to_state = reset_to_state - - def reset(self): - seed = self._seed - self._setup() - if self.block_cog is not None: - self.block.center_of_gravity = self.block_cog - if self.damping is not None: - self.space.damping = self.damping - - # use legacy RandomState for compatibility - state = self.reset_to_state - if state is None: - rs = np.random.RandomState(seed=seed) - state = np.array( - [ - rs.randint(50, 450), - rs.randint(50, 450), - rs.randint(100, 400), - rs.randint(100, 400), - rs.randn() * 2 * np.pi - np.pi, - ] - ) - self._set_state(state) - - observation = self._get_obs() - return observation - - def step(self, action): - dt = 1.0 / self.sim_hz - self.n_contact_points = 0 - n_steps = self.sim_hz // self.control_hz - if action is not None: - self.latest_action = action - for _ in range(n_steps): - # Step PD control. - # self.agent.velocity = self.k_p * (act - self.agent.position) # P control works too. 
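The loop below runs a PD controller at `sim_hz`, driving the kinematic agent toward the commanded position while damping its velocity toward zero. Extracted as a standalone sketch (gains and `dt` mirror the values set in `__init__`):

```python
from pymunk.vec2d import Vec2d


def pd_velocity_update(agent, target, k_p: float = 100.0, k_v: float = 20.0, dt: float = 0.01):
    """One PD step: the P-term pulls toward `target`, the D-term damps velocity toward zero."""
    acceleration = k_p * (target - agent.position) + k_v * (Vec2d(0, 0) - agent.velocity)
    agent.velocity += acceleration * dt
```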
- acceleration = self.k_p * (action - self.agent.position) + self.k_v * ( - Vec2d(0, 0) - self.agent.velocity - ) - self.agent.velocity += acceleration * dt - - # Step physics. - self.space.step(dt) - - # compute reward - goal_body = self._get_goal_pose_body(self.goal_pose) - goal_geom = pymunk_to_shapely(goal_body, self.block.shapes) - block_geom = pymunk_to_shapely(self.block, self.block.shapes) - - intersection_area = goal_geom.intersection(block_geom).area - goal_area = goal_geom.area - coverage = intersection_area / goal_area - reward = np.clip(coverage / self.success_threshold, 0, 1) - done = coverage > self.success_threshold - - observation = self._get_obs() - info = self._get_info() - - return observation, reward, done, info - - def render(self, mode): - return self._render_frame(mode) - - def teleop_agent(self): - TeleopAgent = collections.namedtuple("TeleopAgent", ["act"]) - - def act(obs): - act = None - mouse_position = pymunk.pygame_util.from_pygame(Vec2d(*pygame.mouse.get_pos()), self.screen) - if self.teleop or (mouse_position - self.agent.position).length < 30: - self.teleop = True - act = mouse_position - return act - - return TeleopAgent(act) - - def _get_obs(self): - obs = np.array( - tuple(self.agent.position) + tuple(self.block.position) + (self.block.angle % (2 * np.pi),) - ) - return obs - - def _get_goal_pose_body(self, pose): - mass = 1 - inertia = pymunk.moment_for_box(mass, (50, 100)) - body = pymunk.Body(mass, inertia) - # preserving the legacy assignment order for compatibility - # the order here doesn't matter somehow, maybe because CoM is aligned with body origin - body.position = pose[:2].tolist() - body.angle = pose[2] - return body - - def _get_info(self): - n_steps = self.sim_hz // self.control_hz - n_contact_points_per_step = int(np.ceil(self.n_contact_points / n_steps)) - info = { - "pos_agent": np.array(self.agent.position), - "vel_agent": np.array(self.agent.velocity), - "block_pose": np.array(list(self.block.position) + [self.block.angle]), - "goal_pose": self.goal_pose, - "n_contacts": n_contact_points_per_step, - } - return info - - def _render_frame(self, mode): - if self.window is None and mode == "human": - pygame.init() - pygame.display.init() - self.window = pygame.display.set_mode((self.window_size, self.window_size)) - if self.clock is None and mode == "human": - self.clock = pygame.time.Clock() - - canvas = pygame.Surface((self.window_size, self.window_size)) - canvas.fill((255, 255, 255)) - self.screen = canvas - - draw_options = DrawOptions(canvas) - - # Draw goal pose. - goal_body = self._get_goal_pose_body(self.goal_pose) - for shape in self.block.shapes: - goal_points = [ - pymunk.pygame_util.to_pygame(goal_body.local_to_world(v), draw_options.surface) - for v in shape.get_vertices() - ] - goal_points += [goal_points[0]] - pygame.draw.polygon(canvas, self.goal_color, goal_points) - - # Draw agent and block. 
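`_get_info` above reports contacts as a per-substep average: the collision handler accumulates contact points across all sim substeps of one control step, and the count is then normalized by `sim_hz // control_hz`. As a small sketch:

```python
import numpy as np


def contacts_per_substep(n_contact_points: int, sim_hz: int = 100, control_hz: int = 10) -> int:
    """Round the accumulated contact count up to an average per physics substep."""
    n_steps = sim_hz // control_hz
    return int(np.ceil(n_contact_points / n_steps))
```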
- self.space.debug_draw(draw_options) - - if mode == "human": - # The following line copies our drawings from `canvas` to the visible window - self.window.blit(canvas, canvas.get_rect()) - pygame.event.pump() - pygame.display.update() - - # the clock is already ticked during in step for "human" - - img = np.transpose(np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2)) - img = cv2.resize(img, (self.render_size, self.render_size)) - if self.render_action and self.latest_action is not None: - action = np.array(self.latest_action) - coord = (action / 512 * 96).astype(np.int32) - marker_size = int(8 / 96 * self.render_size) - thickness = int(1 / 96 * self.render_size) - cv2.drawMarker( - img, - coord, - color=(255, 0, 0), - markerType=cv2.MARKER_CROSS, - markerSize=marker_size, - thickness=thickness, - ) - return img - - def close(self): - if self.window is not None: - pygame.display.quit() - pygame.quit() - - def seed(self, seed=None): - if seed is None: - seed = np.random.randint(0, 25536) - self._seed = seed - self.np_random = np.random.default_rng(seed) - - def _handle_collision(self, arbiter, space, data): - self.n_contact_points += len(arbiter.contact_point_set.points) - - def _set_state(self, state): - if isinstance(state, np.ndarray): - state = state.tolist() - pos_agent = state[:2] - pos_block = state[2:4] - rot_block = state[4] - self.agent.position = pos_agent - # setting angle rotates with respect to center of mass - # therefore will modify the geometric position - # if not the same as CoM - # therefore should be modified first. - if self.legacy: - # for compatibility with legacy data - self.block.position = pos_block - self.block.angle = rot_block - else: - self.block.angle = rot_block - self.block.position = pos_block - - # Run physics to take effect - self.space.step(1.0 / self.sim_hz) - - def _set_state_local(self, state_local): - agent_pos_local = state_local[:2] - block_pose_local = state_local[2:] - tf_img_obj = st.AffineTransform(translation=self.goal_pose[:2], rotation=self.goal_pose[2]) - tf_obj_new = st.AffineTransform(translation=block_pose_local[:2], rotation=block_pose_local[2]) - tf_img_new = st.AffineTransform(matrix=tf_img_obj.params @ tf_obj_new.params) - agent_pos_new = tf_img_new(agent_pos_local) - new_state = np.array(list(agent_pos_new[0]) + list(tf_img_new.translation) + [tf_img_new.rotation]) - self._set_state(new_state) - return new_state - - def _setup(self): - self.space = pymunk.Space() - self.space.gravity = 0, 0 - self.space.damping = 0 - self.teleop = False - self.render_buffer = [] - - # Add walls. - walls = [ - self._add_segment((5, 506), (5, 5), 2), - self._add_segment((5, 5), (506, 5), 2), - self._add_segment((506, 5), (506, 506), 2), - self._add_segment((5, 506), (506, 506), 2), - ] - self.space.add(*walls) - - # Add agent, block, and goal zone. - self.agent = self.add_circle((256, 400), 15) - self.block = self.add_tee((256, 300), 0) - self.goal_color = pygame.Color("LightGreen") - self.goal_pose = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians) - - # Add collision handling - self.collision_handeler = self.space.add_collision_handler(0, 0) - self.collision_handeler.post_solve = self._handle_collision - self.n_contact_points = 0 - - self.max_score = 50 * 100 - self.success_threshold = 0.95 # 95% coverage. 
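The reward computed in `step` above is pure geometry: project the goal pose and the block into shapely polygons, then score the fraction of the goal area the block covers, saturating at `success_threshold`. Condensed into a sketch:

```python
import numpy as np


def coverage_reward(goal_geom, block_geom, success_threshold: float = 0.95):
    """`goal_geom`/`block_geom` are shapely polygons, e.g. from pymunk_to_shapely."""
    coverage = goal_geom.intersection(block_geom).area / goal_geom.area
    reward = float(np.clip(coverage / success_threshold, 0.0, 1.0))
    done = coverage > success_threshold
    return reward, done
```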
- - def _add_segment(self, a, b, radius): - shape = pymunk.Segment(self.space.static_body, a, b, radius) - shape.color = pygame.Color("LightGray") # https://htmlcolorcodes.com/color-names - return shape - - def add_circle(self, position, radius): - body = pymunk.Body(body_type=pymunk.Body.KINEMATIC) - body.position = position - body.friction = 1 - shape = pymunk.Circle(body, radius) - shape.color = pygame.Color("RoyalBlue") - self.space.add(body, shape) - return body - - def add_box(self, position, height, width): - mass = 1 - inertia = pymunk.moment_for_box(mass, (height, width)) - body = pymunk.Body(mass, inertia) - body.position = position - shape = pymunk.Poly.create_box(body, (height, width)) - shape.color = pygame.Color("LightSlateGray") - self.space.add(body, shape) - return body - - def add_tee(self, position, angle, scale=30, color="LightSlateGray", mask=None): - if mask is None: - mask = pymunk.ShapeFilter.ALL_MASKS() - mass = 1 - length = 4 - vertices1 = [ - (-length * scale / 2, scale), - (length * scale / 2, scale), - (length * scale / 2, 0), - (-length * scale / 2, 0), - ] - inertia1 = pymunk.moment_for_poly(mass, vertices=vertices1) - vertices2 = [ - (-scale / 2, scale), - (-scale / 2, length * scale), - (scale / 2, length * scale), - (scale / 2, scale), - ] - inertia2 = pymunk.moment_for_poly(mass, vertices=vertices1) - body = pymunk.Body(mass, inertia1 + inertia2) - shape1 = pymunk.Poly(body, vertices1) - shape2 = pymunk.Poly(body, vertices2) - shape1.color = pygame.Color(color) - shape2.color = pygame.Color(color) - shape1.filter = pymunk.ShapeFilter(mask=mask) - shape2.filter = pymunk.ShapeFilter(mask=mask) - body.center_of_gravity = (shape1.center_of_gravity + shape2.center_of_gravity) / 2 - body.position = position - body.angle = angle - body.friction = 1 - self.space.add(body, shape1, shape2) - return body diff --git a/lerobot/common/envs/pusht/pusht_image_env.py b/lerobot/common/envs/pusht/pusht_image_env.py deleted file mode 100644 index 6547835a..00000000 --- a/lerobot/common/envs/pusht/pusht_image_env.py +++ /dev/null @@ -1,41 +0,0 @@ -import numpy as np -from gymnasium import spaces - -from lerobot.common.envs.pusht.pusht_env import PushTEnv - - -class PushTImageEnv(PushTEnv): - metadata = {"render.modes": ["rgb_array"], "video.frames_per_second": 10} - - # Note: legacy defaults to True for compatibility with original - def __init__(self, legacy=True, block_cog=None, damping=None, render_size=96): - super().__init__( - legacy=legacy, block_cog=block_cog, damping=damping, render_size=render_size, render_action=False - ) - ws = self.window_size - self.observation_space = spaces.Dict( - { - "image": spaces.Box(low=0, high=1, shape=(3, render_size, render_size), dtype=np.float32), - "agent_pos": spaces.Box(low=0, high=ws, shape=(2,), dtype=np.float32), - } - ) - self.render_cache = None - - def _get_obs(self): - img = super()._render_frame(mode="rgb_array") - - agent_pos = np.array(self.agent.position) - img_obs = np.moveaxis(img, -1, 0) - obs = {"image": img_obs, "agent_pos": agent_pos} - - self.render_cache = img - - return obs - - def render(self, mode): - assert mode == "rgb_array" - - if self.render_cache is None: - self._get_obs() - - return self.render_cache diff --git a/lerobot/common/envs/pusht/pymunk_override.py b/lerobot/common/envs/pusht/pymunk_override.py deleted file mode 100644 index 7ad76237..00000000 --- a/lerobot/common/envs/pusht/pymunk_override.py +++ /dev/null @@ -1,244 +0,0 @@ -# 
---------------------------------------------------------------------------- -# pymunk -# Copyright (c) 2007-2016 Victor Blomqvist -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ---------------------------------------------------------------------------- - -"""This submodule contains helper functions to help with quick prototyping -using pymunk together with pygame. - -Intended to help with debugging and prototyping, not for actual production use -in a full application. The methods contained in this module is opinionated -about your coordinate system and not in any way optimized. -""" - -__docformat__ = "reStructuredText" - -__all__ = [ - "DrawOptions", - "get_mouse_pos", - "to_pygame", - "from_pygame", - # "lighten", - "positive_y_is_up", -] - -from typing import Sequence, Tuple - -import numpy as np -import pygame -import pymunk -from pymunk.space_debug_draw_options import SpaceDebugColor -from pymunk.vec2d import Vec2d - -positive_y_is_up: bool = False -"""Make increasing values of y point upwards. - -When True:: - - y - ^ - | . (3, 3) - | - | . (2, 2) - | - +------ > x - -When False:: - - +------ > x - | - | . (2, 2) - | - | . (3, 3) - v - y - -""" - - -class DrawOptions(pymunk.SpaceDebugDrawOptions): - def __init__(self, surface: pygame.Surface) -> None: - """Draw a pymunk.Space on a pygame.Surface object. - - Typical usage:: - - >>> import pymunk - >>> surface = pygame.Surface((10,10)) - >>> space = pymunk.Space() - >>> options = pymunk.pygame_util.DrawOptions(surface) - >>> space.debug_draw(options) - - You can control the color of a shape by setting shape.color to the color - you want it drawn in:: - - >>> c = pymunk.Circle(None, 10) - >>> c.color = pygame.Color("pink") - - See pygame_util.demo.py for a full example - - Since pygame uses a coordinate system where y points down (in contrast - to many other cases), you either have to make the physics simulation - with Pymunk also behave in that way, or flip everything when you draw. - - The easiest is probably to just make the simulation behave the same - way as Pygame does. In that way all coordinates used are in the same - orientation and easy to reason about:: - - >>> space = pymunk.Space() - >>> space.gravity = (0, -1000) - >>> body = pymunk.Body() - >>> body.position = (0, 0) # will be positioned in the top left corner - >>> space.debug_draw(options) - - To flip the drawing its possible to set the module property - :py:data:`positive_y_is_up` to True. 
Then the pygame drawing will flip - the simulation upside down before drawing:: - - >>> positive_y_is_up = True - >>> body = pymunk.Body() - >>> body.position = (0, 0) - >>> # Body will be position in bottom left corner - - :Parameters: - surface : pygame.Surface - Surface that the objects will be drawn on - """ - self.surface = surface - super().__init__() - - def draw_circle( - self, - pos: Vec2d, - angle: float, - radius: float, - outline_color: SpaceDebugColor, - fill_color: SpaceDebugColor, - ) -> None: - p = to_pygame(pos, self.surface) - - pygame.draw.circle(self.surface, fill_color.as_int(), p, round(radius), 0) - pygame.draw.circle(self.surface, light_color(fill_color).as_int(), p, round(radius - 4), 0) - - # circle_edge = pos + Vec2d(radius, 0).rotated(angle) - # p2 = to_pygame(circle_edge, self.surface) - # line_r = 2 if radius > 20 else 1 - # pygame.draw.lines(self.surface, outline_color.as_int(), False, [p, p2], line_r) - - def draw_segment(self, a: Vec2d, b: Vec2d, color: SpaceDebugColor) -> None: - p1 = to_pygame(a, self.surface) - p2 = to_pygame(b, self.surface) - - pygame.draw.aalines(self.surface, color.as_int(), False, [p1, p2]) - - def draw_fat_segment( - self, - a: Tuple[float, float], - b: Tuple[float, float], - radius: float, - outline_color: SpaceDebugColor, - fill_color: SpaceDebugColor, - ) -> None: - p1 = to_pygame(a, self.surface) - p2 = to_pygame(b, self.surface) - - r = round(max(1, radius * 2)) - pygame.draw.lines(self.surface, fill_color.as_int(), False, [p1, p2], r) - if r > 2: - orthog = [abs(p2[1] - p1[1]), abs(p2[0] - p1[0])] - if orthog[0] == 0 and orthog[1] == 0: - return - scale = radius / (orthog[0] * orthog[0] + orthog[1] * orthog[1]) ** 0.5 - orthog[0] = round(orthog[0] * scale) - orthog[1] = round(orthog[1] * scale) - points = [ - (p1[0] - orthog[0], p1[1] - orthog[1]), - (p1[0] + orthog[0], p1[1] + orthog[1]), - (p2[0] + orthog[0], p2[1] + orthog[1]), - (p2[0] - orthog[0], p2[1] - orthog[1]), - ] - pygame.draw.polygon(self.surface, fill_color.as_int(), points) - pygame.draw.circle( - self.surface, - fill_color.as_int(), - (round(p1[0]), round(p1[1])), - round(radius), - ) - pygame.draw.circle( - self.surface, - fill_color.as_int(), - (round(p2[0]), round(p2[1])), - round(radius), - ) - - def draw_polygon( - self, - verts: Sequence[Tuple[float, float]], - radius: float, - outline_color: SpaceDebugColor, - fill_color: SpaceDebugColor, - ) -> None: - ps = [to_pygame(v, self.surface) for v in verts] - ps += [ps[0]] - - radius = 2 - pygame.draw.polygon(self.surface, light_color(fill_color).as_int(), ps) - - if radius > 0: - for i in range(len(verts)): - a = verts[i] - b = verts[(i + 1) % len(verts)] - self.draw_fat_segment(a, b, radius, fill_color, fill_color) - - def draw_dot(self, size: float, pos: Tuple[float, float], color: SpaceDebugColor) -> None: - p = to_pygame(pos, self.surface) - pygame.draw.circle(self.surface, color.as_int(), p, round(size), 0) - - -def get_mouse_pos(surface: pygame.Surface) -> Tuple[int, int]: - """Get position of the mouse pointer in pymunk coordinates.""" - p = pygame.mouse.get_pos() - return from_pygame(p, surface) - - -def to_pygame(p: Tuple[float, float], surface: pygame.Surface) -> Tuple[int, int]: - """Convenience method to convert pymunk coordinates to pygame surface - local coordinates. - - Note that in case positive_y_is_up is False, this function won't actually do - anything except converting the point to integers. 
- """ - if positive_y_is_up: - return round(p[0]), surface.get_height() - round(p[1]) - else: - return round(p[0]), round(p[1]) - - -def from_pygame(p: Tuple[float, float], surface: pygame.Surface) -> Tuple[int, int]: - """Convenience method to convert pygame surface local coordinates to - pymunk coordinates - """ - return to_pygame(p, surface) - - -def light_color(color: SpaceDebugColor): - color = np.minimum(1.2 * np.float32([color.r, color.g, color.b, color.a]), np.float32([255])) - color = SpaceDebugColor(r=color[0], g=color[1], b=color[2], a=color[3]) - return color diff --git a/lerobot/common/envs/simxarm/env.py b/lerobot/common/envs/simxarm/env.py deleted file mode 100644 index b81bf499..00000000 --- a/lerobot/common/envs/simxarm/env.py +++ /dev/null @@ -1,237 +0,0 @@ -import importlib -import logging -from collections import deque -from typing import Optional - -import einops -import numpy as np -import torch -from tensordict import TensorDict -from torchrl.data.tensor_specs import ( - BoundedTensorSpec, - CompositeSpec, - DiscreteTensorSpec, - UnboundedContinuousTensorSpec, -) -from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform - -from lerobot.common.envs.abstract import AbstractEnv -from lerobot.common.utils import set_global_seed - -MAX_NUM_ACTIONS = 4 - -_has_gym = importlib.util.find_spec("gymnasium") is not None - - -class SimxarmEnv(AbstractEnv): - name = "simxarm" - available_tasks = ["lift"] - - def __init__( - self, - task, - frame_skip: int = 1, - from_pixels: bool = False, - pixels_only: bool = False, - image_size=None, - seed=1337, - device="cpu", - num_prev_obs=0, - num_prev_action=0, - ): - super().__init__( - task=task, - frame_skip=frame_skip, - from_pixels=from_pixels, - pixels_only=pixels_only, - image_size=image_size, - seed=seed, - device=device, - num_prev_obs=num_prev_obs, - num_prev_action=num_prev_action, - ) - - def _make_env(self): - if not _has_gym: - raise ImportError("Cannot import gymnasium.") - - import gymnasium - - from lerobot.common.envs.simxarm.simxarm import TASKS - - if self.task not in TASKS: - raise ValueError(f"Unknown task {self.task}. 
Must be one of {list(TASKS.keys())}") - - self._env = TASKS[self.task]["env"]() - - num_actions = len(TASKS[self.task]["action_space"]) - self._action_space = gymnasium.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,)) - self._action_padding = np.zeros((MAX_NUM_ACTIONS - num_actions), dtype=np.float32) - if "w" not in TASKS[self.task]["action_space"]: - self._action_padding[-1] = 1.0 - - def render(self, mode="rgb_array", width=384, height=384): - return self._env.render(mode, width=width, height=height) - - def _format_raw_obs(self, raw_obs): - if self.from_pixels: - image = self.render(mode="rgb_array", width=self.image_size, height=self.image_size) - image = image.transpose(2, 0, 1) # (H, W, C) -> (C, H, W) - image = torch.tensor(image.copy(), dtype=torch.uint8) - - obs = {"image": image} - - if not self.pixels_only: - obs["state"] = torch.tensor(self._env.robot_state, dtype=torch.float32) - else: - obs = {"state": torch.tensor(raw_obs["observation"], dtype=torch.float32)} - - # obs = TensorDict(obs, batch_size=[]) - return obs - - def _reset(self, tensordict: Optional[TensorDict] = None): - td = tensordict - if td is None or td.is_empty(): - raw_obs = self._env.reset() - - obs = self._format_raw_obs(raw_obs) - - if self.num_prev_obs > 0: - stacked_obs = {} - if "image" in obs: - self._prev_obs_image_queue = deque( - [obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1) - ) - stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue)) - if "state" in obs: - self._prev_obs_state_queue = deque( - [obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1) - ) - stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) - obs = stacked_obs - - td = TensorDict( - { - "observation": TensorDict(obs, batch_size=[]), - "done": torch.tensor([False], dtype=torch.bool), - }, - batch_size=[], - ) - else: - raise NotImplementedError() - - return td - - def _step(self, tensordict: TensorDict): - td = tensordict - action = td["action"].numpy() - # step expects shape=(4,) so we pad if necessary - action = np.concatenate([action, self._action_padding]) - # TODO(rcadene): add info["is_success"] and info["success"] ? 
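`_step` pads every policy action to the simulator's fixed 4-dim layout: tasks with an `"xyz"` action space get a zero-padded final channel, and when the task has no `"w"` dimension the last padding entry is pinned to 1.0 at construction time. The padding logic as a standalone sketch:

```python
import numpy as np

MAX_NUM_ACTIONS = 4


def make_action_padding(task_action_space: str) -> np.ndarray:
    """E.g. "xyz" -> one padding value; "xyzw" -> no padding."""
    padding = np.zeros(MAX_NUM_ACTIONS - len(task_action_space), dtype=np.float32)
    if "w" not in task_action_space:
        padding[-1] = 1.0  # pin the gripper channel to a constant when the task doesn't control it
    return padding


def pad_action(action: np.ndarray, padding: np.ndarray) -> np.ndarray:
    return np.concatenate([action, padding])
```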
- sum_reward = 0 - - if action.ndim == 1: - action = einops.repeat(action, "c -> t c", t=self.frame_skip) - else: - if self.frame_skip > 1: - raise NotImplementedError() - - num_action_steps = action.shape[0] - for i in range(num_action_steps): - raw_obs, reward, done, info = self._env.step(action[i]) - sum_reward += reward - - obs = self._format_raw_obs(raw_obs) - - if self.num_prev_obs > 0: - stacked_obs = {} - if "image" in obs: - self._prev_obs_image_queue.append(obs["image"]) - stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue)) - if "state" in obs: - self._prev_obs_state_queue.append(obs["state"]) - stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) - obs = stacked_obs - - td = TensorDict( - { - "observation": self._format_raw_obs(raw_obs), - "reward": torch.tensor([sum_reward], dtype=torch.float32), - "done": torch.tensor([done], dtype=torch.bool), - "success": torch.tensor([info["success"]], dtype=torch.bool), - }, - batch_size=[], - ) - return td - - def _make_spec(self): - obs = {} - if self.from_pixels: - image_shape = (3, self.image_size, self.image_size) - if self.num_prev_obs > 0: - image_shape = (self.num_prev_obs + 1, *image_shape) - - obs["image"] = BoundedTensorSpec( - low=0, - high=255, - shape=image_shape, - dtype=torch.uint8, - device=self.device, - ) - if not self.pixels_only: - state_shape = (len(self._env.robot_state),) - if self.num_prev_obs > 0: - state_shape = (self.num_prev_obs + 1, *state_shape) - - obs["state"] = UnboundedContinuousTensorSpec( - shape=state_shape, - dtype=torch.float32, - device=self.device, - ) - else: - # TODO(rcadene): add observation_space achieved_goal and desired_goal? - state_shape = self._env.observation_space["observation"].shape - if self.num_prev_obs > 0: - state_shape = (self.num_prev_obs + 1, *state_shape) - - obs["state"] = UnboundedContinuousTensorSpec( - # TODO: - shape=state_shape, - dtype=torch.float32, - device=self.device, - ) - self.observation_spec = CompositeSpec({"observation": obs}) - - self.action_spec = _gym_to_torchrl_spec_transform( - self._action_space, - device=self.device, - ) - - self.reward_spec = UnboundedContinuousTensorSpec( - shape=(1,), - dtype=torch.float32, - device=self.device, - ) - - self.done_spec = CompositeSpec( - { - "done": DiscreteTensorSpec( - 2, - shape=(1,), - dtype=torch.bool, - device=self.device, - ), - "success": DiscreteTensorSpec( - 2, - shape=(1,), - dtype=torch.bool, - device=self.device, - ), - } - ) - - def _set_seed(self, seed: Optional[int]): - set_global_seed(seed) - self._seed = seed - # TODO(aliberts): change self._reset so that it takes in a seed value - logging.warning("simxarm env is not properly seeded") diff --git a/lerobot/common/envs/simxarm/simxarm/__init__.py b/lerobot/common/envs/simxarm/simxarm/__init__.py deleted file mode 100644 index 903d6042..00000000 --- a/lerobot/common/envs/simxarm/simxarm/__init__.py +++ /dev/null @@ -1,166 +0,0 @@ -from collections import OrderedDict, deque - -import gymnasium as gym -import numpy as np -from gymnasium.wrappers import TimeLimit - -from lerobot.common.envs.simxarm.simxarm.tasks.base import Base as Base -from lerobot.common.envs.simxarm.simxarm.tasks.lift import Lift -from lerobot.common.envs.simxarm.simxarm.tasks.peg_in_box import PegInBox -from lerobot.common.envs.simxarm.simxarm.tasks.push import Push -from lerobot.common.envs.simxarm.simxarm.tasks.reach import Reach - -TASKS = OrderedDict( - ( - ( - "reach", - { - "env": Reach, - "action_space": "xyz", - "episode_length": 50, - 
"description": "Reach a target location with the end effector", - }, - ), - ( - "push", - { - "env": Push, - "action_space": "xyz", - "episode_length": 50, - "description": "Push a cube to a target location", - }, - ), - ( - "peg_in_box", - { - "env": PegInBox, - "action_space": "xyz", - "episode_length": 50, - "description": "Insert a peg into a box", - }, - ), - ( - "lift", - { - "env": Lift, - "action_space": "xyzw", - "episode_length": 50, - "description": "Lift a cube above a height threshold", - }, - ), - ) -) - - -class SimXarmWrapper(gym.Wrapper): - """ - A wrapper for the SimXarm environments. This wrapper is used to - convert the action and observation spaces to the correct format. - """ - - def __init__(self, env, task, obs_mode, image_size, action_repeat, frame_stack=1, channel_last=False): - super().__init__(env) - self._env = env - self.obs_mode = obs_mode - self.image_size = image_size - self.action_repeat = action_repeat - self.frame_stack = frame_stack - self._frames = deque([], maxlen=frame_stack) - self.channel_last = channel_last - self._max_episode_steps = task["episode_length"] // action_repeat - - image_shape = ( - (image_size, image_size, 3 * frame_stack) - if channel_last - else (3 * frame_stack, image_size, image_size) - ) - if obs_mode == "state": - self.observation_space = env.observation_space["observation"] - elif obs_mode == "rgb": - self.observation_space = gym.spaces.Box(low=0, high=255, shape=image_shape, dtype=np.uint8) - elif obs_mode == "all": - self.observation_space = gym.spaces.Dict( - state=gym.spaces.Box(low=-np.inf, high=np.inf, shape=(4,), dtype=np.float32), - rgb=gym.spaces.Box(low=0, high=255, shape=image_shape, dtype=np.uint8), - ) - else: - raise ValueError(f"Unknown obs_mode {obs_mode}. Must be one of [rgb, all, state]") - self.action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(len(task["action_space"]),)) - self.action_padding = np.zeros(4 - len(task["action_space"]), dtype=np.float32) - if "w" not in task["action_space"]: - self.action_padding[-1] = 1.0 - - def _render_obs(self): - obs = self.render(mode="rgb_array", width=self.image_size, height=self.image_size) - if not self.channel_last: - obs = obs.transpose(2, 0, 1) - return obs.copy() - - def _update_frames(self, reset=False): - pixels = self._render_obs() - self._frames.append(pixels) - if reset: - for _ in range(1, self.frame_stack): - self._frames.append(pixels) - assert len(self._frames) == self.frame_stack - - def transform_obs(self, obs, reset=False): - if self.obs_mode == "state": - return obs["observation"] - elif self.obs_mode == "rgb": - self._update_frames(reset=reset) - rgb_obs = np.concatenate(list(self._frames), axis=-1 if self.channel_last else 0) - return rgb_obs - elif self.obs_mode == "all": - self._update_frames(reset=reset) - rgb_obs = np.concatenate(list(self._frames), axis=-1 if self.channel_last else 0) - return OrderedDict((("rgb", rgb_obs), ("state", self.robot_state))) - else: - raise ValueError(f"Unknown obs_mode {self.obs_mode}. 
Must be one of [rgb, all, state]") - - def reset(self): - return self.transform_obs(self._env.reset(), reset=True) - - def step(self, action): - action = np.concatenate([action, self.action_padding]) - reward = 0.0 - for _ in range(self.action_repeat): - obs, r, done, info = self._env.step(action) - reward += r - return self.transform_obs(obs), reward, done, info - - def render(self, mode="rgb_array", width=384, height=384, **kwargs): - return self._env.render(mode, width=width, height=height) - - @property - def state(self): - return self._env.robot_state - - -def make(task, obs_mode="state", image_size=84, action_repeat=1, frame_stack=1, channel_last=False, seed=0): - """ - Create a new environment. - Args: - task (str): The task to create an environment for. Must be one of: - - 'reach' - - 'push' - - 'peg-in-box' - - 'lift' - obs_mode (str): The observation mode to use. Must be one of: - - 'state': Only state observations - - 'rgb': RGB images - - 'all': RGB images and state observations - image_size (int): The size of the image observations - action_repeat (int): The number of times to repeat the action - seed (int): The random seed to use - Returns: - gym.Env: The environment - """ - if task not in TASKS: - raise ValueError(f"Unknown task {task}. Must be one of {list(TASKS.keys())}") - env = TASKS[task]["env"]() - env = TimeLimit(env, TASKS[task]["episode_length"]) - env = SimXarmWrapper(env, TASKS[task], obs_mode, image_size, action_repeat, frame_stack, channel_last) - env.seed(seed) - - return env diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/__init__.py b/lerobot/common/envs/simxarm/simxarm/tasks/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/lift.xml b/lerobot/common/envs/simxarm/simxarm/tasks/assets/lift.xml deleted file mode 100644 index 92231f92..00000000 --- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/lift.xml +++ /dev/null @@ -1,53 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/base_link.stl b/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/base_link.stl deleted file mode 100644 index f1f52955..00000000 --- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/base_link.stl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:21fb81ae7fba19e3c6b2d2ca60c8051712ba273357287eb5a397d92d61c7a736 -size 1211434 diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/block_inner.stl b/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/block_inner.stl deleted file mode 100644 index 6cb88945..00000000 --- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/block_inner.stl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:be68ce180d11630a667a5f37f4dffcc3feebe4217d4bb3912c813b6d9ca3ec66 -size 3284 diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/block_inner2.stl b/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/block_inner2.stl deleted file mode 100644 index dab55ef5..00000000 --- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/block_inner2.stl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:2c6448552bf6b1c4f17334d686a5320ce051bcdfe31431edf69303d8a570d1de -size 3284 diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/block_outer.stl 
b/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/block_outer.stl deleted file mode 100644 index 21cf11fa..00000000 --- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/block_outer.stl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:748b9e197e6521914f18d1f6383a36f211136b3f33f2ad2a8c11b9f921c2cf86 -size 6284 diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/left_finger.stl b/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/left_finger.stl deleted file mode 100644 index 6bf4e502..00000000 --- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/left_finger.stl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:a44756eb72f9c214cb37e61dc209cd7073fdff3e4271a7423476ef6fd090d2d4 -size 242684 diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/left_inner_knuckle.stl b/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/left_inner_knuckle.stl deleted file mode 100644 index 817c7e1d..00000000 --- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/left_inner_knuckle.stl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:e8e48692ad26837bb3d6a97582c89784d09948fc09bfe4e5a59017859ff04dac -size 366284 diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/left_outer_knuckle.stl b/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/left_outer_knuckle.stl deleted file mode 100644 index 010c0f3b..00000000 --- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/left_outer_knuckle.stl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:501665812b08d67e764390db781e839adc6896a9540301d60adf606f57648921 -size 22284 diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link1.stl b/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link1.stl deleted file mode 100644 index f2b676f2..00000000 --- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link1.stl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:34b541122df84d2ef5fcb91b715eb19659dc15ad8d44a191dde481f780265636 -size 184184 diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link2.stl b/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link2.stl deleted file mode 100644 index bf93580c..00000000 --- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link2.stl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:61e641cd47c169ecef779683332e00e4914db729bf02dfb61bfbe69351827455 -size 225584 diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link3.stl b/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link3.stl deleted file mode 100644 index d316d233..00000000 --- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link3.stl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:9e2798e7946dd70046c95455d5ba96392d0b54a6069caba91dc4ca66e1379b42 -size 237084 diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link4.stl b/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link4.stl deleted file mode 100644 index f6d5fe94..00000000 --- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link4.stl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:c757fee95f873191a0633c355c07a360032960771cabbd7593a6cdb0f1ffb089 -size 243684 diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link5.stl 
b/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link5.stl deleted file mode 100644 index e037b8b9..00000000 --- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link5.stl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:715ad5787c5dab57589937fd47289882707b5e1eb997e340d567785b02f4ec90 -size 229084 diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link6.stl b/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link6.stl deleted file mode 100644 index 198c5300..00000000 --- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link6.stl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:85b320aa420497827223d16d492bba8de091173374e361396fc7a5dad7bdb0cb -size 399384 diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link7.stl b/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link7.stl deleted file mode 100644 index ce9a39ac..00000000 --- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link7.stl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:97115d848fbf802cb770cd9be639ae2af993103b9d9bbb0c50c943c738a36f18 -size 231684 diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link_base.stl b/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link_base.stl deleted file mode 100644 index 110b9531..00000000 --- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/link_base.stl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:f6fcbc18258090eb56c21cfb17baa5ae43abc98b1958cd366f3a73b9898fc7f0 -size 2106184 diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/right_finger.stl b/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/right_finger.stl deleted file mode 100644 index 03f26e9a..00000000 --- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/right_finger.stl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:c5dee87c7f37baf554b8456ebfe0b3e8ed0b22b8938bd1add6505c2ad6d32c7d -size 242684 diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/right_inner_knuckle.stl b/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/right_inner_knuckle.stl deleted file mode 100644 index 8586f344..00000000 --- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/right_inner_knuckle.stl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:b41dd2c2c550281bf78d7cc6fa117b14786700e5c453560a0cb5fd6dfa0ffb3e -size 366284 diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/right_outer_knuckle.stl b/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/right_outer_knuckle.stl deleted file mode 100644 index ae7afc25..00000000 --- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/mesh/right_outer_knuckle.stl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:75ca1107d0a42a0f03802a9a49cab48419b31851ee8935f8f1ca06be1c1c91e8 -size 22284 diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/peg_in_box.xml b/lerobot/common/envs/simxarm/simxarm/tasks/assets/peg_in_box.xml deleted file mode 100644 index 0f85459f..00000000 --- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/peg_in_box.xml +++ /dev/null @@ -1,74 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/push.xml 
b/lerobot/common/envs/simxarm/simxarm/tasks/assets/push.xml deleted file mode 100644 index 42a78c8a..00000000 --- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/push.xml +++ /dev/null @@ -1,54 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/reach.xml b/lerobot/common/envs/simxarm/simxarm/tasks/assets/reach.xml deleted file mode 100644 index ded6d209..00000000 --- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/reach.xml +++ /dev/null @@ -1,48 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/shared.xml b/lerobot/common/envs/simxarm/simxarm/tasks/assets/shared.xml deleted file mode 100644 index ee56f8f0..00000000 --- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/shared.xml +++ /dev/null @@ -1,51 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/assets/xarm.xml b/lerobot/common/envs/simxarm/simxarm/tasks/assets/xarm.xml deleted file mode 100644 index 023474d6..00000000 --- a/lerobot/common/envs/simxarm/simxarm/tasks/assets/xarm.xml +++ /dev/null @@ -1,88 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/base.py b/lerobot/common/envs/simxarm/simxarm/tasks/base.py deleted file mode 100644 index b937b290..00000000 --- a/lerobot/common/envs/simxarm/simxarm/tasks/base.py +++ /dev/null @@ -1,145 +0,0 @@ -import os - -import mujoco -import numpy as np -from gymnasium_robotics.envs import robot_env - -from lerobot.common.envs.simxarm.simxarm.tasks import mocap - - -class Base(robot_env.MujocoRobotEnv): - """ - Superclass for all simxarm environments. 
- Args: - xml_name (str): name of the xml environment file - gripper_rotation (list): initial rotation of the gripper (given as a quaternion) - """ - - def __init__(self, xml_name, gripper_rotation=None): - if gripper_rotation is None: - gripper_rotation = [0, 1, 0, 0] - self.gripper_rotation = np.array(gripper_rotation, dtype=np.float32) - self.center_of_table = np.array([1.655, 0.3, 0.63625]) - self.max_z = 1.2 - self.min_z = 0.2 - super().__init__( - model_path=os.path.join(os.path.dirname(__file__), "assets", xml_name + ".xml"), - n_substeps=20, - n_actions=4, - initial_qpos={}, - ) - - @property - def dt(self): - return self.n_substeps * self.model.opt.timestep - - @property - def eef(self): - return self._utils.get_site_xpos(self.model, self.data, "grasp") - - @property - def obj(self): - return self._utils.get_site_xpos(self.model, self.data, "object_site") - - @property - def robot_state(self): - gripper_angle = self._utils.get_joint_qpos(self.model, self.data, "right_outer_knuckle_joint") - return np.concatenate([self.eef, gripper_angle]) - - def is_success(self): - return NotImplementedError() - - def get_reward(self): - raise NotImplementedError() - - def _sample_goal(self): - raise NotImplementedError() - - def get_obs(self): - return self._get_obs() - - def _step_callback(self): - self._mujoco.mj_forward(self.model, self.data) - - def _limit_gripper(self, gripper_pos, pos_ctrl): - if gripper_pos[0] > self.center_of_table[0] - 0.105 + 0.15: - pos_ctrl[0] = min(pos_ctrl[0], 0) - if gripper_pos[0] < self.center_of_table[0] - 0.105 - 0.3: - pos_ctrl[0] = max(pos_ctrl[0], 0) - if gripper_pos[1] > self.center_of_table[1] + 0.3: - pos_ctrl[1] = min(pos_ctrl[1], 0) - if gripper_pos[1] < self.center_of_table[1] - 0.3: - pos_ctrl[1] = max(pos_ctrl[1], 0) - if gripper_pos[2] > self.max_z: - pos_ctrl[2] = min(pos_ctrl[2], 0) - if gripper_pos[2] < self.min_z: - pos_ctrl[2] = max(pos_ctrl[2], 0) - return pos_ctrl - - def _apply_action(self, action): - assert action.shape == (4,) - action = action.copy() - pos_ctrl, gripper_ctrl = action[:3], action[3] - pos_ctrl = self._limit_gripper( - self._utils.get_site_xpos(self.model, self.data, "grasp"), pos_ctrl - ) * (1 / self.n_substeps) - gripper_ctrl = np.array([gripper_ctrl, gripper_ctrl]) - mocap.apply_action( - self.model, - self._model_names, - self.data, - np.concatenate([pos_ctrl, self.gripper_rotation, gripper_ctrl]), - ) - - def _render_callback(self): - self._mujoco.mj_forward(self.model, self.data) - - def _reset_sim(self): - self.data.time = self.initial_time - self.data.qpos[:] = np.copy(self.initial_qpos) - self.data.qvel[:] = np.copy(self.initial_qvel) - self._sample_goal() - self._mujoco.mj_step(self.model, self.data, nstep=10) - return True - - def _set_gripper(self, gripper_pos, gripper_rotation): - self._utils.set_mocap_pos(self.model, self.data, "robot0:mocap", gripper_pos) - self._utils.set_mocap_quat(self.model, self.data, "robot0:mocap", gripper_rotation) - self._utils.set_joint_qpos(self.model, self.data, "right_outer_knuckle_joint", 0) - self.data.qpos[10] = 0.0 - self.data.qpos[12] = 0.0 - - def _env_setup(self, initial_qpos): - for name, value in initial_qpos.items(): - self.data.set_joint_qpos(name, value) - mocap.reset(self.model, self.data) - mujoco.mj_forward(self.model, self.data) - self._sample_goal() - mujoco.mj_forward(self.model, self.data) - - def reset(self): - self._reset_sim() - return self._get_obs() - - def step(self, action): - assert action.shape == (4,) - assert self.action_space.contains(action), 
"{!r} ({}) invalid".format(action, type(action)) - self._apply_action(action) - self._mujoco.mj_step(self.model, self.data, nstep=2) - self._step_callback() - obs = self._get_obs() - reward = self.get_reward() - done = False - info = {"is_success": self.is_success(), "success": self.is_success()} - return obs, reward, done, info - - def render(self, mode="rgb_array", width=384, height=384): - self._render_callback() - # HACK - self.model.vis.global_.offwidth = width - self.model.vis.global_.offheight = height - return self.mujoco_renderer.render(mode) - - def close(self): - if self.mujoco_renderer is not None: - self.mujoco_renderer.close() diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/lift.py b/lerobot/common/envs/simxarm/simxarm/tasks/lift.py deleted file mode 100644 index 0b11196c..00000000 --- a/lerobot/common/envs/simxarm/simxarm/tasks/lift.py +++ /dev/null @@ -1,100 +0,0 @@ -import numpy as np - -from lerobot.common.envs.simxarm.simxarm import Base - - -class Lift(Base): - def __init__(self): - self._z_threshold = 0.15 - super().__init__("lift") - - @property - def z_target(self): - return self._init_z + self._z_threshold - - def is_success(self): - return self.obj[2] >= self.z_target - - def get_reward(self): - reach_dist = np.linalg.norm(self.obj - self.eef) - reach_dist_xy = np.linalg.norm(self.obj[:-1] - self.eef[:-1]) - pick_completed = self.obj[2] >= (self.z_target - 0.01) - obj_dropped = (self.obj[2] < (self._init_z + 0.005)) and (reach_dist > 0.02) - - # Reach - if reach_dist < 0.05: - reach_reward = -reach_dist + max(self._action[-1], 0) / 50 - elif reach_dist_xy < 0.05: - reach_reward = -reach_dist - else: - z_bonus = np.linalg.norm(np.linalg.norm(self.obj[-1] - self.eef[-1])) - reach_reward = -reach_dist - 2 * z_bonus - - # Pick - if pick_completed and not obj_dropped: - pick_reward = self.z_target - elif (reach_dist < 0.1) and (self.obj[2] > (self._init_z + 0.005)): - pick_reward = min(self.z_target, self.obj[2]) - else: - pick_reward = 0 - - return reach_reward / 100 + pick_reward - - def _get_obs(self): - eef_velp = self._utils.get_site_xvelp(self.model, self.data, "grasp") * self.dt - gripper_angle = self._utils.get_joint_qpos(self.model, self.data, "right_outer_knuckle_joint") - eef = self.eef - self.center_of_table - - obj = self.obj - self.center_of_table - obj_rot = self._utils.get_joint_qpos(self.model, self.data, "object_joint0")[-4:] - obj_velp = self._utils.get_site_xvelp(self.model, self.data, "object_site") * self.dt - obj_velr = self._utils.get_site_xvelr(self.model, self.data, "object_site") * self.dt - - obs = np.concatenate( - [ - eef, - eef_velp, - obj, - obj_rot, - obj_velp, - obj_velr, - eef - obj, - np.array( - [ - np.linalg.norm(eef - obj), - np.linalg.norm(eef[:-1] - obj[:-1]), - self.z_target, - self.z_target - obj[-1], - self.z_target - eef[-1], - ] - ), - gripper_angle, - ], - axis=0, - ) - return {"observation": obs, "state": eef, "achieved_goal": eef, "desired_goal": eef} - - def _sample_goal(self): - # Gripper - gripper_pos = np.array([1.280, 0.295, 0.735]) + self.np_random.uniform(-0.05, 0.05, size=3) - super()._set_gripper(gripper_pos, self.gripper_rotation) - - # Object - object_pos = self.center_of_table - np.array([0.15, 0.10, 0.07]) - object_pos[0] += self.np_random.uniform(-0.05, 0.05, size=1) - object_pos[1] += self.np_random.uniform(-0.05, 0.05, size=1) - object_qpos = self._utils.get_joint_qpos(self.model, self.data, "object_joint0") - object_qpos[:3] = object_pos - self._utils.set_joint_qpos(self.model, self.data, 
"object_joint0", object_qpos) - self._init_z = object_pos[2] - - # Goal - return object_pos + np.array([0, 0, self._z_threshold]) - - def reset(self): - self._action = np.zeros(4) - return super().reset() - - def step(self, action): - self._action = action.copy() - return super().step(action) diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/mocap.py b/lerobot/common/envs/simxarm/simxarm/tasks/mocap.py deleted file mode 100644 index 4295bf19..00000000 --- a/lerobot/common/envs/simxarm/simxarm/tasks/mocap.py +++ /dev/null @@ -1,67 +0,0 @@ -# import mujoco_py -import mujoco -import numpy as np - - -def apply_action(model, model_names, data, action): - if model.nmocap > 0: - pos_action, gripper_action = np.split(action, (model.nmocap * 7,)) - if data.ctrl is not None: - for i in range(gripper_action.shape[0]): - data.ctrl[i] = gripper_action[i] - pos_action = pos_action.reshape(model.nmocap, 7) - pos_delta, quat_delta = pos_action[:, :3], pos_action[:, 3:] - reset_mocap2body_xpos(model, model_names, data) - data.mocap_pos[:] = data.mocap_pos + pos_delta - data.mocap_quat[:] = data.mocap_quat + quat_delta - - -def reset(model, data): - if model.nmocap > 0 and model.eq_data is not None: - for i in range(model.eq_data.shape[0]): - # if sim.model.eq_type[i] == mujoco_py.const.EQ_WELD: - if model.eq_type[i] == mujoco.mjtEq.mjEQ_WELD: - # model.eq_data[i, :] = np.array([0., 0., 0., 1., 0., 0., 0.]) - model.eq_data[i, :] = np.array( - [ - 0.0, - 0.0, - 0.0, - 1.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - ] - ) - # sim.forward() - mujoco.mj_forward(model, data) - - -def reset_mocap2body_xpos(model, model_names, data): - if model.eq_type is None or model.eq_obj1id is None or model.eq_obj2id is None: - return - - # For all weld constraints - for eq_type, obj1_id, obj2_id in zip(model.eq_type, model.eq_obj1id, model.eq_obj2id, strict=False): - # if eq_type != mujoco_py.const.EQ_WELD: - if eq_type != mujoco.mjtEq.mjEQ_WELD: - continue - # body2 = model.body_id2name(obj2_id) - body2 = model_names.body_id2name[obj2_id] - if body2 == "B0" or body2 == "B9" or body2 == "B1": - continue - mocap_id = model.body_mocapid[obj1_id] - if mocap_id != -1: - # obj1 is the mocap, obj2 is the welded body - body_idx = obj2_id - else: - # obj2 is the mocap, obj1 is the welded body - mocap_id = model.body_mocapid[obj2_id] - body_idx = obj1_id - assert mocap_id != -1 - data.mocap_pos[mocap_id][:] = data.xpos[body_idx] - data.mocap_quat[mocap_id][:] = data.xquat[body_idx] diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/peg_in_box.py b/lerobot/common/envs/simxarm/simxarm/tasks/peg_in_box.py deleted file mode 100644 index 42e41520..00000000 --- a/lerobot/common/envs/simxarm/simxarm/tasks/peg_in_box.py +++ /dev/null @@ -1,86 +0,0 @@ -import numpy as np - -from lerobot.common.envs.simxarm.simxarm import Base - - -class PegInBox(Base): - def __init__(self): - super().__init__("peg_in_box") - - def _reset_sim(self): - self._act_magnitude = 0 - super()._reset_sim() - for _ in range(10): - self._apply_action(np.array([0, 0, 0, 1], dtype=np.float32)) - self.sim.step() - - @property - def box(self): - return self.sim.data.get_site_xpos("box_site") - - def is_success(self): - return np.linalg.norm(self.obj - self.box) <= 0.05 - - def get_reward(self): - dist_xy = np.linalg.norm(self.obj[:2] - self.box[:2]) - dist_xyz = np.linalg.norm(self.obj - self.box) - return float(dist_xy <= 0.045) * (2 - 6 * dist_xyz) - 0.2 * np.square(self._act_magnitude) - dist_xy - - def _get_obs(self): - eef_velp = 
self.sim.data.get_site_xvelp("grasp") * self.dt - gripper_angle = self.sim.data.get_joint_qpos("right_outer_knuckle_joint") - eef, box = self.eef - self.center_of_table, self.box - self.center_of_table - - obj = self.obj - self.center_of_table - obj_rot = self.sim.data.get_joint_qpos("object_joint0")[-4:] - obj_velp = self.sim.data.get_site_xvelp("object_site") * self.dt - obj_velr = self.sim.data.get_site_xvelr("object_site") * self.dt - - obs = np.concatenate( - [ - eef, - eef_velp, - box, - obj, - obj_rot, - obj_velp, - obj_velr, - eef - box, - eef - obj, - obj - box, - np.array( - [ - np.linalg.norm(eef - box), - np.linalg.norm(eef - obj), - np.linalg.norm(obj - box), - gripper_angle, - ] - ), - ], - axis=0, - ) - return {"observation": obs, "state": eef, "achieved_goal": eef, "desired_goal": box} - - def _sample_goal(self): - # Gripper - gripper_pos = np.array([1.280, 0.295, 0.9]) + self.np_random.uniform(-0.05, 0.05, size=3) - super()._set_gripper(gripper_pos, self.gripper_rotation) - - # Object - object_pos = gripper_pos - np.array([0, 0, 0.06]) + self.np_random.uniform(-0.005, 0.005, size=3) - object_qpos = self.sim.data.get_joint_qpos("object_joint0") - object_qpos[:3] = object_pos - self.sim.data.set_joint_qpos("object_joint0", object_qpos) - - # Box - box_pos = np.array([1.61, 0.18, 0.58]) - box_pos[:2] += self.np_random.uniform(-0.11, 0.11, size=2) - box_qpos = self.sim.data.get_joint_qpos("box_joint0") - box_qpos[:3] = box_pos - self.sim.data.set_joint_qpos("box_joint0", box_qpos) - - return self.box - - def step(self, action): - self._act_magnitude = np.linalg.norm(action[:3]) - return super().step(action) diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/push.py b/lerobot/common/envs/simxarm/simxarm/tasks/push.py deleted file mode 100644 index 36c4a550..00000000 --- a/lerobot/common/envs/simxarm/simxarm/tasks/push.py +++ /dev/null @@ -1,78 +0,0 @@ -import numpy as np - -from lerobot.common.envs.simxarm.simxarm import Base - - -class Push(Base): - def __init__(self): - super().__init__("push") - - def _reset_sim(self): - self._act_magnitude = 0 - super()._reset_sim() - - def is_success(self): - return np.linalg.norm(self.obj - self.goal) <= 0.05 - - def get_reward(self): - dist = np.linalg.norm(self.obj - self.goal) - penalty = self._act_magnitude**2 - return -(dist + 0.15 * penalty) - - def _get_obs(self): - eef_velp = self.sim.data.get_site_xvelp("grasp") * self.dt - gripper_angle = self.sim.data.get_joint_qpos("right_outer_knuckle_joint") - eef, goal = self.eef - self.center_of_table, self.goal - self.center_of_table - - obj = self.obj - self.center_of_table - obj_rot = self.sim.data.get_joint_qpos("object_joint0")[-4:] - obj_velp = self.sim.data.get_site_xvelp("object_site") * self.dt - obj_velr = self.sim.data.get_site_xvelr("object_site") * self.dt - - obs = np.concatenate( - [ - eef, - eef_velp, - goal, - obj, - obj_rot, - obj_velp, - obj_velr, - eef - goal, - eef - obj, - obj - goal, - np.array( - [ - np.linalg.norm(eef - goal), - np.linalg.norm(eef - obj), - np.linalg.norm(obj - goal), - gripper_angle, - ] - ), - ], - axis=0, - ) - return {"observation": obs, "state": eef, "achieved_goal": eef, "desired_goal": goal} - - def _sample_goal(self): - # Gripper - gripper_pos = np.array([1.280, 0.295, 0.735]) + self.np_random.uniform(-0.05, 0.05, size=3) - super()._set_gripper(gripper_pos, self.gripper_rotation) - - # Object - object_pos = self.center_of_table - np.array([0.25, 0, 0.07]) - object_pos[0] += self.np_random.uniform(-0.08, 0.08, size=1) - object_pos[1] += 
self.np_random.uniform(-0.08, 0.08, size=1) - object_qpos = self.sim.data.get_joint_qpos("object_joint0") - object_qpos[:3] = object_pos - self.sim.data.set_joint_qpos("object_joint0", object_qpos) - - # Goal - self.goal = np.array([1.600, 0.200, 0.545]) - self.goal[:2] += self.np_random.uniform(-0.1, 0.1, size=2) - self.sim.model.site_pos[self.sim.model.site_name2id("target0")] = self.goal - return self.goal - - def step(self, action): - self._act_magnitude = np.linalg.norm(action[:3]) - return super().step(action) diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/reach.py b/lerobot/common/envs/simxarm/simxarm/tasks/reach.py deleted file mode 100644 index 941a586f..00000000 --- a/lerobot/common/envs/simxarm/simxarm/tasks/reach.py +++ /dev/null @@ -1,44 +0,0 @@ -import numpy as np - -from lerobot.common.envs.simxarm.simxarm import Base - - -class Reach(Base): - def __init__(self): - super().__init__("reach") - - def _reset_sim(self): - self._act_magnitude = 0 - super()._reset_sim() - - def is_success(self): - return np.linalg.norm(self.eef - self.goal) <= 0.05 - - def get_reward(self): - dist = np.linalg.norm(self.eef - self.goal) - penalty = self._act_magnitude**2 - return -(dist + 0.15 * penalty) - - def _get_obs(self): - eef_velp = self.sim.data.get_site_xvelp("grasp") * self.dt - gripper_angle = self.sim.data.get_joint_qpos("right_outer_knuckle_joint") - eef, goal = self.eef - self.center_of_table, self.goal - self.center_of_table - obs = np.concatenate( - [eef, eef_velp, goal, eef - goal, np.array([np.linalg.norm(eef - goal), gripper_angle])], axis=0 - ) - return {"observation": obs, "state": eef, "achieved_goal": eef, "desired_goal": goal} - - def _sample_goal(self): - # Gripper - gripper_pos = np.array([1.280, 0.295, 0.735]) + self.np_random.uniform(-0.05, 0.05, size=3) - super()._set_gripper(gripper_pos, self.gripper_rotation) - - # Goal - self.goal = np.array([1.550, 0.287, 0.580]) - self.goal[:2] += self.np_random.uniform(-0.125, 0.125, size=2) - self.sim.model.site_pos[self.sim.model.site_name2id("target0")] = self.goal - return self.goal - - def step(self, action): - self._act_magnitude = np.linalg.norm(action[:3]) - return super().step(action) diff --git a/lerobot/common/envs/utils.py b/lerobot/common/envs/utils.py new file mode 100644 index 00000000..4d31ddb2 --- /dev/null +++ b/lerobot/common/envs/utils.py @@ -0,0 +1,41 @@ +import einops +import torch + +from lerobot.common.transforms import apply_inverse_transform + + +def preprocess_observation(observation, transform=None): + # map to expected inputs for the policy + obs = {} + + if isinstance(observation["pixels"], dict): + imgs = {f"observation.images.{key}": img for key, img in observation["pixels"].items()} + else: + imgs = {"observation.image": observation["pixels"]} + + for imgkey, img in imgs.items(): + img = torch.from_numpy(img).float() + # convert to (b c h w) torch format + img = einops.rearrange(img, "b h w c -> b c h w") + obs[imgkey] = img + + # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing requirement for "agent_pos" + obs["observation.state"] = torch.from_numpy(observation["agent_pos"]).float() + + # apply same transforms as in training + if transform is not None: + for key in obs: + obs[key] = torch.stack([transform({key: item})[key] for item in obs[key]]) + + return obs + + +def postprocess_action(action, transform=None): + action = action.to("cpu") + # action is a batch (num_env,action_dim) instead of an item (action_dim), + # we assume applying 
inverse transform on a batch works the same + action = apply_inverse_transform({"action": action}, transform)["action"].numpy() + assert ( + action.ndim == 2 + ), "we assume dimensions are respectively the number of parallel envs, action dimensions" + return action diff --git a/lerobot/common/policies/abstract.py b/lerobot/common/policies/abstract.py deleted file mode 100644 index 6dc72bef..00000000 --- a/lerobot/common/policies/abstract.py +++ /dev/null @@ -1,82 +0,0 @@ -from collections import deque - -import torch -from torch import Tensor, nn - - -class AbstractPolicy(nn.Module): - """Base policy which all policies should be derived from. - - The forward method should generally not be overriden as it plays the role of handling multi-step policies. See its - documentation for more information. - - Note: - When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to: - 1. set the required class attributes: - - for classes inheriting from `AbstractDataset`: `available_datasets` - - for classes inheriting from `AbstractEnv`: `name`, `available_tasks` - - for classes inheriting from `AbstractPolicy`: `name` - 2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`) - 3. update variables in `tests/test_available.py` by importing your new class - """ - - name: str | None = None # same name should be used to instantiate the policy in factory.py - - def __init__(self, n_action_steps: int | None): - """ - n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single - action is returned by `select_actions` and that doesn't have a horizon dimension. The `forward` method then - adds that dimension. - """ - super().__init__() - assert self.name is not None, "Subclasses of `AbstractPolicy` should set the `name` class attribute." - self.n_action_steps = n_action_steps - self.clear_action_queue() - - def update(self, replay_buffer, step): - """One step of the policy's learning algorithm.""" - raise NotImplementedError("Abstract method") - - def save(self, fp): - torch.save(self.state_dict(), fp) - - def load(self, fp): - d = torch.load(fp) - self.load_state_dict(d) - - def select_actions(self, observation) -> Tensor: - """Select an action (or trajectory of actions) based on an observation during rollout. - - If n_action_steps was provided at initialization, this should return a (batch_size, n_action_steps, *) tensor of - actions. Otherwise if n_actions_steps is None, this should return a (batch_size, *) tensor of actions. - """ - raise NotImplementedError("Abstract method") - - def clear_action_queue(self): - """This should be called whenever the environment is reset.""" - if self.n_action_steps is not None: - self._action_queue = deque([], maxlen=self.n_action_steps) - - def forward(self, *args, **kwargs) -> Tensor: - """Inference step that makes multi-step policies compatible with their single-step environments. - - WARNING: In general, this should not be overriden. - - Consider a "policy" that observes the environment then charts a course of N actions to take. To make this fit - into the formalism of a TorchRL environment, we view it as being effectively a policy that (1) makes an - observation and prepares a queue of actions, (2) consumes that queue when queried, regardless of the environment - observation, (3) repopulates the action queue when empty. This method handles the aforementioned logic so that - the subclass doesn't have to. 
- - This method effectively wraps the `select_actions` method of the subclass. The following assumptions are made: - 1. The `select_actions` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is - the action trajectory horizon and * is the action dimensions. - 2. Prior to the `select_actions` method being called, theres is an `n_action_steps` instance attribute defined. - """ - if self.n_action_steps is None: - return self.select_actions(*args, **kwargs) - if len(self._action_queue) == 0: - # `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has shape - # (n_action_steps, batch_size, *), hence the transpose. - self._action_queue.extend(self.select_actions(*args, **kwargs).transpose(0, 1)) - return self._action_queue.popleft() diff --git a/lerobot/common/policies/act/backbone.py b/lerobot/common/policies/act/backbone.py deleted file mode 100644 index 6399d339..00000000 --- a/lerobot/common/policies/act/backbone.py +++ /dev/null @@ -1,115 +0,0 @@ -from typing import List - -import torch -import torchvision -from torch import nn -from torchvision.models._utils import IntermediateLayerGetter - -from .position_encoding import build_position_encoding -from .utils import NestedTensor, is_main_process - - -class FrozenBatchNorm2d(torch.nn.Module): - """ - BatchNorm2d where the batch statistics and the affine parameters are fixed. - - Copy-paste from torchvision.misc.ops with added eps before rqsrt, - without which any other policy_models than torchvision.policy_models.resnet[18,34,50,101] - produce nans. - """ - - def __init__(self, n): - super().__init__() - self.register_buffer("weight", torch.ones(n)) - self.register_buffer("bias", torch.zeros(n)) - self.register_buffer("running_mean", torch.zeros(n)) - self.register_buffer("running_var", torch.ones(n)) - - def _load_from_state_dict( - self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs - ): - num_batches_tracked_key = prefix + "num_batches_tracked" - if num_batches_tracked_key in state_dict: - del state_dict[num_batches_tracked_key] - - super()._load_from_state_dict( - state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs - ) - - def forward(self, x): - # move reshapes to the beginning - # to make it fuser-friendly - w = self.weight.reshape(1, -1, 1, 1) - b = self.bias.reshape(1, -1, 1, 1) - rv = self.running_var.reshape(1, -1, 1, 1) - rm = self.running_mean.reshape(1, -1, 1, 1) - eps = 1e-5 - scale = w * (rv + eps).rsqrt() - bias = b - rm * scale - return x * scale + bias - - -class BackboneBase(nn.Module): - def __init__( - self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool - ): - super().__init__() - # for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this? 
- # if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: - # parameter.requires_grad_(False) - if return_interm_layers: - return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} - else: - return_layers = {"layer4": "0"} - self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) - self.num_channels = num_channels - - def forward(self, tensor): - xs = self.body(tensor) - return xs - # out: Dict[str, NestedTensor] = {} - # for name, x in xs.items(): - # m = tensor_list.mask - # assert m is not None - # mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] - # out[name] = NestedTensor(x, mask) - # return out - - -class Backbone(BackboneBase): - """ResNet backbone with frozen BatchNorm.""" - - def __init__(self, name: str, train_backbone: bool, return_interm_layers: bool, dilation: bool): - backbone = getattr(torchvision.models, name)( - replace_stride_with_dilation=[False, False, dilation], - pretrained=is_main_process(), - norm_layer=FrozenBatchNorm2d, - ) # pretrained # TODO do we want frozen batch_norm?? - num_channels = 512 if name in ("resnet18", "resnet34") else 2048 - super().__init__(backbone, train_backbone, num_channels, return_interm_layers) - - -class Joiner(nn.Sequential): - def __init__(self, backbone, position_embedding): - super().__init__(backbone, position_embedding) - - def forward(self, tensor_list: NestedTensor): - xs = self[0](tensor_list) - out: List[NestedTensor] = [] - pos = [] - for _, x in xs.items(): - out.append(x) - # position encoding - pos.append(self[1](x).to(x.dtype)) - - return out, pos - - -def build_backbone(args): - position_embedding = build_position_encoding(args) - train_backbone = args.lr_backbone > 0 - return_interm_layers = args.masks - backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation) - model = Joiner(backbone, position_embedding) - model.num_channels = backbone.num_channels - return model diff --git a/lerobot/common/policies/act/detr_vae.py b/lerobot/common/policies/act/detr_vae.py deleted file mode 100644 index 0f2626f7..00000000 --- a/lerobot/common/policies/act/detr_vae.py +++ /dev/null @@ -1,212 +0,0 @@ -import numpy as np -import torch -from torch import nn -from torch.autograd import Variable - -from .backbone import build_backbone -from .transformer import TransformerEncoder, TransformerEncoderLayer, build_transformer - - -def reparametrize(mu, logvar): - std = logvar.div(2).exp() - eps = Variable(std.data.new(std.size()).normal_()) - return mu + std * eps - - -def get_sinusoid_encoding_table(n_position, d_hid): - def get_position_angle_vec(position): - return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] - - sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) - sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i - sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 - - return torch.FloatTensor(sinusoid_table).unsqueeze(0) - - -class DETRVAE(nn.Module): - """This is the DETR module that performs object detection""" - - def __init__( - self, backbones, transformer, encoder, state_dim, action_dim, num_queries, camera_names, vae - ): - """Initializes the model. - Parameters: - backbones: torch module of the backbone to be used. See backbone.py - transformer: torch module of the transformer architecture. 
See transformer.py - state_dim: robot state dimension of the environment - num_queries: number of object queries, ie detection slot. This is the maximal number of objects - DETR can detect in a single image. For COCO, we recommend 100 queries. - aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. - """ - super().__init__() - self.num_queries = num_queries - self.camera_names = camera_names - self.transformer = transformer - self.encoder = encoder - self.vae = vae - hidden_dim = transformer.d_model - self.action_head = nn.Linear(hidden_dim, action_dim) - self.is_pad_head = nn.Linear(hidden_dim, 1) - self.query_embed = nn.Embedding(num_queries, hidden_dim) - if backbones is not None: - self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1) - self.backbones = nn.ModuleList(backbones) - self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim) - else: - # input_dim = 14 + 7 # robot_state + env_state - self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim) - # TODO(rcadene): understand what is env_state, and why it needs to be 7 - self.input_proj_env_state = nn.Linear(state_dim // 2, hidden_dim) - self.pos = torch.nn.Embedding(2, hidden_dim) - self.backbones = None - - # encoder extra parameters - self.latent_dim = 32 # final size of latent z # TODO tune - self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding - self.encoder_action_proj = nn.Linear(14, hidden_dim) # project action to embedding - self.encoder_joint_proj = nn.Linear(14, hidden_dim) # project qpos to embedding - self.latent_proj = nn.Linear( - hidden_dim, self.latent_dim * 2 - ) # project hidden state to latent std, var - self.register_buffer( - "pos_table", get_sinusoid_encoding_table(1 + 1 + num_queries, hidden_dim) - ) # [CLS], qpos, a_seq - - # decoder extra parameters - self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding - self.additional_pos_embed = nn.Embedding( - 2, hidden_dim - ) # learned position embedding for proprio and latent - - def forward(self, qpos, image, env_state, actions=None, is_pad=None): - """ - qpos: batch, qpos_dim - image: batch, num_cam, channel, height, width - env_state: None - actions: batch, seq, action_dim - """ - is_training = actions is not None # train or val - bs, _ = qpos.shape - ### Obtain latent z from action sequence - if self.vae and is_training: - # project action sequence to embedding dim, and concat with a CLS token - action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim) - qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim) - qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim) - cls_embed = self.cls_embed.weight # (1, hidden_dim) - cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim) - encoder_input = torch.cat( - [cls_embed, qpos_embed, action_embed], axis=1 - ) # (bs, seq+1, hidden_dim) - encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim) - # do not mask cls token - # cls_joint_is_pad = torch.full((bs, 2), False).to(qpos.device) # False: not a padding - # is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1) - # obtain position embedding - pos_embed = self.pos_table.clone().detach() - pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim) - # query model - encoder_output = self.encoder(encoder_input, pos=pos_embed) # , src_key_padding_mask=is_pad) - encoder_output = encoder_output[0] # take cls output only - latent_info 
= self.latent_proj(encoder_output) - mu = latent_info[:, : self.latent_dim] - logvar = latent_info[:, self.latent_dim :] - latent_sample = reparametrize(mu, logvar) - latent_input = self.latent_out_proj(latent_sample) - else: - mu = logvar = None - latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(qpos.device) - latent_input = self.latent_out_proj(latent_sample) - - if self.backbones is not None: - # Image observation features and position embeddings - all_cam_features = [] - all_cam_pos = [] - for cam_id, _ in enumerate(self.camera_names): - features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED - features = features[0] # take the last layer feature - pos = pos[0] - all_cam_features.append(self.input_proj(features)) - all_cam_pos.append(pos) - # proprioception features - proprio_input = self.input_proj_robot_state(qpos) - # fold camera dimension into width dimension - src = torch.cat(all_cam_features, axis=3) - pos = torch.cat(all_cam_pos, axis=3) - hs = self.transformer( - src, - None, - self.query_embed.weight, - pos, - latent_input, - proprio_input, - self.additional_pos_embed.weight, - )[0] - else: - qpos = self.input_proj_robot_state(qpos) - env_state = self.input_proj_env_state(env_state) - transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2 - hs = self.transformer(transformer_input, None, self.query_embed.weight, self.pos.weight)[0] - a_hat = self.action_head(hs) - is_pad_hat = self.is_pad_head(hs) - return a_hat, is_pad_hat, [mu, logvar] - - -def mlp(input_dim, hidden_dim, output_dim, hidden_depth): - if hidden_depth == 0: - mods = [nn.Linear(input_dim, output_dim)] - else: - mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)] - for _ in range(hidden_depth - 1): - mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)] - mods.append(nn.Linear(hidden_dim, output_dim)) - trunk = nn.Sequential(*mods) - return trunk - - -def build_encoder(args): - d_model = args.hidden_dim # 256 - dropout = args.dropout # 0.1 - nhead = args.nheads # 8 - dim_feedforward = args.dim_feedforward # 2048 - num_encoder_layers = args.enc_layers # 4 # TODO shared with VAE decoder - normalize_before = args.pre_norm # False - activation = "relu" - - encoder_layer = TransformerEncoderLayer( - d_model, nhead, dim_feedforward, dropout, activation, normalize_before - ) - encoder_norm = nn.LayerNorm(d_model) if normalize_before else None - encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) - - return encoder - - -def build(args): - # From state - # backbone = None # from state for now, no need for conv nets - # From image - backbones = [] - backbone = build_backbone(args) - backbones.append(backbone) - - transformer = build_transformer(args) - - encoder = build_encoder(args) - - model = DETRVAE( - backbones, - transformer, - encoder, - state_dim=args.state_dim, - action_dim=args.action_dim, - num_queries=args.num_queries, - camera_names=args.camera_names, - vae=args.vae, - ) - - n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) - print("number of parameters: {:.2f}M".format(n_parameters / 1e6)) - - return model diff --git a/lerobot/common/policies/act/policy.py b/lerobot/common/policies/act/policy.py index ae4f7320..25b814ed 100644 --- a/lerobot/common/policies/act/policy.py +++ b/lerobot/common/policies/act/policy.py @@ -1,173 +1,200 @@ -import logging -import time +"""Action Chunking Transformer Policy +As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware 
(https://arxiv.org/abs/2304.13705). +The majority of changes here involve removing unused code, unifying naming, and adding helpful comments. +""" + +import math +import time +from collections import deque +from itertools import chain +from typing import Callable + +import einops +import numpy as np import torch import torch.nn.functional as F # noqa: N812 +import torchvision import torchvision.transforms as transforms +from torch import Tensor, nn +from torchvision.models._utils import IntermediateLayerGetter +from torchvision.ops.misc import FrozenBatchNorm2d -from lerobot.common.policies.abstract import AbstractPolicy -from lerobot.common.policies.act.detr_vae import build from lerobot.common.utils import get_safe_torch_device -def build_act_model_and_optimizer(cfg): - model = build(cfg) +class ActionChunkingTransformerPolicy(nn.Module): + """ + Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost + Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act) - param_dicts = [ - {"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]}, - { - "params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad], - "lr": cfg.lr_backbone, - }, - ] - optimizer = torch.optim.AdamW(param_dicts, lr=cfg.lr, weight_decay=cfg.weight_decay) + Note: In this code we use the terms `vae_encoder`, 'encoder', `decoder`. The meanings are as follows. + - The `vae_encoder` is, as per the literature around variational auto-encoders (VAE), the part of the + model that encodes the target data (a sequence of actions), and the condition (the robot + joint-space). + - A transformer with an `encoder` (not the VAE encoder) and `decoder` (not the VAE decoder) with + cross-attention is used as the VAE decoder. For these terms, we drop the `vae_` prefix because we + have an option to train this model without the variational objective (in which case we drop the + `vae_encoder` altogether, and nothing about this model has anything to do with a VAE). - return model, optimizer + Transformer + Used alone for inference + (acts as VAE decoder + during training) + ┌───────────────────────┐ + │ Outputs │ + │ ▲ │ + │ ┌─────►┌───────┐ │ + ┌──────┐ │ │ │Transf.│ │ + │ │ │ ├─────►│decoder│ │ + ┌────┴────┐ │ │ │ │ │ │ + │ │ │ │ ┌───┴───┬─►│ │ │ + │ VAE │ │ │ │ │ └───────┘ │ + │ encoder │ │ │ │Transf.│ │ + │ │ │ │ │encoder│ │ + └───▲─────┘ │ │ │ │ │ + │ │ │ └───▲───┘ │ + │ │ │ │ │ + inputs └─────┼─────┘ │ + │ │ + └───────────────────────┘ + """ - -def kl_divergence(mu, logvar): - batch_size = mu.size(0) - assert batch_size != 0 - if mu.data.ndimension() == 4: - mu = mu.view(mu.size(0), mu.size(1)) - if logvar.data.ndimension() == 4: - logvar = logvar.view(logvar.size(0), logvar.size(1)) - - klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()) - total_kld = klds.sum(1).mean(0, True) - dimension_wise_kld = klds.mean(0) - mean_kld = klds.mean(1).mean(0, True) - - return total_kld, dimension_wise_kld, mean_kld - - -class ActionChunkingTransformerPolicy(AbstractPolicy): name = "act" + _multiple_obs_steps_not_handled_msg = ( + "ActionChunkingTransformerPolicy does not handle multiple observation steps." + ) - def __init__(self, cfg, device, n_action_steps=1): - super().__init__(n_action_steps) + def __init__(self, cfg, device): + """ + TODO(alexander-soare): Add documentation for all parameters once we have model configs established. 
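+
+        Until then, note that the configuration entries read in this class include: `n_obs_steps`,
+        `n_action_steps`, `camera_names`, `use_vae`, `horizon`, `d_model`, `num_heads`, `dim_feedforward`,
+        `dropout`, `activation`, `pre_norm`, `vae_enc_layers`, `enc_layers`, `dec_layers`, `state_dim`,
+        `action_dim`, `latent_dim`, `backbone`, `dilation`, `pretrained_backbone`, `image_normalization`,
+        `temporal_agg`, `batch_size`, `grad_clip_norm`, `lr`, `lr_backbone`, `weight_decay` and
+        `kl_weight`.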
+ """ + super().__init__() + if getattr(cfg, "n_obs_steps", 1) != 1: + raise ValueError(self._multiple_obs_steps_not_handled_msg) self.cfg = cfg - self.n_action_steps = n_action_steps + self.n_action_steps = cfg.n_action_steps self.device = get_safe_torch_device(device) - self.model, self.optimizer = build_act_model_and_optimizer(cfg) - self.kl_weight = self.cfg.kl_weight - logging.info(f"KL Weight {self.kl_weight}") + self.camera_names = cfg.camera_names + self.use_vae = cfg.use_vae + self.horizon = cfg.horizon + self.d_model = cfg.d_model + + transformer_common_kwargs = dict( # noqa: C408 + d_model=self.d_model, + num_heads=cfg.num_heads, + dim_feedforward=cfg.dim_feedforward, + dropout=cfg.dropout, + activation=cfg.activation, + normalize_before=cfg.pre_norm, + ) + + # BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence]. + # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]). + if self.use_vae: + self.vae_encoder = _TransformerEncoder(num_layers=cfg.vae_enc_layers, **transformer_common_kwargs) + self.vae_encoder_cls_embed = nn.Embedding(1, self.d_model) + # Projection layer for joint-space configuration to hidden dimension. + self.vae_encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, self.d_model) + # Projection layer for action (joint-space target) to hidden dimension. + self.vae_encoder_action_input_proj = nn.Linear(cfg.state_dim, self.d_model) + self.latent_dim = cfg.latent_dim + # Projection layer from the VAE encoder's output to the latent distribution's parameter space. + self.vae_encoder_latent_output_proj = nn.Linear(self.d_model, self.latent_dim * 2) + # Fixed sinusoidal positional embedding the whole input to the VAE encoder. Unsqueeze for batch + # dimension. + self.register_buffer( + "vae_encoder_pos_enc", + _create_sinusoidal_position_embedding(1 + 1 + self.horizon, self.d_model).unsqueeze(0), + ) + + # Backbone for image feature extraction. + self.image_normalizer = transforms.Normalize( + mean=cfg.image_normalization.mean, std=cfg.image_normalization.std + ) + backbone_model = getattr(torchvision.models, cfg.backbone)( + replace_stride_with_dilation=[False, False, cfg.dilation], + pretrained=cfg.pretrained_backbone, + norm_layer=FrozenBatchNorm2d, + ) + # Note: The forward method of this returns a dict: {"feature_map": output}. + self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"}) + + # Transformer (acts as VAE decoder when training with the variational objective). + self.encoder = _TransformerEncoder(num_layers=cfg.enc_layers, **transformer_common_kwargs) + self.decoder = _TransformerDecoder(num_layers=cfg.dec_layers, **transformer_common_kwargs) + + # Transformer encoder input projections. The tokens will be structured like + # [latent, robot_state, image_feature_map_pixels]. + self.encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, self.d_model) + self.encoder_latent_input_proj = nn.Linear(self.latent_dim, self.d_model) + self.encoder_img_feat_input_proj = nn.Conv2d( + backbone_model.fc.in_features, self.d_model, kernel_size=1 + ) + # Transformer encoder positional embeddings. + self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, self.d_model) + self.encoder_cam_feat_pos_embed = _SinusoidalPositionEmbedding2D(self.d_model // 2) + + # Transformer decoder. + # Learnable positional embedding for the transformer's decoder (in the style of DETR object queries). 
+ self.decoder_pos_embed = nn.Embedding(self.horizon, self.d_model) + + # Final action regression head on the output of the transformer's decoder. + self.action_head = nn.Linear(self.d_model, cfg.action_dim) + + self._reset_parameters() + + self._create_optimizer() self.to(self.device) - def update(self, replay_buffer, step): - del step - - start_time = time.time() - - self.train() - - num_slices = self.cfg.batch_size - batch_size = self.cfg.horizon * num_slices - - assert batch_size % self.cfg.horizon == 0 - assert batch_size % num_slices == 0 - - def process_batch(batch, horizon, num_slices): - # trajectory t = 64, horizon h = 16 - # (t h) ... -> t h ... - batch = batch.reshape(num_slices, horizon) - - image = batch["observation", "image", "top"] - image = image[:, 0] # first observation t=0 - # batch, num_cam, channel, height, width - image = image.unsqueeze(1) - assert image.ndim == 5 - image = image.float() - - state = batch["observation", "state"] - state = state[:, 0] # first observation t=0 - # batch, qpos_dim - assert state.ndim == 2 - - action = batch["action"] - # batch, seq, action_dim - assert action.ndim == 3 - assert action.shape[1] == horizon - - if self.cfg.n_obs_steps > 1: - raise NotImplementedError() - # # keep first n observations of the slice corresponding to t=[-1,0] - # image = image[:, : self.cfg.n_obs_steps] - # state = state[:, : self.cfg.n_obs_steps] - - out = { - "obs": { - "image": image.to(self.device, non_blocking=True), - "agent_pos": state.to(self.device, non_blocking=True), - }, - "action": action.to(self.device, non_blocking=True), - } - return out - - batch = replay_buffer.sample(batch_size) - batch = process_batch(batch, self.cfg.horizon, num_slices) - - data_s = time.time() - start_time - - loss = self.compute_loss(batch) - loss.backward() - - grad_norm = torch.nn.utils.clip_grad_norm_( - self.model.parameters(), - self.cfg.grad_clip_norm, - error_if_nonfinite=False, + def _create_optimizer(self): + optimizer_params_dicts = [ + { + "params": [ + p for n, p in self.named_parameters() if not n.startswith("backbone") and p.requires_grad + ] + }, + { + "params": [ + p for n, p in self.named_parameters() if n.startswith("backbone") and p.requires_grad + ], + "lr": self.cfg.lr_backbone, + }, + ] + self.optimizer = torch.optim.AdamW( + optimizer_params_dicts, lr=self.cfg.lr, weight_decay=self.cfg.weight_decay ) - self.optimizer.step() - self.optimizer.zero_grad() - # self.lr_scheduler.step() + def _reset_parameters(self): + """Xavier-uniform initialization of the transformer parameters as in the original code.""" + for p in chain(self.encoder.parameters(), self.decoder.parameters()): + if p.dim() > 1: + nn.init.xavier_uniform_(p) - info = { - "loss": loss.item(), - "grad_norm": float(grad_norm), - # "lr": self.lr_scheduler.get_last_lr()[0], - "lr": self.cfg.lr, - "data_s": data_s, - "update_s": time.time() - start_time, - } + def reset(self): + """This should be called whenever the environment is reset.""" + if self.n_action_steps is not None: + self._action_queue = deque([], maxlen=self.n_action_steps) - return info - - def save(self, fp): - torch.save(self.state_dict(), fp) - - def load(self, fp): - d = torch.load(fp) - self.load_state_dict(d) - - def compute_loss(self, batch): - loss_dict = self._forward( - qpos=batch["obs"]["agent_pos"], - image=batch["obs"]["image"], - actions=batch["action"], - ) - loss = loss_dict["loss"] - return loss + def select_action(self, batch: dict[str, Tensor], *_, **__) -> Tensor: + """ + This method wraps `select_actions` 
in order to return one action at a time for execution in the + environment. It works by managing the actions in a queue and only calling `select_actions` when the + queue is empty. + """ + if len(self._action_queue) == 0: + # `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has shape + # (n_action_steps, batch_size, *), hence the transpose. + self._action_queue.extend(self.select_actions(batch).transpose(0, 1)) + return self._action_queue.popleft() @torch.no_grad() - def select_actions(self, observation, step_count): - if observation["image"].shape[0] != 1: - raise NotImplementedError("Batch size > 1 not handled") - - # TODO(rcadene): remove unused step_count - del step_count - + def select_actions(self, batch: dict[str, Tensor]) -> Tensor: + """Use the action chunking transformer to generate a sequence of actions.""" self.eval() + self._preprocess_batch(batch, add_obs_steps_dim=True) - # TODO(rcadene): remove hack - # add 1 camera dimension - observation["image", "top"] = observation["image", "top"].unsqueeze(1) - - obs_dict = { - "image": observation["image", "top"], - "agent_pos": observation["state"], - } - action = self._forward(qpos=obs_dict["agent_pos"], image=obs_dict["image"]) + action = self.forward(batch, return_loss=False) if self.cfg.temporal_agg: # TODO(rcadene): implement temporal aggregation @@ -182,35 +209,470 @@ class ActionChunkingTransformerPolicy(AbstractPolicy): # exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1) # raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True) - # take first predicted action or n first actions - action = action[: self.n_action_steps] - return action + return action[: self.n_action_steps] - def _forward(self, qpos, image, actions=None, is_pad=None): - env_state = None - normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) - image = normalize(image) + def __call__(self, *args, **kwargs) -> dict: + # TODO(now): Temporary bridge until we know what to do about the `update` method. + return self.update(*args, **kwargs) - is_training = actions is not None - if is_training: # training time - actions = actions[:, : self.model.num_queries] - if is_pad is not None: - is_pad = is_pad[:, : self.model.num_queries] + def _preprocess_batch( + self, batch: dict[str, Tensor], add_obs_steps_dim: bool = False + ) -> dict[str, Tensor]: + """ + This function expects `batch` to have (at least): + { + "observation.state": (B, 1, J) OR (B, J) tensor of robot states (joint configuration). + "observation.images.top": (B, 1, C, H, W) OR (B, C, H, W) tensor of images. + "action": (B, H, J) tensor of actions (positional target for robot joint configuration) + "action_is_pad": (B, H) mask for whether the actions are padding outside of the episode bounds. + } + """ + if add_obs_steps_dim: + # Add a dimension for the observations steps. Since n_obs_steps > 1 is not supported right now, + # this just amounts to an unsqueeze. + for k in batch: + if k.startswith("observation."): + batch[k] = batch[k].unsqueeze(1) - a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad) + if batch["observation.state"].shape[1] != 1: + raise ValueError(self._multiple_obs_steps_not_handled_msg) + batch["observation.state"] = batch["observation.state"].squeeze(1) + # TODO(alexander-soare): generalize this to multiple images. 
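+        # Until that generalization lands, a valid batch carries exactly one image stream. A sketch of the
+        # expected keys and shapes after the unsqueeze above (key names as used elsewhere in this file):
+        #     batch["observation.state"]:      (B, 1, J)
+        #     batch["observation.images.top"]: (B, 1, C, H, W)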
+ assert ( + sum(k.startswith("observation.images.") and not k.endswith("is_pad") for k in batch) == 1 + ), "ACT only handles one image for now." + # Note: no squeeze is required for "observation.images.top" because then we'd have to unsqueeze to get + # the image index dimension. - all_l1 = F.l1_loss(actions, a_hat, reduction="none") - l1 = all_l1.mean() if is_pad is None else (all_l1 * ~is_pad.unsqueeze(-1)).mean() + def update(self, batch, *_, **__) -> dict: + start_time = time.time() + self._preprocess_batch(batch) + + self.train() + + num_slices = self.cfg.batch_size + batch_size = self.cfg.horizon * num_slices + + assert batch_size % self.cfg.horizon == 0 + assert batch_size % num_slices == 0 + + loss = self.forward(batch, return_loss=True)["loss"] + loss.backward() + + grad_norm = torch.nn.utils.clip_grad_norm_( + self.parameters(), + self.cfg.grad_clip_norm, + error_if_nonfinite=False, + ) + + self.optimizer.step() + self.optimizer.zero_grad() + + info = { + "loss": loss.item(), + "grad_norm": float(grad_norm), + "lr": self.cfg.lr, + "update_s": time.time() - start_time, + } + + return info + + def forward(self, batch: dict[str, Tensor], return_loss: bool = False) -> dict | Tensor: + images = self.image_normalizer(batch["observation.images.top"]) + + if return_loss: # training time + actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward( + batch["observation.state"], images, batch["action"] + ) + + l1_loss = ( + F.l1_loss(batch["action"], actions_hat, reduction="none") + * ~batch["action_is_pad"].unsqueeze(-1) + ).mean() loss_dict = {} - loss_dict["l1"] = l1 - if self.cfg.vae: - total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar) - loss_dict["kl"] = total_kld[0] - loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight + loss_dict["l1"] = l1_loss + if self.cfg.use_vae: + # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for + # each dimension independently, we sum over the latent dimension to get the total + # KL-divergence per batch element, then take the mean over the batch. + # (See App. B of https://arxiv.org/abs/1312.6114 for more details). + mean_kld = ( + (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean() + ) + loss_dict["kl"] = mean_kld + loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.cfg.kl_weight else: loss_dict["loss"] = loss_dict["l1"] return loss_dict else: - action, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior + action, _ = self._forward(batch["observation.state"], images) return action + + def _forward( + self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None + ) -> tuple[Tensor, tuple[Tensor | None, Tensor | None]]: + """ + Args: + robot_state: (B, J) batch of robot joint configurations. + image: (B, N, C, H, W) batch of N camera frames. + actions: (B, S, A) batch of actions from the target dataset which must be provided if the + VAE is enabled and the model is in training mode. + Returns: + (B, S, A) batch of action sequences + Tuple containing the latent PDF's parameters (mean, log(σ²)) both as (B, L) tensors where L is the + latent dimension. + """ + if self.use_vae and self.training: + assert ( + actions is not None + ), "actions must be provided when using the variational objective in training mode." + + batch_size = robot_state.shape[0] + + # Prepare the latent for input to the transformer encoder. 
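+        # CVAE logic in brief: during training with the variational objective the latent is sampled (via
+        # the reparameterization trick) from the posterior produced by the VAE encoder; at inference time,
+        # or when the VAE is disabled, the latent is set to zeros, i.e. the mean of the prior (see the
+        # `else` branch below).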
+ if self.use_vae and actions is not None: + # Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence]. + cls_embed = einops.repeat( + self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size + ) # (B, 1, D) + robot_state_embed = self.vae_encoder_robot_state_input_proj(robot_state).unsqueeze(1) # (B, 1, D) + action_embed = self.vae_encoder_action_input_proj(actions) # (B, S, D) + vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1) # (B, S+2, D) + + # Prepare fixed positional embedding. + # Note: detach() shouldn't be necessary but leaving it the same as the original code just in case. + pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D) + + # Forward pass through VAE encoder to get the latent PDF parameters. + cls_token_out = self.vae_encoder( + vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2) + )[0] # select the class token, with shape (B, D) + latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out) + mu = latent_pdf_params[:, : self.latent_dim] + # This is 2log(sigma). Done this way to match the original implementation. + log_sigma_x2 = latent_pdf_params[:, self.latent_dim :] + + # Sample the latent with the reparameterization trick. + latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu) + else: + # When not using the VAE encoder, we set the latent to be all zeros. + mu = log_sigma_x2 = None + latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to( + robot_state.device + ) + + # Prepare all other transformer encoder inputs. + # Camera observation features and positional embeddings. + all_cam_features = [] + all_cam_pos_embeds = [] + for cam_id, _ in enumerate(self.camera_names): + cam_features = self.backbone(image[:, cam_id])["feature_map"] + cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype) + cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w) + all_cam_features.append(cam_features) + all_cam_pos_embeds.append(cam_pos_embed) + # Concatenate camera observation feature maps and positional embeddings along the width dimension. + encoder_in = torch.cat(all_cam_features, axis=3) + cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=3) + + # Get positional embeddings for robot state and latent. + robot_state_embed = self.encoder_robot_state_input_proj(robot_state) + latent_embed = self.encoder_latent_input_proj(latent_sample) + + # Stack encoder input and positional embeddings moving to (S, B, C). + encoder_in = torch.cat( + [ + torch.stack([latent_embed, robot_state_embed], axis=0), + encoder_in.flatten(2).permute(2, 0, 1), + ] + ) + pos_embed = torch.cat( + [ + self.encoder_robot_and_latent_pos_embed.weight.unsqueeze(1), + cam_pos_embed.flatten(2).permute(2, 0, 1), + ], + axis=0, + ) + + # Forward pass through the transformer modules. + encoder_out = self.encoder(encoder_in, pos_embed=pos_embed) + decoder_in = torch.zeros( + (self.horizon, batch_size, self.d_model), dtype=pos_embed.dtype, device=pos_embed.device + ) + decoder_out = self.decoder( + decoder_in, + encoder_out, + encoder_pos_embed=pos_embed, + decoder_pos_embed=self.decoder_pos_embed.weight.unsqueeze(1), + ) + + # Move back to (B, S, C). 
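+        # (The transformer modules above follow torch.nn.MultiheadAttention's default batch_first=False
+        # convention, hence the (S, B, C) layout up to this point.)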
+ decoder_out = decoder_out.transpose(0, 1) + + actions = self.action_head(decoder_out) + + return actions, (mu, log_sigma_x2) + + def save(self, fp): + torch.save(self.state_dict(), fp) + + def load(self, fp): + d = torch.load(fp) + self.load_state_dict(d) + + +class _TransformerEncoder(nn.Module): + """Convenience module for running multiple encoder layers, maybe followed by normalization.""" + + def __init__(self, num_layers: int, **encoder_layer_kwargs: dict): + super().__init__() + self.layers = nn.ModuleList( + [_TransformerEncoderLayer(**encoder_layer_kwargs) for _ in range(num_layers)] + ) + self.norm = ( + nn.LayerNorm(encoder_layer_kwargs["d_model"]) + if encoder_layer_kwargs["normalize_before"] + else nn.Identity() + ) + + def forward(self, x: Tensor, pos_embed: Tensor | None = None) -> Tensor: + for layer in self.layers: + x = layer(x, pos_embed=pos_embed) + x = self.norm(x) + return x + + +class _TransformerEncoderLayer(nn.Module): + def __init__( + self, + d_model: int, + num_heads: int, + dim_feedforward: int, + dropout: float, + activation: str, + normalize_before: bool, + ): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout) + + # Feed forward layers. + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def forward(self, x, pos_embed: Tensor | None = None) -> Tensor: + skip = x + if self.normalize_before: + x = self.norm1(x) + q = k = x if pos_embed is None else x + pos_embed + x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights + x = skip + self.dropout1(x) + if self.normalize_before: + skip = x + x = self.norm2(x) + else: + x = self.norm1(x) + skip = x + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + x = skip + self.dropout2(x) + if not self.normalize_before: + x = self.norm2(x) + return x + + +class _TransformerDecoder(nn.Module): + def __init__(self, num_layers: int, **decoder_layer_kwargs): + """Convenience module for running multiple decoder layers followed by normalization.""" + super().__init__() + self.layers = nn.ModuleList( + [_TransformerDecoderLayer(**decoder_layer_kwargs) for _ in range(num_layers)] + ) + self.num_layers = num_layers + self.norm = nn.LayerNorm(decoder_layer_kwargs["d_model"]) + + def forward( + self, + x: Tensor, + encoder_out: Tensor, + decoder_pos_embed: Tensor | None = None, + encoder_pos_embed: Tensor | None = None, + ) -> Tensor: + for layer in self.layers: + x = layer( + x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed + ) + if self.norm is not None: + x = self.norm(x) + return x + + +class _TransformerDecoderLayer(nn.Module): + def __init__( + self, + d_model: int, + num_heads: int, + dim_feedforward: int, + dropout: float, + activation: str, + normalize_before: bool, + ): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout) + self.multihead_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout) + + # Feed forward layers. 
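+        # This is the position-wise MLP (d_model -> dim_feedforward -> d_model) applied after the
+        # self-attention and cross-attention blocks, mirroring the encoder layer above.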
+        self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.dropout = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.norm3 = nn.LayerNorm(d_model)
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+        self.dropout3 = nn.Dropout(dropout)
+
+        self.activation = _get_activation_fn(activation)
+        self.normalize_before = normalize_before
+
+    def maybe_add_pos_embed(self, tensor: Tensor, pos_embed: Tensor | None) -> Tensor:
+        return tensor if pos_embed is None else tensor + pos_embed
+
+    def forward(
+        self,
+        x: Tensor,
+        encoder_out: Tensor,
+        decoder_pos_embed: Tensor | None = None,
+        encoder_pos_embed: Tensor | None = None,
+    ) -> Tensor:
+        """
+        Args:
+            x: (Decoder Sequence, Batch, Channel) tensor of input tokens.
+            encoder_out: (Encoder Sequence, B, C) output features from the last layer of the encoder we are
+                cross-attending with.
+            decoder_pos_embed: (DS, 1, C) positional embedding for the queries (from the decoder).
+            encoder_pos_embed: (ES, 1, C) positional embedding for the keys (from the encoder).
+        Returns:
+            (DS, B, C) tensor of decoder output features.
+        """
+        skip = x
+        if self.normalize_before:
+            x = self.norm1(x)
+        q = k = self.maybe_add_pos_embed(x, decoder_pos_embed)
+        x = self.self_attn(q, k, value=x)[0]  # select just the output, not the attention weights
+        x = skip + self.dropout1(x)
+        if self.normalize_before:
+            skip = x
+            x = self.norm2(x)
+        else:
+            x = self.norm1(x)
+            skip = x
+        x = self.multihead_attn(
+            query=self.maybe_add_pos_embed(x, decoder_pos_embed),
+            key=self.maybe_add_pos_embed(encoder_out, encoder_pos_embed),
+            value=encoder_out,
+        )[0]  # select just the output, not the attention weights
+        x = skip + self.dropout2(x)
+        if self.normalize_before:
+            skip = x
+            x = self.norm3(x)
+        else:
+            x = self.norm2(x)
+            skip = x
+        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
+        x = skip + self.dropout3(x)
+        if not self.normalize_before:
+            x = self.norm3(x)
+        return x
+
+
+def _create_sinusoidal_position_embedding(num_positions: int, dimension: int) -> Tensor:
+    """1D sinusoidal positional embeddings as in Attention is All You Need.
+
+    Args:
+        num_positions: Number of token positions required.
+        dimension: Dimension of each position embedding.
+    Returns: (num_positions, dimension) position embeddings (the first dimension is the position index).
+
+    """
+
+    def get_position_angle_vec(position):
+        return [position / np.power(10000, 2 * (hid_j // 2) / dimension) for hid_j in range(dimension)]
+
+    sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(num_positions)])
+    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
+    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
+    return torch.from_numpy(sinusoid_table).float()
+
+
+class _SinusoidalPositionEmbedding2D(nn.Module):
+    """2D sinusoidal positional embeddings similar to what's presented in Attention Is All You Need.
+
+    The variation is that the position indices are normalized in [0, 2π] (not quite: the lower bound is 1/H
+    for the vertical direction, and 1/W for the horizontal direction).
+    """
+
+    def __init__(self, dimension: int):
+        """
+        Args:
+            dimension: The desired dimension of the embeddings.
+        """
+        super().__init__()
+        self.dimension = dimension
+        self._two_pi = 2 * math.pi
+        self._eps = 1e-6
+        # Inverse "common ratio" for the geometric progression in sinusoid frequencies.
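+        # With temperature T, channel pair k is driven at frequency 1 / T^(2k / dimension), so the sinusoid
+        # wavelengths form a geometric progression spanning from one feature-map extent up to roughly T extents.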
+        self._temperature = 10000
+
+    def forward(self, x: Tensor) -> Tensor:
+        """
+        Args:
+            x: A (B, C, H, W) batch of 2D feature map to generate the embeddings for.
+        Returns:
+            A (1, C, H, W) batch of corresponding sinusoidal positional embeddings.
+        """
+        not_mask = torch.ones_like(x[0, :1])  # (1, H, W)
+        # Note: These are like range(1, H+1) and range(1, W+1) respectively, but in most implementations
+        # they would be range(0, H) and range(0, W). Keeping it as is to match the original code.
+        y_range = not_mask.cumsum(1, dtype=torch.float32)
+        x_range = not_mask.cumsum(2, dtype=torch.float32)
+
+        # "Normalize" the position index such that it ranges in [0, 2π].
+        # Note: Adding epsilon on the denominator should not be needed as all values of y_range and x_range
+        # are non-zero by construction. This is an artifact of the original code.
+        y_range = y_range / (y_range[:, -1:, :] + self._eps) * self._two_pi
+        x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi
+
+        inverse_frequency = self._temperature ** (
+            2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension
+        )
+
+        x_range = x_range.unsqueeze(-1) / inverse_frequency  # (1, H, W, dimension)
+        y_range = y_range.unsqueeze(-1) / inverse_frequency  # (1, H, W, dimension)
+
+        # Note: this stack then flatten operation results in interleaved sine and cosine terms.
+        # pos_embed_x and pos_embed_y are (1, H, W, C // 2).
+        pos_embed_x = torch.stack((x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1).flatten(3)
+        pos_embed_y = torch.stack((y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1).flatten(3)
+        pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(0, 3, 1, 2)  # (1, C, H, W)
+
+        return pos_embed
+
+
+def _get_activation_fn(activation: str) -> Callable:
+    """Return an activation function given a string."""
+    if activation == "relu":
+        return F.relu
+    if activation == "gelu":
+        return F.gelu
+    if activation == "glu":
+        return F.glu
+    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
diff --git a/lerobot/common/policies/act/position_encoding.py b/lerobot/common/policies/act/position_encoding.py
deleted file mode 100644
index 63bb4840..00000000
--- a/lerobot/common/policies/act/position_encoding.py
+++ /dev/null
@@ -1,102 +0,0 @@
-"""
-Various positional encodings for the transformer.
-"""
-
-import math
-
-import torch
-from torch import nn
-
-from .utils import NestedTensor
-
-
-class PositionEmbeddingSine(nn.Module):
-    """
-    This is a more standard version of the position embedding, very similar to the one
-    used by the Attention is all you need paper, generalized to work on images.
- """ - - def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): - super().__init__() - self.num_pos_feats = num_pos_feats - self.temperature = temperature - self.normalize = normalize - if scale is not None and normalize is False: - raise ValueError("normalize should be True if scale is passed") - if scale is None: - scale = 2 * math.pi - self.scale = scale - - def forward(self, tensor): - x = tensor - # mask = tensor_list.mask - # assert mask is not None - # not_mask = ~mask - - not_mask = torch.ones_like(x[0, [0]]) - y_embed = not_mask.cumsum(1, dtype=torch.float32) - x_embed = not_mask.cumsum(2, dtype=torch.float32) - if self.normalize: - eps = 1e-6 - y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale - x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale - - dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) - dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) - - pos_x = x_embed[:, :, :, None] / dim_t - pos_y = y_embed[:, :, :, None] / dim_t - pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) - pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) - pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) - return pos - - -class PositionEmbeddingLearned(nn.Module): - """ - Absolute pos embedding, learned. - """ - - def __init__(self, num_pos_feats=256): - super().__init__() - self.row_embed = nn.Embedding(50, num_pos_feats) - self.col_embed = nn.Embedding(50, num_pos_feats) - self.reset_parameters() - - def reset_parameters(self): - nn.init.uniform_(self.row_embed.weight) - nn.init.uniform_(self.col_embed.weight) - - def forward(self, tensor_list: NestedTensor): - x = tensor_list.tensors - h, w = x.shape[-2:] - i = torch.arange(w, device=x.device) - j = torch.arange(h, device=x.device) - x_emb = self.col_embed(i) - y_emb = self.row_embed(j) - pos = ( - torch.cat( - [ - x_emb.unsqueeze(0).repeat(h, 1, 1), - y_emb.unsqueeze(1).repeat(1, w, 1), - ], - dim=-1, - ) - .permute(2, 0, 1) - .unsqueeze(0) - .repeat(x.shape[0], 1, 1, 1) - ) - return pos - - -def build_position_encoding(args): - n_steps = args.hidden_dim // 2 - if args.position_embedding in ("v2", "sine"): - # TODO find a better way of exposing other arguments - position_embedding = PositionEmbeddingSine(n_steps, normalize=True) - elif args.position_embedding in ("v3", "learned"): - position_embedding = PositionEmbeddingLearned(n_steps) - else: - raise ValueError(f"not supported {args.position_embedding}") - - return position_embedding diff --git a/lerobot/common/policies/act/transformer.py b/lerobot/common/policies/act/transformer.py deleted file mode 100644 index 20cfc815..00000000 --- a/lerobot/common/policies/act/transformer.py +++ /dev/null @@ -1,371 +0,0 @@ -""" -DETR Transformer class. 
- -Copy-paste from torch.nn.Transformer with modifications: - * positional encodings are passed in MHattention - * extra LN at the end of encoder is removed - * decoder returns a stack of activations from all decoding layers -""" - -import copy -from typing import Optional - -import torch -import torch.nn.functional as F # noqa: N812 -from torch import Tensor, nn - - -class Transformer(nn.Module): - def __init__( - self, - d_model=512, - nhead=8, - num_encoder_layers=6, - num_decoder_layers=6, - dim_feedforward=2048, - dropout=0.1, - activation="relu", - normalize_before=False, - return_intermediate_dec=False, - ): - super().__init__() - - encoder_layer = TransformerEncoderLayer( - d_model, nhead, dim_feedforward, dropout, activation, normalize_before - ) - encoder_norm = nn.LayerNorm(d_model) if normalize_before else None - self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) - - decoder_layer = TransformerDecoderLayer( - d_model, nhead, dim_feedforward, dropout, activation, normalize_before - ) - decoder_norm = nn.LayerNorm(d_model) - self.decoder = TransformerDecoder( - decoder_layer, num_decoder_layers, decoder_norm, return_intermediate=return_intermediate_dec - ) - - self._reset_parameters() - - self.d_model = d_model - self.nhead = nhead - - def _reset_parameters(self): - for p in self.parameters(): - if p.dim() > 1: - nn.init.xavier_uniform_(p) - - def forward( - self, - src, - mask, - query_embed, - pos_embed, - latent_input=None, - proprio_input=None, - additional_pos_embed=None, - ): - # TODO flatten only when input has H and W - if len(src.shape) == 4: # has H and W - # flatten NxCxHxW to HWxNxC - bs, c, h, w = src.shape - src = src.flatten(2).permute(2, 0, 1) - pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1) - query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) - # mask = mask.flatten(1) - - additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim - pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0) - - addition_input = torch.stack([latent_input, proprio_input], axis=0) - src = torch.cat([addition_input, src], axis=0) - else: - assert len(src.shape) == 3 - # flatten NxHWxC to HWxNxC - bs, hw, c = src.shape - src = src.permute(1, 0, 2) - pos_embed = pos_embed.unsqueeze(1).repeat(1, bs, 1) - query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) - - tgt = torch.zeros_like(query_embed) - memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) - hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed) - hs = hs.transpose(1, 2) - return hs - - -class TransformerEncoder(nn.Module): - def __init__(self, encoder_layer, num_layers, norm=None): - super().__init__() - self.layers = _get_clones(encoder_layer, num_layers) - self.num_layers = num_layers - self.norm = norm - - def forward( - self, - src, - mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - pos: Optional[Tensor] = None, - ): - output = src - - for layer in self.layers: - output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos) - - if self.norm is not None: - output = self.norm(output) - - return output - - -class TransformerDecoder(nn.Module): - def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): - super().__init__() - self.layers = _get_clones(decoder_layer, num_layers) - self.num_layers = num_layers - self.norm = norm - self.return_intermediate = return_intermediate - - def 
forward( - self, - tgt, - memory, - tgt_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - tgt_key_padding_mask: Optional[Tensor] = None, - memory_key_padding_mask: Optional[Tensor] = None, - pos: Optional[Tensor] = None, - query_pos: Optional[Tensor] = None, - ): - output = tgt - - intermediate = [] - - for layer in self.layers: - output = layer( - output, - memory, - tgt_mask=tgt_mask, - memory_mask=memory_mask, - tgt_key_padding_mask=tgt_key_padding_mask, - memory_key_padding_mask=memory_key_padding_mask, - pos=pos, - query_pos=query_pos, - ) - if self.return_intermediate: - intermediate.append(self.norm(output)) - - if self.norm is not None: - output = self.norm(output) - if self.return_intermediate: - intermediate.pop() - intermediate.append(output) - - if self.return_intermediate: - return torch.stack(intermediate) - - return output.unsqueeze(0) - - -class TransformerEncoderLayer(nn.Module): - def __init__( - self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False - ): - super().__init__() - self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) - # Implementation of Feedforward model - self.linear1 = nn.Linear(d_model, dim_feedforward) - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_feedforward, d_model) - - self.norm1 = nn.LayerNorm(d_model) - self.norm2 = nn.LayerNorm(d_model) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - - self.activation = _get_activation_fn(activation) - self.normalize_before = normalize_before - - def with_pos_embed(self, tensor, pos: Optional[Tensor]): - return tensor if pos is None else tensor + pos - - def forward_post( - self, - src, - src_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - pos: Optional[Tensor] = None, - ): - q = k = self.with_pos_embed(src, pos) - src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0] - src = src + self.dropout1(src2) - src = self.norm1(src) - src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) - src = src + self.dropout2(src2) - src = self.norm2(src) - return src - - def forward_pre( - self, - src, - src_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - pos: Optional[Tensor] = None, - ): - src2 = self.norm1(src) - q = k = self.with_pos_embed(src2, pos) - src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0] - src = src + self.dropout1(src2) - src2 = self.norm2(src) - src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) - src = src + self.dropout2(src2) - return src - - def forward( - self, - src, - src_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - pos: Optional[Tensor] = None, - ): - if self.normalize_before: - return self.forward_pre(src, src_mask, src_key_padding_mask, pos) - return self.forward_post(src, src_mask, src_key_padding_mask, pos) - - -class TransformerDecoderLayer(nn.Module): - def __init__( - self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False - ): - super().__init__() - self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) - self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) - # Implementation of Feedforward model - self.linear1 = nn.Linear(d_model, dim_feedforward) - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_feedforward, d_model) - - 
self.norm1 = nn.LayerNorm(d_model) - self.norm2 = nn.LayerNorm(d_model) - self.norm3 = nn.LayerNorm(d_model) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - self.dropout3 = nn.Dropout(dropout) - - self.activation = _get_activation_fn(activation) - self.normalize_before = normalize_before - - def with_pos_embed(self, tensor, pos: Optional[Tensor]): - return tensor if pos is None else tensor + pos - - def forward_post( - self, - tgt, - memory, - tgt_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - tgt_key_padding_mask: Optional[Tensor] = None, - memory_key_padding_mask: Optional[Tensor] = None, - pos: Optional[Tensor] = None, - query_pos: Optional[Tensor] = None, - ): - q = k = self.with_pos_embed(tgt, query_pos) - tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0] - tgt = tgt + self.dropout1(tgt2) - tgt = self.norm1(tgt) - tgt2 = self.multihead_attn( - query=self.with_pos_embed(tgt, query_pos), - key=self.with_pos_embed(memory, pos), - value=memory, - attn_mask=memory_mask, - key_padding_mask=memory_key_padding_mask, - )[0] - tgt = tgt + self.dropout2(tgt2) - tgt = self.norm2(tgt) - tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) - tgt = tgt + self.dropout3(tgt2) - tgt = self.norm3(tgt) - return tgt - - def forward_pre( - self, - tgt, - memory, - tgt_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - tgt_key_padding_mask: Optional[Tensor] = None, - memory_key_padding_mask: Optional[Tensor] = None, - pos: Optional[Tensor] = None, - query_pos: Optional[Tensor] = None, - ): - tgt2 = self.norm1(tgt) - q = k = self.with_pos_embed(tgt2, query_pos) - tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0] - tgt = tgt + self.dropout1(tgt2) - tgt2 = self.norm2(tgt) - tgt2 = self.multihead_attn( - query=self.with_pos_embed(tgt2, query_pos), - key=self.with_pos_embed(memory, pos), - value=memory, - attn_mask=memory_mask, - key_padding_mask=memory_key_padding_mask, - )[0] - tgt = tgt + self.dropout2(tgt2) - tgt2 = self.norm3(tgt) - tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) - tgt = tgt + self.dropout3(tgt2) - return tgt - - def forward( - self, - tgt, - memory, - tgt_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - tgt_key_padding_mask: Optional[Tensor] = None, - memory_key_padding_mask: Optional[Tensor] = None, - pos: Optional[Tensor] = None, - query_pos: Optional[Tensor] = None, - ): - if self.normalize_before: - return self.forward_pre( - tgt, - memory, - tgt_mask, - memory_mask, - tgt_key_padding_mask, - memory_key_padding_mask, - pos, - query_pos, - ) - return self.forward_post( - tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos - ) - - -def _get_clones(module, n): - return nn.ModuleList([copy.deepcopy(module) for _ in range(n)]) - - -def build_transformer(args): - return Transformer( - d_model=args.hidden_dim, - dropout=args.dropout, - nhead=args.nheads, - dim_feedforward=args.dim_feedforward, - num_encoder_layers=args.enc_layers, - num_decoder_layers=args.dec_layers, - normalize_before=args.pre_norm, - return_intermediate_dec=True, - ) - - -def _get_activation_fn(activation): - """Return an activation function given a string""" - if activation == "relu": - return F.relu - if activation == "gelu": - return F.gelu - if activation == "glu": - return F.glu - raise RuntimeError(f"activation should be relu/gelu, 
not {activation}.") diff --git a/lerobot/common/policies/act/utils.py b/lerobot/common/policies/act/utils.py deleted file mode 100644 index 0d935839..00000000 --- a/lerobot/common/policies/act/utils.py +++ /dev/null @@ -1,478 +0,0 @@ -""" -Misc functions, including distributed helpers. - -Mostly copy-paste from torchvision references. -""" - -import datetime -import os -import pickle -import subprocess -import time -from collections import defaultdict, deque -from typing import List, Optional - -import torch -import torch.distributed as dist - -# needed due to empty tensor bug in pytorch and torchvision 0.5 -import torchvision -from packaging import version -from torch import Tensor - -if version.parse(torchvision.__version__) < version.parse("0.7"): - from torchvision.ops import _new_empty_tensor - from torchvision.ops.misc import _output_size - - -class SmoothedValue: - """Track a series of values and provide access to smoothed values over a - window or the global series average. - """ - - def __init__(self, window_size=20, fmt=None): - if fmt is None: - fmt = "{median:.4f} ({global_avg:.4f})" - self.deque = deque(maxlen=window_size) - self.total = 0.0 - self.count = 0 - self.fmt = fmt - - def update(self, value, n=1): - self.deque.append(value) - self.count += n - self.total += value * n - - def synchronize_between_processes(self): - """ - Warning: does not synchronize the deque! - """ - if not is_dist_avail_and_initialized(): - return - t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") - dist.barrier() - dist.all_reduce(t) - t = t.tolist() - self.count = int(t[0]) - self.total = t[1] - - @property - def median(self): - d = torch.tensor(list(self.deque)) - return d.median().item() - - @property - def avg(self): - d = torch.tensor(list(self.deque), dtype=torch.float32) - return d.mean().item() - - @property - def global_avg(self): - return self.total / self.count - - @property - def max(self): - return max(self.deque) - - @property - def value(self): - return self.deque[-1] - - def __str__(self): - return self.fmt.format( - median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value - ) - - -def all_gather(data): - """ - Run all_gather on arbitrary picklable data (not necessarily tensors) - Args: - data: any picklable object - Returns: - list[data]: list of data gathered from each rank - """ - world_size = get_world_size() - if world_size == 1: - return [data] - - # serialized to a Tensor - buffer = pickle.dumps(data) - storage = torch.ByteStorage.from_buffer(buffer) - tensor = torch.ByteTensor(storage).to("cuda") - - # obtain Tensor size of each rank - local_size = torch.tensor([tensor.numel()], device="cuda") - size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] - dist.all_gather(size_list, local_size) - size_list = [int(size.item()) for size in size_list] - max_size = max(size_list) - - # receiving Tensor from all ranks - # we pad the tensor because torch all_gather does not support - # gathering tensors of different shapes - tensor_list = [] - for _ in size_list: - tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) - if local_size != max_size: - padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") - tensor = torch.cat((tensor, padding), dim=0) - dist.all_gather(tensor_list, tensor) - - data_list = [] - for size, tensor in zip(size_list, tensor_list, strict=False): - buffer = tensor.cpu().numpy().tobytes()[:size] - 
data_list.append(pickle.loads(buffer)) - - return data_list - - -def reduce_dict(input_dict, average=True): - """ - Args: - input_dict (dict): all the values will be reduced - average (bool): whether to do average or sum - Reduce the values in the dictionary from all processes so that all processes - have the averaged results. Returns a dict with the same fields as - input_dict, after reduction. - """ - world_size = get_world_size() - if world_size < 2: - return input_dict - with torch.no_grad(): - names = [] - values = [] - # sort the keys so that they are consistent across processes - for k in sorted(input_dict.keys()): - names.append(k) - values.append(input_dict[k]) - values = torch.stack(values, dim=0) - dist.all_reduce(values) - if average: - values /= world_size - reduced_dict = {k: v for k, v in zip(names, values, strict=False)} # noqa: C416 - return reduced_dict - - -class MetricLogger: - def __init__(self, delimiter="\t"): - self.meters = defaultdict(SmoothedValue) - self.delimiter = delimiter - - def update(self, **kwargs): - for k, v in kwargs.items(): - if isinstance(v, torch.Tensor): - v = v.item() - assert isinstance(v, (float, int)) - self.meters[k].update(v) - - def __getattr__(self, attr): - if attr in self.meters: - return self.meters[attr] - if attr in self.__dict__: - return self.__dict__[attr] - raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr)) - - def __str__(self): - loss_str = [] - for name, meter in self.meters.items(): - loss_str.append("{}: {}".format(name, str(meter))) - return self.delimiter.join(loss_str) - - def synchronize_between_processes(self): - for meter in self.meters.values(): - meter.synchronize_between_processes() - - def add_meter(self, name, meter): - self.meters[name] = meter - - def log_every(self, iterable, print_freq, header=None): - if not header: - header = "" - start_time = time.time() - end = time.time() - iter_time = SmoothedValue(fmt="{avg:.4f}") - data_time = SmoothedValue(fmt="{avg:.4f}") - space_fmt = ":" + str(len(str(len(iterable)))) + "d" - if torch.cuda.is_available(): - log_msg = self.delimiter.join( - [ - header, - "[{0" + space_fmt + "}/{1}]", - "eta: {eta}", - "{meters}", - "time: {time}", - "data: {data}", - "max mem: {memory:.0f}", - ] - ) - else: - log_msg = self.delimiter.join( - [ - header, - "[{0" + space_fmt + "}/{1}]", - "eta: {eta}", - "{meters}", - "time: {time}", - "data: {data}", - ] - ) - mega_b = 1024.0 * 1024.0 - for i, obj in enumerate(iterable): - data_time.update(time.time() - end) - yield obj - iter_time.update(time.time() - end) - if i % print_freq == 0 or i == len(iterable) - 1: - eta_seconds = iter_time.global_avg * (len(iterable) - i) - eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) - if torch.cuda.is_available(): - print( - log_msg.format( - i, - len(iterable), - eta=eta_string, - meters=str(self), - time=str(iter_time), - data=str(data_time), - memory=torch.cuda.max_memory_allocated() / mega_b, - ) - ) - else: - print( - log_msg.format( - i, - len(iterable), - eta=eta_string, - meters=str(self), - time=str(iter_time), - data=str(data_time), - ) - ) - end = time.time() - total_time = time.time() - start_time - total_time_str = str(datetime.timedelta(seconds=int(total_time))) - print("{} Total time: {} ({:.4f} s / it)".format(header, total_time_str, total_time / len(iterable))) - - -def get_sha(): - cwd = os.path.dirname(os.path.abspath(__file__)) - - def _run(command): - return subprocess.check_output(command, cwd=cwd).decode("ascii").strip() - 
- sha = "N/A" - diff = "clean" - branch = "N/A" - try: - sha = _run(["git", "rev-parse", "HEAD"]) - subprocess.check_output(["git", "diff"], cwd=cwd) - diff = _run(["git", "diff-index", "HEAD"]) - diff = "has uncommited changes" if diff else "clean" - branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"]) - except Exception: - pass - message = f"sha: {sha}, status: {diff}, branch: {branch}" - return message - - -def collate_fn(batch): - batch = list(zip(*batch, strict=False)) - batch[0] = nested_tensor_from_tensor_list(batch[0]) - return tuple(batch) - - -def _max_by_axis(the_list): - # type: (List[List[int]]) -> List[int] - maxes = the_list[0] - for sublist in the_list[1:]: - for index, item in enumerate(sublist): - maxes[index] = max(maxes[index], item) - return maxes - - -class NestedTensor: - def __init__(self, tensors, mask: Optional[Tensor]): - self.tensors = tensors - self.mask = mask - - def to(self, device): - # type: (Device) -> NestedTensor # noqa - cast_tensor = self.tensors.to(device) - mask = self.mask - if mask is not None: - assert mask is not None - cast_mask = mask.to(device) - else: - cast_mask = None - return NestedTensor(cast_tensor, cast_mask) - - def decompose(self): - return self.tensors, self.mask - - def __repr__(self): - return str(self.tensors) - - -def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): - # TODO make this more general - if tensor_list[0].ndim == 3: - if torchvision._is_tracing(): - # nested_tensor_from_tensor_list() does not export well to ONNX - # call _onnx_nested_tensor_from_tensor_list() instead - return _onnx_nested_tensor_from_tensor_list(tensor_list) - - # TODO make it support different-sized images - max_size = _max_by_axis([list(img.shape) for img in tensor_list]) - # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) - batch_shape = [len(tensor_list)] + max_size - b, c, h, w = batch_shape - dtype = tensor_list[0].dtype - device = tensor_list[0].device - tensor = torch.zeros(batch_shape, dtype=dtype, device=device) - mask = torch.ones((b, h, w), dtype=torch.bool, device=device) - for img, pad_img, m in zip(tensor_list, tensor, mask, strict=False): - pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) - m[: img.shape[1], : img.shape[2]] = False - else: - raise ValueError("not supported") - return NestedTensor(tensor, mask) - - -# _onnx_nested_tensor_from_tensor_list() is an implementation of -# nested_tensor_from_tensor_list() that is supported by ONNX tracing. 
-@torch.jit.unused -def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: - max_size = [] - for i in range(tensor_list[0].dim()): - max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to( - torch.int64 - ) - max_size.append(max_size_i) - max_size = tuple(max_size) - - # work around for - # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) - # m[: img.shape[1], :img.shape[2]] = False - # which is not yet supported in onnx - padded_imgs = [] - padded_masks = [] - for img in tensor_list: - padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape), strict=False)] - padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) - padded_imgs.append(padded_img) - - m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) - padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) - padded_masks.append(padded_mask.to(torch.bool)) - - tensor = torch.stack(padded_imgs) - mask = torch.stack(padded_masks) - - return NestedTensor(tensor, mask=mask) - - -def setup_for_distributed(is_master): - """ - This function disables printing when not in master process - """ - import builtins as __builtin__ - - builtin_print = __builtin__.print - - def print(*args, **kwargs): - force = kwargs.pop("force", False) - if is_master or force: - builtin_print(*args, **kwargs) - - __builtin__.print = print - - -def is_dist_avail_and_initialized(): - if not dist.is_available(): - return False - if not dist.is_initialized(): - return False - return True - - -def get_world_size(): - if not is_dist_avail_and_initialized(): - return 1 - return dist.get_world_size() - - -def get_rank(): - if not is_dist_avail_and_initialized(): - return 0 - return dist.get_rank() - - -def is_main_process(): - return get_rank() == 0 - - -def save_on_master(*args, **kwargs): - if is_main_process(): - torch.save(*args, **kwargs) - - -def init_distributed_mode(args): - if "RANK" in os.environ and "WORLD_SIZE" in os.environ: - args.rank = int(os.environ["RANK"]) - args.world_size = int(os.environ["WORLD_SIZE"]) - args.gpu = int(os.environ["LOCAL_RANK"]) - elif "SLURM_PROCID" in os.environ: - args.rank = int(os.environ["SLURM_PROCID"]) - args.gpu = args.rank % torch.cuda.device_count() - else: - print("Not using distributed mode") - args.distributed = False - return - - args.distributed = True - - torch.cuda.set_device(args.gpu) - args.dist_backend = "nccl" - print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True) - torch.distributed.init_process_group( - backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank - ) - torch.distributed.barrier() - setup_for_distributed(args.rank == 0) - - -@torch.no_grad() -def accuracy(output, target, topk=(1,)): - """Computes the precision@k for the specified values of k""" - if target.numel() == 0: - return [torch.zeros([], device=output.device)] - maxk = max(topk) - batch_size = target.size(0) - - _, pred = output.topk(maxk, 1, True, True) - pred = pred.t() - correct = pred.eq(target.view(1, -1).expand_as(pred)) - - res = [] - for k in topk: - correct_k = correct[:k].view(-1).float().sum(0) - res.append(correct_k.mul_(100.0 / batch_size)) - return res - - -def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): - # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor - """ - Equivalent to nn.functional.interpolate, but 
with support for empty batch sizes. - This will eventually be supported natively by PyTorch, and this - class can go away. - """ - if version.parse(torchvision.__version__) < version.parse("0.7"): - if input.numel() > 0: - return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners) - - output_shape = _output_size(2, input, size, scale_factor) - output_shape = list(input.shape[:-2]) + list(output_shape) - return _new_empty_tensor(input, output_shape) - else: - return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) diff --git a/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py b/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py index 7719fdde..f7432db3 100644 --- a/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py +++ b/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py @@ -244,8 +244,10 @@ class DiffusionUnetImagePolicy(BaseImagePolicy): return result def compute_loss(self, batch): - assert "valid_mask" not in batch - nobs = batch["obs"] + nobs = { + "image": batch["observation.image"], + "agent_pos": batch["observation.state"], + } nactions = batch["action"] batch_size = nactions.shape[0] horizon = nactions.shape[1] @@ -303,6 +305,11 @@ class DiffusionUnetImagePolicy(BaseImagePolicy): loss = F.mse_loss(pred, target, reduction="none") loss = loss * loss_mask.type(loss.dtype) - loss = reduce(loss, "b ... -> b (...)", "mean") + + if "action_is_pad" in batch: + in_episode_bound = ~batch["action_is_pad"] + loss = loss * in_episode_bound[:, :, None].type(loss.dtype) + + loss = reduce(loss, "b t c -> b", "mean", b=batch_size) loss = loss.mean() return loss diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py index 82f39b28..9785358b 100644 --- a/lerobot/common/policies/diffusion/policy.py +++ b/lerobot/common/policies/diffusion/policy.py @@ -1,18 +1,20 @@ import copy import logging import time +from collections import deque import hydra import torch +from torch import nn -from lerobot.common.policies.abstract import AbstractPolicy from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder, RgbEncoder +from lerobot.common.policies.utils import populate_queues from lerobot.common.utils import get_safe_torch_device -class DiffusionPolicy(AbstractPolicy): +class DiffusionPolicy(nn.Module): name = "diffusion" def __init__( @@ -38,8 +40,12 @@ class DiffusionPolicy(AbstractPolicy): # parameters passed to step **kwargs, ): - super().__init__(n_action_steps) + super().__init__() self.cfg = cfg + self.n_obs_steps = n_obs_steps + self.n_action_steps = n_action_steps + # queues are populated during rollout of the policy, they contain the n latest observations and actions + self._queues = None noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler) rgb_model_input_shape = copy.deepcopy(shape_meta.obs.image.shape) @@ -100,75 +106,51 @@ class DiffusionPolicy(AbstractPolicy): last_epoch=self.global_step - 1, ) + def reset(self): + """ + Clear observation and action queues. 
Should be called on `env.reset()` + """ + self._queues = { + "observation.image": deque(maxlen=self.n_obs_steps), + "observation.state": deque(maxlen=self.n_obs_steps), + "action": deque(maxlen=self.n_action_steps), + } + @torch.no_grad() - def select_actions(self, observation, step_count): + def select_action(self, batch, step): """ Note: this uses the ema model weights if self.training == False, otherwise the non-ema model weights. """ - # TODO(rcadene): remove unused step_count - del step_count + # TODO(rcadene): remove unused step + del step + assert "observation.image" in batch + assert "observation.state" in batch + assert len(batch) == 2 - obs_dict = { - "image": observation["image"], - "agent_pos": observation["state"], - } - if self.training: - out = self.diffusion.predict_action(obs_dict) - else: - out = self.ema_diffusion.predict_action(obs_dict) - action = out["action"] + self._queues = populate_queues(self._queues, batch) + + if len(self._queues["action"]) == 0: + # stack n latest observations from the queue + batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch} + + obs_dict = { + "image": batch["observation.image"], + "agent_pos": batch["observation.state"], + } + if self.training: + out = self.diffusion.predict_action(obs_dict) + else: + out = self.ema_diffusion.predict_action(obs_dict) + self._queues["action"].extend(out["action"].transpose(0, 1)) + + action = self._queues["action"].popleft() return action - def update(self, replay_buffer, step): + def forward(self, batch, step): start_time = time.time() self.diffusion.train() - num_slices = self.cfg.batch_size - batch_size = self.cfg.horizon * num_slices - - assert batch_size % self.cfg.horizon == 0 - assert batch_size % num_slices == 0 - - def process_batch(batch, horizon, num_slices): - # trajectory t = 64, horizon h = 16 - # (t h) ... -> t h ... 
- batch = batch.reshape(num_slices, horizon) # .transpose(1, 0).contiguous() - - # |-1|0|1|2|3|4|5|6|7|8|9|10|11|12|13|14| timestamps: 16 - # |o|o| observations: 2 - # | |a|a|a|a|a|a|a|a| actions executed: 8 - # |p|p|p|p|p|p|p|p|p|p|p| p| p| p| p| p| actions predicted: 16 - # note: we predict the action needed to go from t=-1 to t=0 similarly to an inverse kinematic model - - image = batch["observation", "image"] - state = batch["observation", "state"] - action = batch["action"] - assert image.shape[1] == horizon - assert state.shape[1] == horizon - assert action.shape[1] == horizon - - if not (horizon == 16 and self.cfg.n_obs_steps == 2): - raise NotImplementedError() - - # keep first 2 observations of the slice corresponding to t=[-1,0] - image = image[:, : self.cfg.n_obs_steps] - state = state[:, : self.cfg.n_obs_steps] - - out = { - "obs": { - "image": image.to(self.device, non_blocking=True), - "agent_pos": state.to(self.device, non_blocking=True), - }, - "action": action.to(self.device, non_blocking=True), - } - return out - - batch = replay_buffer.sample(batch_size) - batch = process_batch(batch, self.cfg.horizon, num_slices) - - data_s = time.time() - start_time - loss = self.diffusion.compute_loss(batch) loss.backward() @@ -189,7 +171,6 @@ class DiffusionPolicy(AbstractPolicy): "loss": loss.item(), "grad_norm": float(grad_norm), "lr": self.lr_scheduler.get_last_lr()[0], - "data_s": data_s, "update_s": time.time() - start_time, } diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 934f0962..9077d4d0 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -1,11 +1,10 @@ def make_policy(cfg): - if cfg.policy.name != "diffusion" and cfg.rollout_batch_size > 1: - raise NotImplementedError("Only diffusion policy supports rollout_batch_size > 1 for the time being.") - if cfg.policy.name == "tdmpc": from lerobot.common.policies.tdmpc.policy import TDMPCPolicy - policy = TDMPCPolicy(cfg.policy, cfg.device) + policy = TDMPCPolicy( + cfg.policy, n_obs_steps=cfg.n_obs_steps, n_action_steps=cfg.n_action_steps, device=cfg.device + ) elif cfg.policy.name == "diffusion": from lerobot.common.policies.diffusion.policy import DiffusionPolicy @@ -17,24 +16,24 @@ def make_policy(cfg): cfg_obs_encoder=cfg.obs_encoder, cfg_optimizer=cfg.optimizer, cfg_ema=cfg.ema, - n_action_steps=cfg.n_action_steps + cfg.n_latency_steps, + # n_obs_steps=cfg.n_obs_steps, + # n_action_steps=cfg.n_action_steps, **cfg.policy, ) elif cfg.policy.name == "act": from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy - policy = ActionChunkingTransformerPolicy( - cfg.policy, cfg.device, n_action_steps=cfg.n_action_steps + cfg.n_latency_steps - ) + policy = ActionChunkingTransformerPolicy(cfg.policy, cfg.device) + policy.to(cfg.device) else: raise ValueError(cfg.policy.name) if cfg.policy.pretrained_model_path: # TODO(rcadene): hack for old pretrained models from fowm if cfg.policy.name == "tdmpc" and "fowm" in cfg.policy.pretrained_model_path: - if "offline" in cfg.pretrained_model_path: + if "offline" in cfg.policy.pretrained_model_path: policy.step[0] = 25000 - elif "final" in cfg.pretrained_model_path: + elif "final" in cfg.policy.pretrained_model_path: policy.step[0] = 100000 else: raise NotImplementedError() diff --git a/lerobot/common/policies/tdmpc/policy.py b/lerobot/common/policies/tdmpc/policy.py index 64dcc94d..14728576 100644 --- a/lerobot/common/policies/tdmpc/policy.py +++ b/lerobot/common/policies/tdmpc/policy.py @@ 
-1,6 +1,7 @@
 # ruff: noqa: N806
 import time
+from collections import deque
 from copy import deepcopy
 
 import einops
@@ -9,7 +10,7 @@
 import torch
 import torch.nn as nn
 
 import lerobot.common.policies.tdmpc.helper as h
-from lerobot.common.policies.abstract import AbstractPolicy
+from lerobot.common.policies.utils import populate_queues
 from lerobot.common.utils import get_safe_torch_device
 
 FIRST_FRAME = 0
@@ -87,16 +88,18 @@ class TOLD(nn.Module):
         return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2
 
 
-class TDMPCPolicy(AbstractPolicy):
+class TDMPCPolicy(nn.Module):
     """Implementation of TD-MPC learning + inference."""
 
     name = "tdmpc"
 
-    def __init__(self, cfg, device):
-        super().__init__(None)
+    def __init__(self, cfg, n_obs_steps, n_action_steps, device):
+        super().__init__()
         self.action_dim = cfg.action_dim
 
         self.cfg = cfg
+        self.n_obs_steps = n_obs_steps
+        self.n_action_steps = n_action_steps
         self.device = get_safe_torch_device(device)
         self.std = h.linear_schedule(cfg.std_schedule, 0)
         self.model = TOLD(cfg)
@@ -107,7 +110,6 @@ class TDMPCPolicy(AbstractPolicy):
         # self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
         self.model.eval()
         self.model_target.eval()
-        self.batch_size = cfg.batch_size
 
         self.register_buffer("step", torch.zeros(1))
@@ -128,20 +130,54 @@ class TDMPCPolicy(AbstractPolicy):
         self.model.load_state_dict(d["model"])
         self.model_target.load_state_dict(d["model_target"])
 
-    @torch.no_grad()
-    def select_actions(self, observation, step_count):
-        if observation["image"].shape[0] != 1:
-            raise NotImplementedError("Batch size > 1 not handled")
-
-        t0 = step_count.item() == 0
-
-        obs = {
-            # TODO(rcadene): remove contiguous hack...
-            "rgb": observation["image"].contiguous(),
-            "state": observation["state"].contiguous(),
+    def reset(self):
+        """
+        Clear observation and action queues. Should be called on `env.reset()`.
+        """
+        self._queues = {
+            "observation.image": deque(maxlen=self.n_obs_steps),
+            "observation.state": deque(maxlen=self.n_obs_steps),
+            "action": deque(maxlen=self.n_action_steps),
         }
-        # Note: unsqueeze needed because `act` still uses non-batch logic.
-        action = self.act(obs, t0=t0, step=self.step.item()).unsqueeze(0)
+
+    @torch.no_grad()
+    def select_action(self, batch, step):
+        assert "observation.image" in batch
+        assert "observation.state" in batch
+        assert len(batch) == 2
+
+        self._queues = populate_queues(self._queues, batch)
+
+        t0 = step == 0
+
+        self.eval()
+
+        if len(self._queues["action"]) == 0:
+            batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
+
+            if self.n_obs_steps == 1:
+                # hack to remove the time dimension
+                for key in batch:
+                    assert batch[key].shape[1] == 1
+                    batch[key] = batch[key][:, 0]
+
+            actions = []
+            batch_size = batch["observation.image"].shape[0]
+            for i in range(batch_size):
+                obs = {
+                    "rgb": batch["observation.image"][[i]],
+                    "state": batch["observation.state"][[i]],
+                }
+                # Note: `act` still uses non-batch logic, so we process one sample at a time; the [[i]]
+                # indexing above keeps a singleton batch dimension, making an explicit unsqueeze unnecessary.
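+                # `t0` marks the first step of an episode and `self.step` is the persistent training-step
+                # buffer registered above; the planner's schedules (e.g. the std of exploration noise)
+                # depend on it.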
+                action = self.act(obs, t0=t0, step=self.step)
+                actions.append(action)
+            action = torch.stack(actions)
+
+            # tdmpc returns an action for 1 timestep only, so we copy it over `n_action_steps` times
+            for _ in range(self.n_action_steps):
+                self._queues["action"].append(action)
+
+        action = self._queues["action"].popleft()
         return action
@@ -290,117 +326,54 @@ class TDMPCPolicy(AbstractPolicy):
     def _td_target(self, next_z, reward, mask):
         """Compute the TD-target from a reward and the observation at the following time step."""
         next_v = self.model.V(next_z)
-        td_target = reward + self.cfg.discount * mask * next_v
+        td_target = reward + self.cfg.discount * mask * next_v.squeeze(2)
         return td_target
 
-    def update(self, replay_buffer, step, demo_buffer=None):
+    def forward(self, batch, step):
         """Main update function. Corresponds to one iteration of the model learning."""
         start_time = time.time()
 
-        num_slices = self.cfg.batch_size
-        batch_size = self.cfg.horizon * num_slices
+        batch_size = batch["index"].shape[0]
 
-        if demo_buffer is None:
-            demo_batch_size = 0
-        else:
-            # Update oversampling ratio
-            demo_pc_batch = h.linear_schedule(self.cfg.demo_schedule, step)
-            demo_num_slices = int(demo_pc_batch * self.batch_size)
-            demo_batch_size = self.cfg.horizon * demo_num_slices
-            batch_size -= demo_batch_size
-            num_slices -= demo_num_slices
-            replay_buffer._sampler.num_slices = num_slices
-            demo_buffer._sampler.num_slices = demo_num_slices
+        # TODO(rcadene): convert tdmpc with (batch size, time/horizon, channels)
+        # instead of currently (time/horizon, batch size, channels) which is not the pytorch convention
+        # batch size b = 256, time/horizon t = 5
+        # b t ... -> t b ...
+        for key in batch:
+            if batch[key].ndim > 1:
+                batch[key] = batch[key].transpose(1, 0)
 
-        assert demo_batch_size % self.cfg.horizon == 0
-        assert demo_batch_size % demo_num_slices == 0
+        action = batch["action"]
+        reward = batch["next.reward"]
+        # idxs = batch["index"]  # TODO(rcadene): use idxs to update sampling weights
+        done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
+        mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
+        weights = torch.ones(batch_size, dtype=torch.bool, device=reward.device)
 
-        assert batch_size % self.cfg.horizon == 0
-        assert batch_size % num_slices == 0
+        obses = {
+            "rgb": batch["observation.image"],
+            "state": batch["observation.state"],
+        }
 
-        # Sample from interaction dataset
-
-        def process_batch(batch, horizon, num_slices):
-            # trajectory t = 256, horizon h = 5
-            # (t h) ... -> h t ...
- batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous() - - obs = { - "rgb": batch["observation", "image"][FIRST_FRAME].to(self.device, non_blocking=True), - "state": batch["observation", "state"][FIRST_FRAME].to(self.device, non_blocking=True), - } - action = batch["action"].to(self.device, non_blocking=True) - next_obses = { - "rgb": batch["next", "observation", "image"].to(self.device, non_blocking=True), - "state": batch["next", "observation", "state"].to(self.device, non_blocking=True), - } - reward = batch["next", "reward"].to(self.device, non_blocking=True) - - idxs = batch["index"][FIRST_FRAME].to(self.device, non_blocking=True) - weights = batch["_weight"][FIRST_FRAME, :, None].to(self.device, non_blocking=True) - - # TODO(rcadene): rearrange directly in offline dataset - if reward.ndim == 2: - reward = einops.rearrange(reward, "h t -> h t 1") - - assert reward.ndim == 3 - assert reward.shape == (horizon, num_slices, 1) - # We dont use `batch["next", "done"]` since it only indicates the end of an - # episode, but not the end of the trajectory of an episode. - # Neither does `batch["next", "terminated"]` - done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device) - mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device) - return obs, action, next_obses, reward, mask, done, idxs, weights - - batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample() - - obs, action, next_obses, reward, mask, done, idxs, weights = process_batch( - batch, self.cfg.horizon, num_slices - ) - - # Sample from demonstration dataset - if demo_batch_size > 0: - demo_batch = demo_buffer.sample(demo_batch_size) - ( - demo_obs, - demo_action, - demo_next_obses, - demo_reward, - demo_mask, - demo_done, - demo_idxs, - demo_weights, - ) = process_batch(demo_batch, self.cfg.horizon, demo_num_slices) - - if isinstance(obs, dict): - obs = {k: torch.cat([obs[k], demo_obs[k]]) for k in obs} - next_obses = {k: torch.cat([next_obses[k], demo_next_obses[k]], dim=1) for k in next_obses} - else: - obs = torch.cat([obs, demo_obs]) - next_obses = torch.cat([next_obses, demo_next_obses], dim=1) - action = torch.cat([action, demo_action], dim=1) - reward = torch.cat([reward, demo_reward], dim=1) - mask = torch.cat([mask, demo_mask], dim=1) - done = torch.cat([done, demo_done], dim=1) - idxs = torch.cat([idxs, demo_idxs]) - weights = torch.cat([weights, demo_weights]) + shapes = {} + for k in obses: + shapes[k] = obses[k].shape + obses[k] = einops.rearrange(obses[k], "t b ... -> (t b) ... ") # Apply augmentations aug_tf = h.aug(self.cfg) - obs = aug_tf(obs) + obses = aug_tf(obses) - for k in next_obses: - next_obses[k] = einops.rearrange(next_obses[k], "h t ... -> (h t) ...") - next_obses = aug_tf(next_obses) - for k in next_obses: - next_obses[k] = einops.rearrange( - next_obses[k], - "(h t) ... -> h t ...", - h=self.cfg.horizon, - t=self.cfg.batch_size, - ) + for k in obses: + t, b = shapes[k][:2] + obses[k] = einops.rearrange(obses[k], "(t b) ... -> t b ... 
", b=b, t=t) - horizon = self.cfg.horizon + obs, next_obses = {}, {} + for k in obses: + obs[k] = obses[k][0] + next_obses[k] = obses[k][1:].clone() + + horizon = next_obses["rgb"].shape[0] loss_mask = torch.ones_like(mask, device=self.device) for t in range(1, horizon): loss_mask[t] = loss_mask[t - 1] * (~done[t - 1]) @@ -418,7 +391,7 @@ class TDMPCPolicy(AbstractPolicy): td_targets = self._td_target(next_z, reward, mask) # Latent rollout - zs = torch.empty(horizon + 1, self.batch_size, self.cfg.latent_dim, device=self.device) + zs = torch.empty(horizon + 1, batch_size, self.cfg.latent_dim, device=self.device) reward_preds = torch.empty_like(reward, device=self.device) assert reward.shape[0] == horizon z = self.model.encode(obs) @@ -427,22 +400,21 @@ class TDMPCPolicy(AbstractPolicy): for t in range(horizon): z, reward_pred = self.model.next(z, action[t]) zs[t + 1] = z - reward_preds[t] = reward_pred + reward_preds[t] = reward_pred.squeeze(1) with torch.no_grad(): v_target = self.model_target.Q(zs[:-1].detach(), action, return_type="min") # Predictions qs = self.model.Q(zs[:-1], action, return_type="all") + qs = qs.squeeze(3) value_info["Q"] = qs.mean().item() v = self.model.V(zs[:-1]) value_info["V"] = v.mean().item() # Losses - rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(-1, 1, 1) - consistency_loss = (rho * torch.mean(h.mse(zs[1:], z_targets), dim=2, keepdim=True) * loss_mask).sum( - dim=0 - ) + rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(-1, 1) + consistency_loss = (rho * torch.mean(h.mse(zs[1:], z_targets), dim=2) * loss_mask).sum(dim=0) reward_loss = (rho * h.mse(reward_preds, reward) * loss_mask).sum(dim=0) q_value_loss, priority_loss = 0, 0 for q in range(self.cfg.num_q): @@ -450,7 +422,9 @@ class TDMPCPolicy(AbstractPolicy): priority_loss += (rho * h.l1(qs[q], td_targets) * loss_mask).sum(dim=0) expectile = h.linear_schedule(self.cfg.expectile, step) - v_value_loss = (rho * h.l2_expectile(v_target - v, expectile=expectile) * loss_mask).sum(dim=0) + v_value_loss = (rho * h.l2_expectile(v_target - v, expectile=expectile).squeeze(2) * loss_mask).sum( + dim=0 + ) total_loss = ( self.cfg.consistency_coef * consistency_loss @@ -459,7 +433,7 @@ class TDMPCPolicy(AbstractPolicy): + self.cfg.value_coef * v_value_loss ) - weighted_loss = (total_loss.squeeze(1) * weights).mean() + weighted_loss = (total_loss * weights).mean() weighted_loss.register_hook(lambda grad: grad * (1 / self.cfg.horizon)) has_nan = torch.isnan(weighted_loss).item() if has_nan: @@ -472,19 +446,20 @@ class TDMPCPolicy(AbstractPolicy): ) self.optim.step() - if self.cfg.per: - # Update priorities - priorities = priority_loss.clamp(max=1e4).detach() - has_nan = torch.isnan(priorities).any().item() - if has_nan: - print(f"priorities has nan: {priorities=}") - else: - replay_buffer.update_priority( - idxs[:num_slices], - priorities[:num_slices], - ) - if demo_batch_size > 0: - demo_buffer.update_priority(demo_idxs, priorities[num_slices:]) + # TODO(rcadene): implement PrioritizedSampling by modifying sampler.weights with priorities computed by a criterion + # if self.cfg.per: + # # Update priorities + # priorities = priority_loss.clamp(max=1e4).detach() + # has_nan = torch.isnan(priorities).any().item() + # if has_nan: + # print(f"priorities has nan: {priorities=}") + # else: + # replay_buffer.update_priority( + # idxs[:num_slices], + # priorities[:num_slices], + # ) + # if demo_batch_size > 0: + # demo_buffer.update_priority(demo_idxs, 
priorities[num_slices:]) # Update policy + target network _, pi_update_info = self.update_pi(zs[:-1].detach(), acts=action) @@ -507,7 +482,7 @@ class TDMPCPolicy(AbstractPolicy): "data_s": data_s, "update_s": time.time() - start_time, } - info["demo_batch_size"] = demo_batch_size + # info["demo_batch_size"] = demo_batch_size info["expectile"] = expectile info.update(value_info) info.update(pi_update_info) diff --git a/lerobot/common/policies/utils.py b/lerobot/common/policies/utils.py new file mode 100644 index 00000000..b0503fe4 --- /dev/null +++ b/lerobot/common/policies/utils.py @@ -0,0 +1,10 @@ +def populate_queues(queues, batch): + for key in batch: + if len(queues[key]) != queues[key].maxlen: + # initialize by copying the first observation several times until the queue is full + while len(queues[key]) != queues[key].maxlen: + queues[key].append(batch[key]) + else: + # add latest observation to the queue + queues[key].append(batch[key]) + return queues diff --git a/lerobot/common/transforms.py b/lerobot/common/transforms.py index 4832c91b..ec967614 100644 --- a/lerobot/common/transforms.py +++ b/lerobot/common/transforms.py @@ -1,53 +1,49 @@ -from typing import Sequence - import torch -from tensordict import TensorDictBase -from tensordict.nn import dispatch -from tensordict.utils import NestedKey -from torchrl.envs.transforms import ObservationTransform, Transform +from torchvision.transforms.v2 import Compose, Transform -class Prod(ObservationTransform): +def apply_inverse_transform(item, transform): + transforms = transform.transforms if isinstance(transform, Compose) else [transform] + for tf in transforms[::-1]: + if tf.invertible: + item = tf.inverse_transform(item) + else: + raise ValueError(f"Inverse transform called on a non invertible transform ({tf}).") + return item + + +class Prod(Transform): invertible = True - def __init__(self, in_keys: Sequence[NestedKey], prod: float): + def __init__(self, in_keys: list[str], prod: float): super().__init__() self.in_keys = in_keys self.prod = prod self.original_dtypes = {} - def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase: - # _reset is called once when the environment reset to normalize the first observation - tensordict_reset = self._call(tensordict_reset) - return tensordict_reset - - @dispatch(source="in_keys", dest="out_keys") - def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - return self._call(tensordict) - - def _call(self, td): + def forward(self, item): for key in self.in_keys: - if td.get(key, None) is None: + if key not in item: continue - self.original_dtypes[key] = td[key].dtype - td[key] = td[key].type(torch.float32) * self.prod - return td + self.original_dtypes[key] = item[key].dtype + item[key] = item[key].type(torch.float32) * self.prod + return item - def _inv_call(self, td: TensorDictBase) -> TensorDictBase: + def inverse_transform(self, item): for key in self.in_keys: - if td.get(key, None) is None: + if key not in item: continue - td[key] = (td[key] / self.prod).type(self.original_dtypes[key]) - return td + item[key] = (item[key] / self.prod).type(self.original_dtypes[key]) + return item - def transform_observation_spec(self, obs_spec): - for key in self.in_keys: - if obs_spec.get(key, None) is None: - continue - obs_spec[key].space.high = obs_spec[key].space.high.type(torch.float32) * self.prod - obs_spec[key].space.low = obs_spec[key].space.low.type(torch.float32) * self.prod - obs_spec[key].dtype = torch.float32 - return obs_spec + # def 
transform_observation_spec(self, obs_spec): + # for key in self.in_keys: + # if obs_spec.get(key, None) is None: + # continue + # obs_spec[key].space.high = obs_spec[key].space.high.type(torch.float32) * self.prod + # obs_spec[key].space.low = obs_spec[key].space.low.type(torch.float32) * self.prod + # obs_spec[key].dtype = torch.float32 + # return obs_spec class NormalizeTransform(Transform): @@ -55,65 +51,50 @@ class NormalizeTransform(Transform): def __init__( self, - stats: TensorDictBase, - in_keys: Sequence[NestedKey] = None, - out_keys: Sequence[NestedKey] | None = None, - in_keys_inv: Sequence[NestedKey] | None = None, - out_keys_inv: Sequence[NestedKey] | None = None, + stats: dict, + in_keys: list[str] = None, + out_keys: list[str] | None = None, + in_keys_inv: list[str] | None = None, + out_keys_inv: list[str] | None = None, mode="mean_std", ): - if out_keys is None: - out_keys = in_keys - if in_keys_inv is None: - in_keys_inv = out_keys - if out_keys_inv is None: - out_keys_inv = in_keys - super().__init__( - in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, out_keys_inv=out_keys_inv - ) + super().__init__() + self.in_keys = in_keys + self.out_keys = in_keys if out_keys is None else out_keys + self.in_keys_inv = self.out_keys if in_keys_inv is None else in_keys_inv + self.out_keys_inv = self.in_keys if out_keys_inv is None else out_keys_inv self.stats = stats assert mode in ["mean_std", "min_max"] self.mode = mode - def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase: - # _reset is called once when the environment reset to normalize the first observation - tensordict_reset = self._call(tensordict_reset) - return tensordict_reset - - @dispatch(source="in_keys", dest="out_keys") - def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - return self._call(tensordict) - - def _call(self, td: TensorDictBase) -> TensorDictBase: + def forward(self, item): for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False): - # TODO(rcadene): don't know how to do `inkey not in td` - if td.get(inkey, None) is None: + if inkey not in item: continue if self.mode == "mean_std": mean = self.stats[inkey]["mean"] std = self.stats[inkey]["std"] - td[outkey] = (td[inkey] - mean) / (std + 1e-8) + item[outkey] = (item[inkey] - mean) / (std + 1e-8) else: min = self.stats[inkey]["min"] max = self.stats[inkey]["max"] # normalize to [0,1] - td[outkey] = (td[inkey] - min) / (max - min) + item[outkey] = (item[inkey] - min) / (max - min) # normalize to [-1, 1] - td[outkey] = td[outkey] * 2 - 1 - return td + item[outkey] = item[outkey] * 2 - 1 + return item - def _inv_call(self, td: TensorDictBase) -> TensorDictBase: + def inverse_transform(self, item): for inkey, outkey in zip(self.in_keys_inv, self.out_keys_inv, strict=False): - # TODO(rcadene): don't know how to do `inkey not in td` - if td.get(inkey, None) is None: + if inkey not in item: continue if self.mode == "mean_std": mean = self.stats[inkey]["mean"] std = self.stats[inkey]["std"] - td[outkey] = td[inkey] * std + mean + item[outkey] = item[inkey] * std + mean else: min = self.stats[inkey]["min"] max = self.stats[inkey]["max"] - td[outkey] = (td[inkey] + 1) / 2 - td[outkey] = td[outkey] * (max - min) + min - return td + item[outkey] = (item[inkey] + 1) / 2 + item[outkey] = item[outkey] * (max - min) + min + return item diff --git a/lerobot/common/utils.py b/lerobot/common/utils.py index 7ed29334..373a3bbc 100644 --- a/lerobot/common/utils.py +++ b/lerobot/common/utils.py @@ -95,3 
+95,16 @@ def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> D ) cfg = hydra.compose(Path(config_path).stem, overrides) return cfg + + +def print_cuda_memory_usage(): + """Use this function to locate and debug memory leak.""" + import gc + + gc.collect() + # Also clear the cache if you want to fully release the memory + torch.cuda.empty_cache() + print("Current GPU Memory Allocated: {:.2f} MB".format(torch.cuda.memory_allocated(0) / 1024**2)) + print("Maximum GPU Memory Allocated: {:.2f} MB".format(torch.cuda.max_memory_allocated(0) / 1024**2)) + print("Current GPU Memory Reserved: {:.2f} MB".format(torch.cuda.memory_reserved(0) / 1024**2)) + print("Maximum GPU Memory Reserved: {:.2f} MB".format(torch.cuda.max_memory_reserved(0) / 1024**2)) diff --git a/lerobot/configs/env/aloha.yaml b/lerobot/configs/env/aloha.yaml index 51569fea..6b836795 100644 --- a/lerobot/configs/env/aloha.yaml +++ b/lerobot/configs/env/aloha.yaml @@ -4,7 +4,7 @@ eval_episodes: 50 eval_freq: 7500 save_freq: 75000 log_freq: 250 -# TODO: same as simxarm, need to adjust +# TODO: same as xarm, need to adjust offline_steps: 25000 online_steps: 25000 @@ -14,11 +14,10 @@ dataset_id: aloha_sim_insertion_human env: name: aloha - task: sim_insertion + task: AlohaInsertion-v0 from_pixels: True pixels_only: False image_size: [3, 480, 640] - action_repeat: 1 episode_length: 400 fps: ${fps} diff --git a/lerobot/configs/env/pusht.yaml b/lerobot/configs/env/pusht.yaml index 0050530e..a7097ffd 100644 --- a/lerobot/configs/env/pusht.yaml +++ b/lerobot/configs/env/pusht.yaml @@ -4,7 +4,7 @@ eval_episodes: 50 eval_freq: 7500 save_freq: 75000 log_freq: 250 -# TODO: same as simxarm, need to adjust +# TODO: same as xarm, need to adjust offline_steps: 25000 online_steps: 25000 @@ -14,11 +14,10 @@ dataset_id: pusht env: name: pusht - task: pusht + task: PushT-v0 from_pixels: True pixels_only: False image_size: 96 - action_repeat: 1 episode_length: 300 fps: ${fps} diff --git a/lerobot/configs/env/simxarm.yaml b/lerobot/configs/env/xarm.yaml similarity index 86% rename from lerobot/configs/env/simxarm.yaml rename to lerobot/configs/env/xarm.yaml index f79db8f7..bcba659e 100644 --- a/lerobot/configs/env/simxarm.yaml +++ b/lerobot/configs/env/xarm.yaml @@ -12,12 +12,11 @@ fps: 15 dataset_id: xarm_lift_medium env: - name: simxarm - task: lift + name: xarm + task: XarmLift-v0 from_pixels: True pixels_only: False image_size: 84 - action_repeat: 2 episode_length: 25 fps: ${fps} diff --git a/lerobot/configs/policy/act.yaml b/lerobot/configs/policy/act.yaml index a52c3f54..e2074b46 100644 --- a/lerobot/configs/policy/act.yaml +++ b/lerobot/configs/policy/act.yaml @@ -1,6 +1,6 @@ # @package _global_ -offline_steps: 1344000 +offline_steps: 80000 online_steps: 0 eval_episodes: 1 @@ -10,7 +10,6 @@ log_freq: 250 horizon: 100 n_obs_steps: 1 -n_latency_steps: 0 # when temporal_agg=False, n_action_steps=horizon n_action_steps: ${horizon} @@ -21,26 +20,27 @@ policy: lr: 1e-5 lr_backbone: 1e-5 + pretrained_backbone: true weight_decay: 1e-4 grad_clip_norm: 10 backbone: resnet18 - num_queries: ${horizon} # chunk_size horizon: ${horizon} # chunk_size kl_weight: 10 - hidden_dim: 512 + d_model: 512 dim_feedforward: 3200 + vae_enc_layers: 4 enc_layers: 4 - dec_layers: 7 - nheads: 8 + dec_layers: 1 + num_heads: 8 #camera_names: [top, front_close, left_pillar, right_pillar] camera_names: [top] - position_embedding: sine - masks: false dilation: false dropout: 0.1 pre_norm: false + activation: relu + latent_dim: 32 - vae: true + use_vae: true 
batch_size: 8 @@ -51,8 +51,18 @@ policy: utd: 1 n_obs_steps: ${n_obs_steps} + n_action_steps: ${n_action_steps} temporal_agg: false - state_dim: ??? - action_dim: ??? + state_dim: 14 + action_dim: 14 + + image_normalization: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + + delta_timestamps: + observation.images.top: [0.0] + observation.state: [0.0] + action: "[i / ${fps} for i in range(${horizon})]" diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml index 4d6eedca..811ee824 100644 --- a/lerobot/configs/policy/diffusion.yaml +++ b/lerobot/configs/policy/diffusion.yaml @@ -16,7 +16,6 @@ seed: 100000 horizon: 16 n_obs_steps: 2 n_action_steps: 8 -n_latency_steps: 0 dataset_obs_steps: ${n_obs_steps} past_action_visible: False keypoint_visible_rate: 1.0 @@ -38,8 +37,8 @@ policy: shape_meta: ${shape_meta} horizon: ${horizon} - # n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'} n_obs_steps: ${n_obs_steps} + n_action_steps: ${n_action_steps} num_inference_steps: 100 obs_as_global_cond: ${obs_as_global_cond} # crop_shape: null @@ -64,6 +63,11 @@ policy: lr_warmup_steps: 500 grad_clip_norm: 10 + delta_timestamps: + observation.image: [-0.1, 0] + observation.state: [-0.1, 0] + action: [-0.1, 0, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1.0, 1.1, 1.2, 1.3, 1.4] + noise_scheduler: _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler num_train_timesteps: 100 diff --git a/lerobot/configs/policy/tdmpc.yaml b/lerobot/configs/policy/tdmpc.yaml index ff0e6b04..4fd2b6bb 100644 --- a/lerobot/configs/policy/tdmpc.yaml +++ b/lerobot/configs/policy/tdmpc.yaml @@ -1,6 +1,6 @@ # @package _global_ -n_action_steps: 1 +n_action_steps: 2 n_obs_steps: 1 policy: @@ -77,3 +77,9 @@ policy: num_q: 5 mlp_dim: 512 latent_dim: 50 + + delta_timestamps: + observation.image: "[i / ${fps} for i in range(6)]" + observation.state: "[i / ${fps} for i in range(6)]" + action: "[i / ${fps} for i in range(5)]" + next.reward: "[i / ${fps} for i in range(5)]" diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 216769d6..d676623e 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -32,23 +32,21 @@ import json import logging import threading import time +from copy import deepcopy from datetime import datetime as dt from pathlib import Path import einops +import gymnasium as gym import imageio import numpy as np import torch -import tqdm from huggingface_hub import snapshot_download -from tensordict.nn import TensorDictModule -from torchrl.envs import EnvBase -from torchrl.envs.batched_envs import BatchedEnvBase -from lerobot.common.datasets.factory import make_offline_buffer +from lerobot.common.datasets.factory import make_dataset from lerobot.common.envs.factory import make_env +from lerobot.common.envs.utils import postprocess_action, preprocess_observation from lerobot.common.logger import log_output_dir -from lerobot.common.policies.abstract import AbstractPolicy from lerobot.common.policies.factory import make_policy from lerobot.common.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed @@ -58,89 +56,200 @@ def write_video(video_path, stacked_frames, fps): def eval_policy( - env: BatchedEnvBase, - policy: AbstractPolicy, - num_episodes: int = 10, - max_steps: int = 30, - save_video: bool = False, + env: gym.vector.VectorEnv, + policy: torch.nn.Module, + max_episodes_rendered: int = 0, video_dir: Path = None, - fps: int = 15, - return_first_video: bool = False, + # TODO(rcadene): make it possible to 
overwrite fps? we should use env.fps + transform: callable = None, + seed=None, ): + fps = env.unwrapped.metadata["render_fps"] + if policy is not None: policy.eval() + device = "cpu" if policy is None else next(policy.parameters()).device + start = time.time() sum_rewards = [] max_rewards = [] - successes = [] + all_successes = [] seeds = [] threads = [] # for video saving threads episode_counter = 0 # for saving the correct number of videos + num_episodes = len(env.envs) + # TODO(alexander-soare): if num_episodes is not evenly divisible by the batch size, this will do more work than # needed as I'm currently taking a ceil. - for i in tqdm.tqdm(range(-(-num_episodes // env.batch_size[0]))): - ep_frames = [] + ep_frames = [] - def maybe_render_frame(env: EnvBase, _): - if save_video or (return_first_video and i == 0): # noqa: B023 - ep_frames.append(env.render()) # noqa: B023 + def render_frame(env): + # noqa: B023 + eps_rendered = min(max_episodes_rendered, len(env.envs)) + visu = np.stack([env.envs[i].render() for i in range(eps_rendered)]) + ep_frames.append(visu) # noqa: B023 - # Clear the policy's action queue before the start of a new rollout. - if policy is not None: - policy.clear_action_queue() + for _ in range(num_episodes): + seeds.append("TODO") - if env.is_closed: - env.start() # needed to be able to get the seeds the first time as BatchedEnvs are lazy - seeds.extend(env._next_seed) + if hasattr(policy, "reset"): + policy.reset() + else: + logging.warning( + f"Policy {policy} doesnt have a `reset` method. It is required if the policy relies on an internal state during rollout." + ) + + # reset the environment + observation, info = env.reset(seed=seed) + if max_episodes_rendered > 0: + render_frame(env) + + observations = [] + actions = [] + # episode + # frame_id + # timestamp + rewards = [] + successes = [] + dones = [] + + done = torch.tensor([False for _ in env.envs]) + step = 0 + while not done.all(): + # format from env keys to lerobot keys + observation = preprocess_observation(observation) + observations.append(deepcopy(observation)) + + # apply transform to normalize the observations + for key in observation: + observation[key] = torch.stack([transform({key: item})[key] for item in observation[key]]) + + # send observation to device/gpu + observation = {key: observation[key].to(device, non_blocking=True) for key in observation} + + # get the next action for the environment with torch.inference_mode(): - # TODO(alexander-soare): When `break_when_any_done == False` this rolls out for max_steps even when all - # envs are done the first time. But we only use the first rollout. This is a waste of compute. - rollout = env.rollout( - max_steps=max_steps, - policy=policy, - auto_cast_to_device=True, - callback=maybe_render_frame, - break_when_any_done=env.batch_size[0] == 1, + action = policy.select_action(observation, step) + + # apply inverse transform to unnormalize the action + action = postprocess_action(action, transform) + + # apply the next action + observation, reward, terminated, truncated, info = env.step(action) + if max_episodes_rendered > 0: + render_frame(env) + + # TODO(rcadene): implement a wrapper over env to return torch tensors in float32 (and cuda?) 
+ action = torch.from_numpy(action) + reward = torch.from_numpy(reward) + terminated = torch.from_numpy(terminated) + truncated = torch.from_numpy(truncated) + # environment is considered done (no more steps) when the success state is reached (terminated is True), + # the time limit is reached (truncated is True), or it was previously done. + done = terminated | truncated | done + + if "final_info" in info: + # VectorEnv stores is_success in `info["final_info"][env_id]["is_success"]` instead of `info["is_success"]` + success = [ + env_info["is_success"] if env_info is not None else False for env_info in info["final_info"] + ] + else: + success = [False for _ in env.envs] + success = torch.tensor(success) + + actions.append(action) + rewards.append(reward) + dones.append(done) + successes.append(success) + + step += 1 + + env.close() + + # add the last observation when the env is done + observation = preprocess_observation(observation) + observations.append(deepcopy(observation)) + + new_obses = {} + for key in observations[0].keys(): # noqa: SIM118 + new_obses[key] = torch.stack([obs[key] for obs in observations], dim=1) + observations = new_obses + actions = torch.stack(actions, dim=1) + rewards = torch.stack(rewards, dim=1) + successes = torch.stack(successes, dim=1) + dones = torch.stack(dones, dim=1) + + # Figure out where in each rollout sequence the first done condition was encountered (results after + # this won't be included). + # Note: this assumes that the shape of the done key is (batch_size, max_steps). + # Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker. + done_indices = torch.argmax(dones.to(int), axis=1) # (batch_size,) + expand_done_indices = done_indices[:, None].expand(-1, step) + expand_step_indices = torch.arange(step)[None, :].expand(num_episodes, -1) + mask = (expand_step_indices <= expand_done_indices).int() # (batch_size, rollout_steps) + batch_sum_reward = einops.reduce((rewards * mask), "b n -> b", "sum") + batch_max_reward = einops.reduce((rewards * mask), "b n -> b", "max") + batch_success = einops.reduce((successes * mask), "b n -> b", "any") + sum_rewards.extend(batch_sum_reward.tolist()) + max_rewards.extend(batch_max_reward.tolist()) + all_successes.extend(batch_success.tolist()) + + # similar logic is implemented in dataset preprocessing + ep_dicts = [] + num_episodes = dones.shape[0] + total_frames = 0 + idx0 = idx1 = 0 + data_ids_per_episode = {} + for ep_id in range(num_episodes): + num_frames = done_indices[ep_id].item() + 1 + # TODO(rcadene): We need to add a missing last frame which is the observation + # of a done state.
it is critical to have this frame for tdmpc to predict a "done observation/state" + ep_dict = { + "action": actions[ep_id, :num_frames], + "episode": torch.tensor([ep_id] * num_frames), + "frame_id": torch.arange(0, num_frames, 1), + "timestamp": torch.arange(0, num_frames, 1) / fps, + "next.done": dones[ep_id, :num_frames], + "next.reward": rewards[ep_id, :num_frames].type(torch.float32), + } + for key in observations: + ep_dict[key] = observations[key][ep_id, :num_frames] + ep_dicts.append(ep_dict) + + total_frames += num_frames + idx1 += num_frames + + data_ids_per_episode[ep_id] = torch.arange(idx0, idx1, 1) + + idx0 = idx1 + + # similar logic is implemented in dataset preprocessing + data_dict = {} + keys = ep_dicts[0].keys() + for key in keys: + data_dict[key] = torch.cat([x[key] for x in ep_dicts]) + data_dict["index"] = torch.arange(0, total_frames, 1) + + if max_episodes_rendered > 0: + batch_stacked_frames = np.stack(ep_frames, 1) # (b, t, *) + + for stacked_frames, done_index in zip( + batch_stacked_frames, done_indices.flatten().tolist(), strict=False + ): + if episode_counter >= num_episodes: + continue + video_dir.mkdir(parents=True, exist_ok=True) + video_path = video_dir / f"eval_episode_{episode_counter}.mp4" + thread = threading.Thread( + target=write_video, + args=(str(video_path), stacked_frames[:done_index], fps), ) - # Figure out where in each rollout sequence the first done condition was encountered (results after - # this won't be included). - # Note: this assumes that the shape of the done key is (batch_size, max_steps, 1). - # Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker. - rollout_steps = rollout["next", "done"].shape[1] - done_indices = torch.argmax(rollout["next", "done"].to(int), axis=1) # (batch_size, rollout_steps) - mask = (torch.arange(rollout_steps) <= done_indices).unsqueeze(-1) # (batch_size, rollout_steps, 1) - batch_sum_reward = einops.reduce((rollout["next", "reward"] * mask), "b n 1 -> b", "sum") - batch_max_reward = einops.reduce((rollout["next", "reward"] * mask), "b n 1 -> b", "max") - batch_success = einops.reduce((rollout["next", "success"] * mask), "b n 1 -> b", "any") - sum_rewards.extend(batch_sum_reward.tolist()) - max_rewards.extend(batch_max_reward.tolist()) - successes.extend(batch_success.tolist()) + thread.start() + threads.append(thread) + episode_counter += 1 - if save_video or (return_first_video and i == 0): - batch_stacked_frames = np.stack(ep_frames) # (t, b, *) - batch_stacked_frames = batch_stacked_frames.transpose( - 1, 0, *range(2, batch_stacked_frames.ndim) - ) # (b, t, *) - - if save_video: - for stacked_frames, done_index in zip( - batch_stacked_frames, done_indices.flatten().tolist(), strict=False - ): - if episode_counter >= num_episodes: - continue - video_dir.mkdir(parents=True, exist_ok=True) - video_path = video_dir / f"eval_episode_{episode_counter}.mp4" - thread = threading.Thread( - target=write_video, - args=(str(video_path), stacked_frames[:done_index], fps), - ) - thread.start() - threads.append(thread) - episode_counter += 1 - - if return_first_video and i == 0: - first_video = batch_stacked_frames[0].transpose(0, 3, 1, 2) + videos = einops.rearrange(batch_stacked_frames, "b t h w c -> b t c h w") for thread in threads: thread.join() @@ -158,22 +267,26 @@ def eval_policy( zip( sum_rewards[:num_episodes], max_rewards[:num_episodes], - successes[:num_episodes], + all_successes[:num_episodes], seeds[:num_episodes], strict=True, ) ) ], "aggregated": { - 
"avg_sum_reward": np.nanmean(sum_rewards[:num_episodes]), - "avg_max_reward": np.nanmean(max_rewards[:num_episodes]), - "pc_success": np.nanmean(successes[:num_episodes]) * 100, + "avg_sum_reward": float(np.nanmean(sum_rewards[:num_episodes])), + "avg_max_reward": float(np.nanmean(max_rewards[:num_episodes])), + "pc_success": float(np.nanmean(all_successes[:num_episodes]) * 100), "eval_s": time.time() - start, "eval_ep_s": (time.time() - start) / num_episodes, }, + "episodes": { + "data_dict": data_dict, + "data_ids_per_episode": data_ids_per_episode, + }, } - if return_first_video: - return info, first_video + if max_episodes_rendered > 0: + info["videos"] = videos return info @@ -194,35 +307,29 @@ def eval(cfg: dict, out_dir=None, stats_path=None): logging.info("Making transforms.") # TODO(alexander-soare): Completely decouple datasets from evaluation. - offline_buffer = make_offline_buffer(cfg, stats_path=stats_path) + transform = make_dataset(cfg, stats_path=stats_path).transform logging.info("Making environment.") - env = make_env(cfg, transform=offline_buffer.transform) + env = make_env(cfg, num_parallel_envs=cfg.eval_episodes) - if cfg.policy.pretrained_model_path: - policy = make_policy(cfg) - policy = TensorDictModule( - policy, - in_keys=["observation", "step_count"], - out_keys=["action"], - ) - else: - # when policy is None, rollout a random policy - policy = None + logging.info("Making policy.") + policy = make_policy(cfg) info = eval_policy( env, - policy=policy, - save_video=True, + policy, + max_episodes_rendered=10, video_dir=Path(out_dir) / "eval", - fps=cfg.env.fps, - max_steps=cfg.env.episode_length, - num_episodes=cfg.eval_episodes, + transform=transform, + seed=cfg.seed, ) print(info["aggregated"]) # Save info with open(Path(out_dir) / "eval_info.json", "w") as f: + # remove pytorch tensors which are not serializable to save the evaluation results only + del info["episodes"] + del info["videos"] json.dump(info, f, indent=2) logging.info("End of eval") diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 18c3715b..03506f2a 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -1,18 +1,21 @@ import logging +from copy import deepcopy from pathlib import Path import hydra -import numpy as np import torch -from tensordict.nn import TensorDictModule -from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer -from torchrl.data.replay_buffers import PrioritizedSliceSampler -from lerobot.common.datasets.factory import make_offline_buffer +from lerobot.common.datasets.factory import make_dataset +from lerobot.common.datasets.utils import cycle from lerobot.common.envs.factory import make_env from lerobot.common.logger import Logger, log_output_dir from lerobot.common.policies.factory import make_policy -from lerobot.common.utils import format_big_number, get_safe_torch_device, init_logging, set_global_seed +from lerobot.common.utils import ( + format_big_number, + get_safe_torch_device, + init_logging, + set_global_seed, +) from lerobot.scripts.eval import eval_policy @@ -34,19 +37,18 @@ def train_notebook(out_dir=None, job_name=None, config_name="default", config_pa train(cfg, out_dir=out_dir, job_name=job_name) -def log_train_info(logger, info, step, cfg, offline_buffer, is_offline): +def log_train_info(logger, info, step, cfg, dataset, is_offline): loss = info["loss"] grad_norm = info["grad_norm"] lr = info["lr"] - data_s = info["data_s"] update_s = info["update_s"] # A sample is an (observation,action) pair, where observation and 
action # can be on multiple timestamps. In a batch, we have `batch_size`` number of samples. num_samples = (step + 1) * cfg.policy.batch_size - avg_samples_per_ep = offline_buffer.num_samples / offline_buffer.num_episodes + avg_samples_per_ep = dataset.num_samples / dataset.num_episodes num_episodes = num_samples / avg_samples_per_ep - num_epochs = num_samples / offline_buffer.num_samples + num_epochs = num_samples / dataset.num_samples log_items = [ f"step:{format_big_number(step)}", # number of samples seen during training @@ -59,7 +61,6 @@ def log_train_info(logger, info, step, cfg, offline_buffer, is_offline): f"grdn:{grad_norm:.3f}", f"lr:{lr:0.1e}", # in seconds - f"data_s:{data_s:.3f}", f"updt_s:{update_s:.3f}", ] logging.info(" ".join(log_items)) @@ -73,7 +74,7 @@ def log_train_info(logger, info, step, cfg, offline_buffer, is_offline): logger.log_dict(info, step, mode="train") -def log_eval_info(logger, info, step, cfg, offline_buffer, is_offline): +def log_eval_info(logger, info, step, cfg, dataset, is_offline): eval_s = info["eval_s"] avg_sum_reward = info["avg_sum_reward"] pc_success = info["pc_success"] @@ -81,9 +82,9 @@ # A sample is an (observation,action) pair, where observation and action # can be on multiple timestamps. In a batch, we have `batch_size`` number of samples. num_samples = (step + 1) * cfg.policy.batch_size - avg_samples_per_ep = offline_buffer.num_samples / offline_buffer.num_episodes + avg_samples_per_ep = dataset.num_samples / dataset.num_episodes num_episodes = num_samples / avg_samples_per_ep - num_epochs = num_samples / offline_buffer.num_samples + num_epochs = num_samples / dataset.num_samples log_items = [ f"step:{format_big_number(step)}", # number of samples seen during training @@ -107,6 +108,64 @@ logger.log_dict(info, step, mode="eval") +def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float): + """ + Calculate the sampling weight to be assigned to samples so that a specified percentage of the batch comes from the online dataset (on average). + + Parameters: + - n_off (int): Number of offline samples, each with a sampling weight of 1. + - n_on (int): Number of online samples. + - pc_on (float): Desired percentage of online samples in decimal form (e.g., 50% as 0.5). + + The total weight of offline samples is n_off * 1.0. + The total weight of online samples is n_on * w. + The total combined weight of all samples is n_off + n_on * w. + The fraction of the weight that is online is n_on * w / (n_off + n_on * w). + We want this fraction to equal pc_on, so we set up the equation n_on * w / (n_off + n_on * w) = pc_on.
+ The solution is w = - (n_off * pc_on) / (n_on * (pc_on - 1)) + """ + assert 0.0 <= pc_on <= 1.0 + return -(n_off * pc_on) / (n_on * (pc_on - 1)) + + +def add_episodes_inplace(episodes, online_dataset, concat_dataset, sampler, pc_online_samples): + data_dict = episodes["data_dict"] + data_ids_per_episode = episodes["data_ids_per_episode"] + + if len(online_dataset) == 0: + # initialize online dataset + online_dataset.data_dict = data_dict + online_dataset.data_ids_per_episode = data_ids_per_episode + else: + # find episode index and data frame indices according to previous episode in online_dataset + start_episode = max(online_dataset.data_ids_per_episode.keys()) + 1 + start_index = online_dataset.data_dict["index"][-1].item() + 1 + data_dict["episode"] += start_episode + data_dict["index"] += start_index + + # extend online dataset + for key in data_dict: + # TODO(rcadene): avoid reallocating memory at every step by preallocating memory or changing our data structure + online_dataset.data_dict[key] = torch.cat([online_dataset.data_dict[key], data_dict[key]]) + for ep_id in data_ids_per_episode: + online_dataset.data_ids_per_episode[ep_id + start_episode] = ( + data_ids_per_episode[ep_id] + start_index + ) + + # update the concatenated dataset length used during sampling + concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets) + + # update the sampling weights for each frame so that online frames get sampled a certain percentage of times + len_online = len(online_dataset) + len_offline = len(concat_dataset) - len_online + weight_offline = 1.0 + weight_online = calculate_online_sample_weight(len_offline, len_online, pc_online_samples) + sampler.weights = torch.tensor([weight_offline] * len_offline + [weight_online] * len(online_dataset)) + + # update the total number of samples used during sampling + sampler.num_samples = len(concat_dataset) + + def train(cfg: dict, out_dir=None, job_name=None): if out_dir is None: raise NotImplementedError() @@ -124,30 +183,11 @@ def train(cfg: dict, out_dir=None, job_name=None): torch.backends.cuda.matmul.allow_tf32 = True set_global_seed(cfg.seed) - logging.info("make_offline_buffer") - offline_buffer = make_offline_buffer(cfg) - - # TODO(rcadene): move balanced_sampling, per_alpha, per_beta outside policy - if cfg.policy.balanced_sampling: - logging.info("make online_buffer") - num_traj_per_batch = cfg.policy.batch_size - - online_sampler = PrioritizedSliceSampler( - max_capacity=100_000, - alpha=cfg.policy.per_alpha, - beta=cfg.policy.per_beta, - num_slices=num_traj_per_batch, - strict_length=True, - ) - - online_buffer = TensorDictReplayBuffer( - storage=LazyMemmapStorage(100_000), - sampler=online_sampler, - transform=offline_buffer.transform, - ) + logging.info("make_dataset") + offline_dataset = make_dataset(cfg) logging.info("make_env") - env = make_env(cfg, transform=offline_buffer.transform) + env = make_env(cfg, num_parallel_envs=cfg.eval_episodes) logging.info("make_policy") policy = make_policy(cfg) @@ -155,8 +195,6 @@ def train(cfg: dict, out_dir=None, job_name=None): num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) num_total_params = sum(p.numel() for p in policy.parameters()) - td_policy = TensorDictModule(policy, in_keys=["observation", "step_count"], out_keys=["action"]) - # log metrics to terminal and wandb logger = Logger(out_dir, job_name, cfg) @@ -164,9 +202,8 @@ def train(cfg: dict, out_dir=None, job_name=None): logging.info(f"{cfg.env.task=}") 
logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})") logging.info(f"{cfg.online_steps=}") - logging.info(f"{cfg.env.action_repeat=}") - logging.info(f"{offline_buffer.num_samples=} ({format_big_number(offline_buffer.num_samples)})") - logging.info(f"{offline_buffer.num_episodes=}") + logging.info(f"{offline_dataset.num_samples=} ({format_big_number(offline_dataset.num_samples)})") + logging.info(f"{offline_dataset.num_episodes=}") logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})") logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") @@ -174,18 +211,17 @@ def train(cfg: dict, out_dir=None, job_name=None): def _maybe_eval_and_maybe_save(step): if step % cfg.eval_freq == 0: logging.info(f"Eval policy at step {step}") - eval_info, first_video = eval_policy( + eval_info = eval_policy( env, - td_policy, - num_episodes=cfg.eval_episodes, - max_steps=cfg.env.episode_length, - return_first_video=True, + policy, video_dir=Path(out_dir) / "eval", - save_video=True, + max_episodes_rendered=4, + transform=offline_dataset.transform, + seed=cfg.seed, ) - log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_buffer, is_offline) + log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline) if cfg.wandb.enable: - logger.log_video(first_video, step, mode="eval") + logger.log_video(eval_info["videos"][0], step, mode="eval") logging.info("Resume training") if cfg.save_model and step % cfg.save_freq == 0: @@ -193,17 +229,33 @@ def train(cfg: dict, out_dir=None, job_name=None): logger.save_model(policy, identifier=step) logging.info("Resume training") - step = 0 # number of policy update (forward + backward + optim) + # create dataloader for offline training + dataloader = torch.utils.data.DataLoader( + offline_dataset, + num_workers=4, + batch_size=cfg.policy.batch_size, + shuffle=True, + pin_memory=cfg.device != "cpu", + drop_last=False, + ) + dl_iter = cycle(dataloader) + step = 0 # number of policy update (forward + backward + optim) is_offline = True for offline_step in range(cfg.offline_steps): if offline_step == 0: logging.info("Start offline training on a fixed dataset") - # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done? policy.train() - train_info = policy.update(offline_buffer, step) + batch = next(dl_iter) + + for key in batch: + batch[key] = batch[key].to(cfg.device, non_blocking=True) + + train_info = policy(batch, step) + + # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done? if step % cfg.log_freq == 0: - log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline) + log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline) # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass in # step + 1. 
@@ -211,59 +263,60 @@ def train(cfg: dict, out_dir=None, job_name=None): step += 1 - demo_buffer = offline_buffer if cfg.policy.balanced_sampling else None + # create an env dedicated to online episodes collection from policy rollout + rollout_env = make_env(cfg, num_parallel_envs=1) + + # create an empty online dataset similar to offline dataset + online_dataset = deepcopy(offline_dataset) + online_dataset.data_dict = {} + online_dataset.data_ids_per_episode = {} + + # create dataloader for online training + concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset]) + weights = [1.0] * len(concat_dataset) + sampler = torch.utils.data.WeightedRandomSampler( + weights, num_samples=len(concat_dataset), replacement=True + ) + dataloader = torch.utils.data.DataLoader( + concat_dataset, + num_workers=4, + batch_size=cfg.policy.batch_size, + sampler=sampler, + pin_memory=cfg.device != "cpu", + drop_last=False, + ) + dl_iter = cycle(dataloader) + online_step = 0 is_offline = False for env_step in range(cfg.online_steps): if env_step == 0: logging.info("Start online training by interacting with environment") - # TODO: add configurable number of rollout? (default=1) + with torch.no_grad(): - rollout = env.rollout( - max_steps=cfg.env.episode_length, - policy=td_policy, - auto_cast_to_device=True, + eval_info = eval_policy( + rollout_env, + policy, + transform=offline_dataset.transform, + seed=cfg.seed, ) - assert ( - len(rollout.batch_size) == 2 - ), "2 dimensions expected: number of env in parallel x max number of steps during rollout" - - num_parallel_env = rollout.batch_size[0] - if num_parallel_env != 1: - # TODO(rcadene): when num_parallel_env > 1, rollout["episode"] needs to be properly set and we need to add tests - raise NotImplementedError() - - num_max_steps = rollout.batch_size[1] - assert num_max_steps <= cfg.env.episode_length - - # reshape to have a list of steps to insert into online_buffer - rollout = rollout.reshape(num_parallel_env * num_max_steps) - - # set same episode index for all time steps contained in this rollout - rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int) - online_buffer.extend(rollout) - - ep_sum_reward = rollout["next", "reward"].sum() - ep_max_reward = rollout["next", "reward"].max() - ep_success = rollout["next", "success"].any() - rollout_info = { - "avg_sum_reward": np.nanmean(ep_sum_reward), - "avg_max_reward": np.nanmean(ep_max_reward), - "pc_success": np.nanmean(ep_success) * 100, - "env_step": env_step, - "ep_length": len(rollout), - } + online_pc_sampling = cfg.get("demo_schedule", 0.5) + add_episodes_inplace( + eval_info["episodes"], online_dataset, concat_dataset, sampler, online_pc_sampling + ) for _ in range(cfg.policy.utd): - train_info = policy.update( - online_buffer, - step, - demo_buffer=demo_buffer, - ) + policy.train() + batch = next(dl_iter) + + for key in batch: + batch[key] = batch[key].to(cfg.device, non_blocking=True) + + train_info = policy(batch, step) + if step % cfg.log_freq == 0: - train_info.update(rollout_info) - log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline) + log_train_info(logger, train_info, step, cfg, online_dataset, is_offline) # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass # in step + 1. 
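With the torchrl replay buffers gone, the online loop above keeps a single DataLoader over ConcatDataset([offline_dataset, online_dataset]) and steers the offline/online mix by overwriting sampler.weights after each rollout is appended. A minimal standalone sketch of that sampling setup (the TensorDatasets are placeholders, and `cycle` here is a trivial stand-in for lerobot.common.datasets.utils.cycle):

import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset, WeightedRandomSampler


def cycle(iterable):
    # Restart the dataloader whenever it is exhausted so batches can be drawn indefinitely.
    while True:
        yield from iterable


offline = TensorDataset(torch.zeros(100, 3))  # placeholder for the offline dataset
online = TensorDataset(torch.ones(10, 3))  # placeholder for the online dataset
concat = ConcatDataset([offline, online])

# One weight per frame; the training loop replaces these (sampler.weights) once
# new episodes are added to the online dataset.
weights = torch.ones(len(concat), dtype=torch.double)
sampler = WeightedRandomSampler(weights, num_samples=len(concat), replacement=True)
dl_iter = cycle(DataLoader(concat, batch_size=8, sampler=sampler))
batch = next(dl_iter)  # mixes both datasets according to the current weights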
diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py index 3dd7cdfa..4b7b7d6c 100644 --- a/lerobot/scripts/visualize_dataset.py +++ b/lerobot/scripts/visualize_dataset.py @@ -6,11 +6,8 @@ import einops import hydra import imageio import torch -from torchrl.data.replay_buffers import ( - SamplerWithoutReplacement, -) -from lerobot.common.datasets.factory import make_offline_buffer +from lerobot.common.datasets.factory import make_dataset from lerobot.common.logger import log_output_dir from lerobot.common.utils import init_logging @@ -39,85 +36,62 @@ def visualize_dataset(cfg: dict, out_dir=None): init_logging() log_output_dir(out_dir) - # we expect frames of each episode to be stored next to each others sequentially - sampler = SamplerWithoutReplacement( - shuffle=False, - ) - - logging.info("make_offline_buffer") - offline_buffer = make_offline_buffer( + logging.info("make_dataset") + dataset = make_dataset( cfg, - overwrite_sampler=sampler, # remove all transformations such as rescale images from [0,255] to [0,1] or normalization normalize=False, - overwrite_batch_size=1, - overwrite_prefetch=12, ) logging.info("Start rendering episodes from offline buffer") - video_paths = render_dataset(offline_buffer, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER, cfg.fps) + video_paths = render_dataset(dataset, out_dir, NUM_EPISODES_TO_RENDER) for video_path in video_paths: logging.info(video_path) -def render_dataset(offline_buffer, out_dir, max_num_samples, fps): +def render_dataset(dataset, out_dir, max_num_episodes): out_dir = Path(out_dir) video_paths = [] threads = [] - frames = {} - current_ep_idx = 0 - logging.info(f"Visualizing episode {current_ep_idx}") - for i in range(max_num_samples): - # TODO(rcadene): make it work with bsize > 1 - ep_td = offline_buffer.sample(1) - ep_idx = ep_td["episode"][FIRST_FRAME].item() - # TODO(rcadene): modify offline_buffer._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames - num_frames_left = offline_buffer._sampler._sample_list.numel() - episode_is_done = ep_idx != current_ep_idx + dataloader = torch.utils.data.DataLoader( + dataset, + num_workers=4, + batch_size=1, + shuffle=False, + ) + dl_iter = iter(dataloader) - if episode_is_done: - logging.info(f"Rendering episode {current_ep_idx}") + num_episodes = len(dataset.data_ids_per_episode) + for ep_id in range(min(max_num_episodes, num_episodes)): + logging.info(f"Rendering episode {ep_id}") - for im_key in offline_buffer.image_keys: - if not episode_is_done and num_frames_left > 0 and i < (max_num_samples - 1): + frames = {} + for _ in dataset.data_ids_per_episode[ep_id]: + item = next(dl_iter) + + for im_key in dataset.image_keys: # when first frame of episode, initialize frames dict if im_key not in frames: frames[im_key] = [] # add current frame to list of frames to render - frames[im_key].append(ep_td[im_key]) + frames[im_key].append(item[im_key]) + + out_dir.mkdir(parents=True, exist_ok=True) + for im_key in dataset.image_keys: + if len(dataset.image_keys) > 1: + im_name = im_key.replace("observation.images.", "") + video_path = out_dir / f"episode_{ep_id}_{im_name}.mp4" else: - # When episode has no more frame in its list of observation, - # one frame still remains. It is the result of the last action taken. - # It is stored in `"next"`, so we add it to the list of frames to render.
- frames[im_key].append(ep_td["next"][im_key]) + video_path = out_dir / f"episode_{ep_id}.mp4" + video_paths.append(video_path) - out_dir.mkdir(parents=True, exist_ok=True) - if len(offline_buffer.image_keys) > 1: - camera = im_key[-1] - video_path = out_dir / f"episode_{current_ep_idx}_{camera}.mp4" - else: - video_path = out_dir / f"episode_{current_ep_idx}.mp4" - video_paths.append(str(video_path)) - - thread = threading.Thread( - target=cat_and_write_video, - args=(str(video_path), frames[im_key], fps), - ) - thread.start() - threads.append(thread) - - current_ep_idx = ep_idx - - # reset list of frames - del frames[im_key] - - if num_frames_left == 0: - logging.info("Ran out of frames") - break - - if current_ep_idx == NUM_EPISODES_TO_RENDER: - break + thread = threading.Thread( + target=cat_and_write_video, + args=(str(video_path), frames[im_key], dataset.fps), + ) + thread.start() + threads.append(thread) for thread in threads: thread.join() diff --git a/poetry.lock b/poetry.lock index 72397001..0133b3ed 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand. [[package]] name = "absl-py" @@ -521,7 +521,7 @@ files = [ name = "dm-control" version = "1.0.14" description = "Continuous control environments and MuJoCo Python bindings." -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "dm_control-1.0.14-py3-none-any.whl", hash = "sha256:883c63244a7ebf598700a97564ed19fffd3479ca79efd090aed881609cdb9fc6"}, @@ -552,7 +552,7 @@ hdf5 = ["h5py"] name = "dm-env" version = "1.6" description = "A Python interface for Reinforcement Learning environments." -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "dm-env-1.6.tar.gz", hash = "sha256:a436eb1c654c39e0c986a516cee218bea7140b510fceff63f97eb4fcff3d93de"}, @@ -568,7 +568,7 @@ numpy = "*" name = "dm-tree" version = "0.1.8" description = "Tree is a library for working with nested data structures." -optional = false +optional = true python-versions = "*" files = [ {file = "dm-tree-0.1.8.tar.gz", hash = "sha256:0fcaabbb14e7980377439e7140bd05552739ca5e515ecb3119f234acee4b9430"}, @@ -692,18 +692,18 @@ files = [ [[package]] name = "filelock" -version = "3.13.1" +version = "3.13.3" description = "A platform independent file lock." 
optional = false python-versions = ">=3.8" files = [ - {file = "filelock-3.13.1-py3-none-any.whl", hash = "sha256:57dbda9b35157b05fb3e58ee91448612eb674172fab98ee235ccb0b5bee19a1c"}, - {file = "filelock-3.13.1.tar.gz", hash = "sha256:521f5f56c50f8426f5e03ad3b281b490a87ef15bc6c526f168290f0c7148d44e"}, + {file = "filelock-3.13.3-py3-none-any.whl", hash = "sha256:5ffa845303983e7a0b7ae17636509bc97997d58afeafa72fb141a17b152284cb"}, + {file = "filelock-3.13.3.tar.gz", hash = "sha256:a79895a25bbefdf55d1a2a0a80968f7dbb28edcd6d4234a0afb3f37ecde4b546"}, ] [package.extras] -docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.24)"] -testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"] +docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"] typing = ["typing-extensions (>=4.8)"] [[package]] @@ -777,26 +777,27 @@ smmap = ">=3.0.1,<6" [[package]] name = "gitpython" -version = "3.1.42" +version = "3.1.43" description = "GitPython is a Python library used to interact with Git repositories" optional = false python-versions = ">=3.7" files = [ - {file = "GitPython-3.1.42-py3-none-any.whl", hash = "sha256:1bf9cd7c9e7255f77778ea54359e54ac22a72a5b51288c457c881057b7bb9ecd"}, - {file = "GitPython-3.1.42.tar.gz", hash = "sha256:2d99869e0fef71a73cbd242528105af1d6c1b108c60dfabd994bf292f76c3ceb"}, + {file = "GitPython-3.1.43-py3-none-any.whl", hash = "sha256:eec7ec56b92aad751f9912a73404bc02ba212a23adb2c7098ee668417051a1ff"}, + {file = "GitPython-3.1.43.tar.gz", hash = "sha256:35f314a9f878467f5453cc1fee295c3e18e52f1b99f10f6cf5b1682e968a9e7c"}, ] [package.dependencies] gitdb = ">=4.0.1,<5" [package.extras] -test = ["black", "coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre-commit", "pytest (>=7.3.1)", "pytest-cov", "pytest-instafail", "pytest-mock", "pytest-sugar"] +doc = ["sphinx (==4.3.2)", "sphinx-autodoc-typehints", "sphinx-rtd-theme", "sphinxcontrib-applehelp (>=1.0.2,<=1.0.4)", "sphinxcontrib-devhelp (==1.0.2)", "sphinxcontrib-htmlhelp (>=2.0.0,<=2.0.1)", "sphinxcontrib-qthelp (==1.0.3)", "sphinxcontrib-serializinghtml (==1.1.5)"] +test = ["coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre-commit", "pytest (>=7.3.1)", "pytest-cov", "pytest-instafail", "pytest-mock", "pytest-sugar", "typing-extensions"] [[package]] name = "glfw" version = "2.7.0" description = "A ctypes-based wrapper for GLFW3." 
-optional = false +optional = true python-versions = "*" files = [ {file = "glfw-2.7.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38-none-macosx_10_6_intel.whl", hash = "sha256:bd82849edcceda4e262bd1227afaa74b94f9f0731c1197863cd25c15bfc613fc"}, @@ -879,6 +880,69 @@ files = [ [package.extras] protobuf = ["grpcio-tools (>=1.62.1)"] +[[package]] +name = "gym-aloha" +version = "0.1.0" +description = "A gym environment for ALOHA" +optional = true +python-versions = "^3.10" +files = [] +develop = false + +[package.dependencies] +dm-control = "1.0.14" +gymnasium = "^0.29.1" +mujoco = "^2.3.7" + +[package.source] +type = "git" +url = "git@github.com:huggingface/gym-aloha.git" +reference = "HEAD" +resolved_reference = "c636f05ba0d1760df94537da84c860be1487e17f" + +[[package]] +name = "gym-pusht" +version = "0.1.0" +description = "A gymnasium environment for PushT." +optional = true +python-versions = "^3.10" +files = [] +develop = false + +[package.dependencies] +gymnasium = "^0.29.1" +opencv-python = "^4.9.0.80" +pygame = "^2.5.2" +pymunk = "^6.6.0" +scikit-image = "^0.22.0" +shapely = "^2.0.3" + +[package.source] +type = "git" +url = "git@github.com:huggingface/gym-pusht.git" +reference = "HEAD" +resolved_reference = "080d4ce4d8d3140b2fd204ed628bda14dc58ff06" + +[[package]] +name = "gym-xarm" +version = "0.1.0" +description = "A gym environment for xArm" +optional = true +python-versions = "^3.10" +files = [] +develop = false + +[package.dependencies] +gymnasium = "^0.29.1" +gymnasium-robotics = "^1.2.4" +mujoco = "^2.3.7" + +[package.source] +type = "git" +url = "git@github.com:huggingface/gym-xarm.git" +reference = "HEAD" +resolved_reference = "6a88f7d63833705dfbec4b997bf36cac6b4a448c" + [[package]] name = "gymnasium" version = "0.29.1" @@ -913,7 +977,7 @@ toy-text = ["pygame (>=2.1.3)", "pygame (>=2.1.3)"] name = "gymnasium-robotics" version = "1.2.4" description = "Robotics environments for the Gymnasium repo." -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "gymnasium-robotics-1.2.4.tar.gz", hash = "sha256:d304192b066f8b800599dfbe3d9d90bba9b761ee884472bdc4d05968a8bc61cb"}, @@ -1218,7 +1282,7 @@ i18n = ["Babel (>=2.7)"] name = "labmaze" version = "1.0.6" description = "LabMaze: DeepMind Lab's text maze generator." -optional = false +optional = true python-versions = "*" files = [ {file = "labmaze-1.0.6-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:b2ddef976dfd8d992b19cfa6c633f2eba7576d759c2082da534e3f727479a84a"}, @@ -1262,7 +1326,7 @@ setuptools = "!=50.0.0" name = "lazy-loader" version = "0.3" description = "lazy_loader" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "lazy_loader-0.3-py3-none-any.whl", hash = "sha256:1e9e76ee8631e264c62ce10006718e80b2cfc74340d17d1031e0f84af7478554"}, @@ -1305,96 +1369,174 @@ files = [ [[package]] name = "lxml" -version = "5.1.0" +version = "5.2.1" description = "Powerful and Pythonic XML processing library combining libxml2/libxslt with the ElementTree API." 
-optional = false
+optional = true
 python-versions = ">=3.6"
 files = [
-    {file = "lxml-5.1.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:704f5572ff473a5f897745abebc6df40f22d4133c1e0a1f124e4f2bd3330ff7e"},
-    {file = "lxml-5.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9d3c0f8567ffe7502d969c2c1b809892dc793b5d0665f602aad19895f8d508da"},
-    {file = "lxml-5.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5fcfbebdb0c5d8d18b84118842f31965d59ee3e66996ac842e21f957eb76138c"},
-    {file = "lxml-5.1.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2f37c6d7106a9d6f0708d4e164b707037b7380fcd0b04c5bd9cae1fb46a856fb"},
-    {file = "lxml-5.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2befa20a13f1a75c751f47e00929fb3433d67eb9923c2c0b364de449121f447c"},
-    {file = "lxml-5.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:22b7ee4c35f374e2c20337a95502057964d7e35b996b1c667b5c65c567d2252a"},
-    {file = "lxml-5.1.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:bf8443781533b8d37b295016a4b53c1494fa9a03573c09ca5104550c138d5c05"},
-    {file = "lxml-5.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:82bddf0e72cb2af3cbba7cec1d2fd11fda0de6be8f4492223d4a268713ef2147"},
-    {file = "lxml-5.1.0-cp310-cp310-win32.whl", hash = "sha256:b66aa6357b265670bb574f050ffceefb98549c721cf28351b748be1ef9577d93"},
-    {file = "lxml-5.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:4946e7f59b7b6a9e27bef34422f645e9a368cb2be11bf1ef3cafc39a1f6ba68d"},
-    {file = "lxml-5.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:14deca1460b4b0f6b01f1ddc9557704e8b365f55c63070463f6c18619ebf964f"},
-    {file = "lxml-5.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ed8c3d2cd329bf779b7ed38db176738f3f8be637bb395ce9629fc76f78afe3d4"},
-    {file = "lxml-5.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:436a943c2900bb98123b06437cdd30580a61340fbdb7b28aaf345a459c19046a"},
-    {file = "lxml-5.1.0-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:acb6b2f96f60f70e7f34efe0c3ea34ca63f19ca63ce90019c6cbca6b676e81fa"},
-    {file = "lxml-5.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:af8920ce4a55ff41167ddbc20077f5698c2e710ad3353d32a07d3264f3a2021e"},
-    {file = "lxml-5.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7cfced4a069003d8913408e10ca8ed092c49a7f6cefee9bb74b6b3e860683b45"},
-    {file = "lxml-5.1.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:9e5ac3437746189a9b4121db2a7b86056ac8786b12e88838696899328fc44bb2"},
-    {file = "lxml-5.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f4c9bda132ad108b387c33fabfea47866af87f4ea6ffb79418004f0521e63204"},
-    {file = "lxml-5.1.0-cp311-cp311-win32.whl", hash = "sha256:bc64d1b1dab08f679fb89c368f4c05693f58a9faf744c4d390d7ed1d8223869b"},
-    {file = "lxml-5.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:a5ab722ae5a873d8dcee1f5f45ddd93c34210aed44ff2dc643b5025981908cda"},
-    {file = "lxml-5.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:9aa543980ab1fbf1720969af1d99095a548ea42e00361e727c58a40832439114"},
-    {file = "lxml-5.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6f11b77ec0979f7e4dc5ae081325a2946f1fe424148d3945f943ceaede98adb8"},
-    {file = "lxml-5.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a36c506e5f8aeb40680491d39ed94670487ce6614b9d27cabe45d94cd5d63e1e"},
-    {file = "lxml-5.1.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f643ffd2669ffd4b5a3e9b41c909b72b2a1d5e4915da90a77e119b8d48ce867a"},
-    {file = "lxml-5.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:16dd953fb719f0ffc5bc067428fc9e88f599e15723a85618c45847c96f11f431"},
-    {file = "lxml-5.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:16018f7099245157564d7148165132c70adb272fb5a17c048ba70d9cc542a1a1"},
-    {file = "lxml-5.1.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:82cd34f1081ae4ea2ede3d52f71b7be313756e99b4b5f829f89b12da552d3aa3"},
-    {file = "lxml-5.1.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:19a1bc898ae9f06bccb7c3e1dfd73897ecbbd2c96afe9095a6026016e5ca97b8"},
-    {file = "lxml-5.1.0-cp312-cp312-win32.whl", hash = "sha256:13521a321a25c641b9ea127ef478b580b5ec82aa2e9fc076c86169d161798b01"},
-    {file = "lxml-5.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:1ad17c20e3666c035db502c78b86e58ff6b5991906e55bdbef94977700c72623"},
-    {file = "lxml-5.1.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:24ef5a4631c0b6cceaf2dbca21687e29725b7c4e171f33a8f8ce23c12558ded1"},
-    {file = "lxml-5.1.0-cp36-cp36m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8d2900b7f5318bc7ad8631d3d40190b95ef2aa8cc59473b73b294e4a55e9f30f"},
-    {file = "lxml-5.1.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:601f4a75797d7a770daed8b42b97cd1bb1ba18bd51a9382077a6a247a12aa38d"},
-    {file = "lxml-5.1.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b4b68c961b5cc402cbd99cca5eb2547e46ce77260eb705f4d117fd9c3f932b95"},
-    {file = "lxml-5.1.0-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:afd825e30f8d1f521713a5669b63657bcfe5980a916c95855060048b88e1adb7"},
-    {file = "lxml-5.1.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:262bc5f512a66b527d026518507e78c2f9c2bd9eb5c8aeeb9f0eb43fcb69dc67"},
-    {file = "lxml-5.1.0-cp36-cp36m-win32.whl", hash = "sha256:e856c1c7255c739434489ec9c8aa9cdf5179785d10ff20add308b5d673bed5cd"},
-    {file = "lxml-5.1.0-cp36-cp36m-win_amd64.whl", hash = "sha256:c7257171bb8d4432fe9d6fdde4d55fdbe663a63636a17f7f9aaba9bcb3153ad7"},
-    {file = "lxml-5.1.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b9e240ae0ba96477682aa87899d94ddec1cc7926f9df29b1dd57b39e797d5ab5"},
-    {file = "lxml-5.1.0-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a96f02ba1bcd330807fc060ed91d1f7a20853da6dd449e5da4b09bfcc08fdcf5"},
-    {file = "lxml-5.1.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e3898ae2b58eeafedfe99e542a17859017d72d7f6a63de0f04f99c2cb125936"},
-    {file = "lxml-5.1.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:61c5a7edbd7c695e54fca029ceb351fc45cd8860119a0f83e48be44e1c464862"},
-    {file = "lxml-5.1.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:3aeca824b38ca78d9ee2ab82bd9883083d0492d9d17df065ba3b94e88e4d7ee6"},
-    {file = "lxml-5.1.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:8f52fe6859b9db71ee609b0c0a70fea5f1e71c3462ecf144ca800d3f434f0764"},
-    {file = "lxml-5.1.0-cp37-cp37m-win32.whl", hash = "sha256:d42e3a3fc18acc88b838efded0e6ec3edf3e328a58c68fbd36a7263a874906c8"},
-    {file = "lxml-5.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:eac68f96539b32fce2c9b47eb7c25bb2582bdaf1bbb360d25f564ee9e04c542b"},
-    {file = "lxml-5.1.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ae15347a88cf8af0949a9872b57a320d2605ae069bcdf047677318bc0bba45b1"},
-    {file = "lxml-5.1.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c26aab6ea9c54d3bed716b8851c8bfc40cb249b8e9880e250d1eddde9f709bf5"},
-    {file = "lxml-5.1.0-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:342e95bddec3a698ac24378d61996b3ee5ba9acfeb253986002ac53c9a5f6f84"},
-    {file = "lxml-5.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:725e171e0b99a66ec8605ac77fa12239dbe061482ac854d25720e2294652eeaa"},
-    {file = "lxml-5.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d184e0d5c918cff04cdde9dbdf9600e960161d773666958c9d7b565ccc60c45"},
-    {file = "lxml-5.1.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:98f3f020a2b736566c707c8e034945c02aa94e124c24f77ca097c446f81b01f1"},
-    {file = "lxml-5.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:6d48fc57e7c1e3df57be5ae8614bab6d4e7b60f65c5457915c26892c41afc59e"},
-    {file = "lxml-5.1.0-cp38-cp38-win32.whl", hash = "sha256:7ec465e6549ed97e9f1e5ed51c657c9ede767bc1c11552f7f4d022c4df4a977a"},
-    {file = "lxml-5.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:b21b4031b53d25b0858d4e124f2f9131ffc1530431c6d1321805c90da78388d1"},
-    {file = "lxml-5.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:52427a7eadc98f9e62cb1368a5079ae826f94f05755d2d567d93ee1bc3ceb354"},
-    {file = "lxml-5.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6a2a2c724d97c1eb8cf966b16ca2915566a4904b9aad2ed9a09c748ffe14f969"},
-    {file = "lxml-5.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:843b9c835580d52828d8f69ea4302537337a21e6b4f1ec711a52241ba4a824f3"},
-    {file = "lxml-5.1.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9b99f564659cfa704a2dd82d0684207b1aadf7d02d33e54845f9fc78e06b7581"},
-    {file = "lxml-5.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f8b0c78e7aac24979ef09b7f50da871c2de2def043d468c4b41f512d831e912"},
-    {file = "lxml-5.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9bcf86dfc8ff3e992fed847c077bd875d9e0ba2fa25d859c3a0f0f76f07f0c8d"},
-    {file = "lxml-5.1.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:49a9b4af45e8b925e1cd6f3b15bbba2c81e7dba6dce170c677c9cda547411e14"},
-    {file = "lxml-5.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:280f3edf15c2a967d923bcfb1f8f15337ad36f93525828b40a0f9d6c2ad24890"},
-    {file = "lxml-5.1.0-cp39-cp39-win32.whl", hash = "sha256:ed7326563024b6e91fef6b6c7a1a2ff0a71b97793ac33dbbcf38f6005e51ff6e"},
-    {file = "lxml-5.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:8d7b4beebb178e9183138f552238f7e6613162a42164233e2bda00cb3afac58f"},
-    {file = "lxml-5.1.0-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:9bd0ae7cc2b85320abd5e0abad5ccee5564ed5f0cc90245d2f9a8ef330a8deae"},
-    {file = "lxml-5.1.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d8c1d679df4361408b628f42b26a5d62bd3e9ba7f0c0e7969f925021554755aa"},
-    {file = "lxml-5.1.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:2ad3a8ce9e8a767131061a22cd28fdffa3cd2dc193f399ff7b81777f3520e372"},
-    {file = "lxml-5.1.0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:304128394c9c22b6569eba2a6d98392b56fbdfbad58f83ea702530be80d0f9df"},
-    {file = "lxml-5.1.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d74fcaf87132ffc0447b3c685a9f862ffb5b43e70ea6beec2fb8057d5d2a1fea"},
-    {file = "lxml-5.1.0-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:8cf5877f7ed384dabfdcc37922c3191bf27e55b498fecece9fd5c2c7aaa34c33"},
-    {file = "lxml-5.1.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:877efb968c3d7eb2dad540b6cabf2f1d3c0fbf4b2d309a3c141f79c7e0061324"},
-    {file = "lxml-5.1.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3f14a4fb1c1c402a22e6a341a24c1341b4a3def81b41cd354386dcb795f83897"},
-    {file = "lxml-5.1.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:25663d6e99659544ee8fe1b89b1a8c0aaa5e34b103fab124b17fa958c4a324a6"},
-    {file = "lxml-5.1.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:8b9f19df998761babaa7f09e6bc169294eefafd6149aaa272081cbddc7ba4ca3"},
-    {file = "lxml-5.1.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5e53d7e6a98b64fe54775d23a7c669763451340c3d44ad5e3a3b48a1efbdc96f"},
-    {file = "lxml-5.1.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:c3cd1fc1dc7c376c54440aeaaa0dcc803d2126732ff5c6b68ccd619f2e64be4f"},
-    {file = "lxml-5.1.0.tar.gz", hash = "sha256:3eea6ed6e6c918e468e693c41ef07f3c3acc310b70ddd9cc72d9ef84bc9564ca"},
+    {file = "lxml-5.2.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:1f7785f4f789fdb522729ae465adcaa099e2a3441519df750ebdccc481d961a1"},
+    {file = "lxml-5.2.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6cc6ee342fb7fa2471bd9b6d6fdfc78925a697bf5c2bcd0a302e98b0d35bfad3"},
+    {file = "lxml-5.2.1-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:794f04eec78f1d0e35d9e0c36cbbb22e42d370dda1609fb03bcd7aeb458c6377"},
+    {file = "lxml-5.2.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c817d420c60a5183953c783b0547d9eb43b7b344a2c46f69513d5952a78cddf3"},
+    {file = "lxml-5.2.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2213afee476546a7f37c7a9b4ad4d74b1e112a6fafffc9185d6d21f043128c81"},
+    {file = "lxml-5.2.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b070bbe8d3f0f6147689bed981d19bbb33070225373338df755a46893528104a"},
+    {file = "lxml-5.2.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e02c5175f63effbd7c5e590399c118d5db6183bbfe8e0d118bdb5c2d1b48d937"},
+    {file = "lxml-5.2.1-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:3dc773b2861b37b41a6136e0b72a1a44689a9c4c101e0cddb6b854016acc0aa8"},
+    {file = "lxml-5.2.1-cp310-cp310-manylinux_2_28_ppc64le.whl", hash = "sha256:d7520db34088c96cc0e0a3ad51a4fd5b401f279ee112aa2b7f8f976d8582606d"},
+    {file = "lxml-5.2.1-cp310-cp310-manylinux_2_28_s390x.whl", hash = "sha256:bcbf4af004f98793a95355980764b3d80d47117678118a44a80b721c9913436a"},
+    {file = "lxml-5.2.1-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:a2b44bec7adf3e9305ce6cbfa47a4395667e744097faed97abb4728748ba7d47"},
+    {file = "lxml-5.2.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:1c5bb205e9212d0ebddf946bc07e73fa245c864a5f90f341d11ce7b0b854475d"},
+    {file = "lxml-5.2.1-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:2c9d147f754b1b0e723e6afb7ba1566ecb162fe4ea657f53d2139bbf894d050a"},
+    {file = "lxml-5.2.1-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:3545039fa4779be2df51d6395e91a810f57122290864918b172d5dc7ca5bb433"},
+    {file = "lxml-5.2.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a91481dbcddf1736c98a80b122afa0f7296eeb80b72344d7f45dc9f781551f56"},
+    {file = "lxml-5.2.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:2ddfe41ddc81f29a4c44c8ce239eda5ade4e7fc305fb7311759dd6229a080052"},
+    {file = "lxml-5.2.1-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:a7baf9ffc238e4bf401299f50e971a45bfcc10a785522541a6e3179c83eabf0a"},
+    {file = "lxml-5.2.1-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:31e9a882013c2f6bd2f2c974241bf4ba68c85eba943648ce88936d23209a2e01"},
+    {file = "lxml-5.2.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:0a15438253b34e6362b2dc41475e7f80de76320f335e70c5528b7148cac253a1"},
+    {file = "lxml-5.2.1-cp310-cp310-win32.whl", hash = "sha256:6992030d43b916407c9aa52e9673612ff39a575523c5f4cf72cdef75365709a5"},
+    {file = "lxml-5.2.1-cp310-cp310-win_amd64.whl", hash = "sha256:da052e7962ea2d5e5ef5bc0355d55007407087392cf465b7ad84ce5f3e25fe0f"},
+    {file = "lxml-5.2.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:70ac664a48aa64e5e635ae5566f5227f2ab7f66a3990d67566d9907edcbbf867"},
+    {file = "lxml-5.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:1ae67b4e737cddc96c99461d2f75d218bdf7a0c3d3ad5604d1f5e7464a2f9ffe"},
+    {file = "lxml-5.2.1-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f18a5a84e16886898e51ab4b1d43acb3083c39b14c8caeb3589aabff0ee0b270"},
+    {file = "lxml-5.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c6f2c8372b98208ce609c9e1d707f6918cc118fea4e2c754c9f0812c04ca116d"},
+    {file = "lxml-5.2.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:394ed3924d7a01b5bd9a0d9d946136e1c2f7b3dc337196d99e61740ed4bc6fe1"},
+    {file = "lxml-5.2.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d077bc40a1fe984e1a9931e801e42959a1e6598edc8a3223b061d30fbd26bbc"},
+    {file = "lxml-5.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:764b521b75701f60683500d8621841bec41a65eb739b8466000c6fdbc256c240"},
+    {file = "lxml-5.2.1-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:3a6b45da02336895da82b9d472cd274b22dc27a5cea1d4b793874eead23dd14f"},
+    {file = "lxml-5.2.1-cp311-cp311-manylinux_2_28_ppc64le.whl", hash = "sha256:5ea7b6766ac2dfe4bcac8b8595107665a18ef01f8c8343f00710b85096d1b53a"},
+    {file = "lxml-5.2.1-cp311-cp311-manylinux_2_28_s390x.whl", hash = "sha256:e196a4ff48310ba62e53a8e0f97ca2bca83cdd2fe2934d8b5cb0df0a841b193a"},
+    {file = "lxml-5.2.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:200e63525948e325d6a13a76ba2911f927ad399ef64f57898cf7c74e69b71095"},
+    {file = "lxml-5.2.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:dae0ed02f6b075426accbf6b2863c3d0a7eacc1b41fb40f2251d931e50188dad"},
+    {file = "lxml-5.2.1-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:ab31a88a651039a07a3ae327d68ebdd8bc589b16938c09ef3f32a4b809dc96ef"},
+    {file = "lxml-5.2.1-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:df2e6f546c4df14bc81f9498bbc007fbb87669f1bb707c6138878c46b06f6510"},
+    {file = "lxml-5.2.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5dd1537e7cc06efd81371f5d1a992bd5ab156b2b4f88834ca852de4a8ea523fa"},
+    {file = "lxml-5.2.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:9b9ec9c9978b708d488bec36b9e4c94d88fd12ccac3e62134a9d17ddba910ea9"},
+    {file = "lxml-5.2.1-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:8e77c69d5892cb5ba71703c4057091e31ccf534bd7f129307a4d084d90d014b8"},
+    {file = "lxml-5.2.1-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:a8d5c70e04aac1eda5c829a26d1f75c6e5286c74743133d9f742cda8e53b9c2f"},
+    {file = "lxml-5.2.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:c94e75445b00319c1fad60f3c98b09cd63fe1134a8a953dcd48989ef42318534"},
+    {file = "lxml-5.2.1-cp311-cp311-win32.whl", hash = "sha256:4951e4f7a5680a2db62f7f4ab2f84617674d36d2d76a729b9a8be4b59b3659be"},
+    {file = "lxml-5.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:5c670c0406bdc845b474b680b9a5456c561c65cf366f8db5a60154088c92d102"},
+    {file = "lxml-5.2.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:abc25c3cab9ec7fcd299b9bcb3b8d4a1231877e425c650fa1c7576c5107ab851"},
+    {file = "lxml-5.2.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6935bbf153f9a965f1e07c2649c0849d29832487c52bb4a5c5066031d8b44fd5"},
+    {file = "lxml-5.2.1-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d793bebb202a6000390a5390078e945bbb49855c29c7e4d56a85901326c3b5d9"},
+    {file = "lxml-5.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:afd5562927cdef7c4f5550374acbc117fd4ecc05b5007bdfa57cc5355864e0a4"},
+    {file = "lxml-5.2.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0e7259016bc4345a31af861fdce942b77c99049d6c2107ca07dc2bba2435c1d9"},
+    {file = "lxml-5.2.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:530e7c04f72002d2f334d5257c8a51bf409db0316feee7c87e4385043be136af"},
+    {file = "lxml-5.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:59689a75ba8d7ffca577aefd017d08d659d86ad4585ccc73e43edbfc7476781a"},
+    {file = "lxml-5.2.1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:f9737bf36262046213a28e789cc82d82c6ef19c85a0cf05e75c670a33342ac2c"},
+    {file = "lxml-5.2.1-cp312-cp312-manylinux_2_28_ppc64le.whl", hash = "sha256:3a74c4f27167cb95c1d4af1c0b59e88b7f3e0182138db2501c353555f7ec57f4"},
+    {file = "lxml-5.2.1-cp312-cp312-manylinux_2_28_s390x.whl", hash = "sha256:68a2610dbe138fa8c5826b3f6d98a7cfc29707b850ddcc3e21910a6fe51f6ca0"},
+    {file = "lxml-5.2.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:f0a1bc63a465b6d72569a9bba9f2ef0334c4e03958e043da1920299100bc7c08"},
+    {file = "lxml-5.2.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:c2d35a1d047efd68027817b32ab1586c1169e60ca02c65d428ae815b593e65d4"},
+    {file = "lxml-5.2.1-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:79bd05260359170f78b181b59ce871673ed01ba048deef4bf49a36ab3e72e80b"},
+    {file = "lxml-5.2.1-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:865bad62df277c04beed9478fe665b9ef63eb28fe026d5dedcb89b537d2e2ea6"},
+    {file = "lxml-5.2.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:44f6c7caff88d988db017b9b0e4ab04934f11e3e72d478031efc7edcac6c622f"},
+    {file = "lxml-5.2.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:71e97313406ccf55d32cc98a533ee05c61e15d11b99215b237346171c179c0b0"},
+    {file = "lxml-5.2.1-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:057cdc6b86ab732cf361f8b4d8af87cf195a1f6dc5b0ff3de2dced242c2015e0"},
+    {file = "lxml-5.2.1-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:f3bbbc998d42f8e561f347e798b85513ba4da324c2b3f9b7969e9c45b10f6169"},
+    {file = "lxml-5.2.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:491755202eb21a5e350dae00c6d9a17247769c64dcf62d8c788b5c135e179dc4"},
+    {file = "lxml-5.2.1-cp312-cp312-win32.whl", hash = "sha256:8de8f9d6caa7f25b204fc861718815d41cbcf27ee8f028c89c882a0cf4ae4134"},
+    {file = "lxml-5.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:f2a9efc53d5b714b8df2b4b3e992accf8ce5bbdfe544d74d5c6766c9e1146a3a"},
+    {file = "lxml-5.2.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:70a9768e1b9d79edca17890175ba915654ee1725975d69ab64813dd785a2bd5c"},
+    {file = "lxml-5.2.1-cp36-cp36m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c38d7b9a690b090de999835f0443d8aa93ce5f2064035dfc48f27f02b4afc3d0"},
+    {file = "lxml-5.2.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5670fb70a828663cc37552a2a85bf2ac38475572b0e9b91283dc09efb52c41d1"},
+    {file = "lxml-5.2.1-cp36-cp36m-manylinux_2_28_x86_64.whl", hash = "sha256:958244ad566c3ffc385f47dddde4145088a0ab893504b54b52c041987a8c1863"},
+    {file = "lxml-5.2.1-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:b6241d4eee5f89453307c2f2bfa03b50362052ca0af1efecf9fef9a41a22bb4f"},
+    {file = "lxml-5.2.1-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:2a66bf12fbd4666dd023b6f51223aed3d9f3b40fef06ce404cb75bafd3d89536"},
+    {file = "lxml-5.2.1-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:9123716666e25b7b71c4e1789ec829ed18663152008b58544d95b008ed9e21e9"},
+    {file = "lxml-5.2.1-cp36-cp36m-musllinux_1_1_s390x.whl", hash = "sha256:0c3f67e2aeda739d1cc0b1102c9a9129f7dc83901226cc24dd72ba275ced4218"},
+    {file = "lxml-5.2.1-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:5d5792e9b3fb8d16a19f46aa8208987cfeafe082363ee2745ea8b643d9cc5b45"},
+    {file = "lxml-5.2.1-cp36-cp36m-musllinux_1_2_aarch64.whl", hash = "sha256:88e22fc0a6684337d25c994381ed8a1580a6f5ebebd5ad41f89f663ff4ec2885"},
+    {file = "lxml-5.2.1-cp36-cp36m-musllinux_1_2_ppc64le.whl", hash = "sha256:21c2e6b09565ba5b45ae161b438e033a86ad1736b8c838c766146eff8ceffff9"},
+    {file = "lxml-5.2.1-cp36-cp36m-musllinux_1_2_s390x.whl", hash = "sha256:afbbdb120d1e78d2ba8064a68058001b871154cc57787031b645c9142b937a62"},
+    {file = "lxml-5.2.1-cp36-cp36m-musllinux_1_2_x86_64.whl", hash = "sha256:627402ad8dea044dde2eccde4370560a2b750ef894c9578e1d4f8ffd54000461"},
+    {file = "lxml-5.2.1-cp36-cp36m-win32.whl", hash = "sha256:e89580a581bf478d8dcb97d9cd011d567768e8bc4095f8557b21c4d4c5fea7d0"},
+    {file = "lxml-5.2.1-cp36-cp36m-win_amd64.whl", hash = "sha256:59565f10607c244bc4c05c0c5fa0c190c990996e0c719d05deec7030c2aa8289"},
+    {file = "lxml-5.2.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:857500f88b17a6479202ff5fe5f580fc3404922cd02ab3716197adf1ef628029"},
+    {file = "lxml-5.2.1-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:56c22432809085b3f3ae04e6e7bdd36883d7258fcd90e53ba7b2e463efc7a6af"},
+    {file = "lxml-5.2.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a55ee573116ba208932e2d1a037cc4b10d2c1cb264ced2184d00b18ce585b2c0"},
+    {file = "lxml-5.2.1-cp37-cp37m-manylinux_2_28_x86_64.whl", hash = "sha256:6cf58416653c5901e12624e4013708b6e11142956e7f35e7a83f1ab02f3fe456"},
+    {file = "lxml-5.2.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:64c2baa7774bc22dd4474248ba16fe1a7f611c13ac6123408694d4cc93d66dbd"},
+    {file = "lxml-5.2.1-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:74b28c6334cca4dd704e8004cba1955af0b778cf449142e581e404bd211fb619"},
+    {file = "lxml-5.2.1-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:7221d49259aa1e5a8f00d3d28b1e0b76031655ca74bb287123ef56c3db92f213"},
+    {file = "lxml-5.2.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:3dbe858ee582cbb2c6294dc85f55b5f19c918c2597855e950f34b660f1a5ede6"},
+    {file = "lxml-5.2.1-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:04ab5415bf6c86e0518d57240a96c4d1fcfc3cb370bb2ac2a732b67f579e5a04"},
+    {file = "lxml-5.2.1-cp37-cp37m-musllinux_1_2_ppc64le.whl", hash = "sha256:6ab833e4735a7e5533711a6ea2df26459b96f9eec36d23f74cafe03631647c41"},
+    {file = "lxml-5.2.1-cp37-cp37m-musllinux_1_2_s390x.whl", hash = "sha256:f443cdef978430887ed55112b491f670bba6462cea7a7742ff8f14b7abb98d75"},
+    {file = "lxml-5.2.1-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:9e2addd2d1866fe112bc6f80117bcc6bc25191c5ed1bfbcf9f1386a884252ae8"},
+    {file = "lxml-5.2.1-cp37-cp37m-win32.whl", hash = "sha256:f51969bac61441fd31f028d7b3b45962f3ecebf691a510495e5d2cd8c8092dbd"},
+    {file = "lxml-5.2.1-cp37-cp37m-win_amd64.whl", hash = "sha256:b0b58fbfa1bf7367dde8a557994e3b1637294be6cf2169810375caf8571a085c"},
+    {file = "lxml-5.2.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:3e183c6e3298a2ed5af9d7a356ea823bccaab4ec2349dc9ed83999fd289d14d5"},
+    {file = "lxml-5.2.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:804f74efe22b6a227306dd890eecc4f8c59ff25ca35f1f14e7482bbce96ef10b"},
+    {file = "lxml-5.2.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:08802f0c56ed150cc6885ae0788a321b73505d2263ee56dad84d200cab11c07a"},
+    {file = "lxml-5.2.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0f8c09ed18ecb4ebf23e02b8e7a22a05d6411911e6fabef3a36e4f371f4f2585"},
+    {file = "lxml-5.2.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3d30321949861404323c50aebeb1943461a67cd51d4200ab02babc58bd06a86"},
+    {file = "lxml-5.2.1-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:b560e3aa4b1d49e0e6c847d72665384db35b2f5d45f8e6a5c0072e0283430533"},
+    {file = "lxml-5.2.1-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:058a1308914f20784c9f4674036527e7c04f7be6fb60f5d61353545aa7fcb739"},
+    {file = "lxml-5.2.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:adfb84ca6b87e06bc6b146dc7da7623395db1e31621c4785ad0658c5028b37d7"},
+    {file = "lxml-5.2.1-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:417d14450f06d51f363e41cace6488519038f940676ce9664b34ebf5653433a5"},
+    {file = "lxml-5.2.1-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:a2dfe7e2473f9b59496247aad6e23b405ddf2e12ef0765677b0081c02d6c2c0b"},
+    {file = "lxml-5.2.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:bf2e2458345d9bffb0d9ec16557d8858c9c88d2d11fed53998512504cd9df49b"},
+    {file = "lxml-5.2.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:58278b29cb89f3e43ff3e0c756abbd1518f3ee6adad9e35b51fb101c1c1daaec"},
+    {file = "lxml-5.2.1-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:64641a6068a16201366476731301441ce93457eb8452056f570133a6ceb15fca"},
+    {file = "lxml-5.2.1-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:78bfa756eab503673991bdcf464917ef7845a964903d3302c5f68417ecdc948c"},
+    {file = "lxml-5.2.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:11a04306fcba10cd9637e669fd73aa274c1c09ca64af79c041aa820ea992b637"},
+    {file = "lxml-5.2.1-cp38-cp38-win32.whl", hash = "sha256:66bc5eb8a323ed9894f8fa0ee6cb3e3fb2403d99aee635078fd19a8bc7a5a5da"},
+    {file = "lxml-5.2.1-cp38-cp38-win_amd64.whl", hash = "sha256:9676bfc686fa6a3fa10cd4ae6b76cae8be26eb5ec6811d2a325636c460da1806"},
+    {file = "lxml-5.2.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:cf22b41fdae514ee2f1691b6c3cdeae666d8b7fa9434de445f12bbeee0cf48dd"},
+    {file = "lxml-5.2.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ec42088248c596dbd61d4ae8a5b004f97a4d91a9fd286f632e42e60b706718d7"},
+    {file = "lxml-5.2.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cd53553ddad4a9c2f1f022756ae64abe16da1feb497edf4d9f87f99ec7cf86bd"},
+    {file = "lxml-5.2.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:feaa45c0eae424d3e90d78823f3828e7dc42a42f21ed420db98da2c4ecf0a2cb"},
+    {file = "lxml-5.2.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ddc678fb4c7e30cf830a2b5a8d869538bc55b28d6c68544d09c7d0d8f17694dc"},
+    {file = "lxml-5.2.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:853e074d4931dbcba7480d4dcab23d5c56bd9607f92825ab80ee2bd916edea53"},
+    {file = "lxml-5.2.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cc4691d60512798304acb9207987e7b2b7c44627ea88b9d77489bbe3e6cc3bd4"},
+    {file = "lxml-5.2.1-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:beb72935a941965c52990f3a32d7f07ce869fe21c6af8b34bf6a277b33a345d3"},
+    {file = "lxml-5.2.1-cp39-cp39-manylinux_2_28_ppc64le.whl", hash = "sha256:6588c459c5627fefa30139be4d2e28a2c2a1d0d1c265aad2ba1935a7863a4913"},
+    {file = "lxml-5.2.1-cp39-cp39-manylinux_2_28_s390x.whl", hash = "sha256:588008b8497667f1ddca7c99f2f85ce8511f8f7871b4a06ceede68ab62dff64b"},
+    {file = "lxml-5.2.1-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:b6787b643356111dfd4032b5bffe26d2f8331556ecb79e15dacb9275da02866e"},
+    {file = "lxml-5.2.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:7c17b64b0a6ef4e5affae6a3724010a7a66bda48a62cfe0674dabd46642e8b54"},
+    {file = "lxml-5.2.1-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:27aa20d45c2e0b8cd05da6d4759649170e8dfc4f4e5ef33a34d06f2d79075d57"},
+    {file = "lxml-5.2.1-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:d4f2cc7060dc3646632d7f15fe68e2fa98f58e35dd5666cd525f3b35d3fed7f8"},
+    {file = "lxml-5.2.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ff46d772d5f6f73564979cd77a4fffe55c916a05f3cb70e7c9c0590059fb29ef"},
+    {file = "lxml-5.2.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:96323338e6c14e958d775700ec8a88346014a85e5de73ac7967db0367582049b"},
+    {file = "lxml-5.2.1-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:52421b41ac99e9d91934e4d0d0fe7da9f02bfa7536bb4431b4c05c906c8c6919"},
+    {file = "lxml-5.2.1-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:7a7efd5b6d3e30d81ec68ab8a88252d7c7c6f13aaa875009fe3097eb4e30b84c"},
+    {file = "lxml-5.2.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:0ed777c1e8c99b63037b91f9d73a6aad20fd035d77ac84afcc205225f8f41188"},
+    {file = "lxml-5.2.1-cp39-cp39-win32.whl", hash = "sha256:644df54d729ef810dcd0f7732e50e5ad1bd0a135278ed8d6bcb06f33b6b6f708"},
+    {file = "lxml-5.2.1-cp39-cp39-win_amd64.whl", hash = "sha256:9ca66b8e90daca431b7ca1408cae085d025326570e57749695d6a01454790e95"},
+    {file = "lxml-5.2.1-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:9b0ff53900566bc6325ecde9181d89afadc59c5ffa39bddf084aaedfe3b06a11"},
+    {file = "lxml-5.2.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fd6037392f2d57793ab98d9e26798f44b8b4da2f2464388588f48ac52c489ea1"},
+    {file = "lxml-5.2.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b9c07e7a45bb64e21df4b6aa623cb8ba214dfb47d2027d90eac197329bb5e94"},
+    {file = "lxml-5.2.1-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:3249cc2989d9090eeac5467e50e9ec2d40704fea9ab72f36b034ea34ee65ca98"},
+    {file = "lxml-5.2.1-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:f42038016852ae51b4088b2862126535cc4fc85802bfe30dea3500fdfaf1864e"},
+    {file = "lxml-5.2.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:533658f8fbf056b70e434dff7e7aa611bcacb33e01f75de7f821810e48d1bb66"},
+    {file = "lxml-5.2.1-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:622020d4521e22fb371e15f580d153134bfb68d6a429d1342a25f051ec72df1c"},
+    {file = "lxml-5.2.1-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:efa7b51824aa0ee957ccd5a741c73e6851de55f40d807f08069eb4c5a26b2baa"},
+    {file = "lxml-5.2.1-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c6ad0fbf105f6bcc9300c00010a2ffa44ea6f555df1a2ad95c88f5656104817"},
+    {file = "lxml-5.2.1-pp37-pypy37_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:e233db59c8f76630c512ab4a4daf5a5986da5c3d5b44b8e9fc742f2a24dbd460"},
+    {file = "lxml-5.2.1-pp37-pypy37_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:6a014510830df1475176466b6087fc0c08b47a36714823e58d8b8d7709132a96"},
+    {file = "lxml-5.2.1-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:d38c8f50ecf57f0463399569aa388b232cf1a2ffb8f0a9a5412d0db57e054860"},
+    {file = "lxml-5.2.1-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:5aea8212fb823e006b995c4dda533edcf98a893d941f173f6c9506126188860d"},
+    {file = "lxml-5.2.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ff097ae562e637409b429a7ac958a20aab237a0378c42dabaa1e3abf2f896e5f"},
+    {file = "lxml-5.2.1-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f5d65c39f16717a47c36c756af0fb36144069c4718824b7533f803ecdf91138"},
+    {file = "lxml-5.2.1-pp38-pypy38_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:3d0c3dd24bb4605439bf91068598d00c6370684f8de4a67c2992683f6c309d6b"},
+    {file = "lxml-5.2.1-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:e32be23d538753a8adb6c85bd539f5fd3b15cb987404327c569dfc5fd8366e85"},
+    {file = "lxml-5.2.1-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:cc518cea79fd1e2f6c90baafa28906d4309d24f3a63e801d855e7424c5b34144"},
+    {file = "lxml-5.2.1-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:a0af35bd8ebf84888373630f73f24e86bf016642fb8576fba49d3d6b560b7cbc"},
+    {file = "lxml-5.2.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f8aca2e3a72f37bfc7b14ba96d4056244001ddcc18382bd0daa087fd2e68a354"},
+    {file = "lxml-5.2.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ca1e8188b26a819387b29c3895c47a5e618708fe6f787f3b1a471de2c4a94d9"},
+    {file = "lxml-5.2.1-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:c8ba129e6d3b0136a0f50345b2cb3db53f6bda5dd8c7f5d83fbccba97fb5dcb5"},
+    {file = "lxml-5.2.1-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:e998e304036198b4f6914e6a1e2b6f925208a20e2042563d9734881150c6c246"},
+    {file = "lxml-5.2.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:d3be9b2076112e51b323bdf6d5a7f8a798de55fb8d95fcb64bd179460cdc0704"},
+    {file = "lxml-5.2.1.tar.gz", hash = "sha256:3f7765e69bbce0906a7c74d5fe46d2c7a7596147318dbc08e4a2431f3060e306"},
 ]

 [package.extras]
 cssselect = ["cssselect (>=0.7)"]
+html-clean = ["lxml-html-clean"]
 html5 = ["html5lib"]
 htmlsoup = ["BeautifulSoup4"]
-source = ["Cython (>=3.0.7)"]
+source = ["Cython (>=3.0.10)"]

 [[package]]
 name = "markdown"
@@ -1525,7 +1667,7 @@ tests = ["pytest (>=4.6)"]
 name = "mujoco"
 version = "2.3.7"
 description = "MuJoCo Physics Simulator"
-optional = false
+optional = true
 python-versions = ">=3.8"
 files = [
     {file = "mujoco-2.3.7-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:e8714a5ff6a1561b364b7b4648d4c0c8d13e751874cf7401c309b9d23fa9598b"},
@@ -1563,20 +1705,20 @@ pyopengl = "*"

 [[package]]
 name = "networkx"
-version = "3.2.1"
+version = "3.3"
 description = "Python package for creating and manipulating graphs and networks"
 optional = false
-python-versions = ">=3.9"
+python-versions = ">=3.10"
 files = [
-    {file = "networkx-3.2.1-py3-none-any.whl", hash = "sha256:f18c69adc97877c42332c170849c96cefa91881c99a7cb3e95b7c659ebdc1ec2"},
-    {file = "networkx-3.2.1.tar.gz", hash = "sha256:9f1bb5cf3409bf324e0a722c20bdb4c20ee39bf1c30ce8ae499c8502b0b5e0c6"},
+    {file = "networkx-3.3-py3-none-any.whl", hash = "sha256:28575580c6ebdaf4505b22c6256a2b9de86b316dc63ba9e93abde3d78dfdbcf2"},
+    {file = "networkx-3.3.tar.gz", hash = "sha256:0c127d8b2f4865f59ae9cb8aafcd60b5c70f3241ebd66f7defad7c4ab90126c9"},
 ]

 [package.extras]
-default = ["matplotlib (>=3.5)", "numpy (>=1.22)", "pandas (>=1.4)", "scipy (>=1.9,!=1.11.0,!=1.11.1)"]
-developer = ["changelist (==0.4)", "mypy (>=1.1)", "pre-commit (>=3.2)", "rtoml"]
-doc = ["nb2plots (>=0.7)", "nbconvert (<7.9)", "numpydoc (>=1.6)", "pillow (>=9.4)", "pydata-sphinx-theme (>=0.14)", "sphinx (>=7)", "sphinx-gallery (>=0.14)", "texext (>=0.6.7)"]
-extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.11)", "sympy (>=1.10)"]
+default = ["matplotlib (>=3.6)", "numpy (>=1.23)", "pandas (>=1.4)", "scipy (>=1.9,!=1.11.0,!=1.11.1)"]
+developer = ["changelist (==0.5)", "mypy (>=1.1)", "pre-commit (>=3.2)", "rtoml"]
+doc = ["myst-nb (>=1.0)", "numpydoc (>=1.7)", "pillow (>=9.4)", "pydata-sphinx-theme (>=0.14)", "sphinx (>=7)", "sphinx-gallery (>=0.14)", "texext (>=0.6.7)"]
+extra = ["lxml (>=4.6)", "pydot (>=2.0)", "pygraphviz (>=1.12)", "sympy (>=1.10)"]
 test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"]

 [[package]]
@@ -1833,13 +1975,13 @@ files = [

 [[package]]
 name = "nvidia-nvjitlink-cu12"
-version = "12.4.99"
+version = "12.4.127"
 description = "Nvidia JIT LTO Library"
 optional = false
 python-versions = ">=3"
 files = [
-    {file = "nvidia_nvjitlink_cu12-12.4.99-py3-none-manylinux2014_x86_64.whl", hash = "sha256:c6428836d20fe7e327191c175791d38570e10762edc588fb46749217cd444c74"},
-    {file = "nvidia_nvjitlink_cu12-12.4.99-py3-none-win_amd64.whl", hash = "sha256:991905ffa2144cb603d8ca7962d75c35334ae82bf92820b6ba78157277da1ad2"},
+    {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57"},
+    {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:fd9020c501d27d135f983c6d3e244b197a7ccad769e34df53a42e276b0e25fa1"},
 ]

 [[package]]
@@ -1980,7 +2122,7 @@ xml = ["lxml (>=4.9.2)"]
 name = "pettingzoo"
 version = "1.24.3"
 description = "Gymnasium for multi-agent reinforcement learning."
-optional = false
+optional = true
 python-versions = ">=3.8"
 files = [
     {file = "pettingzoo-1.24.3-py3-none-any.whl", hash = "sha256:23ed90517d2e8a7098bdaf5e31234b3a7f7b73ca578d70d1ca7b9d0cb0e37982"},
@@ -2003,79 +2145,80 @@ testing = ["AutoROM", "pre-commit", "pynput", "pytest", "pytest-cov", "pytest-ma

 [[package]]
 name = "pillow"
-version = "10.2.0"
+version = "10.3.0"
 description = "Python Imaging Library (Fork)"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "pillow-10.2.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:7823bdd049099efa16e4246bdf15e5a13dbb18a51b68fa06d6c1d4d8b99a796e"},
-    {file = "pillow-10.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:83b2021f2ade7d1ed556bc50a399127d7fb245e725aa0113ebd05cfe88aaf588"},
-    {file = "pillow-10.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6fad5ff2f13d69b7e74ce5b4ecd12cc0ec530fcee76356cac6742785ff71c452"},
-    {file = "pillow-10.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da2b52b37dad6d9ec64e653637a096905b258d2fc2b984c41ae7d08b938a67e4"},
-    {file = "pillow-10.2.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:47c0995fc4e7f79b5cfcab1fc437ff2890b770440f7696a3ba065ee0fd496563"},
-    {file = "pillow-10.2.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:322bdf3c9b556e9ffb18f93462e5f749d3444ce081290352c6070d014c93feb2"},
-    {file = "pillow-10.2.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:51f1a1bffc50e2e9492e87d8e09a17c5eea8409cda8d3f277eb6edc82813c17c"},
-    {file = "pillow-10.2.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:69ffdd6120a4737710a9eee73e1d2e37db89b620f702754b8f6e62594471dee0"},
-    {file = "pillow-10.2.0-cp310-cp310-win32.whl", hash = "sha256:c6dafac9e0f2b3c78df97e79af707cdc5ef8e88208d686a4847bab8266870023"},
-    {file = "pillow-10.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:aebb6044806f2e16ecc07b2a2637ee1ef67a11840a66752751714a0d924adf72"},
-    {file = "pillow-10.2.0-cp310-cp310-win_arm64.whl", hash = "sha256:7049e301399273a0136ff39b84c3678e314f2158f50f517bc50285fb5ec847ad"},
-    {file = "pillow-10.2.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:35bb52c37f256f662abdfa49d2dfa6ce5d93281d323a9af377a120e89a9eafb5"},
-    {file = "pillow-10.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9c23f307202661071d94b5e384e1e1dc7dfb972a28a2310e4ee16103e66ddb67"},
-    {file = "pillow-10.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:773efe0603db30c281521a7c0214cad7836c03b8ccff897beae9b47c0b657d61"},
-    {file = "pillow-10.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:11fa2e5984b949b0dd6d7a94d967743d87c577ff0b83392f17cb3990d0d2fd6e"},
-    {file = "pillow-10.2.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:716d30ed977be8b37d3ef185fecb9e5a1d62d110dfbdcd1e2a122ab46fddb03f"},
-    {file = "pillow-10.2.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:a086c2af425c5f62a65e12fbf385f7c9fcb8f107d0849dba5839461a129cf311"},
-    {file = "pillow-10.2.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c8de2789052ed501dd829e9cae8d3dcce7acb4777ea4a479c14521c942d395b1"},
-    {file = "pillow-10.2.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:609448742444d9290fd687940ac0b57fb35e6fd92bdb65386e08e99af60bf757"},
-    {file = "pillow-10.2.0-cp311-cp311-win32.whl", hash = "sha256:823ef7a27cf86df6597fa0671066c1b596f69eba53efa3d1e1cb8b30f3533068"},
-    {file = "pillow-10.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:1da3b2703afd040cf65ec97efea81cfba59cdbed9c11d8efc5ab09df9509fc56"},
-    {file = "pillow-10.2.0-cp311-cp311-win_arm64.whl", hash = "sha256:edca80cbfb2b68d7b56930b84a0e45ae1694aeba0541f798e908a49d66b837f1"},
-    {file = "pillow-10.2.0-cp312-cp312-macosx_10_10_x86_64.whl", hash = "sha256:1b5e1b74d1bd1b78bc3477528919414874748dd363e6272efd5abf7654e68bef"},
-    {file = "pillow-10.2.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0eae2073305f451d8ecacb5474997c08569fb4eb4ac231ffa4ad7d342fdc25ac"},
-    {file = "pillow-10.2.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b7c2286c23cd350b80d2fc9d424fc797575fb16f854b831d16fd47ceec078f2c"},
-    {file = "pillow-10.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1e23412b5c41e58cec602f1135c57dfcf15482013ce6e5f093a86db69646a5aa"},
-    {file = "pillow-10.2.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:52a50aa3fb3acb9cf7213573ef55d31d6eca37f5709c69e6858fe3bc04a5c2a2"},
-    {file = "pillow-10.2.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:127cee571038f252a552760076407f9cff79761c3d436a12af6000cd182a9d04"},
-    {file = "pillow-10.2.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:8d12251f02d69d8310b046e82572ed486685c38f02176bd08baf216746eb947f"},
-    {file = "pillow-10.2.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:54f1852cd531aa981bc0965b7d609f5f6cc8ce8c41b1139f6ed6b3c54ab82bfb"},
-    {file = "pillow-10.2.0-cp312-cp312-win32.whl", hash = "sha256:257d8788df5ca62c980314053197f4d46eefedf4e6175bc9412f14412ec4ea2f"},
-    {file = "pillow-10.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:154e939c5f0053a383de4fd3d3da48d9427a7e985f58af8e94d0b3c9fcfcf4f9"},
-    {file = "pillow-10.2.0-cp312-cp312-win_arm64.whl", hash = "sha256:f379abd2f1e3dddb2b61bc67977a6b5a0a3f7485538bcc6f39ec76163891ee48"},
-    {file = "pillow-10.2.0-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:8373c6c251f7ef8bda6675dd6d2b3a0fcc31edf1201266b5cf608b62a37407f9"},
-    {file = "pillow-10.2.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:870ea1ada0899fd0b79643990809323b389d4d1d46c192f97342eeb6ee0b8483"},
-    {file = "pillow-10.2.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b4b6b1e20608493548b1f32bce8cca185bf0480983890403d3b8753e44077129"},
-    {file = "pillow-10.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3031709084b6e7852d00479fd1d310b07d0ba82765f973b543c8af5061cf990e"},
-    {file = "pillow-10.2.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:3ff074fc97dd4e80543a3e91f69d58889baf2002b6be64347ea8cf5533188213"},
-    {file = "pillow-10.2.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:cb4c38abeef13c61d6916f264d4845fab99d7b711be96c326b84df9e3e0ff62d"},
-    {file = "pillow-10.2.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:b1b3020d90c2d8e1dae29cf3ce54f8094f7938460fb5ce8bc5c01450b01fbaf6"},
-    {file = "pillow-10.2.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:170aeb00224ab3dc54230c797f8404507240dd868cf52066f66a41b33169bdbe"},
-    {file = "pillow-10.2.0-cp38-cp38-win32.whl", hash = "sha256:c4225f5220f46b2fde568c74fca27ae9771536c2e29d7c04f4fb62c83275ac4e"},
-    {file = "pillow-10.2.0-cp38-cp38-win_amd64.whl", hash = "sha256:0689b5a8c5288bc0504d9fcee48f61a6a586b9b98514d7d29b840143d6734f39"},
-    {file = "pillow-10.2.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:b792a349405fbc0163190fde0dc7b3fef3c9268292586cf5645598b48e63dc67"},
-    {file = "pillow-10.2.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c570f24be1e468e3f0ce7ef56a89a60f0e05b30a3669a459e419c6eac2c35364"},
-    {file = "pillow-10.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d8ecd059fdaf60c1963c58ceb8997b32e9dc1b911f5da5307aab614f1ce5c2fb"},
-    {file = "pillow-10.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c365fd1703040de1ec284b176d6af5abe21b427cb3a5ff68e0759e1e313a5e7e"},
-    {file = "pillow-10.2.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:70c61d4c475835a19b3a5aa42492409878bbca7438554a1f89d20d58a7c75c01"},
-    {file = "pillow-10.2.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:b6f491cdf80ae540738859d9766783e3b3c8e5bd37f5dfa0b76abdecc5081f13"},
-    {file = "pillow-10.2.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9d189550615b4948f45252d7f005e53c2040cea1af5b60d6f79491a6e147eef7"},
-    {file = "pillow-10.2.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:49d9ba1ed0ef3e061088cd1e7538a0759aab559e2e0a80a36f9fd9d8c0c21591"},
-    {file = "pillow-10.2.0-cp39-cp39-win32.whl", hash = "sha256:babf5acfede515f176833ed6028754cbcd0d206f7f614ea3447d67c33be12516"},
-    {file = "pillow-10.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:0304004f8067386b477d20a518b50f3fa658a28d44e4116970abfcd94fac34a8"},
-    {file = "pillow-10.2.0-cp39-cp39-win_arm64.whl", hash = "sha256:0fb3e7fc88a14eacd303e90481ad983fd5b69c761e9e6ef94c983f91025da869"},
-    {file = "pillow-10.2.0-pp310-pypy310_pp73-macosx_10_10_x86_64.whl", hash = "sha256:322209c642aabdd6207517e9739c704dc9f9db943015535783239022002f054a"},
-    {file = "pillow-10.2.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3eedd52442c0a5ff4f887fab0c1c0bb164d8635b32c894bc1faf4c618dd89df2"},
-    {file = "pillow-10.2.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cb28c753fd5eb3dd859b4ee95de66cc62af91bcff5db5f2571d32a520baf1f04"},
-    {file = "pillow-10.2.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:33870dc4653c5017bf4c8873e5488d8f8d5f8935e2f1fb9a2208c47cdd66efd2"},
-    {file = "pillow-10.2.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:3c31822339516fb3c82d03f30e22b1d038da87ef27b6a78c9549888f8ceda39a"},
-    {file = "pillow-10.2.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:a2b56ba36e05f973d450582fb015594aaa78834fefe8dfb8fcd79b93e64ba4c6"},
-    {file = "pillow-10.2.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:d8e6aeb9201e655354b3ad049cb77d19813ad4ece0df1249d3c793de3774f8c7"},
-    {file = "pillow-10.2.0-pp39-pypy39_pp73-macosx_10_10_x86_64.whl", hash = "sha256:2247178effb34a77c11c0e8ac355c7a741ceca0a732b27bf11e747bbc950722f"},
-    {file = "pillow-10.2.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:15587643b9e5eb26c48e49a7b33659790d28f190fc514a322d55da2fb5c2950e"},
-    {file = "pillow-10.2.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:753cd8f2086b2b80180d9b3010dd4ed147efc167c90d3bf593fe2af21265e5a5"},
-    {file = "pillow-10.2.0-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:7c8f97e8e7a9009bcacbe3766a36175056c12f9a44e6e6f2d5caad06dcfbf03b"},
-    {file = "pillow-10.2.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:d1b35bcd6c5543b9cb547dee3150c93008f8dd0f1fef78fc0cd2b141c5baf58a"},
-    {file = "pillow-10.2.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:fe4c15f6c9285dc54ce6553a3ce908ed37c8f3825b5a51a15c91442bb955b868"},
-    {file = "pillow-10.2.0.tar.gz", hash = "sha256:e87f0b2c78157e12d7686b27d63c070fd65d994e8ddae6f328e0dcf4a0cd007e"},
+    {file = "pillow-10.3.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:90b9e29824800e90c84e4022dd5cc16eb2d9605ee13f05d47641eb183cd73d45"},
+    {file = "pillow-10.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a2c405445c79c3f5a124573a051062300936b0281fee57637e706453e452746c"},
+    {file = "pillow-10.3.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:78618cdbccaa74d3f88d0ad6cb8ac3007f1a6fa5c6f19af64b55ca170bfa1edf"},
+    {file = "pillow-10.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:261ddb7ca91fcf71757979534fb4c128448b5b4c55cb6152d280312062f69599"},
+    {file = "pillow-10.3.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:ce49c67f4ea0609933d01c0731b34b8695a7a748d6c8d186f95e7d085d2fe475"},
+    {file = "pillow-10.3.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:b14f16f94cbc61215115b9b1236f9c18403c15dd3c52cf629072afa9d54c1cbf"},
+    {file = "pillow-10.3.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:d33891be6df59d93df4d846640f0e46f1a807339f09e79a8040bc887bdcd7ed3"},
+    {file = "pillow-10.3.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b50811d664d392f02f7761621303eba9d1b056fb1868c8cdf4231279645c25f5"},
+    {file = "pillow-10.3.0-cp310-cp310-win32.whl", hash = "sha256:ca2870d5d10d8726a27396d3ca4cf7976cec0f3cb706debe88e3a5bd4610f7d2"},
+    {file = "pillow-10.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:f0d0591a0aeaefdaf9a5e545e7485f89910c977087e7de2b6c388aec32011e9f"},
+    {file = "pillow-10.3.0-cp310-cp310-win_arm64.whl", hash = "sha256:ccce24b7ad89adb5a1e34a6ba96ac2530046763912806ad4c247356a8f33a67b"},
+    {file = "pillow-10.3.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:5f77cf66e96ae734717d341c145c5949c63180842a545c47a0ce7ae52ca83795"},
+    {file = "pillow-10.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e4b878386c4bf293578b48fc570b84ecfe477d3b77ba39a6e87150af77f40c57"},
+    {file = "pillow-10.3.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fdcbb4068117dfd9ce0138d068ac512843c52295ed996ae6dd1faf537b6dbc27"},
+    {file = "pillow-10.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9797a6c8fe16f25749b371c02e2ade0efb51155e767a971c61734b1bf6293994"},
+    {file = "pillow-10.3.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:9e91179a242bbc99be65e139e30690e081fe6cb91a8e77faf4c409653de39451"},
+    {file = "pillow-10.3.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:1b87bd9d81d179bd8ab871603bd80d8645729939f90b71e62914e816a76fc6bd"},
+    {file = "pillow-10.3.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:81d09caa7b27ef4e61cb7d8fbf1714f5aec1c6b6c5270ee53504981e6e9121ad"},
+    {file = "pillow-10.3.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:048ad577748b9fa4a99a0548c64f2cb8d672d5bf2e643a739ac8faff1164238c"},
+    {file = "pillow-10.3.0-cp311-cp311-win32.whl", hash = "sha256:7161ec49ef0800947dc5570f86568a7bb36fa97dd09e9827dc02b718c5643f09"},
+    {file = "pillow-10.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:8eb0908e954d093b02a543dc963984d6e99ad2b5e36503d8a0aaf040505f747d"},
+    {file = "pillow-10.3.0-cp311-cp311-win_arm64.whl", hash = "sha256:4e6f7d1c414191c1199f8996d3f2282b9ebea0945693fb67392c75a3a320941f"},
+    {file = "pillow-10.3.0-cp312-cp312-macosx_10_10_x86_64.whl", hash = "sha256:e46f38133e5a060d46bd630faa4d9fa0202377495df1f068a8299fd78c84de84"},
+    {file = "pillow-10.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:50b8eae8f7334ec826d6eeffaeeb00e36b5e24aa0b9df322c247539714c6df19"},
+    {file = "pillow-10.3.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9d3bea1c75f8c53ee4d505c3e67d8c158ad4df0d83170605b50b64025917f338"},
+    {file = "pillow-10.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:19aeb96d43902f0a783946a0a87dbdad5c84c936025b8419da0a0cd7724356b1"},
+    {file = "pillow-10.3.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:74d28c17412d9caa1066f7a31df8403ec23d5268ba46cd0ad2c50fb82ae40462"},
+    {file = "pillow-10.3.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:ff61bfd9253c3915e6d41c651d5f962da23eda633cf02262990094a18a55371a"},
+    {file = "pillow-10.3.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d886f5d353333b4771d21267c7ecc75b710f1a73d72d03ca06df49b09015a9ef"},
+    {file = "pillow-10.3.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4b5ec25d8b17217d635f8935dbc1b9aa5907962fae29dff220f2659487891cd3"},
+    {file = "pillow-10.3.0-cp312-cp312-win32.whl", hash = "sha256:51243f1ed5161b9945011a7360e997729776f6e5d7005ba0c6879267d4c5139d"},
+    {file = "pillow-10.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:412444afb8c4c7a6cc11a47dade32982439925537e483be7c0ae0cf96c4f6a0b"},
+    {file = "pillow-10.3.0-cp312-cp312-win_arm64.whl", hash = "sha256:798232c92e7665fe82ac085f9d8e8ca98826f8e27859d9a96b41d519ecd2e49a"},
+    {file = "pillow-10.3.0-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:4eaa22f0d22b1a7e93ff0a596d57fdede2e550aecffb5a1ef1106aaece48e96b"},
+    {file = "pillow-10.3.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:cd5e14fbf22a87321b24c88669aad3a51ec052eb145315b3da3b7e3cc105b9a2"},
+    {file = "pillow-10.3.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1530e8f3a4b965eb6a7785cf17a426c779333eb62c9a7d1bbcf3ffd5bf77a4aa"},
+    {file = "pillow-10.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d512aafa1d32efa014fa041d38868fda85028e3f930a96f85d49c7d8ddc0383"},
+    {file = "pillow-10.3.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:339894035d0ede518b16073bdc2feef4c991ee991a29774b33e515f1d308e08d"},
+    {file = "pillow-10.3.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:aa7e402ce11f0885305bfb6afb3434b3cd8f53b563ac065452d9d5654c7b86fd"},
+    {file = "pillow-10.3.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:0ea2a783a2bdf2a561808fe4a7a12e9aa3799b701ba305de596bc48b8bdfce9d"},
+    {file = "pillow-10.3.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:c78e1b00a87ce43bb37642c0812315b411e856a905d58d597750eb79802aaaa3"},
+    {file = "pillow-10.3.0-cp38-cp38-win32.whl", hash = "sha256:72d622d262e463dfb7595202d229f5f3ab4b852289a1cd09650362db23b9eb0b"},
+    {file = "pillow-10.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:2034f6759a722da3a3dbd91a81148cf884e91d1b747992ca288ab88c1de15999"},
+    {file = "pillow-10.3.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:2ed854e716a89b1afcedea551cd85f2eb2a807613752ab997b9974aaa0d56936"},
+    {file = "pillow-10.3.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:dc1a390a82755a8c26c9964d457d4c9cbec5405896cba94cf51f36ea0d855002"},
+    {file = "pillow-10.3.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4203efca580f0dd6f882ca211f923168548f7ba334c189e9eab1178ab840bf60"},
+    {file = "pillow-10.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3102045a10945173d38336f6e71a8dc71bcaeed55c3123ad4af82c52807b9375"},
+    {file = "pillow-10.3.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:6fb1b30043271ec92dc65f6d9f0b7a830c210b8a96423074b15c7bc999975f57"},
+    {file = "pillow-10.3.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:1dfc94946bc60ea375cc39cff0b8da6c7e5f8fcdc1d946beb8da5c216156ddd8"},
+    {file = "pillow-10.3.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b09b86b27a064c9624d0a6c54da01c1beaf5b6cadfa609cf63789b1d08a797b9"},
+    {file = "pillow-10.3.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d3b2348a78bc939b4fed6552abfd2e7988e0f81443ef3911a4b8498ca084f6eb"},
+    {file = "pillow-10.3.0-cp39-cp39-win32.whl", hash = "sha256:45ebc7b45406febf07fef35d856f0293a92e7417ae7933207e90bf9090b70572"},
+    {file = "pillow-10.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:0ba26351b137ca4e0db0342d5d00d2e355eb29372c05afd544ebf47c0956ffeb"},
+    {file = "pillow-10.3.0-cp39-cp39-win_arm64.whl", hash = "sha256:50fd3f6b26e3441ae07b7c979309638b72abc1a25da31a81a7fbd9495713ef4f"},
+    {file = "pillow-10.3.0-pp310-pypy310_pp73-macosx_10_10_x86_64.whl", hash = "sha256:6b02471b72526ab8a18c39cb7967b72d194ec53c1fd0a70b050565a0f366d355"},
+    {file = "pillow-10.3.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:8ab74c06ffdab957d7670c2a5a6e1a70181cd10b727cd788c4dd9005b6a8acd9"},
+    {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:048eeade4c33fdf7e08da40ef402e748df113fd0b4584e32c4af74fe78baaeb2"},
+    {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e2ec1e921fd07c7cda7962bad283acc2f2a9ccc1b971ee4b216b75fad6f0463"},
+    {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:4c8e73e99da7db1b4cad7f8d682cf6abad7844da39834c288fbfa394a47bbced"},
+    {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:16563993329b79513f59142a6b02055e10514c1a8e86dca8b48a893e33cf91e3"},
+    {file = "pillow-10.3.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:dd78700f5788ae180b5ee8902c6aea5a5726bac7c364b202b4b3e3ba2d293170"},
+    {file = "pillow-10.3.0-pp39-pypy39_pp73-macosx_10_10_x86_64.whl", hash = "sha256:aff76a55a8aa8364d25400a210a65ff59d0168e0b4285ba6bf2bd83cf675ba32"},
+    {file = "pillow-10.3.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:b7bc2176354defba3edc2b9a777744462da2f8e921fbaf61e52acb95bafa9828"},
+    {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:793b4e24db2e8742ca6423d3fde8396db336698c55cd34b660663ee9e45ed37f"},
+    {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d93480005693d247f8346bc8ee28c72a2191bdf1f6b5db469c096c0c867ac015"},
+    {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:c83341b89884e2b2e55886e8fbbf37c3fa5efd6c8907124aeb72f285ae5696e5"},
+    {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:1a1d1915db1a4fdb2754b9de292642a39a7fb28f1736699527bb649484fb966a"},
+    {file = "pillow-10.3.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:a0eaa93d054751ee9964afa21c06247779b90440ca41d184aeb5d410f20ff591"},
+    {file = "pillow-10.3.0.tar.gz", hash = "sha256:9d2455fbf44c914840c793e89aa82d0e1763a14253a000743719ae5946814b2d"},
 ]

 [package.extras]
@@ -2118,13 +2261,13 @@ testing = ["pytest", "pytest-benchmark"]

 [[package]]
 name = "pre-commit"
-version = "3.6.2"
+version = "3.7.0"
 description = "A framework for managing and maintaining multi-language pre-commit hooks."
 optional = false
 python-versions = ">=3.9"
 files = [
-    {file = "pre_commit-3.6.2-py2.py3-none-any.whl", hash = "sha256:ba637c2d7a670c10daedc059f5c49b5bd0aadbccfcd7ec15592cf9665117532c"},
-    {file = "pre_commit-3.6.2.tar.gz", hash = "sha256:c3ef34f463045c88658c5b99f38c1e297abdcc0ff13f98d3370055fbbfabc67e"},
+    {file = "pre_commit-3.7.0-py2.py3-none-any.whl", hash = "sha256:5eae9e10c2b5ac51577c3452ec0a490455c45a0533f7960f993a0d01e59decab"},
+    {file = "pre_commit-3.7.0.tar.gz", hash = "sha256:e209d61b8acdcf742404408531f0c37d49d2c734fd7cff2d6076083d191cb060"},
 ]

 [package.dependencies]
@@ -2198,13 +2341,13 @@ test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"]

 [[package]]
 name = "pycparser"
-version = "2.21"
+version = "2.22"
 description = "C parser in Python"
 optional = false
-python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
+python-versions = ">=3.8"
 files = [
-    {file = "pycparser-2.21-py2.py3-none-any.whl", hash = "sha256:8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9"},
-    {file = "pycparser-2.21.tar.gz", hash = "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206"},
+    {file = "pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc"},
+    {file = "pycparser-2.22.tar.gz", hash = "sha256:491c8be9c040f5390f5bf44a5b07752bd07f56edf992381b05c701439eec10f6"},
 ]

 [[package]]
@@ -2348,7 +2491,7 @@ dev = ["aafigure", "matplotlib", "pygame", "pyglet (<2.0.0)", "sphinx", "wheel"]
 name = "pyopengl"
 version = "3.1.7"
 description = "Standard OpenGL bindings for Python"
-optional = false
+optional = true
 python-versions = "*"
 files = [
     {file = "PyOpenGL-3.1.7-py3-none-any.whl", hash = "sha256:a6ab19cf290df6101aaf7470843a9c46207789855746399d0af92521a0a92b7a"},
@@ -2359,7 +2502,7 @@ files = [
 name = "pyparsing"
 version = "3.1.2"
 description = "pyparsing module - Classes and methods to define and execute parsing grammars"
-optional = false
+optional = true
 python-versions = ">=3.6.8"
 files = [
     {file = "pyparsing-3.1.2-py3-none-any.whl", hash = "sha256:f9db75911801ed778fe61bb643079ff86601aca99fcae6345aa67292038fb742"},
@@ -2790,7 +2933,7 @@ torch = ["safetensors[numpy]", "torch (>=1.10)"]
 name = "scikit-image"
 version = "0.22.0"
 description = "Image processing in Python"
-optional = false
+optional = true
 python-versions = ">=3.9"
 files = [
     {file = "scikit_image-0.22.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:74ec5c1d4693506842cc7c9487c89d8fc32aed064e9363def7af08b8f8cbb31d"},
@@ -2836,55 +2979,55 @@ test = ["asv", "matplotlib (>=3.5)", "numpydoc (>=1.5)", "pooch (>=1.6.0)", "pyt

 [[package]]
 name = "scipy"
-version = "1.12.0"
+version = "1.13.0"
 description = "Fundamental algorithms for scientific computing in Python"
-optional = false
+optional = true
 python-versions = ">=3.9"
 files = [
-    {file = "scipy-1.12.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:78e4402e140879387187f7f25d91cc592b3501a2e51dfb320f48dfb73565f10b"},
-    {file = "scipy-1.12.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:f5f00ebaf8de24d14b8449981a2842d404152774c1a1d880c901bf454cb8e2a1"},
-    {file = "scipy-1.12.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e53958531a7c695ff66c2e7bb7b79560ffdc562e2051644c5576c39ff8efb563"},
-    {file = "scipy-1.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5e32847e08da8d895ce09d108a494d9eb78974cf6de23063f93306a3e419960c"},
-    {file = "scipy-1.12.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4c1020cad92772bf44b8e4cdabc1df5d87376cb219742549ef69fc9fd86282dd"},
-    {file = "scipy-1.12.0-cp310-cp310-win_amd64.whl", hash = "sha256:75ea2a144096b5e39402e2ff53a36fecfd3b960d786b7efd3c180e29c39e53f2"},
-    {file = "scipy-1.12.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:408c68423f9de16cb9e602528be4ce0d6312b05001f3de61fe9ec8b1263cad08"},
-    {file = "scipy-1.12.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:5adfad5dbf0163397beb4aca679187d24aec085343755fcdbdeb32b3679f254c"},
-    {file = "scipy-1.12.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c3003652496f6e7c387b1cf63f4bb720951cfa18907e998ea551e6de51a04467"},
-    {file = "scipy-1.12.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b8066bce124ee5531d12a74b617d9ac0ea59245246410e19bca549656d9a40a"},
-    {file = "scipy-1.12.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:8bee4993817e204d761dba10dbab0774ba5a8612e57e81319ea04d84945375ba"},
-    {file = "scipy-1.12.0-cp311-cp311-win_amd64.whl", hash = "sha256:a24024d45ce9a675c1fb8494e8e5244efea1c7a09c60beb1eeb80373d0fecc70"},
-    {file = "scipy-1.12.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:e7e76cc48638228212c747ada851ef355c2bb5e7f939e10952bc504c11f4e372"},
-    {file = "scipy-1.12.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:f7ce148dffcd64ade37b2df9315541f9adad6efcaa86866ee7dd5db0c8f041c3"},
-    {file = "scipy-1.12.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9c39f92041f490422924dfdb782527a4abddf4707616e07b021de33467f917bc"},
-    {file = "scipy-1.12.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a7ebda398f86e56178c2fa94cad15bf457a218a54a35c2a7b4490b9f9cb2676c"},
-    {file = "scipy-1.12.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:95e5c750d55cf518c398a8240571b0e0782c2d5a703250872f36eaf737751338"},
-    {file = "scipy-1.12.0-cp312-cp312-win_amd64.whl", hash = "sha256:e646d8571804a304e1da01040d21577685ce8e2db08ac58e543eaca063453e1c"},
-    {file = "scipy-1.12.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:913d6e7956c3a671de3b05ccb66b11bc293f56bfdef040583a7221d9e22a2e35"},
-    {file = "scipy-1.12.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:bba1b0c7256ad75401c73e4b3cf09d1f176e9bd4248f0d3112170fb2ec4db067"},
-    {file = "scipy-1.12.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:730badef9b827b368f351eacae2e82da414e13cf8bd5051b4bdfd720271a5371"},
-    {file = "scipy-1.12.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6546dc2c11a9df6926afcbdd8a3edec28566e4e785b915e849348c6dd9f3f490"},
-    {file = "scipy-1.12.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:196ebad3a4882081f62a5bf4aeb7326aa34b110e533aab23e4374fcccb0890dc"},
-    {file = "scipy-1.12.0-cp39-cp39-win_amd64.whl", hash = "sha256:b360f1b6b2f742781299514e99ff560d1fe9bd1bff2712894b52abe528d1fd1e"},
-    {file = "scipy-1.12.0.tar.gz", hash = "sha256:4bf5abab8a36d20193c698b0f1fc282c1d083c94723902c447e5d2f1780936a3"},
+    {file = "scipy-1.13.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ba419578ab343a4e0a77c0ef82f088238a93eef141b2b8017e46149776dfad4d"},
+    {file = "scipy-1.13.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:22789b56a999265431c417d462e5b7f2b487e831ca7bef5edeb56efe4c93f86e"},
+    {file = "scipy-1.13.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:05f1432ba070e90d42d7fd836462c50bf98bd08bed0aa616c359eed8a04e3922"},
+    {file = "scipy-1.13.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8434f6f3fa49f631fae84afee424e2483289dfc30a47755b4b4e6b07b2633a4"},
+    {file = "scipy-1.13.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:dcbb9ea49b0167de4167c40eeee6e167caeef11effb0670b554d10b1e693a8b9"},
+    {file = "scipy-1.13.0-cp310-cp310-win_amd64.whl", hash = "sha256:1d2f7bb14c178f8b13ebae93f67e42b0a6b0fc50eba1cd8021c9b6e08e8fb1cd"},
+    {file = "scipy-1.13.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0fbcf8abaf5aa2dc8d6400566c1a727aed338b5fe880cde64907596a89d576fa"},
+    {file = "scipy-1.13.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:5e4a756355522eb60fcd61f8372ac2549073c8788f6114449b37e9e8104f15a5"},
+    {file = "scipy-1.13.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b5acd8e1dbd8dbe38d0004b1497019b2dbbc3d70691e65d69615f8a7292865d7"},
+    {file = "scipy-1.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ff7dad5d24a8045d836671e082a490848e8639cabb3dbdacb29f943a678683d"},
+    {file = "scipy-1.13.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4dca18c3ffee287ddd3bc8f1dabaf45f5305c5afc9f8ab9cbfab855e70b2df5c"},
+    {file = "scipy-1.13.0-cp311-cp311-win_amd64.whl", hash = "sha256:a2f471de4d01200718b2b8927f7d76b5d9bde18047ea0fa8bd15c5ba3f26a1d6"},
+    {file = "scipy-1.13.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d0de696f589681c2802f9090fff730c218f7c51ff49bf252b6a97ec4a5d19e8b"},
+    {file = "scipy-1.13.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:b2a3ff461ec4756b7e8e42e1c681077349a038f0686132d623fa404c0bee2551"},
+    {file = "scipy-1.13.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6bf9fe63e7a4bf01d3645b13ff2aa6dea023d38993f42aaac81a18b1bda7a82a"},
+    {file = "scipy-1.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1e7626dfd91cdea5714f343ce1176b6c4745155d234f1033584154f60ef1ff42"},
+    {file = "scipy-1.13.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:109d391d720fcebf2fbe008621952b08e52907cf4c8c7efc7376822151820820"},
+    {file = "scipy-1.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:8930ae3ea371d6b91c203b1032b9600d69c568e537b7988a3073dfe4d4774f21"},
+    {file = "scipy-1.13.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5407708195cb38d70fd2d6bb04b1b9dd5c92297d86e9f9daae1576bd9e06f602"},
+    {file = "scipy-1.13.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:ac38c4c92951ac0f729c4c48c9e13eb3675d9986cc0c83943784d7390d540c78"},
+    {file = "scipy-1.13.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09c74543c4fbeb67af6ce457f6a6a28e5d3739a87f62412e4a16e46f164f0ae5"},
+    {file = "scipy-1.13.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:28e286bf9ac422d6beb559bc61312c348ca9b0f0dae0d7c5afde7f722d6ea13d"},
+    {file = "scipy-1.13.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:33fde20efc380bd23a78a4d26d59fc8704e9b5fd9b08841693eb46716ba13d86"},
+    {file = "scipy-1.13.0-cp39-cp39-win_amd64.whl", hash = "sha256:45c08bec71d3546d606989ba6e7daa6f0992918171e2a6f7fbedfa7361c2de1e"},
+    {file = "scipy-1.13.0.tar.gz", hash = "sha256:58569af537ea29d3f78e5abd18398459f195546bb3be23d16677fb26616cc11e"},
 ]

 [package.dependencies]
-numpy = ">=1.22.4,<1.29.0"
+numpy = ">=1.22.4,<2.3"

 [package.extras]
-dev = ["click", "cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy", "pycodestyle", "pydevtool", "rich-click", "ruff", "types-psutil", "typing_extensions"]
-doc = ["jupytext", "matplotlib (>2)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (==0.9.0)", "sphinx (!=4.1.0)", "sphinx-design (>=0.2.0)"]
-test = ["asv", "gmpy2", "hypothesis", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"]
+dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy", "pycodestyle", "pydevtool", "rich-click", "ruff", "types-psutil", "typing_extensions"]
+doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.12.0)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0)", "sphinx-design (>=0.4.0)"]
+test = ["array-api-strict", "asv", "gmpy2", "hypothesis (>=6.30)", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"]

 [[package]]
 name = "sentry-sdk"
-version = "1.43.0"
+version = "1.44.1"
 description = "Python client for Sentry (https://sentry.io)"
 optional = false
 python-versions = "*"
 files = [
-    {file = "sentry-sdk-1.43.0.tar.gz", hash = "sha256:41df73af89d22921d8733714fb0fc5586c3461907e06688e6537d01a27e0e0f6"},
-    {file = "sentry_sdk-1.43.0-py2.py3-none-any.whl", hash = "sha256:8d768724839ca18d7b4c7463ef7528c40b7aa2bfbf7fe554d5f9a7c044acfd36"},
+    {file = "sentry-sdk-1.44.1.tar.gz", hash = "sha256:24e6a53eeabffd2f95d952aa35ca52f0f4201d17f820ac9d3ff7244c665aaf68"},
+    {file = "sentry_sdk-1.44.1-py2.py3-none-any.whl", hash = "sha256:5f75eb91d8ab6037c754a87b8501cc581b2827e923682f593bed3539ce5b3999"},
 ]

 [package.dependencies]
@@ -3043,7 +3186,7 @@ testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jar
 name = "shapely"
 version = "2.0.3"
 description = "Manipulation and analysis of geometric objects"
-optional = false
+optional = true
 python-versions = ">=3.7"
 files = [
     {file = "shapely-2.0.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:af7e9abe180b189431b0f490638281b43b84a33a960620e6b2e8d3e3458b61a1"},
@@ -3192,31 +3335,6 @@ numpy = "*"
 packaging = "*"
 protobuf = ">=3.20"

-[[package]]
-name = "tensordict"
-version = "0.4.0+b4c91e8"
-description = ""
-optional = false
-python-versions = "*"
-files = []
-develop = false
-
-[package.dependencies]
-cloudpickle = "*"
-numpy = "*"
-torch = ">=2.1.0"
-
-[package.extras]
-checkpointing = ["torchsnapshot-nightly"]
-h5 = ["h5py (>=3.8)"]
-tests = ["pytest", "pytest-benchmark", "pytest-instafail", "pytest-rerunfailures", "pyyaml"]
-
-[package.source]
-type = "git"
-url = "https://github.com/pytorch/tensordict"
-reference = "HEAD"
-resolved_reference = "b4c91e8828c538ca0a50d8383fd99311a9afb078"
-
 [[package]]
 name = "termcolor"
 version = "2.4.0"
@@ -3235,7 +3353,7 @@ tests = ["pytest", "pytest-cov"]
 name = "tifffile"
 version = "2024.2.12"
 description = "Read and write TIFF files"
-optional = false
+optional = true
 python-versions = ">=3.9"
 files = [
     {file = "tifffile-2024.2.12-py3-none-any.whl", hash = "sha256:870998f82fbc94ff7c3528884c1b0ae54863504ff51dbebea431ac3fa8fb7c21"},
@@ -3261,36 +3379,36 @@ files = [

 [[package]]
 name = "torch"
-version = "2.2.1"
+version = "2.2.2"
 description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
 optional = false
 python-versions = ">=3.8.0"
 files = [
-    {file = "torch-2.2.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:8d3bad336dd2c93c6bcb3268e8e9876185bda50ebde325ef211fb565c7d15273"},
-    {file = "torch-2.2.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:5297f13370fdaca05959134b26a06a7f232ae254bf2e11a50eddec62525c9006"},
-    {file =
"torch-2.2.1-cp310-cp310-win_amd64.whl", hash = "sha256:5f5dee8433798888ca1415055f5e3faf28a3bad660e4c29e1014acd3275ab11a"}, - {file = "torch-2.2.1-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:b6d78338acabf1fb2e88bf4559d837d30230cf9c3e4337261f4d83200df1fcbe"}, - {file = "torch-2.2.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:6ab3ea2e29d1aac962e905142bbe50943758f55292f1b4fdfb6f4792aae3323e"}, - {file = "torch-2.2.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:d86664ec85902967d902e78272e97d1aff1d331f7619d398d3ffab1c9b8e9157"}, - {file = "torch-2.2.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:d6227060f268894f92c61af0a44c0d8212e19cb98d05c20141c73312d923bc0a"}, - {file = "torch-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:77e990af75fb1675490deb374d36e726f84732cd5677d16f19124934b2409ce9"}, - {file = "torch-2.2.1-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:46085e328d9b738c261f470231e987930f4cc9472d9ffb7087c7a1343826ac51"}, - {file = "torch-2.2.1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:2d9e7e5ecbb002257cf98fae13003abbd620196c35f85c9e34c2adfb961321ec"}, - {file = "torch-2.2.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:ada53aebede1c89570e56861b08d12ba4518a1f8b82d467c32665ec4d1f4b3c8"}, - {file = "torch-2.2.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:be21d4c41ecebed9e99430dac87de1439a8c7882faf23bba7fea3fea7b906ac1"}, - {file = "torch-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:79848f46196750367dcdf1d2132b722180b9d889571e14d579ae82d2f50596c5"}, - {file = "torch-2.2.1-cp312-none-macosx_10_9_x86_64.whl", hash = "sha256:7ee804847be6be0032fbd2d1e6742fea2814c92bebccb177f0d3b8e92b2d2b18"}, - {file = "torch-2.2.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:84b2fb322ab091039fdfe74e17442ff046b258eb5e513a28093152c5b07325a7"}, - {file = "torch-2.2.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:5c0c83aa7d94569997f1f474595e808072d80b04d34912ce6f1a0e1c24b0c12a"}, - {file = "torch-2.2.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:91a1b598055ba06b2c386415d2e7f6ac818545e94c5def597a74754940188513"}, - {file = "torch-2.2.1-cp38-cp38-win_amd64.whl", hash = "sha256:8f93ddf3001ecec16568390b507652644a3a103baa72de3ad3b9c530e3277098"}, - {file = "torch-2.2.1-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:0e8bdd4c77ac2584f33ee14c6cd3b12767b4da508ec4eed109520be7212d1069"}, - {file = "torch-2.2.1-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:6a21bcd7076677c97ca7db7506d683e4e9db137e8420eb4a68fb67c3668232a7"}, - {file = "torch-2.2.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:f1b90ac61f862634039265cd0f746cc9879feee03ff962c803486301b778714b"}, - {file = "torch-2.2.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:ed9e29eb94cd493b36bca9cb0b1fd7f06a0688215ad1e4b3ab4931726e0ec092"}, - {file = "torch-2.2.1-cp39-cp39-win_amd64.whl", hash = "sha256:c47bc25744c743f3835831a20efdcfd60aeb7c3f9804a213f61e45803d16c2a5"}, - {file = "torch-2.2.1-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:0952549bcb43448c8d860d5e3e947dd18cbab491b14638e21750cb3090d5ad3e"}, - {file = "torch-2.2.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:26bd2272ec46fc62dcf7d24b2fb284d44fcb7be9d529ebf336b9860350d674ed"}, + {file = "torch-2.2.2-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:bc889d311a855dd2dfd164daf8cc903a6b7273a747189cebafdd89106e4ad585"}, + {file = "torch-2.2.2-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:15dffa4cc3261fa73d02f0ed25f5fa49ecc9e12bf1ae0a4c1e7a88bbfaad9030"}, + {file = "torch-2.2.2-cp310-cp310-win_amd64.whl", hash 
= "sha256:11e8fe261233aeabd67696d6b993eeb0896faa175c6b41b9a6c9f0334bdad1c5"}, + {file = "torch-2.2.2-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:b2e2200b245bd9f263a0d41b6a2dab69c4aca635a01b30cca78064b0ef5b109e"}, + {file = "torch-2.2.2-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:877b3e6593b5e00b35bbe111b7057464e76a7dd186a287280d941b564b0563c2"}, + {file = "torch-2.2.2-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:ad4c03b786e074f46606f4151c0a1e3740268bcf29fbd2fdf6666d66341c1dcb"}, + {file = "torch-2.2.2-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:32827fa1fbe5da8851686256b4cd94cc7b11be962862c2293811c94eea9457bf"}, + {file = "torch-2.2.2-cp311-cp311-win_amd64.whl", hash = "sha256:f9ef0a648310435511e76905f9b89612e45ef2c8b023bee294f5e6f7e73a3e7c"}, + {file = "torch-2.2.2-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:95b9b44f3bcebd8b6cd8d37ec802048c872d9c567ba52c894bba90863a439059"}, + {file = "torch-2.2.2-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:49aa4126ede714c5aeef7ae92969b4b0bbe67f19665106463c39f22e0a1860d1"}, + {file = "torch-2.2.2-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:cf12cdb66c9c940227ad647bc9cf5dba7e8640772ae10dfe7569a0c1e2a28aca"}, + {file = "torch-2.2.2-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:89ddac2a8c1fb6569b90890955de0c34e1724f87431cacff4c1979b5f769203c"}, + {file = "torch-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:451331406b760f4b1ab298ddd536486ab3cfb1312614cfe0532133535be60bea"}, + {file = "torch-2.2.2-cp312-none-macosx_10_9_x86_64.whl", hash = "sha256:eb4d6e9d3663e26cd27dc3ad266b34445a16b54908e74725adb241aa56987533"}, + {file = "torch-2.2.2-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:bf9558da7d2bf7463390b3b2a61a6a3dbb0b45b161ee1dd5ec640bf579d479fc"}, + {file = "torch-2.2.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:cd2bf7697c9e95fb5d97cc1d525486d8cf11a084c6af1345c2c2c22a6b0029d0"}, + {file = "torch-2.2.2-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:b421448d194496e1114d87a8b8d6506bce949544e513742b097e2ab8f7efef32"}, + {file = "torch-2.2.2-cp38-cp38-win_amd64.whl", hash = "sha256:3dbcd563a9b792161640c0cffe17e3270d85e8f4243b1f1ed19cca43d28d235b"}, + {file = "torch-2.2.2-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:31f4310210e7dda49f1fb52b0ec9e59382cfcb938693f6d5378f25b43d7c1d29"}, + {file = "torch-2.2.2-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:c795feb7e8ce2e0ef63f75f8e1ab52e7fd5e1a4d7d0c31367ade1e3de35c9e95"}, + {file = "torch-2.2.2-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:a6e5770d68158d07456bfcb5318b173886f579fdfbf747543901ce718ea94782"}, + {file = "torch-2.2.2-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:67dcd726edff108e2cd6c51ff0e416fd260c869904de95750e80051358680d24"}, + {file = "torch-2.2.2-cp39-cp39-win_amd64.whl", hash = "sha256:539d5ef6c4ce15bd3bd47a7b4a6e7c10d49d4d21c0baaa87c7d2ef8698632dfb"}, + {file = "torch-2.2.2-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:dff696de90d6f6d1e8200e9892861fd4677306d0ef604cb18f2134186f719f82"}, + {file = "torch-2.2.2-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:3a4dd910663fd7a124c056c878a52c2b0be4a5a424188058fe97109d4436ee42"}, ] [package.dependencies] @@ -3317,78 +3435,44 @@ typing-extensions = ">=4.8.0" opt-einsum = ["opt-einsum (>=3.3)"] optree = ["optree (>=0.9.1)"] -[[package]] -name = "torchrl" -version = "0.4.0+13bef42" -description = "" -optional = false -python-versions = "*" -files = [] -develop = false - -[package.dependencies] -cloudpickle = "*" -numpy = "*" -packaging = "*" -tensordict = ">=0.4.0" 
-torch = ">=2.1.0" - -[package.extras] -all = ["ale-py", "atari-py", "dm_control", "git", "gym", "gym[accept-rom-license]", "gymnasium", "h5py", "huggingface_hub", "hydra-core (>=1.1)", "hydra-submitit-launcher", "minari", "moviepy", "mujoco", "pandas", "pettingzoo (>=1.24.1)", "pillow", "pygame", "pytest", "pytest-instafail", "pyyaml", "requests", "scikit-learn", "scipy", "tensorboard", "torchsnapshot", "torchvision", "tqdm", "vmas (>=1.2.10)", "wandb"] -atari = ["ale-py", "atari-py", "gym", "gym[accept-rom-license]", "pygame"] -checkpointing = ["torchsnapshot"] -dm-control = ["dm_control"] -gym-continuous = ["gymnasium", "mujoco"] -marl = ["pettingzoo (>=1.24.1)", "vmas (>=1.2.10)"] -offline-data = ["h5py", "huggingface_hub", "minari", "pandas", "pillow", "requests", "scikit-learn", "torchvision", "tqdm"] -rendering = ["moviepy"] -tests = ["pytest", "pytest-instafail", "pyyaml", "scipy"] -utils = ["git", "hydra-core (>=1.1)", "hydra-submitit-launcher", "tensorboard", "tqdm", "wandb"] - -[package.source] -type = "git" -url = "https://github.com/pytorch/rl" -reference = "13bef426dcfa5887c6e5034a6e9697993fa92c37" -resolved_reference = "13bef426dcfa5887c6e5034a6e9697993fa92c37" - [[package]] name = "torchvision" -version = "0.17.1" +version = "0.17.2" description = "image and video datasets and models for torch deep learning" optional = false python-versions = ">=3.8" files = [ - {file = "torchvision-0.17.1-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:06418880212b66e45e855dd39f536e7fd48b4e6b034a11dd9fe9e2384afb51ec"}, - {file = "torchvision-0.17.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:33d65d0c7fdcb3f7bc1dd8ed30ea3cd7e0587b4ad1b104b5677c8191a8bad9f1"}, - {file = "torchvision-0.17.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:aaefef2be6a02f206085ce4bb6c0078b03ebf48cb6ff82bd762ff6248475e08e"}, - {file = "torchvision-0.17.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:ebe5fdb466aff8a8e8e755de84a843418b6f8d500624752c05eaa638d7700f3d"}, - {file = "torchvision-0.17.1-cp310-cp310-win_amd64.whl", hash = "sha256:9d4d45a996f4313e9c5db4da71d31508d44f7ccfbf29d3442bdcc2ad13e0b6f3"}, - {file = "torchvision-0.17.1-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:ea2ccdbf5974e0bf27fd6644a33b19cb0700297cf397bb0469e762c11c6c4105"}, - {file = "torchvision-0.17.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9106e32c9f1e70afa8172cf1b064cf9c2998d8dff0769ec69d537b20209ee43d"}, - {file = "torchvision-0.17.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:5966936c669a08870f6547cd0a90d08b157aeda03293f79e2adbb934687175ed"}, - {file = "torchvision-0.17.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:e74f5a26ef8190eab0c38b3f63914fea94e58e3b2f0e5466611c9f63bd91a80b"}, - {file = "torchvision-0.17.1-cp311-cp311-win_amd64.whl", hash = "sha256:a2109c1a1dcf71e8940d43e91f78c4dd5bf0fcefb3a0a42244102752009f5862"}, - {file = "torchvision-0.17.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:5d241d2a5fb4e608677fccf6f80b34a124446d324ee40c7814ce54bce888275b"}, - {file = "torchvision-0.17.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:e0fe98d9d92c23d2262ff82f973242951b9357fb640f8888ac50848bd00f5b45"}, - {file = "torchvision-0.17.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:32dc5de86d2ade399e11087095674ca08a1649fb322cfe69336d28add467edcb"}, - {file = "torchvision-0.17.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:54902877410ffb5458ee52b6d0de4b25cf01496bee736d6825301a5f0398536e"}, - {file = "torchvision-0.17.1-cp312-cp312-win_amd64.whl", hash = 
"sha256:cc22c1ed0f1aba3f98fd72b6f60021f57aec1d2f6af518522e8a0a83848de3a8"}, - {file = "torchvision-0.17.1-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:2621097065fa1c827885e2b52102e839a3541b933b7a90e0fa3c42c3de1bc3cf"}, - {file = "torchvision-0.17.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5ce76466af2b5a30573939cae1e6e62e29316ceb3ee748091002f312ab0912f6"}, - {file = "torchvision-0.17.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:bd5dcd14a32945c72f5c19341add94aa7c23dd7bca2bafde44d0f3c4344d17ed"}, - {file = "torchvision-0.17.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:dca22795cc02ca0d5ddc08c1422ff620bc9899f63d15dc36f71ef37250e17b75"}, - {file = "torchvision-0.17.1-cp38-cp38-win_amd64.whl", hash = "sha256:524405457dd97d9ab0e48df502f819d0f41a113ce8f00470bb9926d9d36efcf1"}, - {file = "torchvision-0.17.1-cp39-cp39-macosx_10_13_x86_64.whl", hash = "sha256:58299a724b37b893c7ce4d0b32ea1480c30e467cc114167964b45f6013f6c2d3"}, - {file = "torchvision-0.17.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8a1b17fb158b2b881f2c8796fe1839a624e49d5fd07aa61f6dae60ba4819421a"}, - {file = "torchvision-0.17.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:429d63eb7551aa4d8f6cdf08d109b5570c20cbcce36d9cb95b24556418e4dc82"}, - {file = "torchvision-0.17.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:0ecc9a58171bd555aed583bf2f72e7fd6cc4f767c14f8b80b6a8725eacf4ceb1"}, - {file = "torchvision-0.17.1-cp39-cp39-win_amd64.whl", hash = "sha256:5f427ebee15521edcd836bfe05e86feb5189b5c943b9e3999ed0e3f391fbaa1d"}, + {file = "torchvision-0.17.2-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:1f2910fe3c21ad6875b2720d46fad835b2e4b336e9553d31ca364d24c90b1d4f"}, + {file = "torchvision-0.17.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ecc1c503fa8a54fbab777e06a7c228032b8ab78efebf35b28bc8f22f544f51f1"}, + {file = "torchvision-0.17.2-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:f400145fc108833e7c2fc28486a04989ca742146d7a2a2cc48878ebbb40cdbbd"}, + {file = "torchvision-0.17.2-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:e9e4bed404af33dfc92eecc2b513d21ddc4c242a7fd8708b3b09d3a26aa6f444"}, + {file = "torchvision-0.17.2-cp310-cp310-win_amd64.whl", hash = "sha256:ba2e62f233eab3d42b648c122a3a29c47cc108ca314dfd5cbb59cd3a143fd623"}, + {file = "torchvision-0.17.2-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:9b83e55ee7d0a1704f52b9c0ac87388e7a6d1d98a6bde7b0b35f9ab54d7bda54"}, + {file = "torchvision-0.17.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e031004a1bc432c980a7bd642f6c189a3efc316e423fc30b5569837166a4e28d"}, + {file = "torchvision-0.17.2-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:3bbc24b7713e8f22766992562547d8b4b10001208d372fe599255af84bfd1a69"}, + {file = "torchvision-0.17.2-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:833fd2e4216ced924c8aca0525733fe727f9a1af66dfad7c5be7257e97c39678"}, + {file = "torchvision-0.17.2-cp311-cp311-win_amd64.whl", hash = "sha256:6835897df852fad1015e6a106c167c83848114cbcc7d86112384a973404e4431"}, + {file = "torchvision-0.17.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:14fd1d4a033c325bdba2d03a69c3450cab6d3a625f85cc375781d9237ca5d04d"}, + {file = "torchvision-0.17.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:9c3acbebbe379af112b62b535820174277b1f3eed30df264a4e458d58ee4e5b2"}, + {file = "torchvision-0.17.2-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:77d680adf6ce367166a186d2c7fda3a73807ab9a03b2c31a03fa8812c8c5335b"}, + {file = "torchvision-0.17.2-cp312-cp312-manylinux2014_aarch64.whl", hash = 
"sha256:f1c9ab3152cfb27f83aca072cac93a3a4c4e4ab0261cf0f2d516b9868a4e96f3"}, + {file = "torchvision-0.17.2-cp312-cp312-win_amd64.whl", hash = "sha256:3f784381419f3ed3f2ec2aa42fb4aeec5bf4135e298d1631e41c926e6f1a0dff"}, + {file = "torchvision-0.17.2-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:b83aac8d78f48981146d582168d75b6c947cfb0a7693f76e219f1926f6e595a3"}, + {file = "torchvision-0.17.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:1ece40557e122d79975860a005aa7e2a9e2e6c350a03e78a00ec1450083312fd"}, + {file = "torchvision-0.17.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:32dbeba3987e20f2dc1bce8d1504139fff582898346dfe8ad98d649f97ca78fa"}, + {file = "torchvision-0.17.2-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:35ba5c1600c3203549d2316422a659bd20c0cfda1b6085eec94fb9f35f55ca43"}, + {file = "torchvision-0.17.2-cp38-cp38-win_amd64.whl", hash = "sha256:2f69570f50b1d195e51bc03feffb7b7728207bc36efcfb1f0813712b2379d881"}, + {file = "torchvision-0.17.2-cp39-cp39-macosx_10_13_x86_64.whl", hash = "sha256:4868bbfa55758c8107e69a0e7dd5e77b89056035cd38b767ad5b98cdb71c0f0d"}, + {file = "torchvision-0.17.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:efd6d0dd0668e15d01a2cffadc74068433b32cbcf5692e0c4aa15fc5cb250ce7"}, + {file = "torchvision-0.17.2-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:7dc85b397f6c6d9ef12716ce0d6e11ac2b803f5cccff6fe3966db248e7774478"}, + {file = "torchvision-0.17.2-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:d506854c5acd69b20a8b6641f01fe841685a21c5406b56813184f1c9fc94279e"}, + {file = "torchvision-0.17.2-cp39-cp39-win_amd64.whl", hash = "sha256:067095e87a020a7a251ac1d38483aa591c5ccb81e815527c54db88a982fc9267"}, ] [package.dependencies] numpy = "*" pillow = ">=5.3.0,<8.3.dev0 || >=8.4.dev0" -torch = "2.2.1" +torch = "2.2.2" [package.extras] scipy = ["scipy"] @@ -3438,13 +3522,13 @@ tutorials = ["matplotlib", "pandas", "tabulate", "torch"] [[package]] name = "typing-extensions" -version = "4.10.0" +version = "4.11.0" description = "Backported and Experimental Type Hints for Python 3.8+" optional = false python-versions = ">=3.8" files = [ - {file = "typing_extensions-4.10.0-py3-none-any.whl", hash = "sha256:69b1a937c3a517342112fb4c6df7e72fc39a38e7891a5730ed4985b5214b5475"}, - {file = "typing_extensions-4.10.0.tar.gz", hash = "sha256:b0abd7c89e8fb96f98db18d86106ff1d90ab692004eb746cf6eda2682f91b3cb"}, + {file = "typing_extensions-4.11.0-py3-none-any.whl", hash = "sha256:c1f94d72897edaf4ce775bb7558d5b79d8126906a14ea5ed1635921406c0387a"}, + {file = "typing_extensions-4.11.0.tar.gz", hash = "sha256:83f085bd5ca59c80295fc2a82ab5dac679cbe02b9f33f7d83af68e241bea51b0"}, ] [[package]] @@ -3497,13 +3581,13 @@ test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess [[package]] name = "wandb" -version = "0.16.4" +version = "0.16.6" description = "A CLI and library for interacting with the Weights & Biases API." 
 optional = false
 python-versions = ">=3.7"
 files = [
-    {file = "wandb-0.16.4-py3-none-any.whl", hash = "sha256:bb9eb5aa2c2c85e11c76040c4271366f54d4975167aa6320ba86c3f2d97fe5fa"},
-    {file = "wandb-0.16.4.tar.gz", hash = "sha256:8752c67d1347a4c29777e64dc1e1a742a66c5ecde03aebadf2b0d62183fa307c"},
+    {file = "wandb-0.16.6-py3-none-any.whl", hash = "sha256:5810019a3b981c796e98ea58557a7c380f18834e0c6bdaed15df115522e5616e"},
+    {file = "wandb-0.16.6.tar.gz", hash = "sha256:86f491e3012d715e0d7d7421a4d6de41abef643b7403046261f962f3e512fe1c"},
 ]
 [package.dependencies]
@@ -3535,13 +3619,13 @@ sweeps = ["sweeps (>=0.2.0)"]
 [[package]]
 name = "werkzeug"
-version = "3.0.1"
+version = "3.0.2"
 description = "The comprehensive WSGI web application library."
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "werkzeug-3.0.1-py3-none-any.whl", hash = "sha256:90a285dc0e42ad56b34e696398b8122ee4c681833fb35b8334a095d82c56da10"},
-    {file = "werkzeug-3.0.1.tar.gz", hash = "sha256:507e811ecea72b18a404947aded4b3390e1db8f826b494d76550ef45bb3b1dcc"},
+    {file = "werkzeug-3.0.2-py3-none-any.whl", hash = "sha256:3aac3f5da756f93030740bc235d3e09449efcf65f2f55e3602e1d851b8f48795"},
+    {file = "werkzeug-3.0.2.tar.gz", hash = "sha256:e39b645a6ac92822588e7b39a692e7828724ceae0b0d702ef96701f90e70128d"},
 ]
 [package.dependencies]
@@ -3552,20 +3636,20 @@ watchdog = ["watchdog (>=2.3)"]
 [[package]]
 name = "zarr"
-version = "2.17.1"
+version = "2.17.2"
 description = "An implementation of chunked, compressed, N-dimensional arrays for Python"
 optional = false
 python-versions = ">=3.9"
 files = [
-    {file = "zarr-2.17.1-py3-none-any.whl", hash = "sha256:e25df2741a6e92645f3890f30f3136d5b57a0f8f831094b024bbcab5f2797bc7"},
-    {file = "zarr-2.17.1.tar.gz", hash = "sha256:564b3aa072122546fe69a0fa21736f466b20fad41754334b62619f088ce46261"},
+    {file = "zarr-2.17.2-py3-none-any.whl", hash = "sha256:70d7cc07c24280c380ef80644151d136b7503b0d83c9f214e8000ddc0f57f69b"},
+    {file = "zarr-2.17.2.tar.gz", hash = "sha256:2cbaa6cb4e342d45152d4a7a4b2013c337fcd3a8e7bc98253560180de60552ce"},
 ]
 [package.dependencies]
 asciitree = "*"
 fasteners = {version = "*", markers = "sys_platform != \"emscripten\""}
 numcodecs = ">=0.10.0"
-numpy = ">=1.21.1"
+numpy = ">=1.23"
 [package.extras]
 docs = ["numcodecs[msgpack]", "numpydoc", "pydata-sphinx-theme", "sphinx", "sphinx-automodapi", "sphinx-copybutton", "sphinx-design", "sphinx-issues"]
@@ -3586,7 +3670,12 @@ files = [
 docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
 testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"]
+[extras]
+aloha = ["gym-aloha"]
+pusht = ["gym-pusht"]
+xarm = ["gym-xarm"]
+
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "174c7d42f8039eedd2c447a4e6cae5169782cbd94346b5606572a0010194ca05"
+content-hash = "7ec0310f8dd0ffa4d92fa78e06513bce98c3657692b3753ff34aadd297a3766c"
diff --git a/pyproject.toml b/pyproject.toml
index 972c1b61..743dece8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,7 +23,6 @@ packages = [{include = "lerobot"}]
 python = "^3.10"
 termcolor = "^2.4.0"
 omegaconf = "^2.3.0"
-dm-env = "^1.6"
 pandas = "^2.2.1"
 wandb = "^0.16.3"
 moviepy = "^1.0.3"
@@ -34,29 +33,40 @@ einops = "^0.7.0"
 pygame = "^2.5.2"
 pymunk = "^6.6.0"
 zarr = "^2.17.0"
-shapely = "^2.0.3"
-scikit-image = "^0.22.0"
 numba = "^0.59.0"
 mpmath = "^1.3.0"
 torch = "^2.2.1"
-tensordict = {git = "https://github.com/pytorch/tensordict"}
-torchrl = {git = "https://github.com/pytorch/rl", rev = "13bef426dcfa5887c6e5034a6e9697993fa92c37"}
-mujoco = "^2.3.7"
 opencv-python = "^4.9.0.80"
 diffusers = "^0.26.3"
 torchvision = "^0.17.1"
 h5py = "^3.10.0"
-dm-control = "1.0.14"
 huggingface-hub = {extras = ["hf-transfer"], version = "^0.21.4"}
 robomimic = "0.2.0"
-gymnasium-robotics = "^1.2.4"
 gymnasium = "^0.29.1"
 cmake = "^3.29.0.1"
+gym-pusht = { git = "git@github.com:huggingface/gym-pusht.git", optional = true}
+gym-xarm = { git = "git@github.com:huggingface/gym-xarm.git", optional = true}
+gym-aloha = { git = "git@github.com:huggingface/gym-aloha.git", optional = true}
+# gym-pusht = { path = "../gym-pusht", develop = true, optional = true}
+# gym-xarm = { path = "../gym-xarm", develop = true, optional = true}
+# gym-aloha = { path = "../gym-aloha", develop = true, optional = true}
+
+[tool.poetry.extras]
+pusht = ["gym-pusht"]
+xarm = ["gym-xarm"]
+aloha = ["gym-aloha"]
+
+
+[tool.poetry.group.dev]
+optional = true
 [tool.poetry.group.dev.dependencies]
 pre-commit = "^3.6.2"
 debugpy = "^1.8.1"
+
+
+[tool.poetry.group.test.dependencies]
 pytest = "^8.1.0"
 pytest-cov = "^5.0.0"
@@ -100,3 +110,6 @@ enable = true
 [build-system]
 requires = ["poetry-core>=1.0.0", "poetry-dynamic-versioning>=1.0.0,<2.0.0"]
 build-backend = "poetry_dynamic_versioning.backend"
+
+[tool.black]
+line-length = 110
diff --git a/sbatch.sh b/sbatch.sh
deleted file mode 100644
index cb5b285a..00000000
--- a/sbatch.sh
+++ /dev/null
@@ -1,23 +0,0 @@
-#!/bin/bash
-#SBATCH --nodes=1 # total number of nodes (N to be defined)
-#SBATCH --ntasks-per-node=1 # number of tasks per node (here 8 tasks, or 1 task per GPU)
-#SBATCH --gres=gpu:1 # number of GPUs reserved per node (here 8, or all the GPUs)
-#SBATCH --cpus-per-task=8 # number of cores per task (8x8 = 64 cores, or all the cores)
-#SBATCH --time=2-00:00:00
-#SBATCH --output=/home/rcadene/slurm/%j.out
-#SBATCH --error=/home/rcadene/slurm/%j.err
-#SBATCH --qos=medium
-#SBATCH --mail-user=re.cadene@gmail.com
-#SBATCH --mail-type=ALL
-
-CMD=$@
-echo "command: $CMD"
-
-apptainer exec --nv \
-~/apptainer/nvidia_cuda:12.2.2-devel-ubuntu22.04.sif $SHELL
-
-source ~/.bashrc
-#conda activate fowm
-conda activate lerobot
-
-srun $CMD
diff --git a/sbatch_hopper.sh b/sbatch_hopper.sh
deleted file mode 100644
index cc410048..00000000
--- a/sbatch_hopper.sh
+++ /dev/null
@@ -1,17 +0,0 @@
-#!/bin/bash
-#SBATCH --nodes=1 # total number of nodes (N to be defined)
-#SBATCH --ntasks-per-node=1 # number of tasks per node (here 8 tasks, or 1 task per GPU)
-#SBATCH --qos=normal # number of GPUs reserved per node (here 8, or all the GPUs)
-#SBATCH --partition=hopper-prod
-#SBATCH --gres=gpu:1 # number of GPUs reserved per node (here 8, or all the GPUs)
-#SBATCH --cpus-per-task=12 # number of cores per task
-#SBATCH --mem-per-cpu=11G
-#SBATCH --time=12:00:00
-#SBATCH --output=/admin/home/remi_cadene/slurm/%j.out
-#SBATCH --error=/admin/home/remi_cadene/slurm/%j.err
-#SBATCH --mail-user=remi_cadene@huggingface.co
-#SBATCH --mail-type=ALL
-
-CMD=$@
-echo "command: $CMD"
-srun $CMD
diff --git a/tests/data/aloha_sim_insertion_human/data_dict.pth b/tests/data/aloha_sim_insertion_human/data_dict.pth
new file mode 100644
index 00000000..1370c9ea
Binary files /dev/null and b/tests/data/aloha_sim_insertion_human/data_dict.pth differ
diff --git a/tests/data/aloha_sim_insertion_human/data_ids_per_episode.pth
b/tests/data/aloha_sim_insertion_human/data_ids_per_episode.pth new file mode 100644 index 00000000..a1d481dd Binary files /dev/null and b/tests/data/aloha_sim_insertion_human/data_ids_per_episode.pth differ diff --git a/tests/data/aloha_sim_insertion_human/replay_buffer/action.memmap b/tests/data/aloha_sim_insertion_human/replay_buffer/action.memmap deleted file mode 100644 index f64b2989..00000000 --- a/tests/data/aloha_sim_insertion_human/replay_buffer/action.memmap +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:d789deddb081a9f4b626342391de8f48949d38fb5fdead87b5c0737b46c0877a -size 2800 diff --git a/tests/data/aloha_sim_insertion_human/replay_buffer/episode.memmap b/tests/data/aloha_sim_insertion_human/replay_buffer/episode.memmap deleted file mode 100644 index af9fb07f..00000000 --- a/tests/data/aloha_sim_insertion_human/replay_buffer/episode.memmap +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:7a12e561363385e9dfeeab326368731c030ed4b374e7f5897ac819159d2884c5 -size 400 diff --git a/tests/data/aloha_sim_insertion_human/replay_buffer/frame_id.memmap b/tests/data/aloha_sim_insertion_human/replay_buffer/frame_id.memmap deleted file mode 100644 index dc2f585c..00000000 --- a/tests/data/aloha_sim_insertion_human/replay_buffer/frame_id.memmap +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:c202d9cfc7858fd49d522047e16948359bbbb2eda2d3825d552e45a78d5f8585 -size 400 diff --git a/tests/data/aloha_sim_insertion_human/replay_buffer/meta.json b/tests/data/aloha_sim_insertion_human/replay_buffer/meta.json deleted file mode 100644 index 2a0cf0a2..00000000 --- a/tests/data/aloha_sim_insertion_human/replay_buffer/meta.json +++ /dev/null @@ -1 +0,0 @@ -{"action": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "episode": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "frame_id": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/aloha_sim_insertion_human/replay_buffer/next/done.memmap b/tests/data/aloha_sim_insertion_human/replay_buffer/next/done.memmap deleted file mode 100644 index 44fd709f..00000000 --- a/tests/data/aloha_sim_insertion_human/replay_buffer/next/done.memmap +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:cc2786e1f9910a9d811400edcddaf7075195f7a16b216dcbefba3bc7c4f2ae51 -size 50 diff --git a/tests/data/aloha_sim_insertion_human/replay_buffer/next/meta.json b/tests/data/aloha_sim_insertion_human/replay_buffer/next/meta.json deleted file mode 100644 index 3bfa9bd7..00000000 --- a/tests/data/aloha_sim_insertion_human/replay_buffer/next/meta.json +++ /dev/null @@ -1 +0,0 @@ -{"done": {"device": "cpu", "shape": [50, 1], "dtype": "torch.bool"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/aloha_sim_insertion_human/replay_buffer/next/observation/image/meta.json b/tests/data/aloha_sim_insertion_human/replay_buffer/next/observation/image/meta.json deleted file mode 100644 index cb29a5ab..00000000 --- a/tests/data/aloha_sim_insertion_human/replay_buffer/next/observation/image/meta.json +++ /dev/null @@ -1 +0,0 @@ -{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/aloha_sim_insertion_human/replay_buffer/next/observation/image/top.memmap 
b/tests/data/aloha_sim_insertion_human/replay_buffer/next/observation/image/top.memmap deleted file mode 100644 index d3d8bd1c..00000000 --- a/tests/data/aloha_sim_insertion_human/replay_buffer/next/observation/image/top.memmap +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:5c632e3cb06be729e5d673e3ecca1d6f6527b0f48cfe3dc03d7eea4f9eb3bbd7 -size 46080000 diff --git a/tests/data/aloha_sim_insertion_human/replay_buffer/next/observation/meta.json b/tests/data/aloha_sim_insertion_human/replay_buffer/next/observation/meta.json deleted file mode 100644 index 65ce1ca2..00000000 --- a/tests/data/aloha_sim_insertion_human/replay_buffer/next/observation/meta.json +++ /dev/null @@ -1 +0,0 @@ -{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/aloha_sim_insertion_human/replay_buffer/next/observation/state.memmap b/tests/data/aloha_sim_insertion_human/replay_buffer/next/observation/state.memmap deleted file mode 100644 index 1f087a60..00000000 --- a/tests/data/aloha_sim_insertion_human/replay_buffer/next/observation/state.memmap +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:e231f2e07e1cd030137ea2e938b570b112db2c694c6d21b37ceb8f8559e19088 -size 2800 diff --git a/tests/data/aloha_sim_insertion_human/replay_buffer/observation/image/meta.json b/tests/data/aloha_sim_insertion_human/replay_buffer/observation/image/meta.json deleted file mode 100644 index cb29a5ab..00000000 --- a/tests/data/aloha_sim_insertion_human/replay_buffer/observation/image/meta.json +++ /dev/null @@ -1 +0,0 @@ -{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/aloha_sim_insertion_human/replay_buffer/observation/image/top.memmap b/tests/data/aloha_sim_insertion_human/replay_buffer/observation/image/top.memmap deleted file mode 100644 index 00c0b783..00000000 --- a/tests/data/aloha_sim_insertion_human/replay_buffer/observation/image/top.memmap +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:a1ba64c89f4fcf9135fe34c26abf582dd5f0d573506db5c96af3ffe40a52c818 -size 46080000 diff --git a/tests/data/aloha_sim_insertion_human/replay_buffer/observation/meta.json b/tests/data/aloha_sim_insertion_human/replay_buffer/observation/meta.json deleted file mode 100644 index 65ce1ca2..00000000 --- a/tests/data/aloha_sim_insertion_human/replay_buffer/observation/meta.json +++ /dev/null @@ -1 +0,0 @@ -{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/aloha_sim_insertion_human/replay_buffer/observation/state.memmap b/tests/data/aloha_sim_insertion_human/replay_buffer/observation/state.memmap deleted file mode 100644 index a1131179..00000000 --- a/tests/data/aloha_sim_insertion_human/replay_buffer/observation/state.memmap +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:85405686bc065c6ab6c915907920a0391a57cf097b74de058a8c30be0548ade5 -size 2800 diff --git a/tests/data/aloha_sim_insertion_human/stats.pth b/tests/data/aloha_sim_insertion_human/stats.pth index 87b18a24..a7b9248f 100644 Binary files a/tests/data/aloha_sim_insertion_human/stats.pth and b/tests/data/aloha_sim_insertion_human/stats.pth differ diff --git 
a/tests/data/aloha_sim_insertion_scripted/data_dict.pth b/tests/data/aloha_sim_insertion_scripted/data_dict.pth new file mode 100644 index 00000000..00c9f335 Binary files /dev/null and b/tests/data/aloha_sim_insertion_scripted/data_dict.pth differ diff --git a/tests/data/aloha_sim_insertion_scripted/data_ids_per_episode.pth b/tests/data/aloha_sim_insertion_scripted/data_ids_per_episode.pth new file mode 100644 index 00000000..a1d481dd Binary files /dev/null and b/tests/data/aloha_sim_insertion_scripted/data_ids_per_episode.pth differ diff --git a/tests/data/aloha_sim_insertion_scripted/replay_buffer/action.memmap b/tests/data/aloha_sim_insertion_scripted/replay_buffer/action.memmap deleted file mode 100644 index e4068b75..00000000 --- a/tests/data/aloha_sim_insertion_scripted/replay_buffer/action.memmap +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:1f5fe053b760e8471885b82c10f4a6ea40874098036337ae5cc300c4775546be -size 2800 diff --git a/tests/data/aloha_sim_insertion_scripted/replay_buffer/episode.memmap b/tests/data/aloha_sim_insertion_scripted/replay_buffer/episode.memmap deleted file mode 100644 index af9fb07f..00000000 --- a/tests/data/aloha_sim_insertion_scripted/replay_buffer/episode.memmap +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:7a12e561363385e9dfeeab326368731c030ed4b374e7f5897ac819159d2884c5 -size 400 diff --git a/tests/data/aloha_sim_insertion_scripted/replay_buffer/frame_id.memmap b/tests/data/aloha_sim_insertion_scripted/replay_buffer/frame_id.memmap deleted file mode 100644 index dc2f585c..00000000 --- a/tests/data/aloha_sim_insertion_scripted/replay_buffer/frame_id.memmap +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:c202d9cfc7858fd49d522047e16948359bbbb2eda2d3825d552e45a78d5f8585 -size 400 diff --git a/tests/data/aloha_sim_insertion_scripted/replay_buffer/meta.json b/tests/data/aloha_sim_insertion_scripted/replay_buffer/meta.json deleted file mode 100644 index 2a0cf0a2..00000000 --- a/tests/data/aloha_sim_insertion_scripted/replay_buffer/meta.json +++ /dev/null @@ -1 +0,0 @@ -{"action": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "episode": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "frame_id": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/done.memmap b/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/done.memmap deleted file mode 100644 index 44fd709f..00000000 --- a/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/done.memmap +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:cc2786e1f9910a9d811400edcddaf7075195f7a16b216dcbefba3bc7c4f2ae51 -size 50 diff --git a/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/meta.json b/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/meta.json deleted file mode 100644 index 3bfa9bd7..00000000 --- a/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/meta.json +++ /dev/null @@ -1 +0,0 @@ -{"done": {"device": "cpu", "shape": [50, 1], "dtype": "torch.bool"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/observation/image/meta.json b/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/observation/image/meta.json deleted file mode 100644 index 
cb29a5ab..00000000 --- a/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/observation/image/meta.json +++ /dev/null @@ -1 +0,0 @@ -{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/observation/image/top.memmap b/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/observation/image/top.memmap deleted file mode 100644 index 83911729..00000000 --- a/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/observation/image/top.memmap +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:daed2bb10498ba2557983d0d7e89399882fea7585e7ceff910e23c621bfdbf88 -size 46080000 diff --git a/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/observation/meta.json b/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/observation/meta.json deleted file mode 100644 index 65ce1ca2..00000000 --- a/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/observation/meta.json +++ /dev/null @@ -1 +0,0 @@ -{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/observation/state.memmap b/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/observation/state.memmap deleted file mode 100644 index aef69da0..00000000 --- a/tests/data/aloha_sim_insertion_scripted/replay_buffer/next/observation/state.memmap +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:bbad0302af70112ee312efe0eb0f44a2f1c8f6c5ef82ea4fb34625cdafbef057 -size 2800 diff --git a/tests/data/aloha_sim_insertion_scripted/replay_buffer/observation/image/meta.json b/tests/data/aloha_sim_insertion_scripted/replay_buffer/observation/image/meta.json deleted file mode 100644 index cb29a5ab..00000000 --- a/tests/data/aloha_sim_insertion_scripted/replay_buffer/observation/image/meta.json +++ /dev/null @@ -1 +0,0 @@ -{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/aloha_sim_insertion_scripted/replay_buffer/observation/image/top.memmap b/tests/data/aloha_sim_insertion_scripted/replay_buffer/observation/image/top.memmap deleted file mode 100644 index f9f0a759..00000000 --- a/tests/data/aloha_sim_insertion_scripted/replay_buffer/observation/image/top.memmap +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:aba55ebb9dd004bf68444b9ebf024ed7713436099c06a0b8e541100ecbc69290 -size 46080000 diff --git a/tests/data/aloha_sim_insertion_scripted/replay_buffer/observation/meta.json b/tests/data/aloha_sim_insertion_scripted/replay_buffer/observation/meta.json deleted file mode 100644 index 65ce1ca2..00000000 --- a/tests/data/aloha_sim_insertion_scripted/replay_buffer/observation/meta.json +++ /dev/null @@ -1 +0,0 @@ -{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/aloha_sim_insertion_scripted/replay_buffer/observation/state.memmap b/tests/data/aloha_sim_insertion_scripted/replay_buffer/observation/state.memmap deleted file mode 100644 index 91875055..00000000 --- a/tests/data/aloha_sim_insertion_scripted/replay_buffer/observation/state.memmap +++ /dev/null @@ 
-1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:dd4e7e14abf57561ca9839c910581266be90956e41bfb3bb21362ea0c321e77d -size 2800 diff --git a/tests/data/aloha_sim_insertion_scripted/stats.pth b/tests/data/aloha_sim_insertion_scripted/stats.pth index 7d149ca4..990d4647 100644 Binary files a/tests/data/aloha_sim_insertion_scripted/stats.pth and b/tests/data/aloha_sim_insertion_scripted/stats.pth differ diff --git a/tests/data/aloha_sim_transfer_cube_human/data_dict.pth b/tests/data/aloha_sim_transfer_cube_human/data_dict.pth new file mode 100644 index 00000000..ab851779 Binary files /dev/null and b/tests/data/aloha_sim_transfer_cube_human/data_dict.pth differ diff --git a/tests/data/aloha_sim_transfer_cube_human/data_ids_per_episode.pth b/tests/data/aloha_sim_transfer_cube_human/data_ids_per_episode.pth new file mode 100644 index 00000000..a1d481dd Binary files /dev/null and b/tests/data/aloha_sim_transfer_cube_human/data_ids_per_episode.pth differ diff --git a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/action.memmap b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/action.memmap deleted file mode 100644 index 9b4fef33..00000000 --- a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/action.memmap +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:14fed0eed3d529a8ac0dd25a6d41585020772d02f9137fc9d604713b2f0f7076 -size 2800 diff --git a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/episode.memmap b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/episode.memmap deleted file mode 100644 index af9fb07f..00000000 --- a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/episode.memmap +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:7a12e561363385e9dfeeab326368731c030ed4b374e7f5897ac819159d2884c5 -size 400 diff --git a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/frame_id.memmap b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/frame_id.memmap deleted file mode 100644 index dc2f585c..00000000 --- a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/frame_id.memmap +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:c202d9cfc7858fd49d522047e16948359bbbb2eda2d3825d552e45a78d5f8585 -size 400 diff --git a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/meta.json b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/meta.json deleted file mode 100644 index 2a0cf0a2..00000000 --- a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/meta.json +++ /dev/null @@ -1 +0,0 @@ -{"action": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "episode": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "frame_id": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/done.memmap b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/done.memmap deleted file mode 100644 index 44fd709f..00000000 --- a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/done.memmap +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:cc2786e1f9910a9d811400edcddaf7075195f7a16b216dcbefba3bc7c4f2ae51 -size 50 diff --git a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/meta.json b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/meta.json deleted file mode 100644 index 3bfa9bd7..00000000 --- 
a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/meta.json +++ /dev/null @@ -1 +0,0 @@ -{"done": {"device": "cpu", "shape": [50, 1], "dtype": "torch.bool"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/observation/image/meta.json b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/observation/image/meta.json deleted file mode 100644 index cb29a5ab..00000000 --- a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/observation/image/meta.json +++ /dev/null @@ -1 +0,0 @@ -{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/observation/image/top.memmap b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/observation/image/top.memmap deleted file mode 100644 index cd2e7c06..00000000 --- a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/observation/image/top.memmap +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:2f713ea7fc19e592ea409a5e0bdfde403e5b86f834cbabe3463b791e8437fafc -size 46080000 diff --git a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/observation/meta.json b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/observation/meta.json deleted file mode 100644 index 65ce1ca2..00000000 --- a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/observation/meta.json +++ /dev/null @@ -1 +0,0 @@ -{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/observation/state.memmap b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/observation/state.memmap deleted file mode 100644 index 37feaad6..00000000 --- a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/next/observation/state.memmap +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:8c103e2c9d63c9f7cf9645bd24d9a2c4e8e08825dc75e230ebc793b8f9c213b0 -size 2800 diff --git a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/observation/image/meta.json b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/observation/image/meta.json deleted file mode 100644 index cb29a5ab..00000000 --- a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/observation/image/meta.json +++ /dev/null @@ -1 +0,0 @@ -{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/observation/image/top.memmap b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/observation/image/top.memmap deleted file mode 100644 index 1188590c..00000000 --- a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/observation/image/top.memmap +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:7dbf4aa01b184d0eaa21ea999078d7cff86e1ca484a109614176fdc49f1ee05c -size 46080000 diff --git a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/observation/meta.json b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/observation/meta.json deleted file mode 100644 index 65ce1ca2..00000000 --- a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/observation/meta.json +++ /dev/null @@ 
-1 +0,0 @@
-{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/observation/state.memmap b/tests/data/aloha_sim_transfer_cube_human/replay_buffer/observation/state.memmap
deleted file mode 100644
index 9ef4cfd6..00000000
--- a/tests/data/aloha_sim_transfer_cube_human/replay_buffer/observation/state.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:4fa0b9c870d4615037b6fee9e9e85e54d84352e173f2c7c1035232272fe2a3dd
-size 2800
diff --git a/tests/data/aloha_sim_transfer_cube_human/stats.pth b/tests/data/aloha_sim_transfer_cube_human/stats.pth
index 22f3e4d9..1ae356e3 100644
Binary files a/tests/data/aloha_sim_transfer_cube_human/stats.pth and b/tests/data/aloha_sim_transfer_cube_human/stats.pth differ
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/data_dict.pth b/tests/data/aloha_sim_transfer_cube_scripted/data_dict.pth
new file mode 100644
index 00000000..bd308bb0
Binary files /dev/null and b/tests/data/aloha_sim_transfer_cube_scripted/data_dict.pth differ
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/data_ids_per_episode.pth b/tests/data/aloha_sim_transfer_cube_scripted/data_ids_per_episode.pth
new file mode 100644
index 00000000..a1d481dd
Binary files /dev/null and b/tests/data/aloha_sim_transfer_cube_scripted/data_ids_per_episode.pth differ
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/action.memmap b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/action.memmap
deleted file mode 100644
index 8ac0c726..00000000
--- a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/action.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:c0e199a82e2b7462e84406dbced5448a99f1dad9ce172771dfc3feb4b8597115
-size 2800
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/episode.memmap b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/episode.memmap
deleted file mode 100644
index af9fb07f..00000000
--- a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/episode.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:7a12e561363385e9dfeeab326368731c030ed4b374e7f5897ac819159d2884c5
-size 400
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/frame_id.memmap b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/frame_id.memmap
deleted file mode 100644
index dc2f585c..00000000
--- a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/frame_id.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:c202d9cfc7858fd49d522047e16948359bbbb2eda2d3825d552e45a78d5f8585
-size 400
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/meta.json b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/meta.json
deleted file mode 100644
index 2a0cf0a2..00000000
--- a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/meta.json
+++ /dev/null
@@ -1 +0,0 @@
-{"action": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "episode": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "frame_id": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/done.memmap b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/done.memmap
deleted file mode 100644
index 44fd709f..00000000
--- a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/done.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:cc2786e1f9910a9d811400edcddaf7075195f7a16b216dcbefba3bc7c4f2ae51
-size 50
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/meta.json b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/meta.json
deleted file mode 100644
index 3bfa9bd7..00000000
--- a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/meta.json
+++ /dev/null
@@ -1 +0,0 @@
-{"done": {"device": "cpu", "shape": [50, 1], "dtype": "torch.bool"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/observation/image/meta.json b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/observation/image/meta.json
deleted file mode 100644
index cb29a5ab..00000000
--- a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/observation/image/meta.json
+++ /dev/null
@@ -1 +0,0 @@
-{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/observation/image/top.memmap b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/observation/image/top.memmap
deleted file mode 100644
index 8e5f533e..00000000
--- a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/observation/image/top.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:30b44d38cc4d68e06c716a875d39cbdbeacbfdc1657d6366f58c279efd27c52b
-size 46080000
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/observation/meta.json b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/observation/meta.json
deleted file mode 100644
index 65ce1ca2..00000000
--- a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/observation/meta.json
+++ /dev/null
@@ -1 +0,0 @@
-{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/observation/state.memmap b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/observation/state.memmap
deleted file mode 100644
index e88320d1..00000000
--- a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/next/observation/state.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:9f484e7ea4f5f612dd53ee2c0f7891b8f7b2168a54fc81941ac2f2447260c294
-size 2800
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/observation/image/meta.json b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/observation/image/meta.json
deleted file mode 100644
index cb29a5ab..00000000
--- a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/observation/image/meta.json
+++ /dev/null
@@ -1 +0,0 @@
-{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/observation/image/top.memmap b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/observation/image/top.memmap
deleted file mode 100644
index d415da0a..00000000
--- a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/observation/image/top.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:09600206f56cc5b52dfb896204b0044c4e830da368da141d7bd10e52181f6835
-size 46080000
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/observation/meta.json b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/observation/meta.json
deleted file mode 100644
index 65ce1ca2..00000000
--- a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/observation/meta.json
+++ /dev/null
@@ -1 +0,0 @@
-{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/observation/state.memmap b/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/observation/state.memmap
deleted file mode 100644
index be3436fb..00000000
--- a/tests/data/aloha_sim_transfer_cube_scripted/replay_buffer/observation/state.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:60dcb547cf9a6372b78a455217a2408b6bece4371fba1df2a302b334d45c42a8
-size 2800
diff --git a/tests/data/aloha_sim_transfer_cube_scripted/stats.pth b/tests/data/aloha_sim_transfer_cube_scripted/stats.pth
index 63465344..71547f09 100644
Binary files a/tests/data/aloha_sim_transfer_cube_scripted/stats.pth and b/tests/data/aloha_sim_transfer_cube_scripted/stats.pth differ
diff --git a/tests/data/pusht/data_dict.pth b/tests/data/pusht/data_dict.pth
new file mode 100644
index 00000000..a083c86c
Binary files /dev/null and b/tests/data/pusht/data_dict.pth differ
diff --git a/tests/data/pusht/data_ids_per_episode.pth b/tests/data/pusht/data_ids_per_episode.pth
new file mode 100644
index 00000000..a1d481dd
Binary files /dev/null and b/tests/data/pusht/data_ids_per_episode.pth differ
diff --git a/tests/data/pusht/replay_buffer/action.memmap b/tests/data/pusht/replay_buffer/action.memmap
deleted file mode 100644
index f4127fb1..00000000
--- a/tests/data/pusht/replay_buffer/action.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:ba17d8e5c30151ea5f7f6fc31f19e12a68ce2113774b74c8aca0c7ef962a75f4
-size 400
diff --git a/tests/data/pusht/replay_buffer/episode.memmap b/tests/data/pusht/replay_buffer/episode.memmap
deleted file mode 100644
index af9fb07f..00000000
--- a/tests/data/pusht/replay_buffer/episode.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:7a12e561363385e9dfeeab326368731c030ed4b374e7f5897ac819159d2884c5
-size 400
diff --git a/tests/data/pusht/replay_buffer/frame_id.memmap b/tests/data/pusht/replay_buffer/frame_id.memmap
deleted file mode 100644
index dc2f585c..00000000
--- a/tests/data/pusht/replay_buffer/frame_id.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:c202d9cfc7858fd49d522047e16948359bbbb2eda2d3825d552e45a78d5f8585
-size 400
diff --git a/tests/data/pusht/replay_buffer/meta.json b/tests/data/pusht/replay_buffer/meta.json
deleted file mode 100644
index 6f7c4218..00000000
--- a/tests/data/pusht/replay_buffer/meta.json
+++ /dev/null
@@ -1 +0,0 @@
-{"action": {"device": "cpu", "shape": [50, 2], "dtype": "torch.float32"}, "episode": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "frame_id": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
"frame_id": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/pusht/replay_buffer/next/done.memmap b/tests/data/pusht/replay_buffer/next/done.memmap deleted file mode 100644 index 44fd709f..00000000 --- a/tests/data/pusht/replay_buffer/next/done.memmap +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:cc2786e1f9910a9d811400edcddaf7075195f7a16b216dcbefba3bc7c4f2ae51 -size 50 diff --git a/tests/data/pusht/replay_buffer/next/meta.json b/tests/data/pusht/replay_buffer/next/meta.json deleted file mode 100644 index b29a9ff7..00000000 --- a/tests/data/pusht/replay_buffer/next/meta.json +++ /dev/null @@ -1 +0,0 @@ -{"reward": {"device": "cpu", "shape": [50, 1], "dtype": "torch.float32"}, "done": {"device": "cpu", "shape": [50, 1], "dtype": "torch.bool"}, "success": {"device": "cpu", "shape": [50, 1], "dtype": "torch.bool"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/pusht/replay_buffer/next/observation/image.memmap b/tests/data/pusht/replay_buffer/next/observation/image.memmap deleted file mode 100644 index 68634378..00000000 --- a/tests/data/pusht/replay_buffer/next/observation/image.memmap +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:ff6a3748c8223a82e54c61442df7b8baf478a20497ee2353645a1e9ccd765162 -size 5529600 diff --git a/tests/data/pusht/replay_buffer/next/observation/meta.json b/tests/data/pusht/replay_buffer/next/observation/meta.json deleted file mode 100644 index 57e0edea..00000000 --- a/tests/data/pusht/replay_buffer/next/observation/meta.json +++ /dev/null @@ -1 +0,0 @@ -{"image": {"device": "cpu", "shape": [50, 3, 96, 96], "dtype": "torch.float32"}, "state": {"device": "cpu", "shape": [50, 2], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/pusht/replay_buffer/next/observation/state.memmap b/tests/data/pusht/replay_buffer/next/observation/state.memmap deleted file mode 100644 index 8dd28f2a..00000000 --- a/tests/data/pusht/replay_buffer/next/observation/state.memmap +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:fad4ece6d5fd66bbafa34f6ff383c483410082b8d7d4f4616808c3c458ce1d43 -size 400 diff --git a/tests/data/pusht/replay_buffer/next/reward.memmap b/tests/data/pusht/replay_buffer/next/reward.memmap deleted file mode 100644 index 109ed5ad..00000000 --- a/tests/data/pusht/replay_buffer/next/reward.memmap +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:6d9c54dee5660c46886f32d80e57e9dd0ffa57ee0cd2a762b036d9c8e0c3a33a -size 200 diff --git a/tests/data/pusht/replay_buffer/next/success.memmap b/tests/data/pusht/replay_buffer/next/success.memmap deleted file mode 100644 index 44fd709f..00000000 --- a/tests/data/pusht/replay_buffer/next/success.memmap +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:cc2786e1f9910a9d811400edcddaf7075195f7a16b216dcbefba3bc7c4f2ae51 -size 50 diff --git a/tests/data/pusht/replay_buffer/observation/image.memmap b/tests/data/pusht/replay_buffer/observation/image.memmap deleted file mode 100644 index 42c86ef0..00000000 --- a/tests/data/pusht/replay_buffer/observation/image.memmap +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:4bbde5cfd8cff9fd9fc6c9a57177f6fd31c8a03cf853b7d2234312f38380b0ba -size 5529600 diff 
--git a/tests/data/pusht/replay_buffer/observation/meta.json b/tests/data/pusht/replay_buffer/observation/meta.json deleted file mode 100644 index 57e0edea..00000000 --- a/tests/data/pusht/replay_buffer/observation/meta.json +++ /dev/null @@ -1 +0,0 @@ -{"image": {"device": "cpu", "shape": [50, 3, 96, 96], "dtype": "torch.float32"}, "state": {"device": "cpu", "shape": [50, 2], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/pusht/replay_buffer/observation/state.memmap b/tests/data/pusht/replay_buffer/observation/state.memmap deleted file mode 100644 index 3ac8e4ab..00000000 --- a/tests/data/pusht/replay_buffer/observation/state.memmap +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:67c7e39090a16546fb1eade833d704f26464d574d7e431415f828159a154d2bf -size 400 diff --git a/tests/data/pusht/stats.pth b/tests/data/pusht/stats.pth index d7107185..636985fd 100644 Binary files a/tests/data/pusht/stats.pth and b/tests/data/pusht/stats.pth differ diff --git a/tests/data/xarm_lift_medium/data_dict.pth b/tests/data/xarm_lift_medium/data_dict.pth new file mode 100644 index 00000000..5c166576 Binary files /dev/null and b/tests/data/xarm_lift_medium/data_dict.pth differ diff --git a/tests/data/xarm_lift_medium/data_ids_per_episode.pth b/tests/data/xarm_lift_medium/data_ids_per_episode.pth new file mode 100644 index 00000000..21095017 Binary files /dev/null and b/tests/data/xarm_lift_medium/data_ids_per_episode.pth differ diff --git a/tests/data/xarm_lift_medium/replay_buffer/action.memmap b/tests/data/xarm_lift_medium/replay_buffer/action.memmap deleted file mode 100644 index c90afbe9..00000000 --- a/tests/data/xarm_lift_medium/replay_buffer/action.memmap +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:10ec2f944de18f1a2aa3fc2f9a4185c03e3a5afc31148c85c98b58602ac4186e -size 800 diff --git a/tests/data/xarm_lift_medium/replay_buffer/episode.memmap b/tests/data/xarm_lift_medium/replay_buffer/episode.memmap deleted file mode 100644 index 7924f028..00000000 --- a/tests/data/xarm_lift_medium/replay_buffer/episode.memmap +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:1a589cba6bf0dfce138110864b6509508a804d7ea5c519d0a3cd67c4a87fa2d0 -size 200 diff --git a/tests/data/xarm_lift_medium/replay_buffer/frame_id.memmap b/tests/data/xarm_lift_medium/replay_buffer/frame_id.memmap deleted file mode 100644 index a633d346..00000000 --- a/tests/data/xarm_lift_medium/replay_buffer/frame_id.memmap +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:6afe7098f30bdc8564526517c085e62613f6cb67194153840567974cfa6f3815 -size 400 diff --git a/tests/data/xarm_lift_medium/replay_buffer/meta.json b/tests/data/xarm_lift_medium/replay_buffer/meta.json deleted file mode 100644 index 33dc932c..00000000 --- a/tests/data/xarm_lift_medium/replay_buffer/meta.json +++ /dev/null @@ -1 +0,0 @@ -{"action": {"device": "cpu", "shape": [50, 4], "dtype": "torch.float32"}, "episode": {"device": "cpu", "shape": [50], "dtype": "torch.int32"}, "frame_id": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "shape": [50], "device": "cpu", "_type": ""} \ No newline at end of file diff --git a/tests/data/xarm_lift_medium/replay_buffer/next/done.memmap b/tests/data/xarm_lift_medium/replay_buffer/next/done.memmap deleted file mode 100644 index cf5e9cca..00000000 --- a/tests/data/xarm_lift_medium/replay_buffer/next/done.memmap +++ /dev/null 
diff --git a/tests/data/xarm_lift_medium/replay_buffer/next/done.memmap b/tests/data/xarm_lift_medium/replay_buffer/next/done.memmap
deleted file mode 100644
index cf5e9cca..00000000
--- a/tests/data/xarm_lift_medium/replay_buffer/next/done.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:dab3a9712c413c4bfcd91c645752ab981306b23d25bcd4da4c422911574673f7
-size 50
diff --git a/tests/data/xarm_lift_medium/replay_buffer/next/meta.json b/tests/data/xarm_lift_medium/replay_buffer/next/meta.json
deleted file mode 100644
index d69cadad..00000000
--- a/tests/data/xarm_lift_medium/replay_buffer/next/meta.json
+++ /dev/null
@@ -1 +0,0 @@
-{"reward": {"device": "cpu", "shape": [50], "dtype": "torch.float32"}, "done": {"device": "cpu", "shape": [50], "dtype": "torch.bool"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/xarm_lift_medium/replay_buffer/next/observation/image.memmap b/tests/data/xarm_lift_medium/replay_buffer/next/observation/image.memmap
deleted file mode 100644
index 462d0117..00000000
--- a/tests/data/xarm_lift_medium/replay_buffer/next/observation/image.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:d6f9d1422ce3764e7253f70ed4da278f0c07fafef0d5386479f09d6b6b9d8259
-size 1058400
diff --git a/tests/data/xarm_lift_medium/replay_buffer/next/observation/meta.json b/tests/data/xarm_lift_medium/replay_buffer/next/observation/meta.json
deleted file mode 100644
index b13b8ec9..00000000
--- a/tests/data/xarm_lift_medium/replay_buffer/next/observation/meta.json
+++ /dev/null
@@ -1 +0,0 @@
-{"image": {"device": "cpu", "shape": [50, 3, 84, 84], "dtype": "torch.uint8"}, "state": {"device": "cpu", "shape": [50, 4], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/xarm_lift_medium/replay_buffer/next/observation/state.memmap b/tests/data/xarm_lift_medium/replay_buffer/next/observation/state.memmap
deleted file mode 100644
index 1dbe6024..00000000
--- a/tests/data/xarm_lift_medium/replay_buffer/next/observation/state.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:52e7c1a3c4fb2423b195e66ffee2c9e23b3ea0ad7c8bfc4dec30a35c65cadcbb
-size 800
diff --git a/tests/data/xarm_lift_medium/replay_buffer/next/reward.memmap b/tests/data/xarm_lift_medium/replay_buffer/next/reward.memmap
deleted file mode 100644
index 9ff5d5a1..00000000
--- a/tests/data/xarm_lift_medium/replay_buffer/next/reward.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:c4dbe8ea1966e5cc6da6daf5704805b9b5810f4575de7016b8f6cb1da1d7bb8a
-size 200
diff --git a/tests/data/xarm_lift_medium/replay_buffer/observation/image.memmap b/tests/data/xarm_lift_medium/replay_buffer/observation/image.memmap
deleted file mode 100644
index c9416940..00000000
--- a/tests/data/xarm_lift_medium/replay_buffer/observation/image.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:8fca8ddbda3f7bb2f6e7553297c18f3ab8f8b73d64b5c9f81a3695ad9379d403
-size 1058400
diff --git a/tests/data/xarm_lift_medium/replay_buffer/observation/meta.json b/tests/data/xarm_lift_medium/replay_buffer/observation/meta.json
deleted file mode 100644
index b13b8ec9..00000000
--- a/tests/data/xarm_lift_medium/replay_buffer/observation/meta.json
+++ /dev/null
@@ -1 +0,0 @@
-{"image": {"device": "cpu", "shape": [50, 3, 84, 84], "dtype": "torch.uint8"}, "state": {"device": "cpu", "shape": [50, 4], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": ""}
\ No newline at end of file
diff --git a/tests/data/xarm_lift_medium/replay_buffer/observation/state.memmap b/tests/data/xarm_lift_medium/replay_buffer/observation/state.memmap
deleted file mode 100644
index 3bae16df..00000000
--- a/tests/data/xarm_lift_medium/replay_buffer/observation/state.memmap
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:7b3e3e12896d553c208ee152f6d447c877c435e15d010c4a6171966d5b8a0c0b
-size 800
diff --git a/tests/data/xarm_lift_medium/stats.pth b/tests/data/xarm_lift_medium/stats.pth
index 0accffb0..3ab4e05b 100644
Binary files a/tests/data/xarm_lift_medium/stats.pth and b/tests/data/xarm_lift_medium/stats.pth differ
diff --git a/tests/scripts/mock_dataset.py b/tests/scripts/mock_dataset.py
index d9c86464..72480666 100644
--- a/tests/scripts/mock_dataset.py
+++ b/tests/scripts/mock_dataset.py
@@ -18,28 +18,33 @@ Example:
 import argparse
 import shutil

-from tensordict import TensorDict
 from pathlib import Path

+import torch
+

 def mock_dataset(in_data_dir, out_data_dir, num_frames):
     in_data_dir = Path(in_data_dir)
     out_data_dir = Path(out_data_dir)
+    out_data_dir.mkdir(exist_ok=True, parents=True)

-    # load full dataset as a tensor dict
-    in_td_data = TensorDict.load_memmap(in_data_dir / "replay_buffer")
+    # copy the first `n` frames for each data key so that we have real data
+    in_data_dict = torch.load(in_data_dir / "data_dict.pth")
+    out_data_dict = {key: in_data_dict[key][:num_frames].clone() for key in in_data_dict}
+    torch.save(out_data_dict, out_data_dir / "data_dict.pth")

-    # use 1 frame to know the specification of the dataset
-    # and copy it over `n` frames in the test artifact directory
-    out_td_data = in_td_data[0].expand(num_frames).memmap_like(out_data_dir / "replay_buffer")
+    # recreate data_ids_per_episode that corresponds to the subset
+    episodes = in_data_dict["episode"][:num_frames].tolist()
+    data_ids_per_episode = {}
+    for idx, ep_id in enumerate(episodes):
+        if ep_id not in data_ids_per_episode:
+            data_ids_per_episode[ep_id] = []
+        data_ids_per_episode[ep_id].append(idx)
+    for ep_id in data_ids_per_episode:
+        data_ids_per_episode[ep_id] = torch.tensor(data_ids_per_episode[ep_id])
+    torch.save(data_ids_per_episode, out_data_dir / "data_ids_per_episode.pth")

-    # copy the first `n` frames so that we have real data
-    out_td_data[:num_frames] = in_td_data[:num_frames].clone()
-
-    # make sure everything has been properly written
-    out_td_data.lock_()
-
-    # copy the full statistics of dataset since it's pretty small
+    # copy the full statistics of dataset since it's small
     in_stats_path = in_data_dir / "stats.pth"
     out_stats_path = out_data_dir / "stats.pth"
     shutil.copy(in_stats_path, out_stats_path)
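For context on the fixture layout the script above produces: a minimal sketch (not repo code) of reading back the two artifacts that replace the old `replay_buffer` memmap directory, assuming the `data_dict.pth`/`data_ids_per_episode.pth` layout shown in `mock_dataset.py`.

from pathlib import Path

import torch


def load_mock_dataset(data_dir):
    # data_dict.pth: one tensor per key (e.g. "action", "episode"), all sharing
    # the same first dimension (the number of frames kept by mock_dataset).
    data_dict = torch.load(Path(data_dir) / "data_dict.pth")
    # data_ids_per_episode.pth: episode id -> tensor of frame indices, rebuilt
    # above from the truncated "episode" column.
    data_ids_per_episode = torch.load(Path(data_dir) / "data_ids_per_episode.pth")
    num_frames = len(data_dict["episode"])
    assert all(len(v) == num_frames for v in data_dict.values())
    return data_dict, data_ids_per_episode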
diff --git a/tests/test_available.py b/tests/test_available.py
index 9cc91efa..be74a42a 100644
--- a/tests/test_available.py
+++ b/tests/test_available.py
@@ -1,25 +1,20 @@
 """
 This test verifies that all environments, datasets, policies listed in `lerobot/__init__.py` can be successfully
-imported and that their class attributes (eg. `available_datasets`, `name`, `available_tasks`) corresponds.
+imported and that their class attributes (e.g. `available_datasets`, `name`, `available_tasks`) are valid.

-Note:
-    When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
-    1. set the required class attributes:
-        - for classes inheriting from `AbstractDataset`: `available_datasets`
-        - for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
-        - for classes inheriting from `AbstractPolicy`: `name`
-    2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
-    3. update variables in `tests/test_available.py` by importing your new class
+When implementing a new dataset (e.g. `AlohaDataset`), policy (e.g. `DiffusionPolicy`), or environment, follow these steps:
+- Set the required class attribute `available_datasets` (for datasets).
+- Set the required class attribute `name` (for policies).
+- Update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_env`, `available_policies`).
+- Update variables in `tests/test_available.py` by importing your new class.
 """
+import importlib
 import pytest

 import lerobot
+import gymnasium as gym

-from lerobot.common.envs.aloha.env import AlohaEnv
-from lerobot.common.envs.pusht.env import PushtEnv
-from lerobot.common.envs.simxarm.env import SimxarmEnv
-
-from lerobot.common.datasets.simxarm import SimxarmDataset
+from lerobot.common.datasets.xarm import XarmDataset
 from lerobot.common.datasets.aloha import AlohaDataset
 from lerobot.common.datasets.pusht import PushtDataset
@@ -29,36 +24,30 @@ from lerobot.common.policies.tdmpc.policy import TDMPCPolicy


 def test_available():
-    pol_classes = [
+    policy_classes = [
         ActionChunkingTransformerPolicy,
         DiffusionPolicy,
         TDMPCPolicy,
     ]

-    env_classes = [
-        AlohaEnv,
-        PushtEnv,
-        SimxarmEnv,
-    ]
-
-    dat_classes = [
-        AlohaDataset,
-        PushtDataset,
-        SimxarmDataset,
-    ]
+    dataset_class_per_env = {
+        "aloha": AlohaDataset,
+        "pusht": PushtDataset,
+        "xarm": XarmDataset,
+    }

-    policies = [pol_cls.name for pol_cls in pol_classes]
-    assert set(policies) == set(lerobot.available_policies)
+    policies = [pol_cls.name for pol_cls in policy_classes]
+    assert set(policies) == set(lerobot.available_policies), policies

-    envs = [env_cls.name for env_cls in env_classes]
-    assert set(envs) == set(lerobot.available_envs)
+    for env_name in lerobot.available_envs:
+        for task_name in lerobot.available_tasks_per_env[env_name]:
+            package_name = f"gym_{env_name}"
+            importlib.import_module(package_name)
+            gym_handle = f"{package_name}/{task_name}"
+            assert gym_handle in gym.envs.registry.keys(), gym_handle

-    tasks_per_env = {env_cls.name: env_cls.available_tasks for env_cls in env_classes}
-    for env in envs:
-        assert set(tasks_per_env[env]) == set(lerobot.available_tasks_per_env[env])
-
-    datasets_per_env = {env_cls.name: dat_cls.available_datasets for env_cls, dat_cls in zip(env_classes, dat_classes)}
-    for env in envs:
-        assert set(datasets_per_env[env]) == set(lerobot.available_datasets_per_env[env])
+        dataset_class = dataset_class_per_env[env_name]
+        available_datasets = lerobot.available_datasets_per_env[env_name]
+        assert set(available_datasets) == set(dataset_class.available_datasets), f"{env_name=} {available_datasets=}"
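The registry check above relies on a gymnasium convention worth spelling out: an environment package only registers its task ids when it is imported. A hedged sketch, using `gym_pusht`/`PushT-v0` from the lockfile as the example handle:

import importlib

import gymnasium as gym

# Importing the package runs its register() calls; before the import, the
# namespaced handle is absent from gym.envs.registry (a dict keyed by id).
importlib.import_module("gym_pusht")
assert "gym_pusht/PushT-v0" in gym.envs.registry
env = gym.make("gym_pusht/PushT-v0")
env.close()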
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index df41b03f..71eefa9c 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -1,38 +1,91 @@
+import os
+from pathlib import Path
 import einops
 import pytest
 import torch
-from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
-from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

-from lerobot.common.datasets.factory import make_offline_buffer
+from lerobot.common.datasets.utils import compute_stats, get_stats_einops_patterns, load_data_with_delta_timestamps
+from lerobot.common.datasets.xarm import XarmDataset
+from lerobot.common.transforms import Prod
 from lerobot.common.utils import init_hydra_config
+import logging
+from lerobot.common.datasets.factory import make_dataset

 from .utils import DEVICE, DEFAULT_CONFIG_PATH


 @pytest.mark.parametrize(
-    "env_name,dataset_id",
+    "env_name,dataset_id,policy_name",
     [
-        ("simxarm", "lift"),
-        ("pusht", "pusht"),
-        ("aloha", "sim_insertion_human"),
-        ("aloha", "sim_insertion_scripted"),
-        ("aloha", "sim_transfer_cube_human"),
-        ("aloha", "sim_transfer_cube_scripted"),
+        ("xarm", "xarm_lift_medium", "tdmpc"),
+        ("pusht", "pusht", "diffusion"),
+        ("aloha", "aloha_sim_insertion_human", "act"),
+        ("aloha", "aloha_sim_insertion_scripted", "act"),
+        ("aloha", "aloha_sim_transfer_cube_human", "act"),
+        ("aloha", "aloha_sim_transfer_cube_scripted", "act"),
     ],
 )
-def test_factory(env_name, dataset_id):
+def test_factory(env_name, dataset_id, policy_name):
     cfg = init_hydra_config(
         DEFAULT_CONFIG_PATH,
-        overrides=[f"env={env_name}", f"env.task={dataset_id}", f"device={DEVICE}"]
+        overrides=[f"env={env_name}", f"dataset_id={dataset_id}", f"policy={policy_name}", f"device={DEVICE}"]
     )
-    offline_buffer = make_offline_buffer(cfg)
-    for key in offline_buffer.image_keys:
-        img = offline_buffer[0].get(key)
-        assert img.dtype == torch.float32
-        # TODO(rcadene): we assume for now that image normalization takes place in the model
-        assert img.max() <= 1.0
-        assert img.min() >= 0.0
+    dataset = make_dataset(cfg)
+    delta_timestamps = dataset.delta_timestamps
+    image_keys = dataset.image_keys
+
+    item = dataset[0]
+
+    keys_ndim_required = [
+        ("action", 1, True),
+        ("episode", 0, True),
+        ("frame_id", 0, True),
+        ("timestamp", 0, True),
+        # TODO(rcadene): should we rename it agent_pos?
+        ("observation.state", 1, True),
+        ("next.reward", 0, False),
+        ("next.done", 0, False),
+    ]
+
+    for key in image_keys:
+        keys_ndim_required.append(
+            (key, 3, True),
+        )
+        assert dataset.data_dict[key].dtype == torch.uint8, f"{key}"
+
+    # test number of dimensions
+    for key, ndim, required in keys_ndim_required:
+        if key not in item:
+            if required:
+                assert key in item, f"{key}"
+            else:
+                logging.warning(f'Missing key in dataset: "{key}" not in {dataset}.')
+                continue
+
+        if delta_timestamps is not None and key in delta_timestamps:
+            assert item[key].ndim == ndim + 1, f"{key}"
+            assert item[key].shape[0] == len(delta_timestamps[key]), f"{key}"
+        else:
+            assert item[key].ndim == ndim, f"{key}"
+
+        if key in image_keys:
+            assert item[key].dtype == torch.float32, f"{key}"
+            # TODO(rcadene): we assume for now that image normalization takes place in the model
+            assert item[key].max() <= 1.0, f"{key}"
+            assert item[key].min() >= 0.0, f"{key}"
+
+            if delta_timestamps is not None and key in delta_timestamps:
+                # test t,c,h,w
+                assert item[key].shape[1] == 3, f"{key}"
+            else:
+                # test c,h,w
+                assert item[key].shape[0] == 3, f"{key}"
+
+    if delta_timestamps is not None:
+        # test missing keys in delta_timestamps
+        for key in delta_timestamps:
+            assert key in item, f"{key}"
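The ndim assertions above encode a shape rule that is easy to miss: a key listed in `delta_timestamps` gains a leading time dimension of length `len(delta_timestamps[key])`. A small illustration with dummy tensors (the shapes are assumptions consistent with the assertions, not taken from the library):

import torch

delta_timestamps = {"observation.image": [-0.1, 0.0]}  # two requested offsets

# Without delta_timestamps an image item is (c, h, w); with them it is stacked
# into (t, c, h, w), so ndim goes from 3 to 4 and shape[0] equals the number of offsets.
key = "observation.image"
item = {key: torch.rand(2, 3, 96, 96)}
assert item[key].ndim == 3 + 1
assert item[key].shape[0] == len(delta_timestamps[key])
assert item[key].shape[1] == 3  # channels stay first within each stacked frame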
""" - cfg = init_hydra_config( - DEFAULT_CONFIG_PATH, overrides=["env=aloha", "env.task=sim_transfer_cube_human"] + DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None + + # get transform to convert images from uint8 [0,255] to float32 [0,1] + transform = Prod(in_keys=XarmDataset.image_keys, prod=1 / 255.0) + + dataset = XarmDataset( + dataset_id="xarm_lift_medium", + root=DATA_DIR, + transform=transform, ) - buffer = make_offline_buffer(cfg) - # Get all of the data. - all_data = TensorDictReplayBuffer( - storage=buffer._storage, - batch_size=len(buffer), - sampler=SamplerWithoutReplacement(), - ).sample().float() + # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched # computation of the statistics. While doing this, we also make sure it works when we don't divide the # dataset into even batches. - computed_stats = buffer._compute_stats(batch_size=int(len(all_data) * 0.75)) - for k, pattern in buffer.stats_patterns.items(): - expected_mean = einops.reduce(all_data[k], pattern, "mean") - assert torch.allclose(computed_stats[k]["mean"], expected_mean) - assert torch.allclose( - computed_stats[k]["std"], - torch.sqrt(einops.reduce((all_data[k] - expected_mean) ** 2, pattern, "mean")) - ) - assert torch.allclose(computed_stats[k]["min"], einops.reduce(all_data[k], pattern, "min")) - assert torch.allclose(computed_stats[k]["max"], einops.reduce(all_data[k], pattern, "max")) + computed_stats = compute_stats(dataset, batch_size=int(len(dataset) * 0.25)) + + # get einops patterns to aggregate batches and compute statistics + stats_patterns = get_stats_einops_patterns(dataset) + + # get all frames from the dataset in the same dtype and range as during compute_stats + data_dict = transform(dataset.data_dict) + + # compute stats based on all frames from the dataset without any batching + expected_stats = {} + for k, pattern in stats_patterns.items(): + expected_stats[k] = {} + expected_stats[k]["mean"] = einops.reduce(data_dict[k], pattern, "mean") + expected_stats[k]["std"] = torch.sqrt(einops.reduce((data_dict[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean")) + expected_stats[k]["min"] = einops.reduce(data_dict[k], pattern, "min") + expected_stats[k]["max"] = einops.reduce(data_dict[k], pattern, "max") + + # test computed stats match expected stats + for k in stats_patterns: + assert torch.allclose(computed_stats[k]["mean"], expected_stats[k]["mean"]) + assert torch.allclose(computed_stats[k]["std"], expected_stats[k]["std"]) + assert torch.allclose(computed_stats[k]["min"], expected_stats[k]["min"]) + assert torch.allclose(computed_stats[k]["max"], expected_stats[k]["max"]) + + # TODO(rcadene): check that the stats used for training are correct too + # # load stats that are expected to match the ones returned by computed_stats + # assert (dataset.data_dir / "stats.pth").exists() + # loaded_stats = torch.load(dataset.data_dir / "stats.pth") + + # # test loaded stats match expected stats + # for k in stats_patterns: + # assert torch.allclose(loaded_stats[k]["mean"], expected_stats[k]["mean"]) + # assert torch.allclose(loaded_stats[k]["std"], expected_stats[k]["std"]) + # assert torch.allclose(loaded_stats[k]["min"], expected_stats[k]["min"]) + # assert torch.allclose(loaded_stats[k]["max"], expected_stats[k]["max"]) + + +def test_load_data_with_delta_timestamps_within_tolerance(): + data_dict = { + "timestamp": torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]), + "index": torch.tensor([0, 1, 2, 3, 4]), + } + 


+def test_load_data_with_delta_timestamps_within_tolerance():
+    data_dict = {
+        "timestamp": torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]),
+        "index": torch.tensor([0, 1, 2, 3, 4]),
+    }
+    data_ids_per_episode = {0: torch.tensor([0, 1, 2, 3, 4])}
+    delta_timestamps = {"index": [-0.2, 0, 0.139]}
+    key = "index"
+    current_ts = 0.3
+    episode = 0
+    tol = 0.04
+    data, is_pad = load_data_with_delta_timestamps(data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode, tol)
+    assert not is_pad.any(), "Unexpected padding detected"
+    assert torch.equal(data, torch.tensor([0, 2, 3])), "Data does not match expected values"
+
+def test_load_data_with_delta_timestamps_outside_tolerance_inside_episode_range():
+    data_dict = {
+        "timestamp": torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]),
+        "index": torch.tensor([0, 1, 2, 3, 4]),
+    }
+    data_ids_per_episode = {0: torch.tensor([0, 1, 2, 3, 4])}
+    delta_timestamps = {"index": [-0.2, 0, 0.141]}
+    key = "index"
+    current_ts = 0.3
+    episode = 0
+    tol = 0.04
+    with pytest.raises(AssertionError):
+        load_data_with_delta_timestamps(data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode, tol)
+
+def test_load_data_with_delta_timestamps_outside_tolerance_outside_episode_range():
+    data_dict = {
+        "timestamp": torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]),
+        "index": torch.tensor([0, 1, 2, 3, 4]),
+    }
+    data_ids_per_episode = {0: torch.tensor([0, 1, 2, 3, 4])}
+    delta_timestamps = {"index": [-0.3, -0.24, 0, 0.26, 0.3]}
+    key = "index"
+    current_ts = 0.3
+    episode = 0
+    tol = 0.04
+    data, is_pad = load_data_with_delta_timestamps(data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode, tol)
+    assert torch.equal(is_pad, torch.tensor([True, False, False, True, True])), "Padding does not match expected values"
+    assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
+
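The three tests above pin down the padding contract of `load_data_with_delta_timestamps`. Restated as a standalone sketch (an illustration of the expected behavior, not the library implementation): each requested timestamp is snapped to the nearest frame within the episode, and queries farther than `tol` from any frame are clamped to the episode boundary and flagged as padding.

import torch

timestamps = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5])  # one episode's frames
queried = 0.3 + torch.tensor([-0.3, -0.24, 0.0, 0.26, 0.3])  # from the last test
tol = 0.04

dist = (queried[:, None] - timestamps[None, :]).abs()
min_dist, nearest = dist.min(dim=1)
is_pad = min_dist > tol

assert torch.equal(nearest, torch.tensor([0, 0, 2, 4, 4]))
assert torch.equal(is_pad, torch.tensor([True, False, False, True, True]))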
("aloha", "AlohaInsertion-v0", "pixels_agent_pos"), + ("aloha", "AlohaTransferCube-v0", "pixels"), + ("aloha", "AlohaTransferCube-v0", "pixels_agent_pos"), + ("xarm", "XarmLift-v0", "state"), + ("xarm", "XarmLift-v0", "pixels"), + ("xarm", "XarmLift-v0", "pixels_agent_pos"), + ("pusht", "PushT-v0", "state"), + ("pusht", "PushT-v0", "pixels"), + ("pusht", "PushT-v0", "pixels_agent_pos"), ], ) -def test_aloha(task, from_pixels, pixels_only): - env = AlohaEnv( - task, - from_pixels=from_pixels, - pixels_only=pixels_only, - image_size=[3, 480, 640] if from_pixels else None, - ) - # print_spec_rollout(env) - check_env_specs(env) - - -@pytest.mark.parametrize( - "task,from_pixels,pixels_only", - [ - ("lift", False, False), - ("lift", True, False), - ("lift", True, True), - # TODO(aliberts): Add simxarm other tasks - # ("reach", False, False), - # ("reach", True, False), - # ("push", False, False), - # ("push", True, False), - # ("peg_in_box", False, False), - # ("peg_in_box", True, False), - ], -) -def test_simxarm(task, from_pixels, pixels_only): - env = SimxarmEnv( - task, - from_pixels=from_pixels, - pixels_only=pixels_only, - image_size=84 if from_pixels else None, - ) - # print_spec_rollout(env) - check_env_specs(env) - - -@pytest.mark.parametrize( - "from_pixels,pixels_only", - [ - (True, False), - ], -) -def test_pusht(from_pixels, pixels_only): - env = PushtEnv( - from_pixels=from_pixels, - pixels_only=pixels_only, - image_size=96 if from_pixels else None, - ) - # print_spec_rollout(env) - check_env_specs(env) - +def test_env(env_name, task, obs_type): + package_name = f"gym_{env_name}" + importlib.import_module(package_name) + env = gym.make(f"{package_name}/{task}", obs_type=obs_type) + check_env(env.unwrapped, skip_render_check=True) + env.close() @pytest.mark.parametrize( "env_name", [ - "simxarm", "pusht", + "xarm", "aloha", ], ) @@ -116,18 +50,16 @@ def test_factory(env_name): overrides=[f"env={env_name}", f"device={DEVICE}"], ) - offline_buffer = make_offline_buffer(cfg) + dataset = make_dataset(cfg) - env = make_env(cfg) - for key in offline_buffer.image_keys: - assert env.reset().get(key).dtype == torch.uint8 - check_env_specs(env) - - env = make_env(cfg, transform=offline_buffer.transform) - for key in offline_buffer.image_keys: - img = env.reset().get(key) + env = make_env(cfg, num_parallel_envs=1) + obs, _ = env.reset() + obs = preprocess_observation(obs, transform=dataset.transform) + for key in dataset.image_keys: + img = obs[key] assert img.dtype == torch.float32 # TODO(rcadene): we assume for now that image normalization takes place in the model assert img.max() <= 1.0 assert img.min() >= 0.0 - check_env_specs(env) + + env.close() diff --git a/tests/test_policies.py b/tests/test_policies.py index 5d6b46d0..8ccc7c62 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -1,37 +1,35 @@ import pytest -from tensordict import TensorDict -from tensordict.nn import TensorDictModule import torch -from torchrl.data import UnboundedContinuousTensorSpec -from torchrl.envs import EnvBase +from lerobot.common.datasets.utils import cycle +from lerobot.common.envs.utils import postprocess_action, preprocess_observation from lerobot.common.policies.factory import make_policy from lerobot.common.envs.factory import make_env -from lerobot.common.datasets.factory import make_offline_buffer -from lerobot.common.policies.abstract import AbstractPolicy +from lerobot.common.datasets.factory import make_dataset from lerobot.common.utils import init_hydra_config from .utils import 
DEVICE, DEFAULT_CONFIG_PATH @pytest.mark.parametrize( "env_name,policy_name,extra_overrides", [ - ("simxarm", "tdmpc", ["policy.mpc=true"]), + ("xarm", "tdmpc", ["policy.mpc=true"]), ("pusht", "tdmpc", ["policy.mpc=false"]), ("pusht", "diffusion", []), - ("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_human"]), - ("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_scripted"]), - ("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_human"]), - ("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_scripted"]), - # TODO(aliberts): simxarm not working with diffusion - # ("simxarm", "diffusion", []), + ("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_id=aloha_sim_insertion_human"]), + ("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_id=aloha_sim_insertion_scripted"]), + ("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_human"]), + ("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_scripted"]), + # TODO(aliberts): xarm not working with diffusion + # ("xarm", "diffusion", []), ], ) -def test_concrete_policy(env_name, policy_name, extra_overrides): +def test_policy(env_name, policy_name, extra_overrides): """ Tests: - Making the policy object. - Updating the policy. - Using the policy to select actions at inference time. + - Test the action can be applied to the policy """ cfg = init_hydra_config( DEFAULT_CONFIG_PATH, @@ -45,92 +43,44 @@ def test_concrete_policy(env_name, policy_name, extra_overrides): # Check that we can make the policy object. policy = make_policy(cfg) # Check that we run select_actions and get the appropriate output. - offline_buffer = make_offline_buffer(cfg) - env = make_env(cfg, transform=offline_buffer.transform) + dataset = make_dataset(cfg) + env = make_env(cfg, num_parallel_envs=2) - if env_name != "aloha": - # TODO(alexander-soare): Fix this part of the test. PrioritizedSliceSampler raises NotImplementedError: - # seq_length as a list is not supported for now. - policy.update(offline_buffer, torch.tensor(0, device=DEVICE)) - - action = policy( - env.observation_spec.rand()["observation"].to(DEVICE), - torch.tensor(0, device=DEVICE), + dataloader = torch.utils.data.DataLoader( + dataset, + num_workers=4, + batch_size=2, + shuffle=True, + pin_memory=DEVICE != "cpu", + drop_last=True, ) - assert action.shape == env.action_spec.shape + dl_iter = cycle(dataloader) + batch = next(dl_iter) -def test_abstract_policy_forward(): - """ - Given an underlying policy that produces an action trajectory with n_action_steps actions, checks that: - - The policy is invoked the expected number of times during a rollout. - - The environment's termination condition is respected even when part way through an action trajectory. - - The observations are returned correctly. - """ + for key in batch: + batch[key] = batch[key].to(DEVICE, non_blocking=True) - n_action_steps = 8 # our test policy will output 8 action step horizons - terminate_at = 10 # some number that is more than n_action_steps but not a multiple - rollout_max_steps = terminate_at + 1 # some number greater than terminate_at + # Test updating the policy + policy(batch, step=0) - # A minimal environment for testing. 
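test_factory above checks the image contract at the environment boundary: raw gymnasium observations are HWC uint8 in [0, 255], while policies consume CHW float32 in [0, 1]. A sketch of that conversion (illustrative only; the repo's preprocess_observation also handles key mapping and batching):

import numpy as np
import torch

raw = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)  # h, w, c

img = torch.from_numpy(raw).permute(2, 0, 1).contiguous().float() / 255.0
assert img.dtype == torch.float32
assert img.shape == (3, 480, 640)  # c, h, w
assert img.min() >= 0.0 and img.max() <= 1.0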
diff --git a/tests/test_policies.py b/tests/test_policies.py
index 5d6b46d0..8ccc7c62 100644
--- a/tests/test_policies.py
+++ b/tests/test_policies.py
@@ -1,37 +1,35 @@
 import pytest
-from tensordict import TensorDict
-from tensordict.nn import TensorDictModule
 import torch
-from torchrl.data import UnboundedContinuousTensorSpec
-from torchrl.envs import EnvBase

+from lerobot.common.datasets.utils import cycle
+from lerobot.common.envs.utils import postprocess_action, preprocess_observation
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.envs.factory import make_env
-from lerobot.common.datasets.factory import make_offline_buffer
-from lerobot.common.policies.abstract import AbstractPolicy
+from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.utils import init_hydra_config
 from .utils import DEVICE, DEFAULT_CONFIG_PATH


 @pytest.mark.parametrize(
     "env_name,policy_name,extra_overrides",
     [
-        ("simxarm", "tdmpc", ["policy.mpc=true"]),
+        ("xarm", "tdmpc", ["policy.mpc=true"]),
         ("pusht", "tdmpc", ["policy.mpc=false"]),
         ("pusht", "diffusion", []),
-        ("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_human"]),
-        ("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_scripted"]),
-        ("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_human"]),
-        ("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_scripted"]),
-        # TODO(aliberts): simxarm not working with diffusion
-        # ("simxarm", "diffusion", []),
+        ("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_id=aloha_sim_insertion_human"]),
+        ("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_id=aloha_sim_insertion_scripted"]),
+        ("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_human"]),
+        ("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_scripted"]),
+        # TODO(aliberts): xarm not working with diffusion
+        # ("xarm", "diffusion", []),
     ],
 )
-def test_concrete_policy(env_name, policy_name, extra_overrides):
+def test_policy(env_name, policy_name, extra_overrides):
     """
     Tests:
     - Making the policy object.
     - Updating the policy.
     - Using the policy to select actions at inference time.
+    - Applying the selected action to the environment.
     """
     cfg = init_hydra_config(
         DEFAULT_CONFIG_PATH,
@@ -45,92 +43,44 @@
     # Check that we can make the policy object.
     policy = make_policy(cfg)
     # Check that we run select_actions and get the appropriate output.
-    offline_buffer = make_offline_buffer(cfg)
-    env = make_env(cfg, transform=offline_buffer.transform)
+    dataset = make_dataset(cfg)
+    env = make_env(cfg, num_parallel_envs=2)

-    if env_name != "aloha":
-        # TODO(alexander-soare): Fix this part of the test. PrioritizedSliceSampler raises NotImplementedError:
-        # seq_length as a list is not supported for now.
-        policy.update(offline_buffer, torch.tensor(0, device=DEVICE))
-
-    action = policy(
-        env.observation_spec.rand()["observation"].to(DEVICE),
-        torch.tensor(0, device=DEVICE),
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        num_workers=4,
+        batch_size=2,
+        shuffle=True,
+        pin_memory=DEVICE != "cpu",
+        drop_last=True,
     )
-    assert action.shape == env.action_spec.shape
+    dl_iter = cycle(dataloader)

+    batch = next(dl_iter)

-def test_abstract_policy_forward():
-    """
-    Given an underlying policy that produces an action trajectory with n_action_steps actions, checks that:
-    - The policy is invoked the expected number of times during a rollout.
-    - The environment's termination condition is respected even when part way through an action trajectory.
-    - The observations are returned correctly.
-    """
+    for key in batch:
+        batch[key] = batch[key].to(DEVICE, non_blocking=True)

-    n_action_steps = 8  # our test policy will output 8 action step horizons
-    terminate_at = 10  # some number that is more than n_action_steps but not a multiple
-    rollout_max_steps = terminate_at + 1  # some number greater than terminate_at
+    # Test updating the policy
+    policy(batch, step=0)

-    # A minimal environment for testing.
-    class StubEnv(EnvBase):
+    # reset the policy and environment
+    policy.reset()
+    observation, _ = env.reset(seed=cfg.seed)

-        def __init__(self):
-            super().__init__()
-            self.action_spec = UnboundedContinuousTensorSpec(shape=(1,))
-            self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,))
+    # apply transform to normalize the observations
+    observation = preprocess_observation(observation, dataset.transform)

-        def _step(self, tensordict: TensorDict) -> TensorDict:
-            self.invocation_count += 1
-            return TensorDict(
-                {
-                    "observation": torch.tensor([self.invocation_count]),
-                    "reward": torch.tensor([self.invocation_count]),
-                    "terminated": torch.tensor(
-                        tensordict["action"].item() == terminate_at
-                    ),
-                }
-            )
+    # send observation to device/gpu
+    observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation}

-        def _reset(self, tensordict: TensorDict) -> TensorDict:
-            self.invocation_count = 0
-            return TensorDict(
-                {
-                    "observation": torch.tensor([self.invocation_count]),
-                    "reward": torch.tensor([self.invocation_count]),
-                }
-            )
+    # get the next action for the environment
+    with torch.inference_mode():
+        action = policy.select_action(observation, step=0)

-        def _set_seed(self, seed: int | None):
-            return
+    # apply inverse transform to unnormalize the action
+    action = postprocess_action(action, dataset.transform)

-    class StubPolicy(AbstractPolicy):
-        name = "stub"
+    # Test stepping the environment with the selected action
+    env.step(action)

-        def __init__(self):
-            super().__init__(n_action_steps)
-            self.n_policy_invocations = 0
-
-        def update(self):
-            pass
-
-        def select_actions(self):
-            self.n_policy_invocations += 1
-            return torch.stack(
-                [torch.tensor([i]) for i in range(self.n_action_steps)]
-            ).unsqueeze(0)
-
-    env = StubEnv()
-    policy = StubPolicy()
-    policy = TensorDictModule(
-        policy,
-        in_keys=[],
-        out_keys=["action"],
-    )
-
-    # Keep track to make sure the policy is called the expected number of times
-    rollout = env.rollout(rollout_max_steps, policy)
-
-    assert len(rollout) == terminate_at + 1  # +1 for the reset observation
-    assert policy.n_policy_invocations == (terminate_at // n_action_steps) + 1
-    assert torch.equal(rollout["observation"].flatten(), torch.arange(terminate_at + 1))
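Taken together, the rewritten test_policy encodes the intended inference loop: normalize the observation, pick an action under torch.inference_mode(), unnormalize it, then step the environment. A condensed sketch of that loop (the function names follow the imports in the diff; the loop structure itself is an assumption, not repo code):

import numpy as np
import torch

from lerobot.common.envs.utils import postprocess_action, preprocess_observation


def rollout(env, policy, transform, seed=0, max_steps=100):
    # mirrors the sequence exercised by test_policy above
    policy.reset()
    observation, _ = env.reset(seed=seed)
    for step in range(max_steps):
        observation = preprocess_observation(observation, transform)  # uint8 [0,255] -> float32 [0,1]
        with torch.inference_mode():
            action = policy.select_action(observation, step=step)
        action = postprocess_action(action, transform)  # undo normalization
        observation, reward, terminated, truncated, info = env.step(action)
        if np.any(terminated) or np.any(truncated):
            break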