diff --git a/.dockerignore b/.dockerignore
index b8c1be15..c0d8a84b 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
# Misc
.git
tmp
@@ -59,7 +73,7 @@ pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
-!tests/data
+!tests/artifacts
htmlcov/
.tox/
.nox/
diff --git a/.gitattributes b/.gitattributes
index 7da36424..44e16cf1 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
*.memmap filter=lfs diff=lfs merge=lfs -text
*.stl filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml
index 7cbed673..2fb23051 100644
--- a/.github/ISSUE_TEMPLATE/bug-report.yml
+++ b/.github/ISSUE_TEMPLATE/bug-report.yml
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
name: "\U0001F41B Bug Report"
description: Submit a bug report to help us improve LeRobot
body:
diff --git a/.github/workflows/build-docker-images.yml b/.github/workflows/build-docker-images.yml
index 3c63fa11..0cb11d57 100644
--- a/.github/workflows/build-docker-images.yml
+++ b/.github/workflows/build-docker-images.yml
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
# Inspired by
# https://github.com/huggingface/peft/blob/main/.github/workflows/build_docker_images.yml
name: Builds
diff --git a/.github/workflows/nightly-tests.yml b/.github/workflows/nightly-tests.yml
index 210a690c..adac9f20 100644
--- a/.github/workflows/nightly-tests.yml
+++ b/.github/workflows/nightly-tests.yml
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
# Inspired by
# https://github.com/huggingface/peft/blob/main/.github/workflows/nightly.yml
name: Nightly
diff --git a/.github/workflows/quality.yml b/.github/workflows/quality.yml
index b42539e6..332b543c 100644
--- a/.github/workflows/quality.yml
+++ b/.github/workflows/quality.yml
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
name: Quality
on:
@@ -32,13 +46,27 @@ jobs:
id: get-ruff-version
run: |
RUFF_VERSION=$(awk '/repo: https:\/\/github.com\/astral-sh\/ruff-pre-commit/{flag=1;next}/rev:/{if(flag){print $2;exit}}' .pre-commit-config.yaml)
- echo "RUFF_VERSION=${RUFF_VERSION}" >> $GITHUB_ENV
+ echo "ruff_version=${RUFF_VERSION}" >> $GITHUB_OUTPUT
- name: Install Ruff
- run: python -m pip install "ruff==${{ env.RUFF_VERSION }}"
+ env:
+ RUFF_VERSION: ${{ steps.get-ruff-version.outputs.ruff_version }}
+ run: python -m pip install "ruff==${RUFF_VERSION}"
- name: Ruff check
run: ruff check --output-format=github
- name: Ruff format
run: ruff format --diff
+
+ typos:
+ name: Typos
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout Repository
+ uses: actions/checkout@v4
+ with:
+ persist-credentials: false
+
+ - name: typos-action
+ uses: crate-ci/typos@v1.29.10
diff --git a/.github/workflows/test-docker-build.yml b/.github/workflows/test-docker-build.yml
index 4d6e9ce5..c3102564 100644
--- a/.github/workflows/test-docker-build.yml
+++ b/.github/workflows/test-docker-build.yml
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
# Inspired by
# https://github.com/huggingface/peft/blob/main/.github/workflows/test-docker-build.yml
name: Test Dockerfiles
@@ -27,7 +41,7 @@ jobs:
- name: Get changed files
id: changed-files
- uses: tj-actions/changed-files@v44
+ uses: tj-actions/changed-files@3f54ebb830831fc121d3263c1857cfbdc310cdb9 #v42
with:
files: docker/**
json: "true"
@@ -43,7 +57,7 @@ jobs:
needs: get_changed_files
runs-on:
group: aws-general-8-plus
- if: ${{ needs.get_changed_files.outputs.matrix }} != ''
+ if: needs.get_changed_files.outputs.matrix != ''
strategy:
fail-fast: false
matrix:
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 9c3f5756..d91c5364 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
name: Tests
on:
@@ -112,7 +126,7 @@ jobs:
# portaudio19-dev is needed to install pyaudio
run: |
sudo apt-get update && \
- sudo apt-get install -y libegl1-mesa-dev portaudio19-dev
+ sudo apt-get install -y libegl1-mesa-dev ffmpeg portaudio19-dev
- name: Install uv and python
uses: astral-sh/setup-uv@v5
diff --git a/.github/workflows/trufflehog.yml b/.github/workflows/trufflehog.yml
index 487ccea5..166e0590 100644
--- a/.github/workflows/trufflehog.yml
+++ b/.github/workflows/trufflehog.yml
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
on:
push:
diff --git a/.gitignore b/.gitignore
index 0a0ffe10..d6c51c90 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
# Logging
logs
tmp
@@ -64,7 +78,7 @@ pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
-!tests/data
+!tests/artifacts
htmlcov/
.tox/
.nox/
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 00b538e8..e699f543 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,7 +1,29 @@
-exclude: ^(tests/data)
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+exclude: "tests/artifacts/.*\\.safetensors$"
default_language_version:
python: python3.10
repos:
+ ##### Meta #####
+ - repo: meta
+ hooks:
+ - id: check-useless-excludes
+ - id: check-hooks-apply
+
+
+ ##### Style / Misc. #####
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
@@ -13,21 +35,40 @@ repos:
- id: check-toml
- id: end-of-file-fixer
- id: trailing-whitespace
+
+ - repo: https://github.com/crate-ci/typos
+ rev: v1.30.2
+ hooks:
+ - id: typos
+ args: [--force-exclude]
+
- repo: https://github.com/asottile/pyupgrade
rev: v3.19.1
hooks:
- id: pyupgrade
+
- repo: https://github.com/astral-sh/ruff-pre-commit
- rev: v0.9.6
+ rev: v0.9.10
hooks:
- id: ruff
args: [--fix]
- id: ruff-format
+
+
+ ##### Security #####
- repo: https://github.com/gitleaks/gitleaks
- rev: v8.23.3
+ rev: v8.24.0
hooks:
- id: gitleaks
+
- repo: https://github.com/woodruffw/zizmor-pre-commit
- rev: v1.3.1
+ rev: v1.4.1
hooks:
- id: zizmor
+
+ - repo: https://github.com/PyCQA/bandit
+ rev: 1.8.3
+ hooks:
+ - id: bandit
+ args: ["-c", "pyproject.toml"]
+ additional_dependencies: ["bandit[toml]"]
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 61fa2eb9..a9e4a856 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -228,7 +228,7 @@ Follow these steps to start contributing:
git commit
```
- Note, if you already commited some changes that have a wrong formatting, you can use:
+ Note, if you already committed some changes that have a wrong formatting, you can use:
```bash
pre-commit run --all-files
```
@@ -291,7 +291,7 @@ sudo apt-get install git-lfs
git lfs install
```
-Pull artifacts if they're not in [tests/data](tests/data)
+Pull artifacts if they're not in [tests/artifacts](tests/artifacts)
```bash
git lfs pull
```
diff --git a/Makefile b/Makefile
index 772da320..c82483cc 100644
--- a/Makefile
+++ b/Makefile
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
.PHONY: tests
PYTHON_PATH := $(shell which python)
@@ -33,6 +47,7 @@ test-act-ete-train:
--policy.dim_model=64 \
--policy.n_action_steps=20 \
--policy.chunk_size=20 \
+ --policy.device=$(DEVICE) \
--env.type=aloha \
--env.episode_length=5 \
--dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \
@@ -47,7 +62,6 @@ test-act-ete-train:
--save_checkpoint=true \
--log_freq=1 \
--wandb.enable=false \
- --device=$(DEVICE) \
--output_dir=tests/outputs/act/
test-act-ete-train-resume:
@@ -58,11 +72,11 @@ test-act-ete-train-resume:
test-act-ete-eval:
python lerobot/scripts/eval.py \
--policy.path=tests/outputs/act/checkpoints/000004/pretrained_model \
+ --policy.device=$(DEVICE) \
--env.type=aloha \
--env.episode_length=5 \
--eval.n_episodes=1 \
- --eval.batch_size=1 \
- --device=$(DEVICE)
+ --eval.batch_size=1
test-diffusion-ete-train:
python lerobot/scripts/train.py \
@@ -70,6 +84,7 @@ test-diffusion-ete-train:
--policy.down_dims='[64,128,256]' \
--policy.diffusion_step_embed_dim=32 \
--policy.num_inference_steps=10 \
+ --policy.device=$(DEVICE) \
--env.type=pusht \
--env.episode_length=5 \
--dataset.repo_id=lerobot/pusht \
@@ -84,21 +99,21 @@ test-diffusion-ete-train:
--save_freq=2 \
--log_freq=1 \
--wandb.enable=false \
- --device=$(DEVICE) \
--output_dir=tests/outputs/diffusion/
test-diffusion-ete-eval:
python lerobot/scripts/eval.py \
--policy.path=tests/outputs/diffusion/checkpoints/000002/pretrained_model \
+ --policy.device=$(DEVICE) \
--env.type=pusht \
--env.episode_length=5 \
--eval.n_episodes=1 \
- --eval.batch_size=1 \
- --device=$(DEVICE)
+ --eval.batch_size=1
test-tdmpc-ete-train:
python lerobot/scripts/train.py \
--policy.type=tdmpc \
+ --policy.device=$(DEVICE) \
--env.type=xarm \
--env.task=XarmLift-v0 \
--env.episode_length=5 \
@@ -114,15 +129,14 @@ test-tdmpc-ete-train:
--save_freq=2 \
--log_freq=1 \
--wandb.enable=false \
- --device=$(DEVICE) \
--output_dir=tests/outputs/tdmpc/
test-tdmpc-ete-eval:
python lerobot/scripts/eval.py \
--policy.path=tests/outputs/tdmpc/checkpoints/000002/pretrained_model \
+ --policy.device=$(DEVICE) \
--env.type=xarm \
--env.episode_length=5 \
--env.task=XarmLift-v0 \
--eval.n_episodes=1 \
- --eval.batch_size=1 \
- --device=$(DEVICE)
+ --eval.batch_size=1
diff --git a/README.md b/README.md
index 5125ace5..4483940d 100644
--- a/README.md
+++ b/README.md
@@ -23,15 +23,24 @@
Want to take it to the next level? Make your SO-100 mobile by building LeKiwi!
+
Check out the LeKiwi tutorial and bring your robot to life on wheels.
+
+
@@ -89,14 +98,18 @@ conda create -y -n lerobot python=3.10
conda activate lerobot
```
-Install 🤗 LeRobot:
+When using `miniconda`, if you don't have `ffmpeg` in your environment:
```bash
-pip install -e .
+conda install ffmpeg
```
-> **NOTE:** Depending on your platform, If you encounter any build errors during this step
-you may need to install `cmake` and `build-essential` for building some of our dependencies.
-On linux: `sudo apt-get install cmake build-essential`
+Install 🤗 LeRobot:
+```bash
+pip install --no-binary=av -e .
+```
+
+> **NOTE:** If you encounter build errors, you may need to install additional dependencies (`cmake`, `build-essential`, and the FFmpeg development libraries). On Linux, run:
+`sudo apt-get install cmake build-essential python-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev`. For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/installation.html#bring-your-own-ffmpeg)
For simulations, 🤗 LeRobot comes with gymnasium environments that can be installed as extras:
- [aloha](https://github.com/huggingface/gym-aloha)
@@ -105,7 +118,7 @@ For simulations, 🤗 LeRobot comes with gymnasium environments that can be inst
For instance, to install 🤗 LeRobot with aloha and pusht, use:
```bash
-pip install -e ".[aloha, pusht]"
+pip install --no-binary=av -e ".[aloha, pusht]"
```
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with
@@ -210,7 +223,7 @@ A `LeRobotDataset` is serialised using several widespread file formats for each
- videos are stored in mp4 format to save space
- metadata are stored in plain json/jsonl files
-Dataset can be uploaded/downloaded from the HuggingFace hub seamlessly. To work on a local dataset, you can use the `local_files_only` argument and specify its location with the `root` argument if it's not in the default `~/.cache/huggingface/lerobot` location.
+Datasets can be uploaded/downloaded from the HuggingFace hub seamlessly. To work on a local dataset, you can specify its location with the `root` argument if it's not in the default `~/.cache/huggingface/lerobot` location.
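For illustration, a minimal sketch of pointing `LeRobotDataset` at a local copy via `root` (the directory path here is hypothetical; by default datasets live under `~/.cache/huggingface/lerobot/<repo_id>`):

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# `root` should point at the directory holding the dataset's parquet/mp4/json
# files; the path below is purely illustrative.
dataset = LeRobotDataset("lerobot/pusht", root="./my_datasets/lerobot/pusht")
print(dataset.num_episodes, dataset.num_frames)
```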
### Evaluate a pretrained policy
@@ -223,8 +236,8 @@ python lerobot/scripts/eval.py \
--env.type=pusht \
--eval.batch_size=10 \
--eval.n_episodes=10 \
- --use_amp=false \
- --device=cuda
+ --policy.use_amp=false \
+ --policy.device=cuda
```
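The two renamed flags reflect that device placement and automatic mixed precision now live on the policy configuration rather than as top-level options. A rough sketch of the equivalent in Python, with the field names assumed to mirror the CLI flags above:

```python
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig

# Assumed fields mirroring `--policy.device` and `--policy.use_amp`; the CLI
# parser maps dotted flags onto attributes of the policy config dataclass.
cfg = DiffusionConfig(device="cuda", use_amp=False)
print(cfg.device, cfg.use_amp)
```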
Note: After training your own policy, you can re-evaluate the checkpoints with:
@@ -375,3 +388,6 @@ Additionally, if you are using any of the particular policy architecture, pretra
year={2024}
}
```
+## Star History
+
+[](https://star-history.com/#huggingface/lerobot&Timeline)
diff --git a/benchmarks/video/README.md b/benchmarks/video/README.md
index 56cd1d1e..daa3e1f4 100644
--- a/benchmarks/video/README.md
+++ b/benchmarks/video/README.md
@@ -51,7 +51,7 @@ For a comprehensive list and documentation of these parameters, see the ffmpeg d
### Decoding parameters
**Decoder**
We tested two video decoding backends from torchvision:
-- `pyav` (default)
+- `pyav`
- `video_reader` (requires to build torchvision from source)
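For context on how a user selects between these backends, here is a hedged sketch of loading a dataset with a specific decoder (the `video_backend` parameter name is an assumption, not something this benchmark exercises directly):

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# "pyav" is the portable choice; "video_reader" requires torchvision built from
# source with the video_reader extension enabled.
dataset = LeRobotDataset("lerobot/pusht", video_backend="pyav")
sample = dataset[0]  # frames are decoded with the selected backend
```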
**Requested timestamps**
@@ -114,7 +114,7 @@ We tried to measure the most impactful parameters for both encoding and decoding
Additional encoding parameters exist that are not included in this benchmark. In particular:
- `-preset` which allows for selecting encoding presets. This represents a collection of options that will provide a certain encoding speed to compression ratio. By leaving this parameter unspecified, it is considered to be `medium` for libx264 and libx265 and `8` for libsvtav1.
-- `-tune` which allows to optimize the encoding for certains aspects (e.g. film quality, fast decoding, etc.).
+- `-tune` which allows to optimize the encoding for certain aspects (e.g. film quality, fast decoding, etc.).
See the documentation mentioned above for more detailed info on these settings and for a more comprehensive list of other parameters.
diff --git a/benchmarks/video/run_video_benchmark.py b/benchmarks/video/run_video_benchmark.py
index e9066487..c62578c4 100644
--- a/benchmarks/video/run_video_benchmark.py
+++ b/benchmarks/video/run_video_benchmark.py
@@ -67,7 +67,7 @@ def parse_int_or_none(value) -> int | None:
def check_datasets_formats(repo_ids: list) -> None:
for repo_id in repo_ids:
dataset = LeRobotDataset(repo_id)
- if dataset.video:
+ if len(dataset.meta.video_keys) > 0:
raise ValueError(
f"Use only image dataset for running this benchmark. Video dataset provided: {repo_id}"
)
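The updated check reads the dataset's metadata instead of a `video` flag: a dataset counts as video-based when its metadata lists video keys. A small illustration, with attribute names taken from the snippet above and a purely illustrative repo id:

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset("lerobot/pusht")
# meta.camera_keys lists every camera stream; meta.video_keys is the subset
# stored as mp4 videos and is empty for image-only datasets, which is what
# this benchmark expects as input.
print(dataset.meta.camera_keys)
print(dataset.meta.video_keys)
```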
diff --git a/docker/lerobot-cpu/Dockerfile b/docker/lerobot-cpu/Dockerfile
index 06673092..13a45d24 100644
--- a/docker/lerobot-cpu/Dockerfile
+++ b/docker/lerobot-cpu/Dockerfile
@@ -1,33 +1,29 @@
# Configure image
ARG PYTHON_VERSION=3.10
-
FROM python:${PYTHON_VERSION}-slim
-ARG PYTHON_VERSION
-ARG DEBIAN_FRONTEND=noninteractive
-# Install apt dependencies
+# Configure environment variables
+ARG PYTHON_VERSION
+ENV DEBIAN_FRONTEND=noninteractive
+ENV MUJOCO_GL="egl"
+ENV PATH="/opt/venv/bin:$PATH"
+
+# Install dependencies and set up Python in a single layer
RUN apt-get update && apt-get install -y --no-install-recommends \
- build-essential cmake git git-lfs \
+ build-essential cmake git \
libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
speech-dispatcher libgeos-dev \
- && apt-get clean && rm -rf /var/lib/apt/lists/*
+ && ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python \
+ && python -m venv /opt/venv \
+ && apt-get clean && rm -rf /var/lib/apt/lists/* \
+ && echo "source /opt/venv/bin/activate" >> /root/.bashrc
-# Create virtual environment
-RUN ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python
-RUN python -m venv /opt/venv
-ENV PATH="/opt/venv/bin:$PATH"
-RUN echo "source /opt/venv/bin/activate" >> /root/.bashrc
-
-# Install LeRobot
-RUN git lfs install
-RUN git clone https://github.com/huggingface/lerobot.git /lerobot
+# Clone repository and install LeRobot in a single layer
+COPY . /lerobot
WORKDIR /lerobot
-RUN pip install --upgrade --no-cache-dir pip
-RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]" \
- --extra-index-url https://download.pytorch.org/whl/cpu
-
-# Set EGL as the rendering backend for MuJoCo
-ENV MUJOCO_GL="egl"
+RUN /opt/venv/bin/pip install --upgrade --no-cache-dir pip \
+ && /opt/venv/bin/pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]" \
+ --extra-index-url https://download.pytorch.org/whl/cpu
# Execute in bash shell rather than python
CMD ["/bin/bash"]
diff --git a/docker/lerobot-gpu/Dockerfile b/docker/lerobot-gpu/Dockerfile
index 65ca4377..642a8ded 100644
--- a/docker/lerobot-gpu/Dockerfile
+++ b/docker/lerobot-gpu/Dockerfile
@@ -1,31 +1,24 @@
FROM nvidia/cuda:12.4.1-base-ubuntu22.04
-# Configure image
+# Configure environment variables
ARG PYTHON_VERSION=3.10
-ARG DEBIAN_FRONTEND=noninteractive
+ENV DEBIAN_FRONTEND=noninteractive
+ENV MUJOCO_GL="egl"
+ENV PATH="/opt/venv/bin:$PATH"
-
-# Install apt dependencies
+# Install dependencies and set up Python in a single layer
RUN apt-get update && apt-get install -y --no-install-recommends \
- build-essential cmake git git-lfs \
+ build-essential cmake git \
libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
speech-dispatcher libgeos-dev \
python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
- && apt-get clean && rm -rf /var/lib/apt/lists/*
+ && ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python \
+ && python -m venv /opt/venv \
+ && apt-get clean && rm -rf /var/lib/apt/lists/* \
+ && echo "source /opt/venv/bin/activate" >> /root/.bashrc
-
-# Create virtual environment
-RUN ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python
-RUN python -m venv /opt/venv
-ENV PATH="/opt/venv/bin:$PATH"
-RUN echo "source /opt/venv/bin/activate" >> /root/.bashrc
-
-# Install LeRobot
-RUN git lfs install
-RUN git clone https://github.com/huggingface/lerobot.git /lerobot
+# Clone repository and install LeRobot in a single layer
+COPY . /lerobot
WORKDIR /lerobot
-RUN pip install --upgrade --no-cache-dir pip
-RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]"
-
-# Set EGL as the rendering backend for MuJoCo
-ENV MUJOCO_GL="egl"
+RUN /opt/venv/bin/pip install --upgrade --no-cache-dir pip \
+ && /opt/venv/bin/pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]"
diff --git a/examples/10_use_so100.md b/examples/10_use_so100.md
index f7efcb45..8fb6d3b5 100644
--- a/examples/10_use_so100.md
+++ b/examples/10_use_so100.md
@@ -4,8 +4,8 @@
- [A. Source the parts](#a-source-the-parts)
- [B. Install LeRobot](#b-install-lerobot)
- - [C. Configure the motors](#c-configure-the-motors)
- - [D. Assemble the arms](#d-assemble-the-arms)
+ - [C. Configure the Motors](#c-configure-the-motors)
+ - [D. Step-by-Step Assembly Instructions](#d-step-by-step-assembly-instructions)
- [E. Calibrate](#e-calibrate)
- [F. Teleoperate](#f-teleoperate)
- [G. Record a dataset](#g-record-a-dataset)
@@ -59,17 +59,12 @@ git clone https://github.com/huggingface/lerobot.git ~/lerobot
#### 5. Install LeRobot with dependencies for the feetech motors:
```bash
-cd ~/lerobot && pip install -e ".[feetech]"
+cd ~/lerobot && pip install --no-binary=av -e ".[feetech]"
```
-*EXTRA: For Linux only (not Mac)*: install extra dependencies for recording datasets:
-```bash
-conda install -y -c conda-forge ffmpeg
-pip uninstall -y opencv-python
-conda install -y -c conda-forge "opencv>=4.10.0"
-```
Great :hugs:! You are now done installing LeRobot and we can begin assembling the SO100 arms :robot:.
Every time you now want to use LeRobot you can go to the `~/lerobot` folder where we installed LeRobot and run one of the commands.
+
## C. Configure the motors
> [!NOTE]
@@ -98,22 +93,22 @@ Example output when identifying the leader arm's port (e.g., `/dev/tty.usbmodem5
```
Finding all available ports for the MotorBus.
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
-Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
+Remove the usb cable from your MotorsBus and press Enter when done.
[...Disconnect leader arm and press Enter...]
-The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0031751
+The port of this MotorsBus is /dev/tty.usbmodem575E0031751
Reconnect the usb cable.
```
Example output when identifying the follower arm's port (e.g., `/dev/tty.usbmodem575E0032081`, or possibly `/dev/ttyACM1` on Linux):
```
Finding all available ports for the MotorBus.
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
-Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
+Remove the usb cable from your MotorsBus and press Enter when done.
[...Disconnect follower arm and press Enter...]
-The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0032081
+The port of this MotorsBus is /dev/tty.usbmodem575E0032081
Reconnect the usb cable.
```
@@ -221,19 +216,13 @@ Redo the process for all your motors until ID 6. Do the same for the 6 motors of
Follow the video for removing gears. You need to remove the gear for the motors of the leader arm. As a result, you will only use the position encoding of the motor and reduce friction to more easily operate the leader arm.
-#### c. Add motor horn to all 12 motors
+## D. Step-by-Step Assembly Instructions
-
-Video adding motor horn
+**Step 1: Clean Parts**
+- Remove all support material from the 3D-printed parts.
+---
-
-
-
-
-Follow the video for adding the motor horn. For SO-100, you need to align the holes on the motor horn to the motor spline to be approximately 1:30, 4:30, 7:30 and 10:30.
-Try to avoid rotating the motor while doing so to keep position 2048 set during configuration. It is especially tricky for the leader motors as it is more sensible without the gears, but it's ok if it's a bit rotated.
-
-## D. Assemble the arms
+### Additional Guidance
Video assembling arms
@@ -242,7 +231,211 @@ Try to avoid rotating the motor while doing so to keep position 2048 set during
-Follow the video for assembling the arms. It is important to insert the cables into the motor that is being assembled before you assemble the motor into the arm! Inserting the cables beforehand is much easier than doing this afterward. The first arm should take a bit more than 1 hour to assemble, but once you get used to it, you can do it under 1 hour for the second arm.
+**Note:**
+This video provides visual guidance for assembling the arms, but it doesn't specify when or how to do the wiring. Inserting the cables beforehand is much easier than doing it afterward. The first arm may take a bit more than 1 hour to assemble, but once you get used to it, you can assemble the second arm in under 1 hour.
+
+---
+
+### First Motor
+
+**Step 2: Insert Wires**
+- Insert two wires into the first motor.
+
+
+
+**Step 3: Install in Base**
+- Place the first motor into the base.
+
+
+
+**Step 4: Secure Motor**
+- Fasten the motor with 4 screws: two from the bottom and two from the top.
+
+**Step 5: Attach Motor Holder**
+- Slide over the first motor holder and fasten it using two screws (one on each side).
+
+
+
+**Step 6: Attach Motor Horns**
+- Install both motor horns, securing the top horn with a screw. Try not to move the motor position when attaching the motor horn, especially for the leader arms, where we removed the gears.
+
+
+
+ Video adding motor horn
+
+
+
+**Step 7: Attach Shoulder Part**
+- Route one wire to the back of the robot and the other to the left, which is towards you in the photo (see photo).
+- Attach the shoulder part.
+
+
+
+**Step 8: Secure Shoulder**
+- Tighten the shoulder part with 4 screws on top and 4 on the bottom
+*(access the bottom holes by turning the shoulder).*
+
+---
+
+### Second Motor Assembly
+
+**Step 9: Install Motor 2**
+- Slide the second motor in from the top and link the wire from motor 1 to motor 2.
+
+
+
+**Step 10: Attach Shoulder Holder**
+- Add the shoulder motor holder.
+- Ensure the wire from motor 1 to motor 2 goes behind the holder while the other wire is routed upward (see photo).
+- This part can be a tight fit. You can use a workbench as shown in the image, or a similar setup, to push the part into place around the motor.
+
+
+
+
+
+
+
+**Step 11: Secure Motor 2**
+- Fasten the second motor with 4 screws.
+
+**Step 12: Attach Motor Horn**
+- Attach both motor horns to motor 2, again using the horn screw.
+
+**Step 13: Attach Base**
+- Install the base attachment using 2 screws.
+
+
+
+**Step 14: Attach Upper Arm**
+- Attach the upper arm with 4 screws on each side.
+
+
+
+---
+
+### Third Motor Assembly
+
+**Step 15: Install Motor 3**
+- Route the motor cable from motor 2 through the cable holder to motor 3, then secure motor 3 with 4 screws.
+
+**Step 16: Attach Motor Horn**
+- Attach both motor horns to motor 3, again securing one with a horn screw.
+
+
+
+**Step 17: Attach Forearm**
+- Connect the forearm to motor 3 using 4 screws on each side.
+
+
+
+---
+
+### Fourth Motor Assembly
+
+**Step 18: Install Motor 4**
+- Slide in motor 4, attach the cable from motor 3, and secure the cable in its holder with a screw.
+
+
+
+
+
+
+**Step 19: Attach Motor Holder 4**
+- Install the fourth motor holder (a tight fit). Ensure one wire is routed upward and the wire from motor 3 is routed downward (see photo).
+
+
+
+**Step 20: Secure Motor 4 & Attach Horn**
+- Fasten motor 4 with 4 screws and attach its motor horns, securing one of them with a horn screw.
+
+
+
+---
+
+### Wrist Assembly
+
+**Step 21: Install Motor 5**
+- Insert motor 5 into the wrist holder and secure it with 2 front screws.
+
+
+
+**Step 22: Attach Wrist**
+- Connect the wire from motor 4 to motor 5, and insert the other wire for the gripper now as well.
+- Secure the wrist to motor 4 using 4 screws on both sides.
+
+
+
+**Step 23: Attach Wrist Horn**
+- Install only one motor horn on the wrist motor and secure it with a horn screw.
+
+
+
+---
+
+### Follower Configuration
+
+**Step 24: Attach Gripper**
+- Attach the gripper to motor 5.
+
+
+
+**Step 25: Install Gripper Motor**
+- Insert the gripper motor, connect the motor wire from motor 5 to motor 6, and secure it with 3 screws on each side.
+
+
+
+**Step 26: Attach Gripper Horn & Claw**
+- Attach the motor horns, again using a horn screw.
+- Install the gripper claw and secure it with 4 screws on both sides.
+
+
+
+**Step 27: Mount Controller**
+- Attach the motor controller on the back.
+
+
+
+
+
+
+*Assembly complete – proceed to Leader arm assembly.*
+
+---
+
+### Leader Configuration
+
+For the leader configuration, perform **Steps 1–23**. Make sure you have removed the gears from the motors.
+
+**Step 24: Attach Leader Holder**
+- Mount the leader holder onto the wrist and secure it with a screw.
+
+
+
+**Step 25: Attach Handle**
+- Attach the handle to motor 5 using 4 screws.
+
+
+
+**Step 26: Install Gripper Motor**
+- Insert the gripper motor, secure it with 3 screws on each side, attach a motor horn using a horn screw, and connect the motor wire.
+
+
+
+**Step 27: Attach Trigger**
+- Attach the follower trigger with 4 screws.
+
+
+
+**Step 28: Mount Controller**
+- Attach the motor controller on the back.
+
+
+
+
+
+
+*Assembly complete – proceed to calibration.*
+
## E. Calibrate
@@ -255,8 +448,8 @@ Next, you'll need to calibrate your SO-100 robot to ensure that the leader and f
You will need to move the follower arm to these positions sequentially:
-| 1. Zero position | 2. Rotated position | 3. Rest position |
-|---|---|---|
+| 1. Zero position | 2. Rotated position | 3. Rest position |
+| ------------------------------------------------------------------------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| | | |
Make sure both arms are connected and run this script to launch manual calibration:
@@ -271,8 +464,8 @@ python lerobot/scripts/control_robot.py \
#### b. Manual calibration of leader arm
Follow step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724) which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially:
-| 1. Zero position | 2. Rotated position | 3. Rest position |
-|---|---|---|
+| 1. Zero position | 2. Rotated position | 3. Rest position |
+| ------------------------------------------------------------------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------ |
| | | |
Run this script to launch manual calibration:
@@ -335,7 +528,7 @@ python lerobot/scripts/control_robot.py \
--control.push_to_hub=true
```
-Note: You can resume recording by adding `--control.resume=true`. Also if you didn't push your dataset yet, add `--control.local_files_only=true`.
+Note: You can resume recording by adding `--control.resume=true`.
## H. Visualize a dataset
@@ -363,8 +556,6 @@ python lerobot/scripts/control_robot.py \
--control.episode=0
```
-Note: If you didn't push your dataset yet, add `--control.local_files_only=true`.
-
## J. Train a policy
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
@@ -374,20 +565,25 @@ python lerobot/scripts/train.py \
--policy.type=act \
--output_dir=outputs/train/act_so100_test \
--job_name=act_so100_test \
- --device=cuda \
+ --policy.device=cuda \
--wandb.enable=true
```
-Note: If you didn't push your dataset yet, add `--control.local_files_only=true`.
-
Let's explain it:
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/so100_test`.
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
-4. We provided `device=cuda` since we are training on a Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
+4. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
Training should take several hours. You will find checkpoints in `outputs/train/act_so100_test/checkpoints`.
+To resume training from a checkpoint, here is an example command to resume from the `last` checkpoint of the `act_so100_test` policy:
+```bash
+python lerobot/scripts/train.py \
+ --config_path=outputs/train/act_so100_test/checkpoints/last/pretrained_model/train_config.json \
+ --resume=true
+```
+
## K. Evaluate your policy
You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
diff --git a/examples/11_use_lekiwi.md b/examples/11_use_lekiwi.md
index f10a9396..215419e1 100644
--- a/examples/11_use_lekiwi.md
+++ b/examples/11_use_lekiwi.md
@@ -23,6 +23,9 @@ Follow this [README](https://github.com/SIGRobotics-UIUC/LeKiwi). It contains th
Before assembling, you will first need to configure your motors. To this end, we provide a nice script, so let's first install LeRobot. After configuration, we will also guide you through assembly.
+### Wired version
+If you have the **wired** LeKiwi version, you can skip installing the Raspberry Pi and setting up SSH. You can also run all commands, both the LeKiwi scripts and the leader arm teleoperation scripts, directly on your PC.
+
## B. Install software on Pi
Now we have to setup the remote PC that will run on the LeKiwi Robot. This is normally a Raspberry Pi, but can be any PC that can run on 5V and has enough usb ports (2 or more) for the cameras and motor control board.
@@ -66,7 +69,7 @@ git clone https://github.com/huggingface/lerobot.git ~/lerobot
#### 5. Install LeRobot with dependencies for the feetech motors:
```bash
-cd ~/lerobot && pip install -e ".[feetech]"
+cd ~/lerobot && pip install --no-binary=av -e ".[feetech]"
```
## C. Install LeRobot on laptop
@@ -107,15 +110,9 @@ git clone https://github.com/huggingface/lerobot.git ~/lerobot
#### 5. Install LeRobot with dependencies for the feetech motors:
```bash
-cd ~/lerobot && pip install -e ".[feetech]"
+cd ~/lerobot && pip install --no-binary=av -e ".[feetech]"
```
-*EXTRA: For Linux only (not Mac)*: install extra dependencies for recording datasets:
-```bash
-conda install -y -c conda-forge ffmpeg
-pip uninstall -y opencv-python
-conda install -y -c conda-forge "opencv>=4.10.0"
-```
Great :hugs:! You are now done installing LeRobot and we can begin assembling the SO100 arms and Mobile base :robot:.
Every time you now want to use LeRobot you can go to the `~/lerobot` folder where we installed LeRobot and run one of the commands.
@@ -185,7 +182,7 @@ sudo chmod 666 /dev/ttyACM1
#### d. Update config file
-IMPORTANTLY: Now that you have your ports of leader and follower arm and ip adress of the mobile-so100, update the **ip** in Network configuration, **port** in leader_arms and **port** in lekiwi. In the [`LeKiwiRobotConfig`](../lerobot/common/robot_devices/robots/configs.py) file. Where you will find something like:
+IMPORTANTLY: Now that you have the ports of the leader and follower arms and the IP address of the mobile-so100, update the **ip** in the Network configuration, the **port** in leader_arms, and the **port** in lekiwi in the [`LeKiwiRobotConfig`](../lerobot/common/robot_devices/robots/configs.py) file, where you will find something like:
```python
@RobotConfig.register_subclass("lekiwi")
@dataclass
@@ -246,6 +243,110 @@ class LeKiwiRobotConfig(RobotConfig):
}
)
+ teleop_keys: dict[str, str] = field(
+ default_factory=lambda: {
+ # Movement
+ "forward": "w",
+ "backward": "s",
+ "left": "a",
+ "right": "d",
+ "rotate_left": "z",
+ "rotate_right": "x",
+ # Speed control
+ "speed_up": "r",
+ "speed_down": "f",
+ # quit teleop
+ "quit": "q",
+ }
+ )
+
+ mock: bool = False
+```
+
+## Wired version
+
+For the wired LeKiwi version, the configured IP address should refer to your own laptop (127.0.0.1), because in this case both the leader arm and LeKiwi are connected directly to it. Below is an example configuration for this wired setup:
+```python
+@RobotConfig.register_subclass("lekiwi")
+@dataclass
+class LeKiwiRobotConfig(RobotConfig):
+ # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
+ # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
+ # the number of motors in your follower arms.
+ max_relative_target: int | None = None
+
+ # Network Configuration
+ ip: str = "127.0.0.1"
+ port: int = 5555
+ video_port: int = 5556
+
+ cameras: dict[str, CameraConfig] = field(
+ default_factory=lambda: {
+ "front": OpenCVCameraConfig(
+ camera_index=0, fps=30, width=640, height=480, rotation=90
+ ),
+ "wrist": OpenCVCameraConfig(
+ camera_index=1, fps=30, width=640, height=480, rotation=180
+ ),
+ }
+ )
+
+ calibration_dir: str = ".cache/calibration/lekiwi"
+
+ leader_arms: dict[str, MotorsBusConfig] = field(
+ default_factory=lambda: {
+ "main": FeetechMotorsBusConfig(
+ port="/dev/tty.usbmodem585A0077581",
+ motors={
+ # name: (index, model)
+ "shoulder_pan": [1, "sts3215"],
+ "shoulder_lift": [2, "sts3215"],
+ "elbow_flex": [3, "sts3215"],
+ "wrist_flex": [4, "sts3215"],
+ "wrist_roll": [5, "sts3215"],
+ "gripper": [6, "sts3215"],
+ },
+ ),
+ }
+ )
+
+ follower_arms: dict[str, MotorsBusConfig] = field(
+ default_factory=lambda: {
+ "main": FeetechMotorsBusConfig(
+ port="/dev/tty.usbmodem58760431061",
+ motors={
+ # name: (index, model)
+ "shoulder_pan": [1, "sts3215"],
+ "shoulder_lift": [2, "sts3215"],
+ "elbow_flex": [3, "sts3215"],
+ "wrist_flex": [4, "sts3215"],
+ "wrist_roll": [5, "sts3215"],
+ "gripper": [6, "sts3215"],
+ "left_wheel": (7, "sts3215"),
+ "back_wheel": (8, "sts3215"),
+ "right_wheel": (9, "sts3215"),
+ },
+ ),
+ }
+ )
+
+ teleop_keys: dict[str, str] = field(
+ default_factory=lambda: {
+ # Movement
+ "forward": "w",
+ "backward": "s",
+ "left": "a",
+ "right": "d",
+ "rotate_left": "z",
+ "rotate_right": "x",
+ # Speed control
+ "speed_up": "r",
+ "speed_down": "f",
+ # quit teleop
+ "quit": "q",
+ }
+ )
+
mock: bool = False
```
@@ -259,8 +360,8 @@ Now we have to calibrate the leader arm and the follower arm. The wheel motors d
You will need to move the follower arm to these positions sequentially:
-| 1. Zero position | 2. Rotated position | 3. Rest position |
-|---|---|---|
+| 1. Zero position | 2. Rotated position | 3. Rest position |
+| ----------------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| | | |
Make sure the arm is connected to the Raspberry Pi and run this script (on the Raspberry Pi) to launch manual calibration:
@@ -272,11 +373,14 @@ python lerobot/scripts/control_robot.py \
--control.arms='["main_follower"]'
```
+### Wired version
+If you have the **wired** LeKiwi version please run all commands including this calibration command on your laptop.
+
### Calibrate leader arm
Then to calibrate the leader arm (which is attached to the laptop/pc). You will need to move the leader arm to these positions sequentially:
-| 1. Zero position | 2. Rotated position | 3. Rest position |
-|---|---|---|
+| 1. Zero position | 2. Rotated position | 3. Rest position |
+| ------------------------------------------------------------------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------ |
| | | |
Run this script (on your laptop/pc) to launch manual calibration:
@@ -289,6 +393,10 @@ python lerobot/scripts/control_robot.py \
```
# F. Teleoperate
+
+> [!TIP]
+> If you're using a Mac, you might need to give Terminal permission to access your keyboard. Go to System Preferences > Security & Privacy > Input Monitoring and check the box for Terminal.
+
To teleoperate SSH into your Raspberry Pi, and run `conda activate lerobot` and this script:
```bash
python lerobot/scripts/control_robot.py \
@@ -306,25 +414,28 @@ python lerobot/scripts/control_robot.py \
You should see on your laptop something like this: ```[INFO] Connected to remote robot at tcp://172.17.133.91:5555 and video stream at tcp://172.17.133.91:5556.``` Now you can move the leader arm and use the keyboard (w,a,s,d) to drive forward, left, backwards, right. And use (z,x) to turn left or turn right. You can use (r,f) to increase and decrease the speed of the mobile robot. There are three speed modes, see the table below:
| Speed Mode | Linear Speed (m/s) | Rotation Speed (deg/s) |
-|------------|-------------------|-----------------------|
-| Fast | 0.4 | 90 |
-| Medium | 0.25 | 60 |
-| Slow | 0.1 | 30 |
+| ---------- | ------------------ | ---------------------- |
+| Fast | 0.4 | 90 |
+| Medium | 0.25 | 60 |
+| Slow | 0.1 | 30 |
-| Key | Action |
-|------|--------------------------------|
-| W | Move forward |
-| A | Move left |
-| S | Move backward |
-| D | Move right |
-| Z | Turn left |
-| X | Turn right |
-| R | Increase speed |
-| F | Decrease speed |
+| Key | Action |
+| --- | -------------- |
+| W | Move forward |
+| A | Move left |
+| S | Move backward |
+| D | Move right |
+| Z | Turn left |
+| X | Turn right |
+| R | Increase speed |
+| F | Decrease speed |
> [!TIP]
-> If you use a different keyboard you can change the keys for each commmand in the [`LeKiwiRobotConfig`](../lerobot/common/robot_devices/robots/configs.py).
+> If you use a different keyboard you can change the keys for each command in the [`LeKiwiRobotConfig`](../lerobot/common/robot_devices/robots/configs.py).
+
+### Wired version
+If you have the **wired** LeKiwi version please run all commands including both these teleoperation commands on your laptop.
## Troubleshoot communication
@@ -364,6 +475,13 @@ Make sure the configuration file on both your laptop/pc and the Raspberry Pi is
# G. Record a dataset
Once you're familiar with teleoperation, you can record your first dataset with LeKiwi.
+To start the program on LeKiwi, SSH into your Raspberry Pi, run `conda activate lerobot`, and then run this script:
+```bash
+python lerobot/scripts/control_robot.py \
+ --robot.type=lekiwi \
+ --control.type=remote_robot
+```
+
If you want to use the Hugging Face hub features for uploading your dataset and you haven't previously done it, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
```bash
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
@@ -374,8 +492,7 @@ Store your Hugging Face repository name in a variable to run these commands:
HF_USER=$(huggingface-cli whoami | head -n 1)
echo $HF_USER
```
-
-Record 2 episodes and upload your dataset to the hub:
+Then, on your laptop, run this command to record 2 episodes and upload your dataset to the hub:
```bash
python lerobot/scripts/control_robot.py \
--robot.type=lekiwi \
@@ -391,7 +508,10 @@ python lerobot/scripts/control_robot.py \
--control.push_to_hub=true
```
-Note: You can resume recording by adding `--control.resume=true`. Also if you didn't push your dataset yet, add `--control.local_files_only=true`.
+Note: You can resume recording by adding `--control.resume=true`.
+
+### Wired version
+If you have the **wired** LeKiwi version please run all commands including both these record dataset commands on your laptop.
# H. Visualize a dataset
@@ -418,8 +538,6 @@ python lerobot/scripts/control_robot.py \
--control.episode=0
```
-Note: If you didn't push your dataset yet, add `--control.local_files_only=true`.
-
## J. Train a policy
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
@@ -429,16 +547,14 @@ python lerobot/scripts/train.py \
--policy.type=act \
--output_dir=outputs/train/act_lekiwi_test \
--job_name=act_lekiwi_test \
- --device=cuda \
+ --policy.device=cuda \
--wandb.enable=true
```
-Note: If you didn't push your dataset yet, add `--control.local_files_only=true`.
-
Let's explain it:
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/lekiwi_test`.
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
-4. We provided `device=cuda` since we are training on a Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
+4. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
Training should take several hours. You will find checkpoints in `outputs/train/act_lekiwi_test/checkpoints`.
diff --git a/examples/11_use_moss.md b/examples/11_use_moss.md
index e35ba9b2..7b1be232 100644
--- a/examples/11_use_moss.md
+++ b/examples/11_use_moss.md
@@ -2,7 +2,7 @@ This tutorial explains how to use [Moss v1](https://github.com/jess-moss/moss-ro
## Source the parts
-Follow this [README](https://github.com/jess-moss/moss-robot-arms). It contains the bill of materials, with link to source the parts, as well as the instructions to 3D print the parts, and advices if it's your first time printing or if you don't own a 3D printer already.
+Follow this [README](https://github.com/jess-moss/moss-robot-arms). It contains the bill of materials with link to source the parts, as well as the instructions to 3D print the parts and advice if it's your first time printing or if you don't own a 3D printer already.
**Important**: Before assembling, you will first need to configure your motors. To this end, we provide a nice script, so let's first install LeRobot. After configuration, we will also guide you through assembly.
@@ -33,14 +33,7 @@ git clone https://github.com/huggingface/lerobot.git ~/lerobot
5. Install LeRobot with dependencies for the feetech motors:
```bash
-cd ~/lerobot && pip install -e ".[feetech]"
-```
-
-For Linux only (not Mac), install extra dependencies for recording datasets:
-```bash
-conda install -y -c conda-forge ffmpeg
-pip uninstall -y opencv-python
-conda install -y -c conda-forge "opencv>=4.10.0"
+cd ~/lerobot && pip install --no-binary=av -e ".[feetech]"
```
## Configure the motors
@@ -176,8 +169,8 @@ Next, you'll need to calibrate your Moss v1 robot to ensure that the leader and
You will need to move the follower arm to these positions sequentially:
-| 1. Zero position | 2. Rotated position | 3. Rest position |
-|---|---|---|
+| 1. Zero position | 2. Rotated position | 3. Rest position |
+| ------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| | | |
Make sure both arms are connected and run this script to launch manual calibration:
@@ -192,8 +185,8 @@ python lerobot/scripts/control_robot.py \
**Manual calibration of leader arm**
Follow step 6 of the [assembly video](https://www.youtube.com/watch?v=DA91NJOtMic) which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially:
-| 1. Zero position | 2. Rotated position | 3. Rest position |
-|---|---|---|
+| 1. Zero position | 2. Rotated position | 3. Rest position |
+| ------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------- |
| | | |
Run this script to launch manual calibration:
@@ -256,7 +249,7 @@ python lerobot/scripts/control_robot.py \
--control.push_to_hub=true
```
-Note: You can resume recording by adding `--control.resume=true`. Also if you didn't push your dataset yet, add `--control.local_files_only=true`.
+Note: You can resume recording by adding `--control.resume=true`.
## Visualize a dataset
@@ -284,8 +277,6 @@ python lerobot/scripts/control_robot.py \
--control.episode=0
```
-Note: If you didn't push your dataset yet, add `--control.local_files_only=true`.
-
## Train a policy
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
@@ -295,16 +286,14 @@ python lerobot/scripts/train.py \
--policy.type=act \
--output_dir=outputs/train/act_moss_test \
--job_name=act_moss_test \
- --device=cuda \
+ --policy.device=cuda \
--wandb.enable=true
```
-Note: If you didn't push your dataset yet, add `--control.local_files_only=true`.
-
Let's explain it:
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/moss_test`.
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
-4. We provided `device=cuda` since we are training on a Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
+4. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
Training should take several hours. You will find checkpoints in `outputs/train/act_moss_test/checkpoints`.
diff --git a/examples/1_load_lerobot_dataset.py b/examples/1_load_lerobot_dataset.py
index 96c104b6..c374a375 100644
--- a/examples/1_load_lerobot_dataset.py
+++ b/examples/1_load_lerobot_dataset.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""
This script demonstrates the use of `LeRobotDataset` class for handling and processing robotic datasets from Hugging Face.
It illustrates how to load datasets, manipulate them, and apply transformations suitable for machine learning tasks in PyTorch.
diff --git a/examples/2_evaluate_pretrained_policy.py b/examples/2_evaluate_pretrained_policy.py
index 0a7b8deb..24b5ea2c 100644
--- a/examples/2_evaluate_pretrained_policy.py
+++ b/examples/2_evaluate_pretrained_policy.py
@@ -1,10 +1,24 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""
This script demonstrates how to evaluate a pretrained policy from the HuggingFace Hub or from your local
training outputs directory. In the latter case, you might want to run examples/3_train_policy.py first.
It requires the installation of the 'gym_pusht' simulation environment. Install it by running:
```bash
-pip install -e ".[pusht]"`
+pip install --no-binary=av -e ".[pusht]"
```
"""
@@ -30,7 +44,7 @@ pretrained_policy_path = "lerobot/diffusion_pusht"
# OR a path to a local outputs/train folder.
# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")
-policy = DiffusionPolicy.from_pretrained(pretrained_policy_path, map_location=device)
+policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
# Initialize evaluation environment to render two observation types:
# an image of the scene and state/position of the agent. The environment
diff --git a/examples/3_train_policy.py b/examples/3_train_policy.py
index cf5d4d3e..6c3af54e 100644
--- a/examples/3_train_policy.py
+++ b/examples/3_train_policy.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""This scripts demonstrates how to train Diffusion Policy on the PushT environment.
Once you have trained a model with this script, you can try to evaluate it on
@@ -85,7 +99,7 @@ def main():
done = False
while not done:
for batch in dataloader:
- batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
+ batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
loss, _ = policy.forward(batch)
loss.backward()
optimizer.step()
diff --git a/examples/4_train_policy_with_script.md b/examples/4_train_policy_with_script.md
index 58ed239a..b23d2271 100644
--- a/examples/4_train_policy_with_script.md
+++ b/examples/4_train_policy_with_script.md
@@ -1,5 +1,5 @@
This tutorial will explain the training script, how to use it, and particularly how to configure everything needed for the training run.
-> **Note:** The following assume you're running these commands on a machine equipped with a cuda GPU. If you don't have one (or if you're using a Mac), you can add `--device=cpu` (`--device=mps` respectively). However, be advised that the code executes much slower on cpu.
+> **Note:** The following assumes you're running these commands on a machine equipped with a CUDA GPU. If you don't have one (or if you're using a Mac), you can add `--policy.device=cpu` (`--policy.device=mps` respectively). However, be advised that the code executes much slower on CPU.
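For example, appending the device override to a training command of the kind used later in these tutorials (the dataset, output directory and job name below are placeholders):

```bash
python lerobot/scripts/train.py \
  --dataset.repo_id=${HF_USER}/koch_test \
  --policy.type=act \
  --output_dir=outputs/train/act_koch_test \
  --job_name=act_koch_test \
  --policy.device=mps
```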
## The training script
diff --git a/examples/7_get_started_with_real_robot.md b/examples/7_get_started_with_real_robot.md
index e57d783a..5b12e903 100644
--- a/examples/7_get_started_with_real_robot.md
+++ b/examples/7_get_started_with_real_robot.md
@@ -33,7 +33,7 @@ First, install the additional dependencies required for robots built with dynami
Using `pip`:
```bash
-pip install -e ".[dynamixel]"
+pip install --no-binary=av -e ".[dynamixel]"
```
Using `poetry`:
@@ -46,13 +46,6 @@ Using `uv`:
uv sync --extra "dynamixel"
```
-/!\ For Linux only, ffmpeg and opencv requires conda install for now. Run this exact sequence of commands:
-```bash
-conda install -c conda-forge ffmpeg
-pip uninstall opencv-python
-conda install -c conda-forge "opencv>=4.10.0"
-```
-
You are now ready to plug the 5V power supply to the motor bus of the leader arm (the smaller one) since all its motors only require 5V.
Then plug the 12V power supply to the motor bus of the follower arm. It has two motors that need 12V, and the rest will be powered with 5V through the voltage convertor.
@@ -292,6 +285,11 @@ Steps:
- Scan for devices. All 12 motors should appear.
- Select the motors one by one and move the arm. Check that the graphical indicator near the top right shows the movement.
+**Troubleshooting:** There is a common issue with the Dynamixel XL430-W250 motors where they become undiscoverable after upgrading their firmware from the Mac and Windows Dynamixel Wizard 2 applications. When this occurs, a firmware recovery is required (select `DYNAMIXEL Firmware Recovery` and follow the prompts). There are two known workarounds for performing this recovery:
+ 1) Install the Dynamixel Wizard on a Linux machine and complete the firmware recovery there.
+ 2) Use a Dynamixel U2D2 adapter to perform the recovery from Windows or Mac. The U2D2 can be purchased [here](https://www.robotis.us/u2d2/).
+ For either option, open DYNAMIXEL Wizard 2.0 and select the appropriate port. You will likely not see the motor in the GUI at this point. Select `Firmware Recovery`, carefully choose the correct model, and wait for the process to complete. Finally, re-scan to confirm the recovery was successful.
+
**Read and Write with DynamixelMotorsBus**
To get familiar with how `DynamixelMotorsBus` communicates with the motors, you can start by reading data from them. Copy-paste this code into the same interactive Python session:
@@ -386,19 +384,19 @@ When you connect your robot for the first time, the [`ManipulatorRobot`](../lero
Here are the positions you'll move the follower arm to:
-| 1. Zero position | 2. Rotated position | 3. Rest position |
-|---|---|---|
+| 1. Zero position | 2. Rotated position | 3. Rest position |
+| ----------------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| | | |
And here are the corresponding positions for the leader arm:
-| 1. Zero position | 2. Rotated position | 3. Rest position |
-|---|---|---|
+| 1. Zero position | 2. Rotated position | 3. Rest position |
+| ----------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------- |
| | | |
You can watch a [video tutorial of the calibration procedure](https://youtu.be/8drnU9uRY24) for more details.
-During calibration, we count the number of full 360-degree rotations your motors have made since they were first used. That's why we ask yo to move to this arbitrary "zero" position. We don't actually "set" the zero position, so you don't need to be accurate. After calculating these "offsets" to shift the motor values around 0, we need to assess the rotation direction of each motor, which might differ. That's why we ask you to rotate all motors to roughly 90 degrees, to mesure if the values changed negatively or positively.
+During calibration, we count the number of full 360-degree rotations your motors have made since they were first used. That's why we ask you to move to this arbitrary "zero" position. We don't actually "set" the zero position, so you don't need to be accurate. After calculating these "offsets" to shift the motor values around 0, we need to assess the rotation direction of each motor, which might differ. That's why we ask you to rotate all motors to roughly 90 degrees, to measure if the values changed negatively or positively.
Finally, the rest position ensures that the follower and leader arms are roughly aligned after calibration, preventing sudden movements that could damage the motors when starting teleoperation.
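To make the offsets and rotation directions concrete, here is a minimal illustrative sketch (not LeRobot's actual calibration code; the step values are made up) of how they could be derived from the raw readings captured at the zero and rotated positions:

```python
import numpy as np

# Hypothetical raw readings (in encoder steps, 4096 steps per turn) at the two poses.
zero_pose = np.array([2047, 2010, 2100, 2055, 1990, 2048])
rotated_pose = np.array([3071, 995, 3120, 1030, 3010, 3070])

# Offset: whatever each motor reports at the "zero" pose becomes its new origin.
offsets = zero_pose

# Direction: if moving to roughly +90 degrees made the raw value decrease,
# that motor's sign is flipped so every joint counts in the same direction.
directions = np.sign(rotated_pose - zero_pose)

def to_calibrated(raw: np.ndarray) -> np.ndarray:
    """Map raw encoder steps to values centered on the zero pose."""
    return directions * (raw - offsets)

# Each joint now reads roughly +1024 steps, i.e. about +90 degrees.
print(to_calibrated(rotated_pose))
```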
@@ -626,7 +624,7 @@ Finally, run this code to instantiate and connectyour camera:
from lerobot.common.robot_devices.cameras.configs import OpenCVCameraConfig
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
-camera_config = OpenCVCameraConfig(camera_index=0)
+config = OpenCVCameraConfig(camera_index=0)
camera = OpenCVCamera(config)
camera.connect()
color_image = camera.read()
@@ -663,18 +661,20 @@ camera.disconnect()
**Instantiate your robot with cameras**
-Additionaly, you can set up your robot to work with your cameras.
+Additionally, you can set up your robot to work with your cameras.
Modify the following Python code with the appropriate camera names and configurations:
```python
robot = ManipulatorRobot(
- leader_arms={"main": leader_arm},
- follower_arms={"main": follower_arm},
- calibration_dir=".cache/calibration/koch",
- cameras={
- "laptop": OpenCVCameraConfig(0, fps=30, width=640, height=480),
- "phone": OpenCVCameraConfig(1, fps=30, width=640, height=480),
- },
+ KochRobotConfig(
+ leader_arms={"main": leader_arm},
+ follower_arms={"main": follower_arm},
+ calibration_dir=".cache/calibration/koch",
+ cameras={
+ "laptop": OpenCVCameraConfig(0, fps=30, width=640, height=480),
+ "phone": OpenCVCameraConfig(1, fps=30, width=640, height=480),
+ },
+ )
)
robot.connect()
```
@@ -711,7 +711,7 @@ python lerobot/scripts/control_robot.py \
You will see a lot of lines appearing like this one:
```
-INFO 2024-08-10 11:15:03 ol_robot.py:209 dt: 5.12 (195.1hz) dtRlead: 4.93 (203.0hz) dtRfoll: 0.19 (5239.0hz)
+INFO 2024-08-10 11:15:03 ol_robot.py:209 dt: 5.12 (195.1hz) dtRlead: 4.93 (203.0hz) dtWfoll: 0.19 (5239.0hz)
```
It contains
@@ -768,7 +768,7 @@ You can use the `record` function from [`lerobot/scripts/control_robot.py`](../l
1. Frames from cameras are saved on disk in threads, and encoded into videos at the end of each episode recording.
2. Video streams from cameras are displayed in a window so that you can verify them.
3. Data is stored in the [`LeRobotDataset`](../lerobot/common/datasets/lerobot_dataset.py) format, which is pushed to your Hugging Face page (unless `--control.push_to_hub=false` is provided).
-4. Checkpoints are done during recording, so if any issue occurs, you can resume recording by re-running the same command again with `--control.resume=true`. You might need to add `--control.local_files_only=true` if your dataset was not uploaded to hugging face hub. Also you will need to manually delete the dataset directory to start recording from scratch.
+4. Checkpoints are done during recording, so if any issue occurs, you can resume recording by re-running the same command with `--control.resume=true`. You will need to manually delete the dataset directory if you want to start recording from scratch.
5. Set the flow of data recording using command line arguments:
- `--control.warmup_time_s=10` defines the number of seconds before starting data collection. It allows the robot devices to warmup and synchronize (10 seconds by default).
- `--control.episode_time_s=60` defines the number of seconds for data recording for each episode (60 seconds by default).
@@ -823,15 +823,10 @@ It contains:
- `dtRlead: 5.06 (197.5hz)` which is the delta time of reading the present position of the leader arm.
- `dtWfoll: 0.25 (3963.7hz)` which is the delta time of writing the goal position on the follower arm ; writing is asynchronous so it takes less time than reading.
- `dtRfoll: 6.22 (160.7hz)` which is the delta time of reading the present position on the follower arm.
-- `dtRlaptop:32.57 (30.7hz) ` which is the delta time of capturing an image from the laptop camera in the thread running asynchrously.
-- `dtRphone:33.84 (29.5hz)` which is the delta time of capturing an image from the phone camera in the thread running asynchrously.
+- `dtRlaptop:32.57 (30.7hz) ` which is the delta time of capturing an image from the laptop camera in the thread running asynchronously.
+- `dtRphone:33.84 (29.5hz)` which is the delta time of capturing an image from the phone camera in the thread running asynchronously.
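As a quick sanity check on these numbers, each frequency in parentheses is simply the inverse of the corresponding delta time, which is printed in milliseconds (small differences come from rounding of the displayed value):

```python
for name, dt_ms in [("dtRlead", 5.06), ("dtRfoll", 6.22), ("dtRlaptop", 32.57)]:
    print(f"{name}: {1000 / dt_ms:.1f}hz")  # 197.6hz, 160.8hz, 30.7hz
```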
Troubleshooting:
-- On Linux, if you encounter a hanging issue when using cameras, uninstall opencv and re-install it with conda:
-```bash
-pip uninstall opencv-python
-conda install -c conda-forge opencv=4.10.0
-```
- On Linux, if you encounter any issue during video encoding with `ffmpeg: unknown encoder libsvtav1`, you can:
- install with conda-forge by running `conda install -c conda-forge ffmpeg` (it should be compiled with `libsvtav1`),
- or, install [Homebrew](https://brew.sh) and run `brew install ffmpeg` (it should be compiled with `libsvtav1`),
@@ -844,7 +839,7 @@ At the end of data recording, your dataset will be uploaded on your Hugging Face
echo https://huggingface.co/datasets/${HF_USER}/koch_test
```
-### b. Advices for recording dataset
+### b. Advice for recording dataset
Once you're comfortable with data recording, it's time to create a larger dataset for training. A good starting task is grasping an object at different locations and placing it in a bin. We suggest recording at least 50 episodes, with 10 episodes per location. Keep the cameras fixed and maintain consistent grasping behavior throughout the recordings.
@@ -883,8 +878,6 @@ python lerobot/scripts/control_robot.py \
--control.episode=0
```
-Note: You might need to add `--control.local_files_only=true` if your dataset was not uploaded to hugging face hub.
-
Your robot should replicate movements similar to those you recorded. For example, check out [this video](https://x.com/RemiCadene/status/1793654950905680090) where we use `replay` on an Aloha robot from [Trossen Robotics](https://www.trossenrobotics.com).
## 4. Train a policy on your data
@@ -898,16 +891,14 @@ python lerobot/scripts/train.py \
--policy.type=act \
--output_dir=outputs/train/act_koch_test \
--job_name=act_koch_test \
- --device=cuda \
+ --policy.device=cuda \
--wandb.enable=true
```
-Note: You might need to add `--dataset.local_files_only=true` if your dataset was not uploaded to hugging face hub.
-
Let's explain it:
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/koch_test`.
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
-4. We provided `device=cuda` since we are training on a Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
+4. We provided `policy.device=cuda` since we are training on an Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
For more information on the `train` script see the previous tutorial: [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md)
diff --git a/examples/8_use_stretch.md b/examples/8_use_stretch.md
index 2f8c0ffb..d02e7ef3 100644
--- a/examples/8_use_stretch.md
+++ b/examples/8_use_stretch.md
@@ -45,18 +45,11 @@ git clone https://github.com/huggingface/lerobot.git ~/lerobot
6. Install LeRobot with stretch dependencies:
```bash
-cd ~/lerobot && pip install -e ".[stretch]"
+cd ~/lerobot && pip install --no-binary=av -e ".[stretch]"
```
> **Note:** If you get this message, you can ignore it: `ERROR: pip's dependency resolver does not currently take into account all the packages that are installed.`
-For Linux only (not Mac), install extra dependencies for recording datasets:
-```bash
-conda install -y -c conda-forge ffmpeg
-pip uninstall -y opencv-python
-conda install -y -c conda-forge "opencv>=4.10.0"
-```
-
7. Run a [system check](https://docs.hello-robot.com/0.3/getting_started/stretch_hardware_overview/#system-check) to make sure your robot is ready:
```bash
stretch_system_check.py
@@ -98,7 +91,7 @@ python lerobot/scripts/control_robot.py \
```
This is equivalent to running `stretch_robot_home.py`
-> **Note:** If you run any of the LeRobot scripts below and Stretch is not poperly homed, it will automatically home/calibrate first.
+> **Note:** If you run any of the LeRobot scripts below and Stretch is not properly homed, it will automatically home/calibrate first.
**Teleoperate**
Before trying teleoperation, you need to activate the gamepad controller by pressing the middle button. For more info, see Stretch's [doc](https://docs.hello-robot.com/0.3/getting_started/hello_robot/#gamepad-teleoperation).
diff --git a/examples/9_use_aloha.md b/examples/9_use_aloha.md
index d74c8b7a..1f7aee3c 100644
--- a/examples/9_use_aloha.md
+++ b/examples/9_use_aloha.md
@@ -2,7 +2,7 @@ This tutorial explains how to use [Aloha and Aloha 2 stationary](https://www.tro
## Setup
-Follow the [documentation from Trossen Robotics](https://docs.trossenrobotics.com/aloha_docs/getting_started/stationary/hardware_setup.html) for setting up the hardware and plugging the 4 arms and 4 cameras to your computer.
+Follow the [documentation from Trossen Robotics](https://docs.trossenrobotics.com/aloha_docs/2.0/getting_started/stationary/hardware_setup.html) for setting up the hardware and plugging the 4 arms and 4 cameras to your computer.
## Install LeRobot
@@ -32,14 +32,7 @@ git clone https://github.com/huggingface/lerobot.git ~/lerobot
5. Install LeRobot with dependencies for the Aloha motors (dynamixel) and cameras (intelrealsense):
```bash
-cd ~/lerobot && pip install -e ".[dynamixel, intelrealsense]"
-```
-
-For Linux only (not Mac), install extra dependencies for recording datasets:
-```bash
-conda install -y -c conda-forge ffmpeg
-pip uninstall -y opencv-python
-conda install -y -c conda-forge "opencv>=4.10.0"
+cd ~/lerobot && pip install --no-binary=av -e ".[dynamixel, intelrealsense]"
```
## Teleoperate
@@ -135,14 +128,14 @@ python lerobot/scripts/train.py \
--policy.type=act \
--output_dir=outputs/train/act_aloha_test \
--job_name=act_aloha_test \
- --device=cuda \
+ --policy.device=cuda \
--wandb.enable=true
```
Let's explain it:
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/aloha_test`.
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
-4. We provided `device=cuda` since we are training on a Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
+4. We provided `policy.device=cuda` since we are training on an Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
For more information on the `train` script see the previous tutorial: [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md)
@@ -172,10 +165,10 @@ python lerobot/scripts/control_robot.py \
As you can see, it's almost the same command as previously used to record your training dataset. A few things changed:
1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint (e.g. `outputs/train/eval_act_aloha_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_aloha_test`).
2. The name of the dataset begins with `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_aloha_test`).
-3. We use `--control.num_image_writer_processes=1` instead of the default value (`0`). On our computer, using a dedicated process to write images from the 4 cameras on disk allows to reach constent 30 fps during inference. Feel free to explore different values for `--control.num_image_writer_processes`.
+3. We use `--control.num_image_writer_processes=1` instead of the default value (`0`). On our computer, using a dedicated process to write images from the 4 cameras on disk allows us to reach a constant 30 fps during inference. Feel free to explore different values for `--control.num_image_writer_processes`.
## More
-Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth explaination.
+Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth explanation.
If you have any question or need help, please reach out on Discord in the channel `#aloha-arm`.
diff --git a/examples/advanced/1_add_image_transforms.py b/examples/advanced/1_add_image_transforms.py
index 882710e3..f1460926 100644
--- a/examples/advanced/1_add_image_transforms.py
+++ b/examples/advanced/1_add_image_transforms.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""
This script demonstrates how to use torchvision's image transformation with LeRobotDataset for data
augmentation purposes. The transformations are passed to the dataset as an argument upon creation, and
diff --git a/examples/advanced/2_calculate_validation_loss.py b/examples/advanced/2_calculate_validation_loss.py
index 6f234719..47b4dd02 100644
--- a/examples/advanced/2_calculate_validation_loss.py
+++ b/examples/advanced/2_calculate_validation_loss.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""This script demonstrates how to slice a dataset and calculate the loss on a subset of the data.
This technique can be useful for debugging and testing purposes, as well as identifying whether a policy
diff --git a/examples/port_datasets/pusht_zarr.py b/examples/port_datasets/pusht_zarr.py
deleted file mode 100644
index 1506f427..00000000
--- a/examples/port_datasets/pusht_zarr.py
+++ /dev/null
@@ -1,222 +0,0 @@
-import shutil
-from pathlib import Path
-
-import numpy as np
-import torch
-
-from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset
-from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
-
-PUSHT_TASK = "Push the T-shaped blue block onto the T-shaped green target surface."
-PUSHT_FEATURES = {
- "observation.state": {
- "dtype": "float32",
- "shape": (2,),
- "names": {
- "axes": ["x", "y"],
- },
- },
- "action": {
- "dtype": "float32",
- "shape": (2,),
- "names": {
- "axes": ["x", "y"],
- },
- },
- "next.reward": {
- "dtype": "float32",
- "shape": (1,),
- "names": None,
- },
- "next.success": {
- "dtype": "bool",
- "shape": (1,),
- "names": None,
- },
- "observation.environment_state": {
- "dtype": "float32",
- "shape": (16,),
- "names": [
- "keypoints",
- ],
- },
- "observation.image": {
- "dtype": None,
- "shape": (3, 96, 96),
- "names": [
- "channels",
- "height",
- "width",
- ],
- },
-}
-
-
-def build_features(mode: str) -> dict:
- features = PUSHT_FEATURES
- if mode == "keypoints":
- features.pop("observation.image")
- else:
- features.pop("observation.environment_state")
- features["observation.image"]["dtype"] = mode
-
- return features
-
-
-def load_raw_dataset(zarr_path: Path):
- try:
- from lerobot.common.datasets.push_dataset_to_hub._diffusion_policy_replay_buffer import (
- ReplayBuffer as DiffusionPolicyReplayBuffer,
- )
- except ModuleNotFoundError as e:
- print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
- raise e
-
- zarr_data = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path)
- return zarr_data
-
-
-def calculate_coverage(zarr_data):
- try:
- import pymunk
- from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
- except ModuleNotFoundError as e:
- print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
- raise e
-
- block_pos = zarr_data["state"][:, 2:4]
- block_angle = zarr_data["state"][:, 4]
-
- num_frames = len(block_pos)
-
- coverage = np.zeros((num_frames,))
- # 8 keypoints with 2 coords each
- keypoints = np.zeros((num_frames, 16))
-
- # Set x, y, theta (in radians)
- goal_pos_angle = np.array([256, 256, np.pi / 4])
- goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle)
-
- for i in range(num_frames):
- space = pymunk.Space()
- space.gravity = 0, 0
- space.damping = 0
-
- # Add walls.
- walls = [
- PushTEnv.add_segment(space, (5, 506), (5, 5), 2),
- PushTEnv.add_segment(space, (5, 5), (506, 5), 2),
- PushTEnv.add_segment(space, (506, 5), (506, 506), 2),
- PushTEnv.add_segment(space, (5, 506), (506, 506), 2),
- ]
- space.add(*walls)
-
- block_body, block_shapes = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
- goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
- block_geom = pymunk_to_shapely(block_body, block_body.shapes)
- intersection_area = goal_geom.intersection(block_geom).area
- goal_area = goal_geom.area
- coverage[i] = intersection_area / goal_area
- keypoints[i] = torch.from_numpy(PushTEnv.get_keypoints(block_shapes).flatten())
-
- return coverage, keypoints
-
-
-def calculate_success(coverage: float, success_threshold: float):
- return coverage > success_threshold
-
-
-def calculate_reward(coverage: float, success_threshold: float):
- return np.clip(coverage / success_threshold, 0, 1)
-
-
-def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = True):
- if mode not in ["video", "image", "keypoints"]:
- raise ValueError(mode)
-
- if (LEROBOT_HOME / repo_id).exists():
- shutil.rmtree(LEROBOT_HOME / repo_id)
-
- if not raw_dir.exists():
- download_raw(raw_dir, repo_id="lerobot-raw/pusht_raw")
-
- zarr_data = load_raw_dataset(zarr_path=raw_dir / "pusht_cchi_v7_replay.zarr")
-
- env_state = zarr_data["state"][:]
- agent_pos = env_state[:, :2]
-
- action = zarr_data["action"][:]
- image = zarr_data["img"] # (b, h, w, c)
-
- episode_data_index = {
- "from": np.concatenate(([0], zarr_data.meta["episode_ends"][:-1])),
- "to": zarr_data.meta["episode_ends"],
- }
-
- # Calculate success and reward based on the overlapping area
- # of the T-object and the T-area.
- coverage, keypoints = calculate_coverage(zarr_data)
- success = calculate_success(coverage, success_threshold=0.95)
- reward = calculate_reward(coverage, success_threshold=0.95)
-
- features = build_features(mode)
- dataset = LeRobotDataset.create(
- repo_id=repo_id,
- fps=10,
- robot_type="2d pointer",
- features=features,
- image_writer_threads=4,
- )
- episodes = range(len(episode_data_index["from"]))
- for ep_idx in episodes:
- from_idx = episode_data_index["from"][ep_idx]
- to_idx = episode_data_index["to"][ep_idx]
- num_frames = to_idx - from_idx
-
- for frame_idx in range(num_frames):
- i = from_idx + frame_idx
- frame = {
- "action": torch.from_numpy(action[i]),
- # Shift reward and success by +1 until the last item of the episode
- "next.reward": reward[i + (frame_idx < num_frames - 1)],
- "next.success": success[i + (frame_idx < num_frames - 1)],
- }
-
- frame["observation.state"] = torch.from_numpy(agent_pos[i])
-
- if mode == "keypoints":
- frame["observation.environment_state"] = torch.from_numpy(keypoints[i])
- else:
- frame["observation.image"] = torch.from_numpy(image[i])
-
- dataset.add_frame(frame)
-
- dataset.save_episode(task=PUSHT_TASK)
-
- dataset.consolidate()
-
- if push_to_hub:
- dataset.push_to_hub()
-
-
-if __name__ == "__main__":
- # To try this script, modify the repo id with your own HuggingFace user (e.g cadene/pusht)
- repo_id = "lerobot/pusht"
-
- modes = ["video", "image", "keypoints"]
- # Uncomment if you want to try with a specific mode
- # modes = ["video"]
- # modes = ["image"]
- # modes = ["keypoints"]
-
- raw_dir = Path("data/lerobot-raw/pusht_raw")
- for mode in modes:
- if mode in ["image", "keypoints"]:
- repo_id += f"_{mode}"
-
- # download and load raw dataset, create LeRobotDataset, populate it, push to hub
- main(raw_dir, repo_id=repo_id, mode=mode)
-
- # Uncomment if you want to load the local dataset and explore it
- # dataset = LeRobotDataset(repo_id=repo_id, local_files_only=True)
- # breakpoint()
diff --git a/lerobot/common/constants.py b/lerobot/common/constants.py
index 34da4ac0..973595cd 100644
--- a/lerobot/common/constants.py
+++ b/lerobot/common/constants.py
@@ -1,4 +1,22 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
# keys
+import os
+from pathlib import Path
+
+from huggingface_hub.constants import HF_HOME
+
OBS_ENV = "observation.environment_state"
OBS_ROBOT = "observation.state"
OBS_IMAGE = "observation.image"
@@ -15,3 +33,13 @@ TRAINING_STEP = "training_step.json"
OPTIMIZER_STATE = "optimizer_state.safetensors"
OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json"
SCHEDULER_STATE = "scheduler_state.json"
+
+# cache dir
+default_cache_path = Path(HF_HOME) / "lerobot"
+HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expanduser()
+
+if "LEROBOT_HOME" in os.environ:
+ raise ValueError(
+ f"You have a 'LEROBOT_HOME' environment variable set to '{os.getenv('LEROBOT_HOME')}'.\n"
+ "'LEROBOT_HOME' is deprecated, please use 'HF_LEROBOT_HOME' instead."
+ )
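In practice, the new cache location can be checked or relocated through this environment variable; a small usage sketch (the target directory is just an example):

```bash
# Relocate LeRobot's local dataset cache (defaults to $HF_HOME/lerobot).
export HF_LEROBOT_HOME=/data/lerobot_cache

# Note: setting the old LEROBOT_HOME variable now raises an error instead of being ignored.
python -c "from lerobot.common.constants import HF_LEROBOT_HOME; print(HF_LEROBOT_HOME)"
```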
diff --git a/lerobot/common/datasets/backward_compatibility.py b/lerobot/common/datasets/backward_compatibility.py
new file mode 100644
index 00000000..cf8e31c4
--- /dev/null
+++ b/lerobot/common/datasets/backward_compatibility.py
@@ -0,0 +1,68 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import packaging.version
+
+V2_MESSAGE = """
+The dataset you requested ({repo_id}) is in {version} format.
+
+We introduced a new format since v2.0 which is not backward compatible with v1.x.
+Please, use our conversion script. Modify the following command with your own task description:
+```
+python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\
+ --repo-id {repo_id} \\
+ --single-task "TASK DESCRIPTION." # <---- /!\\ Replace TASK DESCRIPTION /!\\
+```
+
+A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.", "Insert the
+peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.", "Open the top
+cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped
+target.", "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the
+sweatshirt.", ...
+
+If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
+or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
+"""
+
+V21_MESSAGE = """
+The dataset you requested ({repo_id}) is in {version} format.
+While the current version of LeRobot is backward-compatible with it, the version of your dataset still uses global
+stats instead of per-episode stats. Update your dataset stats to the new format using this command:
+```
+python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py --repo-id={repo_id}
+```
+
+If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
+or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
+"""
+
+FUTURE_MESSAGE = """
+The dataset you requested ({repo_id}) is only available in {version} format.
+As we cannot ensure forward compatibility with it, please update your current version of lerobot.
+"""
+
+
+class CompatibilityError(Exception): ...
+
+
+class BackwardCompatibilityError(CompatibilityError):
+ def __init__(self, repo_id: str, version: packaging.version.Version):
+ message = V2_MESSAGE.format(repo_id=repo_id, version=version)
+ super().__init__(message)
+
+
+class ForwardCompatibilityError(CompatibilityError):
+ def __init__(self, repo_id: str, version: packaging.version.Version):
+ message = FUTURE_MESSAGE.format(repo_id=repo_id, version=version)
+ super().__init__(message)
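For reference, a minimal sketch of how these exceptions are meant to be raised by dataset-loading code; the version check itself is an assumption here, only the exception classes and message templates come from this file:

```python
import packaging.version

from lerobot.common.datasets.backward_compatibility import BackwardCompatibilityError

dataset_version = packaging.version.parse("v1.6")
if dataset_version < packaging.version.parse("v2.0"):
    # Renders V2_MESSAGE with the repo_id and offending version filled in.
    raise BackwardCompatibilityError("lerobot/pusht", dataset_version)
```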
diff --git a/lerobot/common/datasets/compute_stats.py b/lerobot/common/datasets/compute_stats.py
index c6211699..1149ec83 100644
--- a/lerobot/common/datasets/compute_stats.py
+++ b/lerobot/common/datasets/compute_stats.py
@@ -13,202 +13,164 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from copy import deepcopy
-from math import ceil
+import numpy as np
-import einops
-import torch
-import tqdm
+from lerobot.common.datasets.utils import load_image_as_numpy
-def get_stats_einops_patterns(dataset, num_workers=0):
- """These einops patterns will be used to aggregate batches and compute statistics.
+def estimate_num_samples(
+ dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
+) -> int:
+ """Heuristic to estimate the number of samples based on dataset size.
+ The power controls the sample growth relative to dataset size.
+    Lower the power to use fewer samples.
- Note: We assume the images are in channel first format
+ For default arguments, we have:
+ - from 1 to ~500, num_samples=100
+ - at 1000, num_samples=177
+ - at 2000, num_samples=299
+ - at 5000, num_samples=594
+ - at 10000, num_samples=1000
+ - at 20000, num_samples=1681
"""
+ if dataset_len < min_num_samples:
+ min_num_samples = dataset_len
+ return max(min_num_samples, min(int(dataset_len**power), max_num_samples))
- dataloader = torch.utils.data.DataLoader(
- dataset,
- num_workers=num_workers,
- batch_size=2,
- shuffle=False,
- )
- batch = next(iter(dataloader))
- stats_patterns = {}
+def sample_indices(data_len: int) -> list[int]:
+ num_samples = estimate_num_samples(data_len)
+ return np.round(np.linspace(0, data_len - 1, num_samples)).astype(int).tolist()
- for key in dataset.features:
- # sanity check that tensors are not float64
- assert batch[key].dtype != torch.float64
- # if isinstance(feats_type, (VideoFrame, Image)):
- if key in dataset.meta.camera_keys:
- # sanity check that images are channel first
- _, c, h, w = batch[key].shape
- assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
+def auto_downsample_height_width(img: np.ndarray, target_size: int = 150, max_size_threshold: int = 300):
+ _, height, width = img.shape
- # sanity check that images are float32 in range [0,1]
- assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}"
- assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}"
- assert batch[key].min() >= 0, f"expect pixels greater than 1, but instead {batch[key].min()=}"
+ if max(width, height) < max_size_threshold:
+ # no downsampling needed
+ return img
- stats_patterns[key] = "b c h w -> c 1 1"
- elif batch[key].ndim == 2:
- stats_patterns[key] = "b c -> c "
- elif batch[key].ndim == 1:
- stats_patterns[key] = "b -> 1"
+ downsample_factor = int(width / target_size) if width > height else int(height / target_size)
+ return img[:, ::downsample_factor, ::downsample_factor]
+
+
+def sample_images(image_paths: list[str]) -> np.ndarray:
+ sampled_indices = sample_indices(len(image_paths))
+
+ images = None
+ for i, idx in enumerate(sampled_indices):
+ path = image_paths[idx]
+ # we load as uint8 to reduce memory usage
+ img = load_image_as_numpy(path, dtype=np.uint8, channel_first=True)
+ img = auto_downsample_height_width(img)
+
+ if images is None:
+ images = np.empty((len(sampled_indices), *img.shape), dtype=np.uint8)
+
+ images[i] = img
+
+ return images
+
+
+def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
+ return {
+ "min": np.min(array, axis=axis, keepdims=keepdims),
+ "max": np.max(array, axis=axis, keepdims=keepdims),
+ "mean": np.mean(array, axis=axis, keepdims=keepdims),
+ "std": np.std(array, axis=axis, keepdims=keepdims),
+ "count": np.array([len(array)]),
+ }
+
+
+def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict:
+ ep_stats = {}
+ for key, data in episode_data.items():
+ if features[key]["dtype"] == "string":
+ continue # HACK: we should receive np.arrays of strings
+ elif features[key]["dtype"] in ["image", "video"]:
+ ep_ft_array = sample_images(data) # data is a list of image paths
+ axes_to_reduce = (0, 2, 3) # keep channel dim
+ keepdims = True
else:
- raise ValueError(f"{key}, {batch[key].shape}")
+ ep_ft_array = data # data is already a np.ndarray
+ axes_to_reduce = 0 # compute stats over the first axis
+ keepdims = data.ndim == 1 # keep as np.array
- return stats_patterns
+ ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)
+
+ # finally, we normalize and remove batch dim for images
+ if features[key]["dtype"] in ["image", "video"]:
+ ep_stats[key] = {
+ k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
+ }
+
+ return ep_stats
-def compute_stats(dataset, batch_size=8, num_workers=8, max_num_samples=None):
- """Compute mean/std and min/max statistics of all data keys in a LeRobotDataset."""
- if max_num_samples is None:
- max_num_samples = len(dataset)
-
- # for more info on why we need to set the same number of workers, see `load_from_videos`
- stats_patterns = get_stats_einops_patterns(dataset, num_workers)
-
- # mean and std will be computed incrementally while max and min will track the running value.
- mean, std, max, min = {}, {}, {}, {}
- for key in stats_patterns:
- mean[key] = torch.tensor(0.0).float()
- std[key] = torch.tensor(0.0).float()
- max[key] = torch.tensor(-float("inf")).float()
- min[key] = torch.tensor(float("inf")).float()
-
- def create_seeded_dataloader(dataset, batch_size, seed):
- generator = torch.Generator()
- generator.manual_seed(seed)
- dataloader = torch.utils.data.DataLoader(
- dataset,
- num_workers=num_workers,
- batch_size=batch_size,
- shuffle=True,
- drop_last=False,
- generator=generator,
- )
- return dataloader
-
- # Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
- # surprises when rerunning the sampler.
- first_batch = None
- running_item_count = 0 # for online mean computation
- dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
- for i, batch in enumerate(
- tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
- ):
- this_batch_size = len(batch["index"])
- running_item_count += this_batch_size
- if first_batch is None:
- first_batch = deepcopy(batch)
- for key, pattern in stats_patterns.items():
- batch[key] = batch[key].float()
- # Numerically stable update step for mean computation.
- batch_mean = einops.reduce(batch[key], pattern, "mean")
- # Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
- # the update step, N is the running item count, B is this batch size, x̄ is the running mean,
- # and x is the current batch mean. Some rearrangement is then required to avoid risking
- # numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
- # x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
- mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
- max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
- min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
-
- if i == ceil(max_num_samples / batch_size) - 1:
- break
-
- first_batch_ = None
- running_item_count = 0 # for online std computation
- dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
- for i, batch in enumerate(
- tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
- ):
- this_batch_size = len(batch["index"])
- running_item_count += this_batch_size
- # Sanity check to make sure the batches are still in the same order as before.
- if first_batch_ is None:
- first_batch_ = deepcopy(batch)
- for key in stats_patterns:
- assert torch.equal(first_batch_[key], first_batch[key])
- for key, pattern in stats_patterns.items():
- batch[key] = batch[key].float()
- # Numerically stable update step for mean computation (where the mean is over squared
- # residuals).See notes in the mean computation loop above.
- batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
- std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
-
- if i == ceil(max_num_samples / batch_size) - 1:
- break
-
- for key in stats_patterns:
- std[key] = torch.sqrt(std[key])
-
- stats = {}
- for key in stats_patterns:
- stats[key] = {
- "mean": mean[key],
- "std": std[key],
- "max": max[key],
- "min": min[key],
- }
- return stats
+def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
+ for i in range(len(stats_list)):
+ for fkey in stats_list[i]:
+ for k, v in stats_list[i][fkey].items():
+ if not isinstance(v, np.ndarray):
+ raise ValueError(
+ f"Stats must be composed of numpy array, but key '{k}' of feature '{fkey}' is of type '{type(v)}' instead."
+ )
+ if v.ndim == 0:
+ raise ValueError("Number of dimensions must be at least 1, and is 0 instead.")
+ if k == "count" and v.shape != (1,):
+ raise ValueError(f"Shape of 'count' must be (1), but is {v.shape} instead.")
+ if "image" in fkey and k != "count" and v.shape != (3, 1, 1):
+ raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.")
-def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]:
- """Aggregate stats of multiple LeRobot datasets into one set of stats without recomputing from scratch.
+def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
+ """Aggregates stats for a single feature."""
+ means = np.stack([s["mean"] for s in stats_ft_list])
+ variances = np.stack([s["std"] ** 2 for s in stats_ft_list])
+ counts = np.stack([s["count"] for s in stats_ft_list])
+ total_count = counts.sum(axis=0)
- The final stats will have the union of all data keys from each of the datasets.
+ # Prepare weighted mean by matching number of dimensions
+ while counts.ndim < means.ndim:
+ counts = np.expand_dims(counts, axis=-1)
- The final stats will have the union of all data keys from each of the datasets. For instance:
- - new_max = max(max_dataset_0, max_dataset_1, ...)
+ # Compute the weighted mean
+ weighted_means = means * counts
+ total_mean = weighted_means.sum(axis=0) / total_count
+
+ # Compute the variance using the parallel algorithm
+ delta_means = means - total_mean
+ weighted_variances = (variances + delta_means**2) * counts
+ total_variance = weighted_variances.sum(axis=0) / total_count
+
+ return {
+ "min": np.min(np.stack([s["min"] for s in stats_ft_list]), axis=0),
+ "max": np.max(np.stack([s["max"] for s in stats_ft_list]), axis=0),
+ "mean": total_mean,
+ "std": np.sqrt(total_variance),
+ "count": total_count,
+ }
+
+
+def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
+ """Aggregate stats from multiple compute_stats outputs into a single set of stats.
+
+ The final stats will have the union of all data keys from each of the stats dicts.
+
+ For instance:
- new_min = min(min_dataset_0, min_dataset_1, ...)
- - new_mean = (mean of all data)
+ - new_max = max(max_dataset_0, max_dataset_1, ...)
+ - new_mean = (mean of all data, weighted by counts)
- new_std = (std of all data)
"""
- data_keys = set()
- for dataset in ls_datasets:
- data_keys.update(dataset.meta.stats.keys())
- stats = {k: {} for k in data_keys}
- for data_key in data_keys:
- for stat_key in ["min", "max"]:
- # compute `max(dataset_0["max"], dataset_1["max"], ...)`
- stats[data_key][stat_key] = einops.reduce(
- torch.stack(
- [ds.meta.stats[data_key][stat_key] for ds in ls_datasets if data_key in ds.meta.stats],
- dim=0,
- ),
- "n ... -> ...",
- stat_key,
- )
- total_samples = sum(d.num_frames for d in ls_datasets if data_key in d.meta.stats)
- # Compute the "sum" statistic by multiplying each mean by the number of samples in the respective
- # dataset, then divide by total_samples to get the overall "mean".
- # NOTE: the brackets around (d.num_frames / total_samples) are needed tor minimize the risk of
- # numerical overflow!
- stats[data_key]["mean"] = sum(
- d.meta.stats[data_key]["mean"] * (d.num_frames / total_samples)
- for d in ls_datasets
- if data_key in d.meta.stats
- )
- # The derivation for standard deviation is a little more involved but is much in the same spirit as
- # the computation of the mean.
- # Given two sets of data where the statistics are known:
- # σ_combined = sqrt[ (n1 * (σ1^2 + d1^2) + n2 * (σ2^2 + d2^2)) / (n1 + n2) ]
- # where d1 = μ1 - μ_combined, d2 = μ2 - μ_combined
- # NOTE: the brackets around (d.num_frames / total_samples) are needed tor minimize the risk of
- # numerical overflow!
- stats[data_key]["std"] = torch.sqrt(
- sum(
- (
- d.meta.stats[data_key]["std"] ** 2
- + (d.meta.stats[data_key]["mean"] - stats[data_key]["mean"]) ** 2
- )
- * (d.num_frames / total_samples)
- for d in ls_datasets
- if data_key in d.meta.stats
- )
- )
- return stats
+
+ _assert_type_and_shape(stats_list)
+
+ data_keys = {key for stats in stats_list for key in stats}
+ aggregated_stats = {key: {} for key in data_keys}
+
+ for key in data_keys:
+ stats_with_key = [stats[key] for stats in stats_list if key in stats]
+ aggregated_stats[key] = aggregate_feature_stats(stats_with_key)
+
+ return aggregated_stats
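To illustrate the sampling heuristic above: the values listed in the `estimate_num_samples` docstring follow directly from `max(min_num_samples, min(int(n**0.75), max_num_samples))`. A quick check using the functions added in this file:

```python
from lerobot.common.datasets.compute_stats import estimate_num_samples, sample_indices

# Reproduces the docstring's table: 177 at 1000 frames, 299 at 2000, 594 at 5000, ...
for n in [100, 1_000, 2_000, 5_000, 10_000, 20_000]:
    print(n, estimate_num_samples(n))

# Indices are spread evenly over the episode rather than drawn at random.
print(sample_indices(10))  # [0, 1, ..., 9]: tiny episodes keep every frame
```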
diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py
index 95ba76b8..38c01b42 100644
--- a/lerobot/common/datasets/factory.py
+++ b/lerobot/common/datasets/factory.py
@@ -83,15 +83,18 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
)
if isinstance(cfg.dataset.repo_id, str):
- ds_meta = LeRobotDatasetMetadata(cfg.dataset.repo_id, local_files_only=cfg.dataset.local_files_only)
+ ds_meta = LeRobotDatasetMetadata(
+ cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
+ )
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
dataset = LeRobotDataset(
cfg.dataset.repo_id,
+ root=cfg.dataset.root,
episodes=cfg.dataset.episodes,
delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
+ revision=cfg.dataset.revision,
video_backend=cfg.dataset.video_backend,
- local_files_only=cfg.dataset.local_files_only,
)
else:
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
diff --git a/lerobot/common/datasets/image_writer.py b/lerobot/common/datasets/image_writer.py
index 85dd6830..6fc0ee2f 100644
--- a/lerobot/common/datasets/image_writer.py
+++ b/lerobot/common/datasets/image_writer.py
@@ -38,22 +38,40 @@ def safe_stop_image_writer(func):
return wrapper
-def image_array_to_image(image_array: np.ndarray) -> PIL.Image.Image:
+def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
# TODO(aliberts): handle 1 channel and 4 for depth images
- if image_array.ndim == 3 and image_array.shape[0] in [1, 3]:
+ if image_array.ndim != 3:
+ raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.")
+
+ if image_array.shape[0] == 3:
# Transpose from pytorch convention (C, H, W) to (H, W, C)
image_array = image_array.transpose(1, 2, 0)
+
+ elif image_array.shape[-1] != 3:
+ raise NotImplementedError(
+ f"The image has {image_array.shape[-1]} channels, but 3 is required for now."
+ )
+
if image_array.dtype != np.uint8:
- # Assume the image is in [0, 1] range for floating-point data
- image_array = np.clip(image_array, 0, 1)
+ if range_check:
+ max_ = image_array.max().item()
+ min_ = image_array.min().item()
+ if max_ > 1.0 or min_ < 0.0:
+ raise ValueError(
+ "The image data type is float, which requires values in the range [0.0, 1.0]. "
+ f"However, the provided range is [{min_}, {max_}]. Please adjust the range or "
+ "provide a uint8 image with values in the range [0, 255]."
+ )
+
image_array = (image_array * 255).astype(np.uint8)
+
return PIL.Image.fromarray(image_array)
def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path):
try:
if isinstance(image, np.ndarray):
- img = image_array_to_image(image)
+ img = image_array_to_pil_image(image)
elif isinstance(image, PIL.Image.Image):
img = image
else:
diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py
index 9483bf0a..6ef955dd 100644
--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -13,62 +13,68 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import contextlib
import logging
-import os
import shutil
-from functools import cached_property
from pathlib import Path
from typing import Callable
import datasets
import numpy as np
+import packaging.version
import PIL.Image
import torch
import torch.utils
-from datasets import load_dataset
-from huggingface_hub import create_repo, snapshot_download, upload_folder
+from datasets import concatenate_datasets, load_dataset
+from huggingface_hub import HfApi, snapshot_download
+from huggingface_hub.constants import REPOCARD_NAME
+from huggingface_hub.errors import RevisionNotFoundError
-from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats
+from lerobot.common.constants import HF_LEROBOT_HOME
+from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
from lerobot.common.datasets.utils import (
DEFAULT_FEATURES,
DEFAULT_IMAGE_PATH,
- EPISODES_PATH,
INFO_PATH,
- STATS_PATH,
TASKS_PATH,
append_jsonlines,
+ backward_compatible_episodes_stats,
check_delta_timestamps,
check_timestamps_sync,
check_version_compatibility,
- create_branch,
create_empty_dataset_info,
create_lerobot_dataset_card,
+ embed_images,
get_delta_indices,
get_episode_data_index,
get_features_from_robot,
get_hf_features_from_features,
- get_hub_safe_version,
+ get_safe_version,
hf_transform_to_torch,
+ is_valid_version,
load_episodes,
+ load_episodes_stats,
load_info,
load_stats,
load_tasks,
- serialize_dict,
+ validate_episode_buffer,
+ validate_frame,
+ write_episode,
+ write_episode_stats,
+ write_info,
write_json,
- write_parquet,
)
from lerobot.common.datasets.video_utils import (
VideoFrame,
- decode_video_frames_torchvision,
+ decode_video_frames,
encode_video_frames,
+ get_safe_default_codec,
get_video_info,
)
from lerobot.common.robot_devices.robots.utils import Robot
-# For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
-CODEBASE_VERSION = "v2.0"
-LEROBOT_HOME = Path(os.getenv("LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser()
+CODEBASE_VERSION = "v2.1"
class LeRobotDatasetMetadata:
@@ -76,19 +82,36 @@ class LeRobotDatasetMetadata:
self,
repo_id: str,
root: str | Path | None = None,
- local_files_only: bool = False,
+ revision: str | None = None,
+ force_cache_sync: bool = False,
):
self.repo_id = repo_id
- self.root = Path(root) if root is not None else LEROBOT_HOME / repo_id
- self.local_files_only = local_files_only
+ self.revision = revision if revision else CODEBASE_VERSION
+ self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
- # Load metadata
- (self.root / "meta").mkdir(exist_ok=True, parents=True)
- self.pull_from_repo(allow_patterns="meta/")
+ try:
+ if force_cache_sync:
+ raise FileNotFoundError
+ self.load_metadata()
+ except (FileNotFoundError, NotADirectoryError):
+ if is_valid_version(self.revision):
+ self.revision = get_safe_version(self.repo_id, self.revision)
+
+ (self.root / "meta").mkdir(exist_ok=True, parents=True)
+ self.pull_from_repo(allow_patterns="meta/")
+ self.load_metadata()
+
+ def load_metadata(self):
self.info = load_info(self.root)
- self.stats = load_stats(self.root)
- self.tasks = load_tasks(self.root)
+ check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
+ self.tasks, self.task_to_task_index = load_tasks(self.root)
self.episodes = load_episodes(self.root)
+ if self._version < packaging.version.parse("v2.1"):
+ self.stats = load_stats(self.root)
+ self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
+ else:
+ self.episodes_stats = load_episodes_stats(self.root)
+ self.stats = aggregate_stats(list(self.episodes_stats.values()))
def pull_from_repo(
self,
@@ -98,21 +121,16 @@ class LeRobotDatasetMetadata:
snapshot_download(
self.repo_id,
repo_type="dataset",
- revision=self._hub_version,
+ revision=self.revision,
local_dir=self.root,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
- local_files_only=self.local_files_only,
)
- @cached_property
- def _hub_version(self) -> str | None:
- return None if self.local_files_only else get_hub_safe_version(self.repo_id, CODEBASE_VERSION)
-
@property
- def _version(self) -> str:
+ def _version(self) -> packaging.version.Version:
"""Codebase version used to create this dataset."""
- return self.info["codebase_version"]
+ return packaging.version.parse(self.info["codebase_version"])
def get_data_file_path(self, ep_index: int) -> Path:
ep_chunk = self.get_episode_chunk(ep_index)
@@ -202,54 +220,65 @@ class LeRobotDatasetMetadata:
"""Max number of episodes per chunk."""
return self.info["chunks_size"]
- @property
- def task_to_task_index(self) -> dict:
- return {task: task_idx for task_idx, task in self.tasks.items()}
-
- def get_task_index(self, task: str) -> int:
+ def get_task_index(self, task: str) -> int | None:
"""
Given a task in natural language, returns its task_index if the task already exists in the dataset,
- otherwise creates a new task_index.
+ otherwise returns None.
"""
- task_index = self.task_to_task_index.get(task, None)
- return task_index if task_index is not None else self.total_tasks
+ return self.task_to_task_index.get(task, None)
- def save_episode(self, episode_index: int, episode_length: int, task: str, task_index: int) -> None:
+ def add_task(self, task: str):
+ """
+ Given a task in natural language, add it to the dictionary of tasks.
+ """
+ if task in self.task_to_task_index:
+ raise ValueError(f"The task '{task}' already exists and can't be added twice.")
+
+ task_index = self.info["total_tasks"]
+ self.task_to_task_index[task] = task_index
+ self.tasks[task_index] = task
+ self.info["total_tasks"] += 1
+
+ task_dict = {
+ "task_index": task_index,
+ "task": task,
+ }
+ append_jsonlines(task_dict, self.root / TASKS_PATH)
+
+ def save_episode(
+ self,
+ episode_index: int,
+ episode_length: int,
+ episode_tasks: list[str],
+ episode_stats: dict[str, dict],
+ ) -> None:
self.info["total_episodes"] += 1
self.info["total_frames"] += episode_length
- if task_index not in self.tasks:
- self.info["total_tasks"] += 1
- self.tasks[task_index] = task
- task_dict = {
- "task_index": task_index,
- "task": task,
- }
- append_jsonlines(task_dict, self.root / TASKS_PATH)
-
chunk = self.get_episode_chunk(episode_index)
if chunk >= self.total_chunks:
self.info["total_chunks"] += 1
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
self.info["total_videos"] += len(self.video_keys)
- write_json(self.info, self.root / INFO_PATH)
+ if len(self.video_keys) > 0:
+ self.update_video_info()
+
+ write_info(self.info, self.root)
episode_dict = {
"episode_index": episode_index,
- "tasks": [task],
+ "tasks": episode_tasks,
"length": episode_length,
}
- self.episodes.append(episode_dict)
- append_jsonlines(episode_dict, self.root / EPISODES_PATH)
+ self.episodes[episode_index] = episode_dict
+ write_episode(episode_dict, self.root)
- # TODO(aliberts): refactor stats in save_episodes
- # image_sampling = int(self.fps / 2) # sample 2 img/s for the stats
- # ep_stats = compute_episode_stats(episode_buffer, self.features, episode_length, image_sampling=image_sampling)
- # ep_stats = serialize_dict(ep_stats)
- # append_jsonlines(ep_stats, self.root / STATS_PATH)
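+ # Record the per-episode stats and fold them into the dataset-level stats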
+ self.episodes_stats[episode_index] = episode_stats
+ self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats
+ write_episode_stats(episode_index, episode_stats, self.root)
- def write_video_info(self) -> None:
+ def update_video_info(self) -> None:
"""
Warning: this function writes info from the first episode's videos, implicitly assuming that all videos
have been encoded the same way. It also assumes that the first episode exists.
@@ -259,8 +288,6 @@ class LeRobotDatasetMetadata:
video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
self.info["features"][key]["info"] = get_video_info(video_path)
- write_json(self.info, self.root / INFO_PATH)
-
def __repr__(self):
feature_keys = list(self.features)
return (
@@ -286,7 +313,7 @@ class LeRobotDatasetMetadata:
"""Creates metadata for a LeRobotDataset."""
obj = cls.__new__(cls)
obj.repo_id = repo_id
- obj.root = Path(root) if root is not None else LEROBOT_HOME / repo_id
+ obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
obj.root.mkdir(parents=True, exist_ok=False)
@@ -304,6 +331,7 @@ class LeRobotDatasetMetadata:
)
else:
# TODO(aliberts, rcadene): implement sanity check for features
+ features = {**features, **DEFAULT_FEATURES}
# check if none of the features contains a "/" in their names,
# as this would break the dict flattening in the stats computation, which uses '/' as separator
@@ -313,12 +341,13 @@ class LeRobotDatasetMetadata:
features = {**features, **DEFAULT_FEATURES}
- obj.tasks, obj.stats, obj.episodes = {}, {}, []
+ obj.tasks, obj.task_to_task_index = {}, {}
+ obj.episodes_stats, obj.stats, obj.episodes = {}, {}, {}
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
if len(obj.video_keys) > 0 and not use_videos:
raise ValueError()
write_json(obj.info, obj.root / INFO_PATH)
- obj.local_files_only = True
+ obj.revision = None
return obj
@@ -331,8 +360,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None,
tolerance_s: float = 1e-4,
+ revision: str | None = None,
+ force_cache_sync: bool = False,
download_videos: bool = True,
- local_files_only: bool = False,
video_backend: str | None = None,
):
"""
@@ -342,7 +372,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
- On your local disk in the 'root' folder. This is typically the case when you recorded your
dataset locally and you may or may not have pushed it to the hub yet. Instantiating this class
with 'root' will load your dataset directly from disk. This can happen while you're offline (no
- internet connection), in that case, use local_files_only=True.
+ internet connection).
- On the Hugging Face Hub at the address https://huggingface.co/datasets/{repo_id} and not on
your local disk in the 'root' folder. Instantiating this class with this 'repo_id' will download
@@ -362,7 +392,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
- info contains various information about the dataset like shapes, keys, fps etc.
- stats stores the dataset statistics of the different modalities for normalization
- tasks contains the prompts for each task of the dataset, which can be used for
- task-conditionned training.
+ task-conditioned training.
- hf_dataset (from datasets.Dataset), which will read any values from parquet files.
- videos (optional) from which frames are loaded to be synchronous with data from parquet files.
@@ -424,24 +454,28 @@ class LeRobotDataset(torch.utils.data.Dataset):
timestamps is separated to the next by 1/fps +/- tolerance_s. This also applies to frames
decoded from video files. It is also used to check that `delta_timestamps` (when provided) are
multiples of 1/fps. Defaults to 1e-4.
+ revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a
+ commit hash. Defaults to the current codebase version tag.
+ force_cache_sync (bool, optional): Flag to sync and refresh local files first. If False and the
+ files are already present in the local cache, loading is faster, but the local files might not
+ be in sync with the version on the hub, especially if you specified 'revision'. Defaults to
+ False.
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
video files are already present on local disk, they won't be downloaded again. Defaults to
True.
- local_files_only (bool, optional): Flag to use local files only. If True, no requests to the hub
- will be made. Defaults to False.
- video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
- a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
+ video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to
+ 'torchcodec' when available on the platform; otherwise, defaults to 'pyav'. You can also use the
+ 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader',
+ another Torchvision decoder.
"""
super().__init__()
self.repo_id = repo_id
- self.root = Path(root) if root else LEROBOT_HOME / repo_id
+ self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
self.episodes = episodes
self.tolerance_s = tolerance_s
- self.video_backend = video_backend if video_backend else "pyav"
+ self.revision = revision if revision else CODEBASE_VERSION
+ self.video_backend = video_backend if video_backend else get_safe_default_codec()
self.delta_indices = None
- self.local_files_only = local_files_only
# Unused attributes
self.image_writer = None
@@ -450,64 +484,92 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.root.mkdir(exist_ok=True, parents=True)
# Load metadata
- self.meta = LeRobotDatasetMetadata(self.repo_id, self.root, self.local_files_only)
-
- # Check version
- check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
+ self.meta = LeRobotDatasetMetadata(
+ self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
+ )
+ if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
+ episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
+ self.stats = aggregate_stats(episodes_stats)
# Load actual data
- self.download_episodes(download_videos)
- self.hf_dataset = self.load_hf_dataset()
+ try:
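+ # When a cache refresh is forced, or an episode file is missing locally, re-download the data from the hub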
+ if force_cache_sync:
+ raise FileNotFoundError
+ assert all((self.root / fpath).is_file() for fpath in self.get_episodes_file_paths())
+ self.hf_dataset = self.load_hf_dataset()
+ except (AssertionError, FileNotFoundError, NotADirectoryError):
+ self.revision = get_safe_version(self.repo_id, self.revision)
+ self.download_episodes(download_videos)
+ self.hf_dataset = self.load_hf_dataset()
+
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
# Check timestamps
- check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
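+ # Convert the relevant columns to numpy before checking that timestamps follow the expected fps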
+ timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
+ episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
+ ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
+ check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
# Setup delta_indices
if self.delta_timestamps is not None:
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
- # Available stats implies all videos have been encoded and dataset is iterable
- self.consolidated = self.meta.stats is not None
-
def push_to_hub(
self,
+ branch: str | None = None,
tags: list | None = None,
license: str | None = "apache-2.0",
+ tag_version: bool = True,
push_videos: bool = True,
private: bool = False,
+ allow_patterns: list[str] | str | None = None,
+ upload_large_folder: bool = False,
**card_kwargs,
) -> None:
- if not self.consolidated:
- logging.warning(
- "You are trying to upload to the hub a LeRobotDataset that has not been consolidated yet. "
- "Consolidating first."
- )
- self.consolidate()
-
ignore_patterns = ["images/"]
if not push_videos:
ignore_patterns.append("videos/")
- create_repo(
+ hub_api = HfApi()
+ hub_api.create_repo(
repo_id=self.repo_id,
private=private,
repo_type="dataset",
exist_ok=True,
)
+ if branch:
+ hub_api.create_branch(
+ repo_id=self.repo_id,
+ branch=branch,
+ revision=self.revision,
+ repo_type="dataset",
+ exist_ok=True,
+ )
- upload_folder(
- repo_id=self.repo_id,
- folder_path=self.root,
- repo_type="dataset",
- ignore_patterns=ignore_patterns,
- )
- card = create_lerobot_dataset_card(
- tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
- )
- card.push_to_hub(repo_id=self.repo_id, repo_type="dataset")
- create_branch(repo_id=self.repo_id, branch=CODEBASE_VERSION, repo_type="dataset")
+ upload_kwargs = {
+ "repo_id": self.repo_id,
+ "folder_path": self.root,
+ "repo_type": "dataset",
+ "revision": branch,
+ "allow_patterns": allow_patterns,
+ "ignore_patterns": ignore_patterns,
+ }
+ if upload_large_folder:
+ hub_api.upload_large_folder(**upload_kwargs)
+ else:
+ hub_api.upload_folder(**upload_kwargs)
+
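+ # Only push a generated dataset card if the repo does not already have one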
+ if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch):
+ card = create_lerobot_dataset_card(
+ tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
+ )
+ card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
+
+ if tag_version:
+ with contextlib.suppress(RevisionNotFoundError):
+ hub_api.delete_tag(self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
+ hub_api.create_tag(self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
def pull_from_repo(
self,
@@ -517,11 +579,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
snapshot_download(
self.repo_id,
repo_type="dataset",
- revision=self.meta._hub_version,
+ revision=self.revision,
local_dir=self.root,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
- local_files_only=self.local_files_only,
)
def download_episodes(self, download_videos: bool = True) -> None:
@@ -535,17 +596,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
files = None
ignore_patterns = None if download_videos else "videos/"
if self.episodes is not None:
- files = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
- if len(self.meta.video_keys) > 0 and download_videos:
- video_files = [
- str(self.meta.get_video_file_path(ep_idx, vid_key))
- for vid_key in self.meta.video_keys
- for ep_idx in self.episodes
- ]
- files += video_files
+ files = self.get_episodes_file_paths()
self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
+ def get_episodes_file_paths(self) -> list[str]:
+ episodes = self.episodes if self.episodes is not None else list(range(self.meta.total_episodes))
+ fpaths = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in episodes]
+ if len(self.meta.video_keys) > 0:
+ video_files = [
+ str(self.meta.get_video_file_path(ep_idx, vid_key))
+ for vid_key in self.meta.video_keys
+ for ep_idx in episodes
+ ]
+ fpaths += video_files
+
+ return fpaths
+
def load_hf_dataset(self) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
if self.episodes is None:
@@ -557,7 +624,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
# TODO(aliberts): hf_dataset.set_format("torch")
hf_dataset.set_transform(hf_transform_to_torch)
+ return hf_dataset
+ def create_hf_dataset(self) -> datasets.Dataset:
+ features = get_hf_features_from_features(self.features)
+ ft_dict = {col: [] for col in features}
+ hf_dataset = datasets.Dataset.from_dict(ft_dict, features=features, split="train")
+
+ # TODO(aliberts): hf_dataset.set_format("torch")
+ hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
@property
@@ -624,7 +699,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
if key not in self.meta.video_keys
}
- def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict:
+ def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a
Segmentation Fault. This probably happens because a memory reference to the video loader is created in
@@ -633,9 +708,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
item = {}
for vid_key, query_ts in query_timestamps.items():
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
- frames = decode_video_frames_torchvision(
- video_path, query_ts, self.tolerance_s, self.video_backend
- )
+ frames = decode_video_frames(video_path, query_ts, self.tolerance_s, self.video_backend)
item[vid_key] = frames.squeeze(0)
return item
@@ -654,8 +727,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
query_indices = None
if self.delta_indices is not None:
- current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx
- query_indices, padding = self._get_query_indices(idx, current_ep_idx)
+ query_indices, padding = self._get_query_indices(idx, ep_idx)
query_result = self._query_hf_dataset(query_indices)
item = {**item, **padding}
for key, val in query_result.items():
@@ -691,10 +763,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
def create_episode_buffer(self, episode_index: int | None = None) -> dict:
current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
- return {
- "size": 0,
- **{key: current_ep_idx if key == "episode_index" else [] for key in self.features},
- }
+ ep_buffer = {}
+ # size and task are special cases that are not in self.features
+ ep_buffer["size"] = 0
+ ep_buffer["task"] = []
+ for key in self.features:
+ ep_buffer[key] = current_ep_idx if key == "episode_index" else []
+ return ep_buffer
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
fpath = DEFAULT_IMAGE_PATH.format(
@@ -716,25 +791,35 @@ class LeRobotDataset(torch.utils.data.Dataset):
temporary directory — nothing is written to disk. To save those frames, the 'save_episode()' method
then needs to be called.
"""
- # TODO(aliberts, rcadene): Add sanity check for the input, check it's numpy or torch,
- # check the dtype and shape matches, etc.
+ # Convert torch to numpy if needed
+ for name in frame:
+ if isinstance(frame[name], torch.Tensor):
+ frame[name] = frame[name].numpy()
+
+ validate_frame(frame, self.features)
if self.episode_buffer is None:
self.episode_buffer = self.create_episode_buffer()
+ # Automatically add frame_index and timestamp to episode buffer
frame_index = self.episode_buffer["size"]
timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
self.episode_buffer["frame_index"].append(frame_index)
self.episode_buffer["timestamp"].append(timestamp)
+ # Add frame features to episode_buffer
for key in frame:
- if key not in self.features:
- raise ValueError(key)
+ if key == "task":
+ # Note: we associate the task in natural language to its task index during `save_episode`
+ self.episode_buffer["task"].append(frame["task"])
+ continue
- if self.features[key]["dtype"] not in ["image", "video"]:
- item = frame[key].numpy() if isinstance(frame[key], torch.Tensor) else frame[key]
- self.episode_buffer[key].append(item)
- elif self.features[key]["dtype"] in ["image", "video"]:
+ if key not in self.features:
+ raise ValueError(
+ f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'."
+ )
+
+ if self.features[key]["dtype"] in ["image", "video"]:
img_path = self._get_image_file_path(
episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
)
@@ -742,80 +827,95 @@ class LeRobotDataset(torch.utils.data.Dataset):
img_path.parent.mkdir(parents=True, exist_ok=True)
self._save_image(frame[key], img_path)
self.episode_buffer[key].append(str(img_path))
+ else:
+ self.episode_buffer[key].append(frame[key])
self.episode_buffer["size"] += 1
- def save_episode(self, task: str, encode_videos: bool = True, episode_data: dict | None = None) -> None:
+ def save_episode(self, episode_data: dict | None = None) -> None:
"""
- This will save to disk the current episode in self.episode_buffer. Note that since it affects files on
- disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
- the hub.
+ This will save to disk the current episode in self.episode_buffer.
- Use 'encode_videos' if you want to encode videos during the saving of this episode. Otherwise,
- you can do it later with dataset.consolidate(). This is to give more flexibility on when to spend
- time for video encoding.
+ Args:
+ episode_data (dict | None, optional): Dict containing the episode data to save. If None, this will
+ save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to
+ None.
"""
if not episode_data:
episode_buffer = self.episode_buffer
+ validate_episode_buffer(episode_buffer, self.meta.total_episodes, self.features)
+
+ # size and task are special cases that won't be added to hf_dataset
episode_length = episode_buffer.pop("size")
+ tasks = episode_buffer.pop("task")
+ episode_tasks = list(set(tasks))
episode_index = episode_buffer["episode_index"]
- if episode_index != self.meta.total_episodes:
- # TODO(aliberts): Add option to use existing episode_index
- raise NotImplementedError(
- "You might have manually provided the episode_buffer with an episode_index that doesn't "
- "match the total number of episodes in the dataset. This is not supported for now."
- )
- if episode_length == 0:
- raise ValueError(
- "You must add one or several frames with `add_frame` before calling `add_episode`."
- )
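+ # Add global frame indices and repeat the episode index for every frame of the episode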
+ episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
+ episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
- task_index = self.meta.get_task_index(task)
+ # Add new tasks to the tasks dictionary
+ for task in episode_tasks:
+ task_index = self.meta.get_task_index(task)
+ if task_index is None:
+ self.meta.add_task(task)
- if not set(episode_buffer.keys()) == set(self.features):
- raise ValueError()
+ # Given tasks in natural language, find their corresponding task indices
+ episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks])
for key, ft in self.features.items():
- if key == "index":
- episode_buffer[key] = np.arange(
- self.meta.total_frames, self.meta.total_frames + episode_length
- )
- elif key == "episode_index":
- episode_buffer[key] = np.full((episode_length,), episode_index)
- elif key == "task_index":
- episode_buffer[key] = np.full((episode_length,), task_index)
- elif ft["dtype"] in ["image", "video"]:
+ # index, episode_index and task_index are already processed above; image and video features
+ # are handled separately by storing the image path and frame info as metadata
+ if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
continue
- elif len(ft["shape"]) == 1 and ft["shape"][0] == 1:
- episode_buffer[key] = np.array(episode_buffer[key], dtype=ft["dtype"])
- elif len(ft["shape"]) == 1 and ft["shape"][0] > 1:
- episode_buffer[key] = np.stack(episode_buffer[key])
- else:
- raise ValueError(key)
+ episode_buffer[key] = np.stack(episode_buffer[key])
self._wait_image_writer()
self._save_episode_table(episode_buffer, episode_index)
+ ep_stats = compute_episode_stats(episode_buffer, self.features)
- self.meta.save_episode(episode_index, episode_length, task, task_index)
-
- if encode_videos and len(self.meta.video_keys) > 0:
+ if len(self.meta.video_keys) > 0:
video_paths = self.encode_episode_videos(episode_index)
for key in self.meta.video_keys:
episode_buffer[key] = video_paths[key]
+ # `meta.save_episode` must be executed after encoding the videos
+ self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
+
+ ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index])
+ ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()}
+ check_timestamps_sync(
+ episode_buffer["timestamp"],
+ episode_buffer["episode_index"],
+ ep_data_index_np,
+ self.fps,
+ self.tolerance_s,
+ )
+
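+ # Sanity check: expect one parquet file per episode and one video per episode per camera key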
+ video_files = list(self.root.rglob("*.mp4"))
+ assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
+
+ parquet_files = list(self.root.rglob("*.parquet"))
+ assert len(parquet_files) == self.num_episodes
+
+ # delete images
+ img_dir = self.root / "images"
+ if img_dir.is_dir():
+ shutil.rmtree(self.root / "images")
+
if not episode_data: # Reset the buffer
self.episode_buffer = self.create_episode_buffer()
- self.consolidated = False
-
def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None:
episode_dict = {key: episode_buffer[key] for key in self.hf_features}
ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train")
+ ep_dataset = embed_images(ep_dataset)
+ self.hf_dataset = concatenate_datasets([self.hf_dataset, ep_dataset])
+ self.hf_dataset.set_transform(hf_transform_to_torch)
ep_data_path = self.root / self.meta.get_data_file_path(ep_index=episode_index)
ep_data_path.parent.mkdir(parents=True, exist_ok=True)
- write_parquet(ep_dataset, ep_data_path)
+ ep_dataset.to_parquet(ep_data_path)
def clear_episode_buffer(self) -> None:
episode_index = self.episode_buffer["episode_index"]
@@ -884,38 +984,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
return video_paths
- def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None:
- self.hf_dataset = self.load_hf_dataset()
- self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
- check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
-
- if len(self.meta.video_keys) > 0:
- self.encode_videos()
- self.meta.write_video_info()
-
- if not keep_image_files:
- img_dir = self.root / "images"
- if img_dir.is_dir():
- shutil.rmtree(self.root / "images")
-
- video_files = list(self.root.rglob("*.mp4"))
- assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
-
- parquet_files = list(self.root.rglob("*.parquet"))
- assert len(parquet_files) == self.num_episodes
-
- if run_compute_stats:
- self.stop_image_writer()
- # TODO(aliberts): refactor stats in save_episodes
- self.meta.stats = compute_stats(self)
- serialized_stats = serialize_dict(self.meta.stats)
- write_json(serialized_stats, self.root / STATS_PATH)
- self.consolidated = True
- else:
- logging.warning(
- "Skipping computation of the dataset statistics, dataset is not fully consolidated."
- )
-
@classmethod
def create(
cls,
@@ -944,7 +1012,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
)
obj.repo_id = obj.meta.repo_id
obj.root = obj.meta.root
- obj.local_files_only = obj.meta.local_files_only
+ obj.revision = None
obj.tolerance_s = tolerance_s
obj.image_writer = None
@@ -954,19 +1022,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
# TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
obj.episode_buffer = obj.create_episode_buffer()
- # This bool indicates that the current LeRobotDataset instance is in sync with the files on disk. It
- # is used to know when certain operations are need (for instance, computing dataset statistics). In
- # order to be able to push the dataset to the hub, it needs to be consolidated first by calling
- # self.consolidate().
- obj.consolidated = True
-
obj.episodes = None
- obj.hf_dataset = None
+ obj.hf_dataset = obj.create_hf_dataset()
obj.image_transforms = None
obj.delta_timestamps = None
obj.delta_indices = None
obj.episode_data_index = None
- obj.video_backend = video_backend if video_backend is not None else "pyav"
+ obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
return obj
@@ -986,12 +1048,11 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
delta_timestamps: dict[list[float]] | None = None,
tolerances_s: dict | None = None,
download_videos: bool = True,
- local_files_only: bool = False,
video_backend: str | None = None,
):
super().__init__()
self.repo_ids = repo_ids
- self.root = Path(root) if root else LEROBOT_HOME
+ self.root = Path(root) if root else HF_LEROBOT_HOME
self.tolerances_s = tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids}
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
# are handled by this class.
@@ -1004,7 +1065,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
delta_timestamps=delta_timestamps,
tolerance_s=self.tolerances_s[repo_id],
download_videos=download_videos,
- local_files_only=local_files_only,
video_backend=video_backend,
)
for repo_id in repo_ids
@@ -1032,7 +1092,10 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
- self.stats = aggregate_stats(self._datasets)
+ # TODO(rcadene, aliberts): We should not perform this aggregation for datasets
+ # with multiple robots of different ranges. Instead we should have one normalization
+ # per robot.
+ self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets])
@property
def repo_id_to_index(self):
diff --git a/lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md b/lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
deleted file mode 100644
index 8fcc8bbe..00000000
--- a/lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
+++ /dev/null
@@ -1,56 +0,0 @@
-## Using / Updating `CODEBASE_VERSION` (for maintainers)
-
-Since our dataset pushed to the hub are decoupled with the evolution of this repo, we ensure compatibility of
-the datasets with our code, we use a `CODEBASE_VERSION` (defined in
-lerobot/common/datasets/lerobot_dataset.py) variable.
-
-For instance, [`lerobot/pusht`](https://huggingface.co/datasets/lerobot/pusht) has many versions to maintain backward compatibility between LeRobot codebase versions:
-- [v1.0](https://huggingface.co/datasets/lerobot/pusht/tree/v1.0)
-- [v1.1](https://huggingface.co/datasets/lerobot/pusht/tree/v1.1)
-- [v1.2](https://huggingface.co/datasets/lerobot/pusht/tree/v1.2)
-- [v1.3](https://huggingface.co/datasets/lerobot/pusht/tree/v1.3)
-- [v1.4](https://huggingface.co/datasets/lerobot/pusht/tree/v1.4)
-- [v1.5](https://huggingface.co/datasets/lerobot/pusht/tree/v1.5)
-- [v1.6](https://huggingface.co/datasets/lerobot/pusht/tree/v1.6) <-- last version
-- [main](https://huggingface.co/datasets/lerobot/pusht/tree/main) <-- points to the last version
-
-Starting with v1.6, every dataset pushed to the hub or saved locally also have this version number in their
-`info.json` metadata.
-
-### Uploading a new dataset
-If you are pushing a new dataset, you don't need to worry about any of the instructions below, nor to be
-compatible with previous codebase versions. The `push_dataset_to_hub.py` script will automatically tag your
-dataset with the current `CODEBASE_VERSION`.
-
-### Updating an existing dataset
-If you want to update an existing dataset, you need to change the `CODEBASE_VERSION` from `lerobot_dataset.py`
-before running `push_dataset_to_hub.py`. This is especially useful if you introduce a breaking change
-intentionally or not (i.e. something not backward compatible such as modifying the reward functions used,
-deleting some frames at the end of an episode, etc.). That way, people running a previous version of the
-codebase won't be affected by your change and backward compatibility is maintained.
-
-However, you will need to update the version of ALL the other datasets so that they have the new
-`CODEBASE_VERSION` as a branch in their hugging face dataset repository. Don't worry, there is an easy way
-that doesn't require to run `push_dataset_to_hub.py`. You can just "branch-out" from the `main` branch on HF
-dataset repo by running this script which corresponds to a `git checkout -b` (so no copy or upload needed):
-
-```python
-from huggingface_hub import HfApi
-
-from lerobot import available_datasets
-from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
-
-api = HfApi()
-
-for repo_id in available_datasets:
- dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
- branches = [b.name for b in dataset_info.branches]
- if CODEBASE_VERSION in branches:
- print(f"{repo_id} already @{CODEBASE_VERSION}, skipping.")
- continue
- else:
- # Now create a branch named after the new version by branching out from "main"
- # which is expected to be the preceding version
- api.create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION, revision="main")
- print(f"{repo_id} successfully updated @{CODEBASE_VERSION}")
-```
diff --git a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/mobile_cabinet.txt b/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/mobile_cabinet.txt
deleted file mode 100644
index 8e821d29..00000000
--- a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/mobile_cabinet.txt
+++ /dev/null
@@ -1,85 +0,0 @@
-https://drive.google.com/file/d/1_SOJkgfP5yZyVjMhTt3nwhvyUjcnlI51/view?usp=drive_link
-https://drive.google.com/file/d/1rmgN8UUzph1qwJnzG1d-uOafodn-gLvb/view?usp=drive_link
-https://drive.google.com/file/d/1NYQ-XxsBVinB6dUoZmVWweT83367P3i2/view?usp=drive_link
-https://drive.google.com/file/d/1oAv_j74zxxCJieMG7r5Vl2BeHK1__3s3/view?usp=drive_link
-https://drive.google.com/file/d/1wFUJQROsrTJt64YRuIeExhFjr2wnK5uu/view?usp=drive_link
-https://drive.google.com/file/d/1KzL3Tt0Le7jVl58XVRUcmigmXjyiuhbK/view?usp=drive_link
-https://drive.google.com/file/d/1qy_YBladeHtianSSGtgAPSHtMin7msvf/view?usp=drive_link
-https://drive.google.com/file/d/1rA_F0V_qL_nyuC_0aBKCisF4-0TIkF2Y/view?usp=drive_link
-https://drive.google.com/file/d/1hw-8qMpz9VgSt62XoASqNRuPECpCwJQP/view?usp=drive_link
-https://drive.google.com/file/d/1BpHOl9rKMzdvNGka6js7C0s40hH6vnDA/view?usp=drive_link
-https://drive.google.com/file/d/1PazhkhiDnJ-OUMyDVDFxEZNKQQqHiNWS/view?usp=drive_link
-https://drive.google.com/file/d/1lZ665R6ATl57dypxH4dGJ2NSt6XYnbuz/view?usp=drive_link
-https://drive.google.com/file/d/1V9HzLaf-tlG15wUzT7KrTDCS_z1vi5NV/view?usp=drive_link
-https://drive.google.com/file/d/1aKauWiXoKqbNwn_2xs4MrmLlaNYlVNmO/view?usp=drive_link
-https://drive.google.com/file/d/1WVD5DFhriO1YmmOgiVHhacR6HWoTPxav/view?usp=drive_link
-https://drive.google.com/file/d/1_X43WgeBAsfkhH9EmpyPki8U9joMeAGC/view?usp=drive_link
-https://drive.google.com/file/d/1t8x0GqWoNKWtnBsB7_D40Z34nL9ak4kf/view?usp=drive_link
-https://drive.google.com/file/d/15V_f26WaKOXjKnq2T3HRWAmtQUi4lbu2/view?usp=drive_link
-https://drive.google.com/file/d/11VFIAsiSDsMOBANgrOcZBpKB9AFWnLy7/view?usp=drive_link
-https://drive.google.com/file/d/1M0NS7vVaxJv3FHnuRYtdwTFYF7We4LxP/view?usp=drive_link
-https://drive.google.com/file/d/1mR0OItTNqFnVLoczcyKYlm6drAy778lO/view?usp=drive_link
-https://drive.google.com/file/d/1NbVFWDQAh-z4JJ4D-Zw6Lps9kdvpqh2j/view?usp=drive_link
-https://drive.google.com/file/d/1JQoZGBzl4W3QG26-n39tefcGN0fDRMbB/view?usp=drive_link
-https://drive.google.com/file/d/1VBjHl-TvZpncopvasIP5G9gecbB2a5f6/view?usp=drive_link
-https://drive.google.com/file/d/1VzSf6zaB21nahm7MsPwroXbJ84NIwq0b/view?usp=drive_link
-https://drive.google.com/file/d/1OtNnfMEydNtZOcivs4k6E_uJSpf8PkGy/view?usp=drive_link
-https://drive.google.com/file/d/14nVvpvsrFr_03Pa_N7MKzwnRwibOUYM6/view?usp=drive_link
-https://drive.google.com/file/d/1M8li6duiO2r3lv_9HhF_XJn0oZUIEK5F/view?usp=drive_link
-https://drive.google.com/file/d/1Cpzea6fO14lxAaNfSBifqoa4ekhCiLD1/view?usp=drive_link
-https://drive.google.com/file/d/1mbxRTm5vlbsY9UJ0jfjM6j9D7kPJjBpG/view?usp=drive_link
-https://drive.google.com/file/d/1RXD1i6IfWsHRlCxVmG04h2h5Ycm_WwZN/view?usp=drive_link
-https://drive.google.com/file/d/1QFqFSwDGOk1BkgGmqgCcc2BRWnJ6R3MA/view?usp=drive_link
-https://drive.google.com/file/d/1bFqWR8DQM0ZUxxtS2bl-RANQvukeFLzp/view?usp=drive_link
-https://drive.google.com/file/d/1pR-rH3yNGoyPdD4hJ6-3lXQ-PstBx9du/view?usp=drive_link
-https://drive.google.com/file/d/107OAwLY-hva9HeQLIK7VCh-ytdDabVjr/view?usp=drive_link
-https://drive.google.com/file/d/1Tpl08QOaSZ37GTO4awFWSdD8wBR9xdlT/view?usp=drive_link
-https://drive.google.com/file/d/1MR164AOM-0S1T6RX8xKTV2IHyaCvpqAW/view?usp=drive_link
-https://drive.google.com/file/d/1_wknJfVnStIhJ82lU_QtcrwahsqYIsr8/view?usp=drive_link
-https://drive.google.com/file/d/1ZuEktWrbYkTx0l5pj3WiZ2CJrfbDOHNo/view?usp=drive_link
-https://drive.google.com/file/d/15G_10hkkkq6yxvyI5NGZirlF-RzduR2F/view?usp=drive_link
-https://drive.google.com/file/d/1DBKxg3ONqh7dhLuX6oh1Yyo2x383V1Hp/view?usp=drive_link
-https://drive.google.com/file/d/1B5iDBkTUr5vopDddV_fHud18SqAHhauS/view?usp=drive_link
-https://drive.google.com/file/d/1acwFV0eenRkki1QcjSKH5xqOtys-P3Pr/view?usp=drive_link
-https://drive.google.com/file/d/1S47BI83xyrh-FKXsvAQqer98Biu_p8XK/view?usp=drive_link
-https://drive.google.com/file/d/1JL6DmBZl3uyq9dyLfgSqtGF06e7E9JwM/view?usp=drive_link
-https://drive.google.com/file/d/16WvRS4Kjog8Pxgr0E3sGGnI01YwL9Uql/view?usp=drive_link
-https://drive.google.com/file/d/12ttGqL33IPWg0-s1SD44rr22M6LiSQBr/view?usp=drive_link
-https://drive.google.com/file/d/1OyZqqnldTU_DliRbr6x0C4a_iWPwIN7j/view?usp=drive_link
-https://drive.google.com/file/d/1oYk00IpLnR9fesLfD15Ebe7nVBffEbcS/view?usp=drive_link
-https://drive.google.com/file/d/1eyE2-MQduCEqCd-5_kl5zsoOEERAzpZD/view?usp=drive_link
-https://drive.google.com/file/d/1ir1Ya-vO0d97pfvbePlUeuKTTRc0qIMU/view?usp=drive_link
-https://drive.google.com/file/d/1hOi-JnqlMt47gVnLZHMTqeojyYVErohl/view?usp=drive_link
-https://drive.google.com/file/d/1NFFw5_PqigQ7xGqsL-MNq2B1r5yAscCf/view?usp=drive_link
-https://drive.google.com/file/d/1uftq1-Zlh8d2sNLWrlVcKYQUwZTD7o24/view?usp=drive_link
-https://drive.google.com/file/d/1-ax19dSLPacVgk000T-m3l4flPcg07pM/view?usp=drive_link
-https://drive.google.com/file/d/126y-lgn86-ZmCz8hooF1THKJGGObw3OB/view?usp=drive_link
-https://drive.google.com/file/d/1JiDniK0VmDIkk92AbBILb8J2Ba59PWML/view?usp=drive_link
-https://drive.google.com/file/d/1kr8nPIRljiU0R4J9SMgj80o1FPQxzu9z/view?usp=drive_link
-https://drive.google.com/file/d/1bbThWRij1pKBh_kFgV8FwK0sXtTHBoLX/view?usp=drive_link
-https://drive.google.com/file/d/1WenzDW6lxk1xkOFm-OiGFfc0ROskAuKU/view?usp=drive_link
-https://drive.google.com/file/d/1MiKRzuzUn1yN-k_6kPJJzIGy7dT-nnsD/view?usp=drive_link
-https://drive.google.com/file/d/17rRg2tcmB-gNhQ0KoZJQmNfyFeoij1jH/view?usp=drive_link
-https://drive.google.com/file/d/11mokBpvrY3ld6sY5WztREtJ1jgqfQV70/view?usp=drive_link
-https://drive.google.com/file/d/1Il_6IOx9NDp1bX_KHizJfBwzTufTmn86/view?usp=drive_link
-https://drive.google.com/file/d/1KswtJGsxJ7eeBDAmNA_aeLjOxcH6MIxa/view?usp=drive_link
-https://drive.google.com/file/d/1gzMhi5uWu4C3Y6WbQ3L-08V96GxTZrRR/view?usp=drive_link
-https://drive.google.com/file/d/1nRQFtaBxfUCYc2W90Qibh0kHCt6YQCfc/view?usp=drive_link
-https://drive.google.com/file/d/1vs-gyW-KheqHbUATwAhA2mmR9GOGw7f_/view?usp=drive_link
-https://drive.google.com/file/d/1MuxzGOA2fgLaHryq82KkQumtuRJGcUOC/view?usp=drive_link
-https://drive.google.com/file/d/1IIwxZnGlqrXLUXqG6yMO0r7uhCvhpk9e/view?usp=drive_link
-https://drive.google.com/file/d/1vE7XPyaFcXP4DtTY5Y9WKIt7zWgmX-Cr/view?usp=drive_link
-https://drive.google.com/file/d/1j-bIV09gr21RC3-x1N_pK4RPLV3fmWKz/view?usp=drive_link
-https://drive.google.com/file/d/1t3nW1rD3S-EL0Oymb5U7ZAj5UMkydkln/view?usp=drive_link
-https://drive.google.com/file/d/14hbfHCdMKtJZ41F9CQReMec2jeRFTOqR/view?usp=drive_link
-https://drive.google.com/file/d/1x-hUyOSne5BW0AzQ3W6_Pf4g5yXQWi9M/view?usp=drive_link
-https://drive.google.com/file/d/1sw9JqRg6E-3P84I3ZhzTrJMu0vuiaMmP/view?usp=drive_link
-https://drive.google.com/file/d/1LuqhQlL4MGZhB_6THmkovRxrlP26BbdC/view?usp=drive_link
-https://drive.google.com/file/d/15C5K6v_lkjnMSmUvVyqHQKwh2N166e7K/view?usp=drive_link
-https://drive.google.com/file/d/1ns_9eSsQeeoZ10nlbkLy8tu0GmJFSnkt/view?usp=drive_link
-https://drive.google.com/file/d/1NpzWJeK6CqjxzjIMYe6aYdX8xGsQwD4o/view?usp=drive_link
-https://drive.google.com/file/d/1NMLezwufKJ9_8xTc9KQThSzVVD71B9Ui/view?usp=drive_link
-https://drive.google.com/file/d/1aa71DCUqs6oXlIxX35jgsmsgm-NlDxPV/view?usp=drive_link
-https://drive.google.com/file/d/1UJzkIZzAL0j-D5YQBnoq7mHvttASy12O/view?usp=drive_link
-https://drive.google.com/file/d/1nPgx36HIJFb7oI94VbRzWjpPP2GANxzG/view?usp=drive_link
-https://drive.google.com/file/d/1NovAP-KVJjqcuvWy3d6G4ptGGAIDqcCx/view?usp=drive_link
diff --git a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/mobile_chair.txt b/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/mobile_chair.txt
deleted file mode 100644
index 497f8d04..00000000
--- a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/mobile_chair.txt
+++ /dev/null
@@ -1,55 +0,0 @@
-https://drive.google.com/file/d/11M3Ye0r5agMaaicPbVGD0q2Hb3rGklbb/view?usp=drive_link
-https://drive.google.com/file/d/1-tx7SvYYgSvXCvnf_EI2OVdwK-CkFY6S/view?usp=drive_link
-https://drive.google.com/file/d/1EWJunmOpMHaU1hE106wwpbkGYcjQXYAF/view?usp=drive_link
-https://drive.google.com/file/d/1IDn95Z7FSiCckrSENtGV4u3RyFHNQSDY/view?usp=drive_link
-https://drive.google.com/file/d/1CwzvWj1i7QOtqrZvsCZ6BdZaKNDfpN32/view?usp=drive_link
-https://drive.google.com/file/d/1HvAvlhm77nAD3Td24QPSeq8lw-Rl_aOh/view?usp=drive_link
-https://drive.google.com/file/d/1t-suKYOPhXH666RpAYNRp2QU_DOy3AeM/view?usp=drive_link
-https://drive.google.com/file/d/18xpKgWh7RWyjMN5PkLTOo-AxsAadAuRw/view?usp=drive_link
-https://drive.google.com/file/d/1oci5Eto-ztv-AQNz8EnwZveBIhxvk-xJ/view?usp=drive_link
-https://drive.google.com/file/d/1Y-t_4vxdE6NpHO0DLJR8f3mD0Q-Wj5-c/view?usp=drive_link
-https://drive.google.com/file/d/1lylRqbbbB8bgtpsBWMPACmHJreuKmllv/view?usp=drive_link
-https://drive.google.com/file/d/1yliSyMig_NXShWfQx6qyW7Ijf2Y5lFK6/view?usp=drive_link
-https://drive.google.com/file/d/1XXhwJsJbeb7KXAooGvJapnm9bjnGUmxS/view?usp=drive_link
-https://drive.google.com/file/d/1_xs1f3hW2JArKyvfF7UWubWjyROGTLs6/view?usp=drive_link
-https://drive.google.com/file/d/1WVEHpr6EqKCZbkHapQSTXJq4xE4SWFT-/view?usp=drive_link
-https://drive.google.com/file/d/1RqOHv9pEQGvW8NUA7ynffFmG999TL_Az/view?usp=drive_link
-https://drive.google.com/file/d/1cu5AgD2gh-uA3PFJmzxxzNaF3qOSlYY1/view?usp=drive_link
-https://drive.google.com/file/d/1SsrXqiPclNrnYToPZ9Uq-k3y0C4qdHT1/view?usp=drive_link
-https://drive.google.com/file/d/1-J7EXf0vjkLIfSqT8ICEsP6CTjzSLBop/view?usp=drive_link
-https://drive.google.com/file/d/11O7ewUmoZXfyyKjy_6B5RW4DpjICxqBT/view?usp=drive_link
-https://drive.google.com/file/d/1iic44kZoCsjNsfAz2cMstZ9-WQvAhblF/view?usp=drive_link
-https://drive.google.com/file/d/1yLV1lVX-2WnWQldGlnQZ0x7QBuDiVkL3/view?usp=drive_link
-https://drive.google.com/file/d/1Tybp9ru98TTbGn4eyROpUQwDFuALWXmk/view?usp=drive_link
-https://drive.google.com/file/d/13E9OTMiipVJByDs5-J19oWwAz7l94LTN/view?usp=drive_link
-https://drive.google.com/file/d/1EeTpJQdMSliw4JzSMtJ6CyTvVdexjM4M/view?usp=drive_link
-https://drive.google.com/file/d/1NHyNwoFqzeAu-1_PSpq5JfxaiD_xbpn9/view?usp=drive_link
-https://drive.google.com/file/d/1fJcS0phDp4xm_FyGaJ5wr9Pe4KqtHaxD/view?usp=drive_link
-https://drive.google.com/file/d/12AqrLUaewDPEcFRqPZeZFb_TQ0Lfi3At/view?usp=drive_link
-https://drive.google.com/file/d/1x_hd4Qsq1oJS-aj2t3qM7WbbV7KZj05b/view?usp=drive_link
-https://drive.google.com/file/d/14OUSUArmsB068hs6BuEIXQhI1Cyz8Sf0/view?usp=drive_link
-https://drive.google.com/file/d/16zlzh1T5zeUJQnFf382NXkFEKEnDub4O/view?usp=drive_link
-https://drive.google.com/file/d/1IbDltmN-NEFCNtr1TO4ILxEgQ94rtjWv/view?usp=drive_link
-https://drive.google.com/file/d/15gmlf8Gx9455pZ1AlqcCSwh3nDPxMzSr/view?usp=drive_link
-https://drive.google.com/file/d/1qHpRL1oZfIMo_vxnm8qfwQ-7l0BZIVva/view?usp=drive_link
-https://drive.google.com/file/d/1H1xskIgiFZivkYn23rMzH3xePGOh3VTC/view?usp=drive_link
-https://drive.google.com/file/d/1avls6Pv0kYiCMNVknbc1zQsgy64MUDMM/view?usp=drive_link
-https://drive.google.com/file/d/1MmWVgCj5khc8KMIifmt3EzF1o-CtPyyn/view?usp=drive_link
-https://drive.google.com/file/d/1U0kCc_xqW0WNppf4sbnK14euWKdPZtzB/view?usp=drive_link
-https://drive.google.com/file/d/16CaEyQscOuhLj23PEGDTL9DeyNkohkMn/view?usp=drive_link
-https://drive.google.com/file/d/1Iu8uM6UUJ0zW8tvN-9UiOe_4oSNzEutg/view?usp=drive_link
-https://drive.google.com/file/d/1UImqiBaIxCR-1DNJaZhHqeHhaySOtVIr/view?usp=drive_link
-https://drive.google.com/file/d/1VpU2V_leIoRIyv_lAvE7eLHBG8DxCTnp/view?usp=drive_link
-https://drive.google.com/file/d/1_Q8J27OT3Xby7QY6yHvIJauFRWEMxkRm/view?usp=drive_link
-https://drive.google.com/file/d/1bantmVo1L9Xz4tbiNw_a1UC2Z_HPO1wT/view?usp=drive_link
-https://drive.google.com/file/d/1IRIXMJMCBDkBjbaHvAlEiBogSvZ1jK_3/view?usp=drive_link
-https://drive.google.com/file/d/1mAHXKjiFbjwydypW2t5Lv8_H5x6nHegl/view?usp=drive_link
-https://drive.google.com/file/d/1SfyY796fLrBCMY39OcyuxZafqSCRZPZk/view?usp=drive_link
-https://drive.google.com/file/d/1X-44sZ8CcfzIskc0dvSx882o1yFhHaZB/view?usp=drive_link
-https://drive.google.com/file/d/1BOIWCCCk6DLD4Bmvc75ZbbLi9AQm-1ao/view?usp=drive_link
-https://drive.google.com/file/d/1RuyDtRE1kk76sw-wP8vx5SgLoPF3PA_H/view?usp=drive_link
-https://drive.google.com/file/d/1c4eoQiBbGuy3CTAQDUSkd84Ponh1roAQ/view?usp=drive_link
-https://drive.google.com/file/d/19PXB9z4Ljq6dsbf9TqcOrrP5SRbw2Tc_/view?usp=drive_link
-https://drive.google.com/file/d/1nn1VVZVoIXWdYDozR7XHXE4mPLQG80PQ/view?usp=drive_link
-https://drive.google.com/file/d/1MBdFGOKPV8GUhwoSsJ_Ky3qAMLM2Bv3K/view?usp=drive_link
-https://drive.google.com/file/d/1of3k_M-7Nh3I1TndcWedxK4ca9dn8Sc5/view?usp=drive_link
diff --git a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/mobile_elevator.txt b/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/mobile_elevator.txt
deleted file mode 100644
index abb42b55..00000000
--- a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/mobile_elevator.txt
+++ /dev/null
@@ -1,20 +0,0 @@
-https://drive.google.com/file/d/12ctkOAdkCNGN1JLbZb5ww3XTBn2LFpGI/view?usp=drive_link
-https://drive.google.com/file/d/1G_Vd46_4fq6O64gHHjUbJX5Ld44ZZx0y/view?usp=drive_link
-https://drive.google.com/file/d/1uKgUy73B3xBogQAOUhfZjO0X5qZGsi2c/view?usp=drive_link
-https://drive.google.com/file/d/1fu9cIrfI-fE2LhdGUxbx7-8Ci_PF8Ypm/view?usp=drive_link
-https://drive.google.com/file/d/1Ygk9ZPJzx8xw2A9JF3NHbJ44TqnvSTQR/view?usp=drive_link
-https://drive.google.com/file/d/18m5xPuccNsEB20WPshm3zhxmXc6k63ED/view?usp=drive_link
-https://drive.google.com/file/d/1DiqqxC44rriviRQpqogcv0-EB-Y6nr9g/view?usp=drive_link
-https://drive.google.com/file/d/1qPdaoTVDizJXkfXLioWU7iJ8hqCXSyOQ/view?usp=drive_link
-https://drive.google.com/file/d/1Fj9kIA_mG7f67WFfACJEaZ7izcHG7vUm/view?usp=drive_link
-https://drive.google.com/file/d/1WpYehZnI2P7dUdJPfkE-ij1rqCnjZEbB/view?usp=drive_link
-https://drive.google.com/file/d/1_zwWkT4jPyzB38STWb6whlzsPzXmfA9r/view?usp=drive_link
-https://drive.google.com/file/d/1U6-J4I_fPlSFFGfhZPxS5_YzKXwXIZYp/view?usp=drive_link
-https://drive.google.com/file/d/1pRhxxcTfZp5tQo_EScvJUwfc3amiS6Vk/view?usp=drive_link
-https://drive.google.com/file/d/1lWLntqra83RlYU_gN7Vostnfydf6gutd/view?usp=drive_link
-https://drive.google.com/file/d/1vIBKo0x-NYEHV1FvRpco1lQMpRdAWAIL/view?usp=drive_link
-https://drive.google.com/file/d/1pdrLV3JTQou_XH0Aap61Ssf60iVKm1jJ/view?usp=drive_link
-https://drive.google.com/file/d/1QTsLoQ7SwmKdQHjBGVDaR2uTwfFwtrOf/view?usp=drive_link
-https://drive.google.com/file/d/1Gytai8M_12J36GY6L_TulEcOC-035jwS/view?usp=drive_link
-https://drive.google.com/file/d/14LJudNc629NT-i8xreXtzl27ce_DxOFJ/view?usp=drive_link
-https://drive.google.com/file/d/1sBvPCODbzxGAI0S3lgN5cSG9Go3lRi00/view?usp=drive_link
diff --git a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/mobile_shrimp.txt b/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/mobile_shrimp.txt
deleted file mode 100644
index a6d76bd7..00000000
--- a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/mobile_shrimp.txt
+++ /dev/null
@@ -1,18 +0,0 @@
-https://drive.google.com/file/d/1MJn9GbC8p9lN4gC9KDMLEkTkP_gGpXj0/view?usp=drive_link
-https://drive.google.com/file/d/1-4LXgjl7ZCOgp-8GCJmFRD8OeqN5Jf7-/view?usp=drive_link
-https://drive.google.com/file/d/1Ho06Ce0SPbqU3juaMxNUwAt3zCRLGC8W/view?usp=drive_link
-https://drive.google.com/file/d/1ivHoj7_7olBSxH-Y8kqXEW7ttITK-45j/view?usp=drive_link
-https://drive.google.com/file/d/1qjY4hM_IvZ8cq2II_n9MeJbvyeuN4oBP/view?usp=drive_link
-https://drive.google.com/file/d/1rKVhO_f92-7sw13T8hTVrza3B9oAVgoy/view?usp=drive_link
-https://drive.google.com/file/d/1pcLPHO8fBkc1-CRa88tyQtEueE4xiXNi/view?usp=drive_link
-https://drive.google.com/file/d/1Vev_chCsIeEdvQ8poEYNsOJFGy_QU8kZ/view?usp=drive_link
-https://drive.google.com/file/d/1l5G4zpRkxSLCQjvGPYSN4zfCvVRQuzMz/view?usp=drive_link
-https://drive.google.com/file/d/14vgthE1eoakXkr2-DRw50E6lAqYOiUuE/view?usp=drive_link
-https://drive.google.com/file/d/17nPSmKKmgQ2B7zkzWrZYiLM3RBuFod82/view?usp=drive_link
-https://drive.google.com/file/d/1QcDsxplVvb_ID9BVrihl5FvlC-j7waXi/view?usp=drive_link
-https://drive.google.com/file/d/18pEejBpI-eEVaWAAjBCyC0vgbX3T1Esj/view?usp=drive_link
-https://drive.google.com/file/d/1H8eH6_IRODtEFT6WoM77ltR5OoOrqXmI/view?usp=drive_link
-https://drive.google.com/file/d/1IWlpFRZhoxyG4nS13CWK4leZVk5wbNx4/view?usp=drive_link
-https://drive.google.com/file/d/1PbZA8_OCGmMLxNP9xbkLRSChniL4uGxl/view?usp=drive_link
-https://drive.google.com/file/d/1p9XAdmG2f_WeflNO4DIJ_tr1rK6M9B4B/view?usp=drive_link
-https://drive.google.com/file/d/1nS59Et1cNAvKo3Y4SeSGRuZD5TvBbCF3/view?usp=drive_link
diff --git a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/mobile_wash_pan.txt b/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/mobile_wash_pan.txt
deleted file mode 100644
index 5e3732bd..00000000
--- a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/mobile_wash_pan.txt
+++ /dev/null
@@ -1 +0,0 @@
-https://drive.google.com/drive/folders/1S8eFg98IaGAIKVZ8QFWG1bx4mHa-O204
diff --git a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/mobile_wipe_wine.txt b/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/mobile_wipe_wine.txt
deleted file mode 100644
index 17a13f1a..00000000
--- a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/mobile_wipe_wine.txt
+++ /dev/null
@@ -1,4 +0,0 @@
-https://drive.google.com/drive/folders/1tC_g1AJ8lglBLY-fjsQrG6DMBa3Ucp-0
-https://drive.google.com/file/d/1fG_Yi2MJrFjiUVN3XoiWXLtTxHlwwaDv/view?usp=drive_link
-https://drive.google.com/file/d/1WX32VWfzzX3Blmd06DRxLwFbMJfVe7P4/view?usp=drive_link
-https://drive.google.com/file/d/18onsX3vXg3xkFwP5bVUCjdV4n9TRn0C9/view?usp=drive_link
diff --git a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/sim_insertion_human.txt b/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/sim_insertion_human.txt
deleted file mode 100644
index 19bb7114..00000000
--- a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/sim_insertion_human.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF
-https://drive.google.com/file/d/18Cudl6nikDtgRolea7je8iF_gGKzynOP/view?usp=drive_link
-https://drive.google.com/file/d/1C1kZYyROzs-PrLc0SkDgUgMi4-L3lauE/view?usp=drive_link
diff --git a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/sim_insertion_scripted.txt b/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/sim_insertion_scripted.txt
deleted file mode 100644
index fc80579b..00000000
--- a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/sim_insertion_scripted.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-https://drive.google.com/drive/folders/1TsojQQSXtHEoGnqgJ3gmpPQR2DPLtS2N
-https://drive.google.com/file/d/1wfMSZ24oOh5KR_0aaP3Cnu_c4ZCveduB/view?usp=drive_link
-https://drive.google.com/file/d/17EuCUWS6uCCr6yyNzpXdcdE-_TTNCKtf/view?usp=drive_link
diff --git a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/sim_transfer_cube_human.txt b/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/sim_transfer_cube_human.txt
deleted file mode 100644
index f5161ea2..00000000
--- a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/sim_transfer_cube_human.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-https://drive.google.com/drive/folders/1sc-E4QYW7A0o23m1u2VWNGVq5smAsfCo
-https://drive.google.com/file/d/18smMymtr8tIxaNUQ61gW6dG50pt3MvGq/view?usp=drive_link
-https://drive.google.com/file/d/1Nk7l53d9sJoGDBKAOnNrExX5nLacATc6/view?usp=drive_link
diff --git a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/sim_transfer_cube_scripted.txt b/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/sim_transfer_cube_scripted.txt
deleted file mode 100644
index d3a5b414..00000000
--- a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/sim_transfer_cube_scripted.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-https://drive.google.com/drive/folders/1aRyoOhQwxhyt1J8XgEig4s6kzaw__LXj
-https://drive.google.com/file/d/1pnGIOd-E4-rhz2P3VxpknMKRZCoKt6eI/view?usp=drive_link
-https://drive.google.com/file/d/1GKReZHrXU73NMiC5zKCq_UtqPVtYq8eo/view?usp=drive_link
diff --git a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_battery.txt b/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_battery.txt
deleted file mode 100644
index a3613eb7..00000000
--- a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_battery.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-https://drive.google.com/drive/folders/19qS_n7vKgDcPeTMnvDHQ5-n73xEbJz5D
-https://drive.google.com/file/d/1oC31By0A2bsBeHyUwBdQw1z4ng6yi9Za/view?usp=drive_link
diff --git a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_candy.txt b/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_candy.txt
deleted file mode 100644
index a39bde56..00000000
--- a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_candy.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-https://drive.google.com/drive/folders/1m5rQ6UVH8Q9RQp_6c0CxkQ88-L-ScO7q
-https://drive.google.com/file/d/1wHz2qcmwcVG0C0CZ9MjQDQcmj4OY9_a3/view?usp=drive_link
diff --git a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_coffee.txt b/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_coffee.txt
deleted file mode 100644
index 3f4acbd0..00000000
--- a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_coffee.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-https://drive.google.com/drive/folders/1seQGay470nGQ-knBI5TjsTr8iL9Qws5q
-https://drive.google.com/file/d/1T89hSX5U99wLGvGTE7yUBaQPOpyj6Sai/view?usp=drive_link
diff --git a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_coffee_new.txt b/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_coffee_new.txt
deleted file mode 100644
index 06667fef..00000000
--- a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_coffee_new.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-https://drive.google.com/drive/folders/1t3eDc5Rg0DveyRe8oTm6Dia_FYU5mXyf
-https://drive.google.com/file/d/1TXFaduTakvS0ZWJqKCX-HIvYglum_5CY/view?usp=drive_link
diff --git a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_cups_open.txt b/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_cups_open.txt
deleted file mode 100644
index 2cde5fa0..00000000
--- a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_cups_open.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-https://drive.google.com/drive/folders/1Z9X3DNzd6LS0FFjQemNUMoMA5yk5VQOh
-https://drive.google.com/file/d/1Wlyc0vTkjXuWB6zbaVOWhEfD7BmPgUV_/view?usp=drive_link
diff --git a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_fork_pick_up.txt b/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_fork_pick_up.txt
deleted file mode 100644
index 92b0d474..00000000
--- a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_fork_pick_up.txt
+++ /dev/null
@@ -1,53 +0,0 @@
-https://drive.google.com/drive/folders/1DYgB4ifX4uIid9m9jnC0Zdz8Nf7ZC0fc
-https://drive.google.com/file/d/1Eb-NRNk_FmVleCbU_Ng5Y4dfcjTKN7Rv/view?usp=drive_link
-https://drive.google.com/file/d/1dkhjEADakT-44l9jf-nK4x89kr4yG_qb/view?usp=drive_link
-https://drive.google.com/file/d/14hDhgcZkVqNExGb4tIXpSjMshhqZETch/view?usp=drive_link
-https://drive.google.com/file/d/1zVMEHpHbuNyP5A_lYU7RPSLB-4V0yfZw/view?usp=drive_link
-https://drive.google.com/file/d/1JtgDjBvy7FnRpFzrx_foC3quorYQFAR-/view?usp=drive_link
-https://drive.google.com/file/d/1EHdneB6F-PP0dQlX8qPaXbxmKoBy_YwO/view?usp=drive_link
-https://drive.google.com/file/d/17Z0jjVBy1OPKREPu77_n_rQzorDiapji/view?usp=drive_link
-https://drive.google.com/file/d/1F4i23qPJ_qTf5jWjfLo4ARGJChznYWt3/view?usp=drive_link
-https://drive.google.com/file/d/1kZtXWM3uS0-rLblydBfJ0mMcVnMMXw9w/view?usp=drive_link
-https://drive.google.com/file/d/1mNODox87xFfY5Z_o5mcLsr8SHb39jDik/view?usp=drive_link
-https://drive.google.com/file/d/1Ob44VdmEUA93FKDECiRb5Ogz2xQg5IWp/view?usp=drive_link
-https://drive.google.com/file/d/1fdQLdjj3Cwv33R1wZhfrLz9Del8mqgHb/view?usp=drive_link
-https://drive.google.com/file/d/1Yu3L3ft21zP__XL8pCfhb788ZleuW1n5/view?usp=drive_link
-https://drive.google.com/file/d/1ozBBWXVZ9hXDh9ooHUNroHdYm8UDqnhJ/view?usp=drive_link
-https://drive.google.com/file/d/1o0TGqvfWw_Lunxb5ubKDS21Lr_WC0h75/view?usp=drive_link
-https://drive.google.com/file/d/1jZnd5eP5L6BH5l98BPN6OnoQx3fu8e9n/view?usp=drive_link
-https://drive.google.com/file/d/1S5sYbz8wcLYp0V67v13i4PRcBxodn4Hg/view?usp=drive_link
-https://drive.google.com/file/d/1rFeg_x6ftJYwPtBv34D3h2L2cpDLeR4G/view?usp=drive_link
-https://drive.google.com/file/d/1GvS3lcm4o6nm_scUk0XxKeVFNmzjucDZ/view?usp=drive_link
-https://drive.google.com/file/d/1-9i0riphC7NhhDahcQfD1QoBXP5gF90A/view?usp=drive_link
-https://drive.google.com/file/d/15p_IqGsMbKuvzMS872THAZr-3SBtb1Fr/view?usp=drive_link
-https://drive.google.com/file/d/1ToyYcBfJL8gbQn0q_59zPLsFmm7dmMJo/view?usp=drive_link
-https://drive.google.com/file/d/1e_7PNH7CYafE4pAebP7ZdI7XFbmEcy_i/view?usp=drive_link
-https://drive.google.com/file/d/1JoabvGVsIQdug2xOhUIhetEIyDM91y_Y/view?usp=drive_link
-https://drive.google.com/file/d/1kOMw1y0lmnVaCjwZICfzCsx6e0Z8MNGR/view?usp=drive_link
-https://drive.google.com/file/d/16it_wd1JOevUQTK2_CvF_pBACTgpIPgM/view?usp=drive_link
-https://drive.google.com/file/d/1IRcCj9HnJSfbyMgr5XEERGlEnWeZQwOc/view?usp=drive_link
-https://drive.google.com/file/d/1Z2dIJfq_S3liGmPN9Rphvkmucnmw7tlb/view?usp=drive_link
-https://drive.google.com/file/d/1J3NoAjzndGx9yNyaBOJHdNny1epzUoBt/view?usp=drive_link
-https://drive.google.com/file/d/18nOvxV1k8FSmBrhT4TPo2sKKSZXougyx/view?usp=drive_link
-https://drive.google.com/file/d/1CT8FxclafFMjSd7gCWVw3VSeryeiF04i/view?usp=drive_link
-https://drive.google.com/file/d/16M9KVqQMFfSsXfypK0bocFft8Nz3j2Rt/view?usp=drive_link
-https://drive.google.com/file/d/18QPVkw6bj6HW8LTPrQLWrrUX4R6RcF42/view?usp=drive_link
-https://drive.google.com/file/d/1hQTVtA5hBTE_StXpJafTZJ3tgt2VQQ_t/view?usp=drive_link
-https://drive.google.com/file/d/1Dn-d5g69H6EgAWgsFdrcbJKtz7ySsCQ8/view?usp=drive_link
-https://drive.google.com/file/d/13hMr16483P7ALYv73yMRUN37fJdVQM62/view?usp=drive_link
-https://drive.google.com/file/d/1848yN3XMN5zJMEgApt6KzrWgfRPfimtv/view?usp=drive_link
-https://drive.google.com/file/d/1oAD9kSnS0fTgj-CjD4u9VdZ5X67IOIMa/view?usp=drive_link
-https://drive.google.com/file/d/1ilzIWLCCG5b_KgF5s0wdN2I5-lFNpwC1/view?usp=drive_link
-https://drive.google.com/file/d/1rjsT2YBjnidxod1s9s-myAYz8boHr-WB/view?usp=drive_link
-https://drive.google.com/file/d/18Gg48HTub15bd8qzbhiCUufbVy0fbN5G/view?usp=drive_link
-https://drive.google.com/file/d/1WsSnQSqmMTVSRwrhT1Y-v782My2zcjLm/view?usp=drive_link
-https://drive.google.com/file/d/1ea9ZCvoyc-xqiFXgeDcA_mOWsw7VUuoi/view?usp=drive_link
-https://drive.google.com/file/d/1wv1v3-XhPgbNzp62BXbJTDzMPu2tlDUc/view?usp=drive_link
-https://drive.google.com/file/d/18-ikzt8LoZ83Gi3goKCELs4U4z8hrRoF/view?usp=drive_link
-https://drive.google.com/file/d/16Bjhp7JNCXkGuLvyNcZowAx3W-Y-15DV/view?usp=drive_link
-https://drive.google.com/file/d/1Gc-KRI-xwcp1fMR55ugbrLg_5y3SPde-/view?usp=drive_link
-https://drive.google.com/file/d/1oP72Q386Z4Sy5MMm-t5yNogIe5Van_9k/view?usp=drive_link
-https://drive.google.com/file/d/112T90eDUDVH-SyOV7UnZl5bscAH2hcfq/view?usp=drive_link
-https://drive.google.com/file/d/1y-uKOesRRhjgDtFbG_j65f4SGg0v8XDg/view?usp=drive_link
-https://drive.google.com/file/d/1LOP05OagoI3km-ZKQBrS204A85UVk7Ok/view?usp=drive_link
-https://drive.google.com/file/d/1QkHQKgasVzWsmdPvkXgGhWyQ84d93_Az/view?usp=drive_link
diff --git a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_pingpong_test.txt b/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_pingpong_test.txt
deleted file mode 100644
index c622def6..00000000
--- a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_pingpong_test.txt
+++ /dev/null
@@ -1 +0,0 @@
-https://drive.google.com/drive/folders/1Ut2cv6o6Pkfgg46DgwVUM7Z5PkNG8eJ-
diff --git a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_pro_pencil.txt b/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_pro_pencil.txt
deleted file mode 100644
index bdfc447f..00000000
--- a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_pro_pencil.txt
+++ /dev/null
@@ -1 +0,0 @@
-https://drive.google.com/drive/folders/1FqxPV0PgvgIu8XFjtvZSPSExuNcxVVAY
diff --git a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_screw_driver.txt b/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_screw_driver.txt
deleted file mode 100644
index fe5548fd..00000000
--- a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_screw_driver.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-https://drive.google.com/drive/folders/1SKtG0ct9q0nVdYssJNMWSOjikcXliT58
-https://drive.google.com/file/d/1nchD21O30B3i3LDoqramo1zgW5YvpJIN/view?usp=drive_link
diff --git a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_tape.txt b/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_tape.txt
deleted file mode 100644
index 46d95479..00000000
--- a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_tape.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-https://drive.google.com/drive/folders/1_4DHf2cma0xsChLQFghwigX6Ukti5-zQ
-https://drive.google.com/file/d/1_8vS4hDNDgUQY-SmekrNaa7dF67QJYU-/view?usp=drive_link
diff --git a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_thread_velcro.txt b/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_thread_velcro.txt
deleted file mode 100644
index 46d95479..00000000
--- a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_thread_velcro.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-https://drive.google.com/drive/folders/1_4DHf2cma0xsChLQFghwigX6Ukti5-zQ
-https://drive.google.com/file/d/1_8vS4hDNDgUQY-SmekrNaa7dF67QJYU-/view?usp=drive_link
diff --git a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_towel.txt b/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_towel.txt
deleted file mode 100644
index 19288fa5..00000000
--- a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_towel.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-https://drive.google.com/drive/folders/1fAD7vkyTGTFB_nGXIKofCU1U05oE3MFv
-https://drive.google.com/file/d/1XzyQ2B6LLvcurIonOpEu4nij2qwNWshH/view?usp=drive_link
diff --git a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_vinh_cup.txt b/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_vinh_cup.txt
deleted file mode 100644
index 65ec35c4..00000000
--- a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_vinh_cup.txt
+++ /dev/null
@@ -1,53 +0,0 @@
-https://drive.google.com/drive/folders/13EQsVsnxT86K20QAoyE_YpsFbQ7fZQdu
-https://drive.google.com/file/d/1-W_JHghZG65FNTVhw1SXhtQrazdLL3Ue/view?usp=drive_link
-https://drive.google.com/file/d/1VwRJgdWUo-2nQaNM7Bs77-fsm8iwUxEo/view?usp=drive_link
-https://drive.google.com/file/d/1wFzGRo5iYA13WLi6IV1ry64RyahQBFio/view?usp=drive_link
-https://drive.google.com/file/d/1IKtQzQ-n-UTv64hYpReu2R4cqUvmNQqD/view?usp=drive_link
-https://drive.google.com/file/d/1GicVci9OiuuZZH79i5Mg7AtWod94MzwT/view?usp=drive_link
-https://drive.google.com/file/d/1JVnIoR7EIQp70T4eAf9RX65JcTrzsjQc/view?usp=drive_link
-https://drive.google.com/file/d/1W2xr4h23ucjPrc-mBEeqnACsfaImpc0p/view?usp=drive_link
-https://drive.google.com/file/d/10xj_0V7A07o3uCa7v5omUrTC0YlPW8H3/view?usp=drive_link
-https://drive.google.com/file/d/1FOc3EMaCy8Mb0_a7PuXLAwKwvxkbKmwU/view?usp=drive_link
-https://drive.google.com/file/d/143PgDXBcf2GQ0Q07ZPMVMfBgZDd5sLJG/view?usp=drive_link
-https://drive.google.com/file/d/1pE5Tyj0LlGbGWvUzuhixp86Ibu55Ez3I/view?usp=drive_link
-https://drive.google.com/file/d/141668b1VzX80ncrVJPzhkoAeIFB4MEK9/view?usp=drive_link
-https://drive.google.com/file/d/1bw12lo37p1ZvRvErHsll7cEYi2OxscvZ/view?usp=drive_link
-https://drive.google.com/file/d/1zfnMFvbgBjl6SzYhksbaOzfbwLrCN6tb/view?usp=drive_link
-https://drive.google.com/file/d/1-GIszA6mUJMaNB-tdh9r9skc77SWA0VX/view?usp=drive_link
-https://drive.google.com/file/d/1fTB0zWFYU6zh4IIUFT2zX_OkwYqmElwY/view?usp=drive_link
-https://drive.google.com/file/d/1gPIPNKGmrO9c7gKF7SP0SuUYbIBBq8z1/view?usp=drive_link
-https://drive.google.com/file/d/12JeJ-dQd5lYyn6PlDOGdE-ChVeiZ-Uv0/view?usp=drive_link
-https://drive.google.com/file/d/100_20cgCqerU6qoh3TfTbwLy9mlDAFEG/view?usp=drive_link
-https://drive.google.com/file/d/111oAGJ76ku_pYgbBoIdZAC1_XEQcPI__/view?usp=drive_link
-https://drive.google.com/file/d/1UhC8L-354ZQ2gblPFGI35EMsVwfpuKa0/view?usp=drive_link
-https://drive.google.com/file/d/1sIXQSgUR_xdrNtGrL6QGBnkLMKErsIp1/view?usp=drive_link
-https://drive.google.com/file/d/16Ax77bDSIXnsn4GFL8XYKKT1P6bPpfMd/view?usp=drive_link
-https://drive.google.com/file/d/1pgRVYwwVIsWq_qsWqZpe1UBzZfF5Fa9D/view?usp=drive_link
-https://drive.google.com/file/d/1jtimaZkWsY1P5gC2bbS64H_WCUU7HXN2/view?usp=drive_link
-https://drive.google.com/file/d/1N6Bh02P-RiTEgtx1YH1Db_X3TGpP-X_r/view?usp=drive_link
-https://drive.google.com/file/d/14Fy8EwJ8d9Vh97Yt1VOvUChSCrfIjBij/view?usp=drive_link
-https://drive.google.com/file/d/1IRuv42dvIMPuKhcMZmuXaBjJ-lPFOmQd/view?usp=drive_link
-https://drive.google.com/file/d/16XWzNY2D8ucVVn5geBgsVdhm3ppO4que/view?usp=drive_link
-https://drive.google.com/file/d/1xsVOoQgthK_L_SDrmq_JvQgUpAvPEAY8/view?usp=drive_link
-https://drive.google.com/file/d/1bZbw66DyEMvnJnzkdUUNbKjvNKg8KFYM/view?usp=drive_link
-https://drive.google.com/file/d/1CyTVkdrNGGpouCXr4CfhKbMzE6Ah3oo3/view?usp=drive_link
-https://drive.google.com/file/d/1hDRyeM-XEDpHXpptbT8LvNnlQUR3PWOh/view?usp=drive_link
-https://drive.google.com/file/d/1XhHWxbra8Iy5irQZ83IvxwaJqHq9x4s1/view?usp=drive_link
-https://drive.google.com/file/d/1haZcn6aM1o4JlmP9tJj3x2enrxiPaDSD/view?usp=drive_link
-https://drive.google.com/file/d/1ypDyuUTbljaBZ34f-t7lj3O_0bRmyX2n/view?usp=drive_link
-https://drive.google.com/file/d/1ILEEZo_tA9_ChIAprr2mPaNVKZi5vXsO/view?usp=drive_link
-https://drive.google.com/file/d/1U7nVYFaGE8vVTfLCW33D74xOjDcqfgyJ/view?usp=drive_link
-https://drive.google.com/file/d/1rZ93_rmCov5SMDxPkfM3qthcRELZrQX6/view?usp=drive_link
-https://drive.google.com/file/d/1mYO1b_csddtyE3qT6cwLiw-m2w2_1Lxh/view?usp=drive_link
-https://drive.google.com/file/d/1xz7Q5x2jikY8wJQjMRQpRws6AnfWlHm5/view?usp=drive_link
-https://drive.google.com/file/d/1OO8GaO-0FrSZRd1kxMYwBmubyiLOWnbl/view?usp=drive_link
-https://drive.google.com/file/d/1EXn4NVDmf-4_HCy34mYwT-vwK2CFI9ev/view?usp=drive_link
-https://drive.google.com/file/d/10hH70XhXRL9C5SnAG4toHtfHqfJUJo4H/view?usp=drive_link
-https://drive.google.com/file/d/18tiBcxea0guUai4lwsXQvt0q2LZ8ZnnJ/view?usp=drive_link
-https://drive.google.com/file/d/1Q8R8qv37vk5PQ5kQ2ibx6BFLOySD0VpX/view?usp=drive_link
-https://drive.google.com/file/d/17aNriHzjhdibCyuUjQoMFZqjybJZtggG/view?usp=drive_link
-https://drive.google.com/file/d/1LVjEYHSdeKm6CotU1QguIeNEPaIaFl_1/view?usp=drive_link
-https://drive.google.com/file/d/1ufAhE_EkgJ85slg2EW8aW_grOzE_Lmxd/view?usp=drive_link
-https://drive.google.com/file/d/1wtzLtXrkw9eXRGESTPIOlpl1tInu-b2m/view?usp=drive_link
-https://drive.google.com/file/d/1Mk5qvVtD_QHwGOUApRq76TUw2T5THu6f/view?usp=drive_link
-https://drive.google.com/file/d/1y1WQ3hboWVJ68KEYQQ3OhreGuaUpSgwc/view?usp=drive_link
diff --git a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_vinh_cup_left.txt b/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_vinh_cup_left.txt
deleted file mode 100644
index 8823a9b5..00000000
--- a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_vinh_cup_left.txt
+++ /dev/null
@@ -1,52 +0,0 @@
-https://drive.google.com/drive/folders/1dxWh6YFZUDt6qXIoxgD9bla3CiFjZ11C
-https://drive.google.com/file/d/1hNBJN00SCAlOl0ZEgm7RRGbAGDjyBs0p/view?usp=drive_link
-https://drive.google.com/file/d/17He0CVwXGeoMmXg4SHKo-osNn7YPKVL7/view?usp=drive_link
-https://drive.google.com/file/d/1laNKUVID1x2CV6a2O2WQjwFewKu4lidL/view?usp=drive_link
-https://drive.google.com/file/d/1pNf36xbZJGRArYLmNAvRj5y6CoqdC6kB/view?usp=drive_link
-https://drive.google.com/file/d/1_4E1-y3JXk5I0ebycLYM70YDPK9g52gZ/view?usp=drive_link
-https://drive.google.com/file/d/1PHfzhGPdbolKyOpS3FnR2w7Q8zUlJXSk/view?usp=drive_link
-https://drive.google.com/file/d/17ls2PPN-Pi3tEuK059cwV2_iDT8aGhOO/view?usp=drive_link
-https://drive.google.com/file/d/1LWsg6PmCT00Kv_N_slrmcwKmQPGoBT3k/view?usp=drive_link
-https://drive.google.com/file/d/12LckrchoHTUVH7rxi8J7zD9dA19GXvoW/view?usp=drive_link
-https://drive.google.com/file/d/1VqrJKjAIkj5gtFXL69grdSeu9CyaqnSw/view?usp=drive_link
-https://drive.google.com/file/d/1g5rQYDBZvW-kUtYPeyF3qmd53v6k7kXu/view?usp=drive_link
-https://drive.google.com/file/d/10kUgaSJ0TS7teaG83G3Rf_DG4XGrBt6A/view?usp=drive_link
-https://drive.google.com/file/d/1je9XmneZQZvTma5adMJICUPDovW3ppei/view?usp=drive_link
-https://drive.google.com/file/d/1v28r6bedwZGbUPVVTVImXhK-42XdtGfj/view?usp=drive_link
-https://drive.google.com/file/d/1-TEEx9sGVvzMMaNXYfQMtY2JJ6cvl0dT/view?usp=drive_link
-https://drive.google.com/file/d/1YdBKdJFP9rJWBUX7qrOYL_gfUA8o6J9M/view?usp=drive_link
-https://drive.google.com/file/d/1X9vffwQHNUSKLXr2RlYNtbWDIFCIDfdF/view?usp=drive_link
-https://drive.google.com/file/d/11hqesqa5kvEe5FABUnZRcvmOhR373cYM/view?usp=drive_link
-https://drive.google.com/file/d/1ltTTECjEcbQPgS3UPRgMzaE2x9n6H7dC/view?usp=drive_link
-https://drive.google.com/file/d/1Zxqfa29JdwT-bfMpivi6IG2vz34d21dD/view?usp=drive_link
-https://drive.google.com/file/d/11LQlVxS5hz494dYUJ_PNRPx2NHIJbQns/view?usp=drive_link
-https://drive.google.com/file/d/1i1JhNtnZpO_E8rAv8gxBP3ZTZRvcvsZi/view?usp=drive_link
-https://drive.google.com/file/d/11jOXAr2EULUO4Qkm748634lg4UUFho5U/view?usp=drive_link
-https://drive.google.com/file/d/1rj67wur8DdB_Pipwx24bY43xu4X1eQ5e/view?usp=drive_link
-https://drive.google.com/file/d/15ZTm6lO6f_JQy_4SNfrOu3iPYn1Ro8mh/view?usp=drive_link
-https://drive.google.com/file/d/1q4gBtqWPJtCwXEvknGgN0WHGp7Vfn1b9/view?usp=drive_link
-https://drive.google.com/file/d/1t17keyre47AYqm8GgXiQ7EcvcUkeSiDQ/view?usp=drive_link
-https://drive.google.com/file/d/1OYUPGxtZgOF86Ng_BEOTXm_XOYpuQPsO/view?usp=drive_link
-https://drive.google.com/file/d/1cBjbGHi3dwWHtx6r9EQJi0JT_CE3LuHt/view?usp=drive_link
-https://drive.google.com/file/d/14qaMyF0mcbCB-fCYKNyo5_2NahSC6D5u/view?usp=drive_link
-https://drive.google.com/file/d/12FgX86eA7Y5co9ULBVK80XMsiKQSs-Ri/view?usp=drive_link
-https://drive.google.com/file/d/1yvoHWidf-jdBVw6qCCXOFfkVwKj_2hPk/view?usp=drive_link
-https://drive.google.com/file/d/1a2SugsSDlC8UtUrFzp-_KAwyZckQOvdQ/view?usp=drive_link
-https://drive.google.com/file/d/1l8pILBFSAosypWJMza2K09Vm7rug9axm/view?usp=drive_link
-https://drive.google.com/file/d/1hfPQ8dBCk97PnOhq6_MIISm3IEzcOxJG/view?usp=drive_link
-https://drive.google.com/file/d/1PPAUwlJCFKpms8cqF_k1v2_fCgDBOc3S/view?usp=drive_link
-https://drive.google.com/file/d/1lVKQZeqFfK3amEmLuFhYLUFQ2eyE8rOW/view?usp=drive_link
-https://drive.google.com/file/d/1K9iPMLfDowcIFoyzpvgn88dQ6x6kVwNG/view?usp=drive_link
-https://drive.google.com/file/d/1PNvMqG9tL7QxeLaYBGHiWYR6SYb5iIct/view?usp=drive_link
-https://drive.google.com/file/d/1xkRtzbvIkUsylx9hrFLGQsJn0h1EYu-5/view?usp=drive_link
-https://drive.google.com/file/d/1nxMRrJlSayjDIfr5CmHO1NzAw3COhsLi/view?usp=drive_link
-https://drive.google.com/file/d/1Qs3WEyMGrmagiHIkkFEueWNnJhkUeR1s/view?usp=drive_link
-https://drive.google.com/file/d/1D-G2_Q0SS3M8zyJbg_XzkF2ANPw1HTuX/view?usp=drive_link
-https://drive.google.com/file/d/1mdmJsDGO-YtJAOF_yPKl6lq4PJOIbQhT/view?usp=drive_link
-https://drive.google.com/file/d/11m9bwfop_sPmnQr_8amB6EEsrbAeG_z5/view?usp=drive_link
-https://drive.google.com/file/d/19tyYt5FMn5kru0g9o2nMJhKPnsDqkIZv/view?usp=drive_link
-https://drive.google.com/file/d/1XvTpUdsVTZ-vydvdYYmynbma--HfUGSl/view?usp=drive_link
-https://drive.google.com/file/d/1MO3hFu68J6NohTzr9aB_fY02VA6QSOqj/view?usp=drive_link
-https://drive.google.com/file/d/1Lh-UjwAk__04YOTWINF_QGVU8SjetVaY/view?usp=drive_link
-https://drive.google.com/file/d/1jkSOUwZV5GJ7rZlVeErjcu0DBQs8Np0d/view?usp=drive_link
-https://drive.google.com/file/d/1VIN1eLI-93WrVQwCjsv6XQr353DqqBYA/view?usp=drive_link
diff --git a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_ziploc_slide.txt b/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_ziploc_slide.txt
deleted file mode 100644
index 5db6ed95..00000000
--- a/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_ziploc_slide.txt
+++ /dev/null
@@ -1,8 +0,0 @@
-https://drive.google.com/drive/folders/1EgKar7rWBmTIRmeJYZciSwjZx3uP2mHO
-https://drive.google.com/file/d/12eYWQO15atK2hBjXhynPJd9MKAj_42pz/view?usp=drive_link
-https://drive.google.com/file/d/1Ul4oEeICJDjgfYTl4H1uaisTzVYIM6wd/view?usp=drive_link
-https://drive.google.com/file/d/1WSF-OG8lKSe2wVYCv5D1aJNipxpgddk-/view?usp=drive_link
-https://drive.google.com/file/d/1_ppD5j5sFh26aWW0JmhLzJMeNB-lCArk/view?usp=drive_link
-https://drive.google.com/file/d/1WUp846dgWXYhu4oJfhHxiU6YL_7N6s4W/view?usp=drive_link
-https://drive.google.com/file/d/1HRZNAIoAQw_uYiPwnBvtBioQoqiqoXdA/view?usp=drive_link
-https://drive.google.com/file/d/1hedGq-QDMnIn8GlXXBC3GiEJ_Y-LTxyt/view?usp=drive_link
diff --git a/lerobot/common/datasets/push_dataset_to_hub/_diffusion_policy_replay_buffer.py b/lerobot/common/datasets/push_dataset_to_hub/_diffusion_policy_replay_buffer.py
deleted file mode 100644
index 33b4c974..00000000
--- a/lerobot/common/datasets/push_dataset_to_hub/_diffusion_policy_replay_buffer.py
+++ /dev/null
@@ -1,634 +0,0 @@
-#!/usr/bin/env python
-
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Helper code for loading PushT dataset from Diffusion Policy (https://diffusion-policy.cs.columbia.edu/)
-
-Copied from the original Diffusion Policy repository and used in our `download_and_upload_dataset.py` script.
-"""
-
-from __future__ import annotations
-
-import math
-import numbers
-import os
-from functools import cached_property
-
-import numcodecs
-import numpy as np
-import zarr
-
-
-def check_chunks_compatible(chunks: tuple, shape: tuple):
- assert len(shape) == len(chunks)
- for c in chunks:
- assert isinstance(c, numbers.Integral)
- assert c > 0
-
-
-def rechunk_recompress_array(group, name, chunks=None, chunk_length=None, compressor=None, tmp_key="_temp"):
- old_arr = group[name]
- if chunks is None:
- chunks = (chunk_length,) + old_arr.chunks[1:] if chunk_length is not None else old_arr.chunks
- check_chunks_compatible(chunks, old_arr.shape)
-
- if compressor is None:
- compressor = old_arr.compressor
-
- if (chunks == old_arr.chunks) and (compressor == old_arr.compressor):
- # no change
- return old_arr
-
- # rechunk recompress
- group.move(name, tmp_key)
- old_arr = group[tmp_key]
- n_copied, n_skipped, n_bytes_copied = zarr.copy(
- source=old_arr,
- dest=group,
- name=name,
- chunks=chunks,
- compressor=compressor,
- )
- del group[tmp_key]
- arr = group[name]
- return arr
-
-
-def get_optimal_chunks(shape, dtype, target_chunk_bytes=2e6, max_chunk_length=None):
- """
- Common shapes
- T,D
- T,N,D
- T,H,W,C
- T,N,H,W,C
- """
- itemsize = np.dtype(dtype).itemsize
- # reversed
- rshape = list(shape[::-1])
- if max_chunk_length is not None:
- rshape[-1] = int(max_chunk_length)
- split_idx = len(shape) - 1
- for i in range(len(shape) - 1):
- this_chunk_bytes = itemsize * np.prod(rshape[:i])
- next_chunk_bytes = itemsize * np.prod(rshape[: i + 1])
- if this_chunk_bytes <= target_chunk_bytes and next_chunk_bytes > target_chunk_bytes:
- split_idx = i
-
- rchunks = rshape[:split_idx]
- item_chunk_bytes = itemsize * np.prod(rshape[:split_idx])
- this_max_chunk_length = rshape[split_idx]
- next_chunk_length = min(this_max_chunk_length, math.ceil(target_chunk_bytes / item_chunk_bytes))
- rchunks.append(next_chunk_length)
- len_diff = len(shape) - len(rchunks)
- rchunks.extend([1] * len_diff)
- chunks = tuple(rchunks[::-1])
- # print(np.prod(chunks) * itemsize / target_chunk_bytes)
- return chunks
-
-
-class ReplayBuffer:
- """
-    Zarr-based temporal data structure.
-    Assumes the first dimension is time and only chunks along the time dimension.
- """
-
- def __init__(self, root: zarr.Group | dict[str, dict]):
- """
- Dummy constructor. Use copy_from* and create_from* class methods instead.
- """
- assert "data" in root
- assert "meta" in root
- assert "episode_ends" in root["meta"]
- for value in root["data"].values():
- assert value.shape[0] == root["meta"]["episode_ends"][-1]
- self.root = root
-
- # ============= create constructors ===============
- @classmethod
- def create_empty_zarr(cls, storage=None, root=None):
- if root is None:
- if storage is None:
- storage = zarr.MemoryStore()
- root = zarr.group(store=storage)
- root.require_group("data", overwrite=False)
- meta = root.require_group("meta", overwrite=False)
- if "episode_ends" not in meta:
- meta.zeros("episode_ends", shape=(0,), dtype=np.int64, compressor=None, overwrite=False)
- return cls(root=root)
-
- @classmethod
- def create_empty_numpy(cls):
- root = {"data": {}, "meta": {"episode_ends": np.zeros((0,), dtype=np.int64)}}
- return cls(root=root)
-
- @classmethod
- def create_from_group(cls, group, **kwargs):
- if "data" not in group:
-            # create from scratch
- buffer = cls.create_empty_zarr(root=group, **kwargs)
- else:
- # already exist
- buffer = cls(root=group, **kwargs)
- return buffer
-
- @classmethod
- def create_from_path(cls, zarr_path, mode="r", **kwargs):
- """
-        Open an on-disk zarr directly (for datasets larger than memory).
- Slower.
- """
- group = zarr.open(os.path.expanduser(zarr_path), mode)
- return cls.create_from_group(group, **kwargs)
-
- # ============= copy constructors ===============
- @classmethod
- def copy_from_store(
- cls,
- src_store,
- store=None,
- keys=None,
- chunks: dict[str, tuple] | None = None,
- compressors: dict | str | numcodecs.abc.Codec | None = None,
- if_exists="replace",
- **kwargs,
- ):
- """
- Load to memory.
- """
- src_root = zarr.group(src_store)
- if chunks is None:
- chunks = {}
- if compressors is None:
- compressors = {}
- root = None
- if store is None:
- # numpy backend
- meta = {}
- for key, value in src_root["meta"].items():
- if len(value.shape) == 0:
- meta[key] = np.array(value)
- else:
- meta[key] = value[:]
-
- if keys is None:
- keys = src_root["data"].keys()
- data = {}
- for key in keys:
- arr = src_root["data"][key]
- data[key] = arr[:]
-
- root = {"meta": meta, "data": data}
- else:
- root = zarr.group(store=store)
- # copy without recompression
- n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
- source=src_store, dest=store, source_path="/meta", dest_path="/meta", if_exists=if_exists
- )
- data_group = root.create_group("data", overwrite=True)
- if keys is None:
- keys = src_root["data"].keys()
- for key in keys:
- value = src_root["data"][key]
- cks = cls._resolve_array_chunks(chunks=chunks, key=key, array=value)
- cpr = cls._resolve_array_compressor(compressors=compressors, key=key, array=value)
- if cks == value.chunks and cpr == value.compressor:
- # copy without recompression
- this_path = "/data/" + key
- n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
- source=src_store,
- dest=store,
- source_path=this_path,
- dest_path=this_path,
- if_exists=if_exists,
- )
- else:
- # copy with recompression
- n_copied, n_skipped, n_bytes_copied = zarr.copy(
- source=value,
- dest=data_group,
- name=key,
- chunks=cks,
- compressor=cpr,
- if_exists=if_exists,
- )
- buffer = cls(root=root)
- return buffer
-
- @classmethod
- def copy_from_path(
- cls,
- zarr_path,
- backend=None,
- store=None,
- keys=None,
- chunks: dict[str, tuple] | None = None,
- compressors: dict | str | numcodecs.abc.Codec | None = None,
- if_exists="replace",
- **kwargs,
- ):
- """
-        Copy an on-disk zarr into an in-memory compressed store.
-        Recommended.
- """
- if chunks is None:
- chunks = {}
- if compressors is None:
- compressors = {}
- if backend == "numpy":
- print("backend argument is deprecated!")
- store = None
- group = zarr.open(os.path.expanduser(zarr_path), "r")
- return cls.copy_from_store(
- src_store=group.store,
- store=store,
- keys=keys,
- chunks=chunks,
- compressors=compressors,
- if_exists=if_exists,
- **kwargs,
- )
-
- # ============= save methods ===============
- def save_to_store(
- self,
- store,
- chunks: dict[str, tuple] | None = None,
- compressors: str | numcodecs.abc.Codec | dict | None = None,
- if_exists="replace",
- **kwargs,
- ):
- root = zarr.group(store)
- if chunks is None:
- chunks = {}
- if compressors is None:
- compressors = {}
- if self.backend == "zarr":
- # recompression free copy
- n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
- source=self.root.store,
- dest=store,
- source_path="/meta",
- dest_path="/meta",
- if_exists=if_exists,
- )
- else:
- meta_group = root.create_group("meta", overwrite=True)
- # save meta, no chunking
- for key, value in self.root["meta"].items():
- _ = meta_group.array(name=key, data=value, shape=value.shape, chunks=value.shape)
-
- # save data, chunk
- data_group = root.create_group("data", overwrite=True)
- for key, value in self.root["data"].items():
- cks = self._resolve_array_chunks(chunks=chunks, key=key, array=value)
- cpr = self._resolve_array_compressor(compressors=compressors, key=key, array=value)
- if isinstance(value, zarr.Array):
- if cks == value.chunks and cpr == value.compressor:
- # copy without recompression
- this_path = "/data/" + key
- n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
- source=self.root.store,
- dest=store,
- source_path=this_path,
- dest_path=this_path,
- if_exists=if_exists,
- )
- else:
- # copy with recompression
- n_copied, n_skipped, n_bytes_copied = zarr.copy(
- source=value,
- dest=data_group,
- name=key,
- chunks=cks,
- compressor=cpr,
- if_exists=if_exists,
- )
- else:
- # numpy
- _ = data_group.array(name=key, data=value, chunks=cks, compressor=cpr)
- return store
-
- def save_to_path(
- self,
- zarr_path,
- chunks: dict[str, tuple] | None = None,
- compressors: str | numcodecs.abc.Codec | dict | None = None,
- if_exists="replace",
- **kwargs,
- ):
- if chunks is None:
- chunks = {}
- if compressors is None:
- compressors = {}
- store = zarr.DirectoryStore(os.path.expanduser(zarr_path))
- return self.save_to_store(
- store, chunks=chunks, compressors=compressors, if_exists=if_exists, **kwargs
- )
-
- @staticmethod
- def resolve_compressor(compressor="default"):
- if compressor == "default":
- compressor = numcodecs.Blosc(cname="lz4", clevel=5, shuffle=numcodecs.Blosc.NOSHUFFLE)
- elif compressor == "disk":
- compressor = numcodecs.Blosc("zstd", clevel=5, shuffle=numcodecs.Blosc.BITSHUFFLE)
- return compressor
-
- @classmethod
- def _resolve_array_compressor(cls, compressors: dict | str | numcodecs.abc.Codec, key, array):
- # allows compressor to be explicitly set to None
- cpr = "nil"
- if isinstance(compressors, dict):
- if key in compressors:
- cpr = cls.resolve_compressor(compressors[key])
- elif isinstance(array, zarr.Array):
- cpr = array.compressor
- else:
- cpr = cls.resolve_compressor(compressors)
- # backup default
- if cpr == "nil":
- cpr = cls.resolve_compressor("default")
- return cpr
-
- @classmethod
- def _resolve_array_chunks(cls, chunks: dict | tuple, key, array):
- cks = None
- if isinstance(chunks, dict):
- if key in chunks:
- cks = chunks[key]
- elif isinstance(array, zarr.Array):
- cks = array.chunks
- elif isinstance(chunks, tuple):
- cks = chunks
- else:
- raise TypeError(f"Unsupported chunks type {type(chunks)}")
- # backup default
- if cks is None:
- cks = get_optimal_chunks(shape=array.shape, dtype=array.dtype)
- # check
- check_chunks_compatible(chunks=cks, shape=array.shape)
- return cks
-
- # ============= properties =================
- @cached_property
- def data(self):
- return self.root["data"]
-
- @cached_property
- def meta(self):
- return self.root["meta"]
-
- def update_meta(self, data):
- # sanitize data
- np_data = {}
- for key, value in data.items():
- if isinstance(value, np.ndarray):
- np_data[key] = value
- else:
- arr = np.array(value)
- if arr.dtype == object:
- raise TypeError(f"Invalid value type {type(value)}")
- np_data[key] = arr
-
- meta_group = self.meta
- if self.backend == "zarr":
- for key, value in np_data.items():
- _ = meta_group.array(
- name=key, data=value, shape=value.shape, chunks=value.shape, overwrite=True
- )
- else:
- meta_group.update(np_data)
-
- return meta_group
-
- @property
- def episode_ends(self):
- return self.meta["episode_ends"]
-
- def get_episode_idxs(self):
- import numba
-
-        # JIT-compile the per-step episode index lookup
-        @numba.jit(nopython=True)
-        def _get_episode_idxs(episode_ends):
- result = np.zeros((episode_ends[-1],), dtype=np.int64)
- for i in range(len(episode_ends)):
- start = 0
- if i > 0:
- start = episode_ends[i - 1]
- end = episode_ends[i]
- for idx in range(start, end):
- result[idx] = i
- return result
-
-        return _get_episode_idxs(self.episode_ends[:])
-
- @property
- def backend(self):
- backend = "numpy"
- if isinstance(self.root, zarr.Group):
- backend = "zarr"
- return backend
-
- # =========== dict-like API ==============
- def __repr__(self) -> str:
- if self.backend == "zarr":
- return str(self.root.tree())
- else:
- return super().__repr__()
-
- def keys(self):
- return self.data.keys()
-
- def values(self):
- return self.data.values()
-
- def items(self):
- return self.data.items()
-
- def __getitem__(self, key):
- return self.data[key]
-
- def __contains__(self, key):
- return key in self.data
-
- # =========== our API ==============
- @property
- def n_steps(self):
- if len(self.episode_ends) == 0:
- return 0
- return self.episode_ends[-1]
-
- @property
- def n_episodes(self):
- return len(self.episode_ends)
-
- @property
- def chunk_size(self):
- if self.backend == "zarr":
- return next(iter(self.data.arrays()))[-1].chunks[0]
- return None
-
- @property
- def episode_lengths(self):
- ends = self.episode_ends[:]
- ends = np.insert(ends, 0, 0)
- lengths = np.diff(ends)
- return lengths
-
- def add_episode(
- self,
- data: dict[str, np.ndarray],
- chunks: dict[str, tuple] | None = None,
- compressors: str | numcodecs.abc.Codec | dict | None = None,
- ):
- if chunks is None:
- chunks = {}
- if compressors is None:
- compressors = {}
- assert len(data) > 0
- is_zarr = self.backend == "zarr"
-
- curr_len = self.n_steps
- episode_length = None
- for value in data.values():
- assert len(value.shape) >= 1
- if episode_length is None:
- episode_length = len(value)
- else:
- assert episode_length == len(value)
- new_len = curr_len + episode_length
-
- for key, value in data.items():
- new_shape = (new_len,) + value.shape[1:]
- # create array
- if key not in self.data:
- if is_zarr:
- cks = self._resolve_array_chunks(chunks=chunks, key=key, array=value)
- cpr = self._resolve_array_compressor(compressors=compressors, key=key, array=value)
- arr = self.data.zeros(
- name=key, shape=new_shape, chunks=cks, dtype=value.dtype, compressor=cpr
- )
- else:
-                    # copy data to avoid modifying the caller's array
- arr = np.zeros(shape=new_shape, dtype=value.dtype)
- self.data[key] = arr
- else:
- arr = self.data[key]
- assert value.shape[1:] == arr.shape[1:]
- # same method for both zarr and numpy
- if is_zarr:
- arr.resize(new_shape)
- else:
- arr.resize(new_shape, refcheck=False)
- # copy data
- arr[-value.shape[0] :] = value
-
- # append to episode ends
- episode_ends = self.episode_ends
- if is_zarr:
- episode_ends.resize(episode_ends.shape[0] + 1)
- else:
- episode_ends.resize(episode_ends.shape[0] + 1, refcheck=False)
- episode_ends[-1] = new_len
-
- # rechunk
- if is_zarr and episode_ends.chunks[0] < episode_ends.shape[0]:
- rechunk_recompress_array(self.meta, "episode_ends", chunk_length=int(episode_ends.shape[0] * 1.5))
-
- def drop_episode(self):
- is_zarr = self.backend == "zarr"
- episode_ends = self.episode_ends[:].copy()
- assert len(episode_ends) > 0
- start_idx = 0
- if len(episode_ends) > 1:
- start_idx = episode_ends[-2]
- for value in self.data.values():
- new_shape = (start_idx,) + value.shape[1:]
- if is_zarr:
- value.resize(new_shape)
- else:
- value.resize(new_shape, refcheck=False)
- if is_zarr:
- self.episode_ends.resize(len(episode_ends) - 1)
- else:
- self.episode_ends.resize(len(episode_ends) - 1, refcheck=False)
-
- def pop_episode(self):
- assert self.n_episodes > 0
- episode = self.get_episode(self.n_episodes - 1, copy=True)
- self.drop_episode()
- return episode
-
- def extend(self, data):
- self.add_episode(data)
-
- def get_episode(self, idx, copy=False):
- idx = list(range(len(self.episode_ends)))[idx]
- start_idx = 0
- if idx > 0:
- start_idx = self.episode_ends[idx - 1]
- end_idx = self.episode_ends[idx]
- result = self.get_steps_slice(start_idx, end_idx, copy=copy)
- return result
-
- def get_episode_slice(self, idx):
- start_idx = 0
- if idx > 0:
- start_idx = self.episode_ends[idx - 1]
- end_idx = self.episode_ends[idx]
- return slice(start_idx, end_idx)
-
- def get_steps_slice(self, start, stop, step=None, copy=False):
- _slice = slice(start, stop, step)
-
- result = {}
- for key, value in self.data.items():
- x = value[_slice]
- if copy and isinstance(value, np.ndarray):
- x = x.copy()
- result[key] = x
- return result
-
- # =========== chunking =============
- def get_chunks(self) -> dict:
- assert self.backend == "zarr"
- chunks = {}
- for key, value in self.data.items():
- chunks[key] = value.chunks
- return chunks
-
- def set_chunks(self, chunks: dict):
- assert self.backend == "zarr"
- for key, value in chunks.items():
- if key in self.data:
- arr = self.data[key]
- if value != arr.chunks:
- check_chunks_compatible(chunks=value, shape=arr.shape)
- rechunk_recompress_array(self.data, key, chunks=value)
-
- def get_compressors(self) -> dict:
- assert self.backend == "zarr"
- compressors = {}
- for key, value in self.data.items():
- compressors[key] = value.compressor
- return compressors
-
- def set_compressors(self, compressors: dict):
- assert self.backend == "zarr"
- for key, value in compressors.items():
- if key in self.data:
- arr = self.data[key]
- compressor = self.resolve_compressor(value)
- if compressor != arr.compressor:
- rechunk_recompress_array(self.data, key, compressor=compressor)
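A minimal usage sketch of the `ReplayBuffer` defined above, using its numpy backend (the `state`/`action` keys and shapes are illustrative, and the import assumes the module is still available):

```python
import numpy as np

from lerobot.common.datasets.push_dataset_to_hub._diffusion_policy_replay_buffer import ReplayBuffer

# In-memory buffer: "data" holds one array per key, "meta/episode_ends" holds the
# cumulative step count at the end of each episode.
buffer = ReplayBuffer.create_empty_numpy()

# Every array within an episode must share the same first (time) dimension.
buffer.add_episode({"state": np.zeros((50, 7)), "action": np.zeros((50, 7))})
buffer.add_episode({"state": np.ones((30, 7)), "action": np.ones((30, 7))})

print(buffer.n_episodes)       # 2
print(buffer.n_steps)          # 80
print(buffer.episode_lengths)  # [50 30]

# Slice one episode back out as a dict of arrays.
episode = buffer.get_episode(1)
print(episode["state"].shape)  # (30, 7)
```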
diff --git a/lerobot/common/datasets/push_dataset_to_hub/_download_raw.py b/lerobot/common/datasets/push_dataset_to_hub/_download_raw.py
deleted file mode 100644
index edeaf093..00000000
--- a/lerobot/common/datasets/push_dataset_to_hub/_download_raw.py
+++ /dev/null
@@ -1,202 +0,0 @@
-#!/usr/bin/env python
-
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-This file contains download scripts for raw datasets.
-
-Example of usage:
-```
-python lerobot/common/datasets/push_dataset_to_hub/_download_raw.py \
---raw-dir data/lerobot-raw/pusht_raw \
---repo-id lerobot-raw/pusht_raw
-```
-"""
-
-import argparse
-import logging
-import warnings
-from pathlib import Path
-
-from huggingface_hub import snapshot_download
-
-from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id
-
-# {raw_repo_id: raw_format}
-AVAILABLE_RAW_REPO_IDS = {
- "lerobot-raw/aloha_mobile_cabinet_raw": "aloha_hdf5",
- "lerobot-raw/aloha_mobile_chair_raw": "aloha_hdf5",
- "lerobot-raw/aloha_mobile_elevator_raw": "aloha_hdf5",
- "lerobot-raw/aloha_mobile_shrimp_raw": "aloha_hdf5",
- "lerobot-raw/aloha_mobile_wash_pan_raw": "aloha_hdf5",
- "lerobot-raw/aloha_mobile_wipe_wine_raw": "aloha_hdf5",
- "lerobot-raw/aloha_sim_insertion_human_raw": "aloha_hdf5",
- "lerobot-raw/aloha_sim_insertion_scripted_raw": "aloha_hdf5",
- "lerobot-raw/aloha_sim_transfer_cube_human_raw": "aloha_hdf5",
- "lerobot-raw/aloha_sim_transfer_cube_scripted_raw": "aloha_hdf5",
- "lerobot-raw/aloha_static_battery_raw": "aloha_hdf5",
- "lerobot-raw/aloha_static_candy_raw": "aloha_hdf5",
- "lerobot-raw/aloha_static_coffee_new_raw": "aloha_hdf5",
- "lerobot-raw/aloha_static_coffee_raw": "aloha_hdf5",
- "lerobot-raw/aloha_static_cups_open_raw": "aloha_hdf5",
- "lerobot-raw/aloha_static_fork_pick_up_raw": "aloha_hdf5",
- "lerobot-raw/aloha_static_pingpong_test_raw": "aloha_hdf5",
- "lerobot-raw/aloha_static_pro_pencil_raw": "aloha_hdf5",
- "lerobot-raw/aloha_static_screw_driver_raw": "aloha_hdf5",
- "lerobot-raw/aloha_static_tape_raw": "aloha_hdf5",
- "lerobot-raw/aloha_static_thread_velcro_raw": "aloha_hdf5",
- "lerobot-raw/aloha_static_towel_raw": "aloha_hdf5",
- "lerobot-raw/aloha_static_vinh_cup_left_raw": "aloha_hdf5",
- "lerobot-raw/aloha_static_vinh_cup_raw": "aloha_hdf5",
- "lerobot-raw/aloha_static_ziploc_slide_raw": "aloha_hdf5",
- "lerobot-raw/umi_cup_in_the_wild_raw": "umi_zarr",
- "lerobot-raw/pusht_raw": "pusht_zarr",
- "lerobot-raw/unitreeh1_fold_clothes_raw": "aloha_hdf5",
- "lerobot-raw/unitreeh1_rearrange_objects_raw": "aloha_hdf5",
- "lerobot-raw/unitreeh1_two_robot_greeting_raw": "aloha_hdf5",
- "lerobot-raw/unitreeh1_warehouse_raw": "aloha_hdf5",
- "lerobot-raw/xarm_lift_medium_raw": "xarm_pkl",
- "lerobot-raw/xarm_lift_medium_replay_raw": "xarm_pkl",
- "lerobot-raw/xarm_push_medium_raw": "xarm_pkl",
- "lerobot-raw/xarm_push_medium_replay_raw": "xarm_pkl",
- "lerobot-raw/fractal20220817_data_raw": "openx_rlds.fractal20220817_data",
- "lerobot-raw/kuka_raw": "openx_rlds.kuka",
- "lerobot-raw/bridge_openx_raw": "openx_rlds.bridge_openx",
- "lerobot-raw/taco_play_raw": "openx_rlds.taco_play",
- "lerobot-raw/jaco_play_raw": "openx_rlds.jaco_play",
- "lerobot-raw/berkeley_cable_routing_raw": "openx_rlds.berkeley_cable_routing",
- "lerobot-raw/roboturk_raw": "openx_rlds.roboturk",
- "lerobot-raw/nyu_door_opening_surprising_effectiveness_raw": "openx_rlds.nyu_door_opening_surprising_effectiveness",
- "lerobot-raw/viola_raw": "openx_rlds.viola",
- "lerobot-raw/berkeley_autolab_ur5_raw": "openx_rlds.berkeley_autolab_ur5",
- "lerobot-raw/toto_raw": "openx_rlds.toto",
- "lerobot-raw/language_table_raw": "openx_rlds.language_table",
- "lerobot-raw/columbia_cairlab_pusht_real_raw": "openx_rlds.columbia_cairlab_pusht_real",
- "lerobot-raw/stanford_kuka_multimodal_dataset_raw": "openx_rlds.stanford_kuka_multimodal_dataset",
- "lerobot-raw/nyu_rot_dataset_raw": "openx_rlds.nyu_rot_dataset",
- "lerobot-raw/io_ai_tech_raw": "openx_rlds.io_ai_tech",
- "lerobot-raw/stanford_hydra_dataset_raw": "openx_rlds.stanford_hydra_dataset",
- "lerobot-raw/austin_buds_dataset_raw": "openx_rlds.austin_buds_dataset",
- "lerobot-raw/nyu_franka_play_dataset_raw": "openx_rlds.nyu_franka_play_dataset",
- "lerobot-raw/maniskill_dataset_raw": "openx_rlds.maniskill_dataset",
- "lerobot-raw/furniture_bench_dataset_raw": "openx_rlds.furniture_bench_dataset",
- "lerobot-raw/cmu_franka_exploration_dataset_raw": "openx_rlds.cmu_franka_exploration_dataset",
- "lerobot-raw/ucsd_kitchen_dataset_raw": "openx_rlds.ucsd_kitchen_dataset",
- "lerobot-raw/ucsd_pick_and_place_dataset_raw": "openx_rlds.ucsd_pick_and_place_dataset",
- "lerobot-raw/spoc_raw": "openx_rlds.spoc",
- "lerobot-raw/austin_sailor_dataset_raw": "openx_rlds.austin_sailor_dataset",
- "lerobot-raw/austin_sirius_dataset_raw": "openx_rlds.austin_sirius_dataset",
- "lerobot-raw/bc_z_raw": "openx_rlds.bc_z",
- "lerobot-raw/utokyo_pr2_opening_fridge_raw": "openx_rlds.utokyo_pr2_opening_fridge",
- "lerobot-raw/utokyo_pr2_tabletop_manipulation_raw": "openx_rlds.utokyo_pr2_tabletop_manipulation",
- "lerobot-raw/utokyo_xarm_pick_and_place_raw": "openx_rlds.utokyo_xarm_pick_and_place",
- "lerobot-raw/utokyo_xarm_bimanual_raw": "openx_rlds.utokyo_xarm_bimanual",
- "lerobot-raw/utokyo_saytap_raw": "openx_rlds.utokyo_saytap",
- "lerobot-raw/robo_net_raw": "openx_rlds.robo_net",
- "lerobot-raw/robo_set_raw": "openx_rlds.robo_set",
- "lerobot-raw/berkeley_mvp_raw": "openx_rlds.berkeley_mvp",
- "lerobot-raw/berkeley_rpt_raw": "openx_rlds.berkeley_rpt",
- "lerobot-raw/kaist_nonprehensile_raw": "openx_rlds.kaist_nonprehensile",
- "lerobot-raw/stanford_mask_vit_raw": "openx_rlds.stanford_mask_vit",
- "lerobot-raw/tokyo_u_lsmo_raw": "openx_rlds.tokyo_u_lsmo",
- "lerobot-raw/dlr_sara_pour_raw": "openx_rlds.dlr_sara_pour",
- "lerobot-raw/dlr_sara_grid_clamp_raw": "openx_rlds.dlr_sara_grid_clamp",
- "lerobot-raw/dlr_edan_shared_control_raw": "openx_rlds.dlr_edan_shared_control",
- "lerobot-raw/asu_table_top_raw": "openx_rlds.asu_table_top",
- "lerobot-raw/stanford_robocook_raw": "openx_rlds.stanford_robocook",
- "lerobot-raw/imperialcollege_sawyer_wrist_cam_raw": "openx_rlds.imperialcollege_sawyer_wrist_cam",
- "lerobot-raw/iamlab_cmu_pickup_insert_raw": "openx_rlds.iamlab_cmu_pickup_insert",
- "lerobot-raw/uiuc_d3field_raw": "openx_rlds.uiuc_d3field",
- "lerobot-raw/utaustin_mutex_raw": "openx_rlds.utaustin_mutex",
- "lerobot-raw/berkeley_fanuc_manipulation_raw": "openx_rlds.berkeley_fanuc_manipulation",
- "lerobot-raw/cmu_playing_with_food_raw": "openx_rlds.cmu_playing_with_food",
- "lerobot-raw/cmu_play_fusion_raw": "openx_rlds.cmu_play_fusion",
- "lerobot-raw/cmu_stretch_raw": "openx_rlds.cmu_stretch",
- "lerobot-raw/berkeley_gnm_recon_raw": "openx_rlds.berkeley_gnm_recon",
- "lerobot-raw/berkeley_gnm_cory_hall_raw": "openx_rlds.berkeley_gnm_cory_hall",
- "lerobot-raw/berkeley_gnm_sac_son_raw": "openx_rlds.berkeley_gnm_sac_son",
- "lerobot-raw/droid_raw": "openx_rlds.droid",
- "lerobot-raw/droid_100_raw": "openx_rlds.droid100",
- "lerobot-raw/fmb_raw": "openx_rlds.fmb",
- "lerobot-raw/dobbe_raw": "openx_rlds.dobbe",
- "lerobot-raw/usc_cloth_sim_raw": "openx_rlds.usc_cloth_sim",
- "lerobot-raw/plex_robosuite_raw": "openx_rlds.plex_robosuite",
- "lerobot-raw/conq_hose_manipulation_raw": "openx_rlds.conq_hose_manipulation",
- "lerobot-raw/vima_raw": "openx_rlds.vima",
- "lerobot-raw/robot_vqa_raw": "openx_rlds.robot_vqa",
- "lerobot-raw/mimic_play_raw": "openx_rlds.mimic_play",
- "lerobot-raw/tidybot_raw": "openx_rlds.tidybot",
- "lerobot-raw/eth_agent_affordances_raw": "openx_rlds.eth_agent_affordances",
-}
-
-
-def download_raw(raw_dir: Path, repo_id: str):
- check_repo_id(repo_id)
- user_id, dataset_id = repo_id.split("/")
-
- if not dataset_id.endswith("_raw"):
- warnings.warn(
- f"""`dataset_id` ({dataset_id}) doesn't end with '_raw' (e.g. 'lerobot/pusht_raw'). Following this
- naming convention by renaming your repository is advised, but not mandatory.""",
- stacklevel=1,
- )
-
-    # Warn if raw_dir isn't well formatted
- if raw_dir.parts[-2] != user_id or raw_dir.parts[-1] != dataset_id:
- warnings.warn(
- f"""`raw_dir` ({raw_dir}) doesn't contain a community or user id `/` the name of the dataset that
-            matches the `repo_id` (e.g. 'data/lerobot/pusht_raw'). Following this naming convention is advised,
- but not mandatory.""",
- stacklevel=1,
- )
- raw_dir.mkdir(parents=True, exist_ok=True)
-
- logging.info(f"Start downloading from huggingface.co/{user_id} for {dataset_id}")
- snapshot_download(repo_id, repo_type="dataset", local_dir=raw_dir)
- logging.info(f"Finish downloading from huggingface.co/{user_id} for {dataset_id}")
-
-
-def download_all_raw_datasets(data_dir: Path | None = None):
- if data_dir is None:
- data_dir = Path("data")
- for repo_id in AVAILABLE_RAW_REPO_IDS:
- raw_dir = data_dir / repo_id
- download_raw(raw_dir, repo_id)
-
-
-def main():
- parser = argparse.ArgumentParser(
- description=f"""A script to download raw datasets from Hugging Face hub to a local directory. Here is a
- non exhaustive list of available repositories to use in `--repo-id`: {list(AVAILABLE_RAW_REPO_IDS.keys())}""",
- )
-
- parser.add_argument(
- "--raw-dir",
- type=Path,
- required=True,
- help="Directory containing input raw datasets (e.g. `data/aloha_mobile_chair_raw` or `data/pusht_raw).",
- )
- parser.add_argument(
- "--repo-id",
- type=str,
- required=True,
- help="""Repositery identifier on Hugging Face: a community or a user name `/` the name of
- the dataset (e.g. `lerobot/pusht_raw`, `cadene/aloha_sim_insertion_human_raw`).""",
- )
- args = parser.parse_args()
- download_raw(**vars(args))
-
-
-if __name__ == "__main__":
- main()
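For programmatic use, the CLI example in the module docstring above translates to a short sketch like the following (assuming the helper is importable and the `data/<user>/<dataset>` layout is followed so no warning is emitted):

```python
from pathlib import Path

from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw

# Fetch the raw PushT dataset from the Hub; the last two path parts of raw_dir
# match the repo_id, which is the layout the function expects.
download_raw(raw_dir=Path("data/lerobot-raw/pusht_raw"), repo_id="lerobot-raw/pusht_raw")
```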
diff --git a/lerobot/common/datasets/push_dataset_to_hub/_encode_datasets.py b/lerobot/common/datasets/push_dataset_to_hub/_encode_datasets.py
deleted file mode 100644
index 184d79fb..00000000
--- a/lerobot/common/datasets/push_dataset_to_hub/_encode_datasets.py
+++ /dev/null
@@ -1,184 +0,0 @@
-#!/usr/bin/env python
-
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-Use this script to batch encode lerobot datasets from their raw format to LeRobotDataset and push their updated
-versions to the hub. Under the hood, this script reuses 'push_dataset_to_hub.py'. It assumes that you already
-downloaded raw datasets, which you can do with the related '_download_raw.py' script.
-
-For instance, for codebase_version = 'v1.6', the following command was run, assuming raw datasets from
-lerobot-raw were downloaded in 'raw/datasets/directory':
-```bash
-python lerobot/common/datasets/push_dataset_to_hub/_encode_datasets.py \
- --raw-dir raw/datasets/directory \
- --raw-repo-ids lerobot-raw \
- --local-dir push/datasets/directory \
- --tests-data-dir tests/data \
- --push-repo lerobot \
- --vcodec libsvtav1 \
- --pix-fmt yuv420p \
- --g 2 \
- --crf 30
-```
-"""
-
-import argparse
-from pathlib import Path
-
-from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
-from lerobot.common.datasets.push_dataset_to_hub._download_raw import AVAILABLE_RAW_REPO_IDS
-from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id
-from lerobot.scripts.push_dataset_to_hub import push_dataset_to_hub
-
-
-def get_push_repo_id_from_raw(raw_repo_id: str, push_repo: str) -> str:
- dataset_id_raw = raw_repo_id.split("/")[1]
- dataset_id = dataset_id_raw.removesuffix("_raw")
- return f"{push_repo}/{dataset_id}"
-
-
-def encode_datasets(
- raw_dir: Path,
- raw_repo_ids: list[str],
- push_repo: str,
- vcodec: str,
- pix_fmt: str,
- g: int,
- crf: int,
- local_dir: Path | None = None,
- tests_data_dir: Path | None = None,
- raw_format: str | None = None,
- dry_run: bool = False,
-) -> None:
- if len(raw_repo_ids) == 1 and raw_repo_ids[0].lower() == "lerobot-raw":
- raw_repo_ids_format = AVAILABLE_RAW_REPO_IDS
- else:
- if raw_format is None:
-            raise ValueError("`raw_format` must be provided when `raw_repo_ids` is not 'lerobot-raw'.")
- raw_repo_ids_format = {id_: raw_format for id_ in raw_repo_ids}
-
- for raw_repo_id, repo_raw_format in raw_repo_ids_format.items():
- check_repo_id(raw_repo_id)
- dataset_repo_id_push = get_push_repo_id_from_raw(raw_repo_id, push_repo)
- dataset_raw_dir = raw_dir / raw_repo_id
- dataset_dir = local_dir / dataset_repo_id_push if local_dir is not None else None
- encoding = {
- "vcodec": vcodec,
- "pix_fmt": pix_fmt,
- "g": g,
- "crf": crf,
- }
-
- if not (dataset_raw_dir).is_dir():
- raise NotADirectoryError(dataset_raw_dir)
-
- if not dry_run:
- push_dataset_to_hub(
- dataset_raw_dir,
- raw_format=repo_raw_format,
- repo_id=dataset_repo_id_push,
- local_dir=dataset_dir,
- resume=True,
- encoding=encoding,
- tests_data_dir=tests_data_dir,
- )
- else:
- print(
- f"DRY RUN: {dataset_raw_dir} --> {dataset_dir} --> {dataset_repo_id_push}@{CODEBASE_VERSION}"
- )
-
-
-def main():
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--raw-dir",
- type=Path,
- default=Path("data"),
- help="Directory where raw datasets are located.",
- )
- parser.add_argument(
- "--raw-repo-ids",
- type=str,
- nargs="*",
- default=["lerobot-raw"],
- help="""Raw dataset repo ids. if 'lerobot-raw', the keys from `AVAILABLE_RAW_REPO_IDS` will be
- used and raw datasets will be fetched from the 'lerobot-raw/' repo and pushed with their
- associated format. It is assumed that each dataset is located at `raw_dir / raw_repo_id` """,
- )
- parser.add_argument(
- "--raw-format",
- type=str,
- default=None,
- help="""Raw format to use for the raw repo-ids. Must be specified if --raw-repo-ids is not
- 'lerobot-raw'""",
- )
- parser.add_argument(
- "--local-dir",
- type=Path,
- default=None,
- help="""When provided, writes the dataset converted to LeRobotDataset format in this directory
- (e.g. `data/lerobot/aloha_mobile_chair`).""",
- )
- parser.add_argument(
- "--push-repo",
- type=str,
- default="lerobot",
- help="Repo to upload datasets to",
- )
- parser.add_argument(
- "--vcodec",
- type=str,
- default="libsvtav1",
- help="Codec to use for encoding videos",
- )
- parser.add_argument(
- "--pix-fmt",
- type=str,
- default="yuv420p",
- help="Pixel formats (chroma subsampling) to be used for encoding",
- )
- parser.add_argument(
- "--g",
- type=int,
- default=2,
- help="Group of pictures sizes to be used for encoding.",
- )
- parser.add_argument(
- "--crf",
- type=int,
- default=30,
- help="Constant rate factors to be used for encoding.",
- )
- parser.add_argument(
- "--tests-data-dir",
- type=Path,
- default=None,
- help=(
- "When provided, save tests artifacts into the given directory "
- "(e.g. `--tests-data-dir tests/data` will save to tests/data/{--repo-id})."
- ),
- )
- parser.add_argument(
- "--dry-run",
- type=int,
- default=0,
- help="If not set to 0, this script won't download or upload anything.",
- )
- args = parser.parse_args()
- encode_datasets(**vars(args))
-
-
-if __name__ == "__main__":
- main()
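A minimal sketch of a programmatic dry run with `encode_datasets` (it assumes the raw PushT dataset was already downloaded to `data/lerobot-raw/pusht_raw`, since the directory check runs even when `dry_run` is set):

```python
from pathlib import Path

from lerobot.common.datasets.push_dataset_to_hub._encode_datasets import encode_datasets

# Only prints the planned "raw dir --> local dir --> pushed repo id" mapping;
# nothing is encoded or uploaded while dry_run is truthy.
encode_datasets(
    raw_dir=Path("data"),
    raw_repo_ids=["lerobot-raw/pusht_raw"],
    raw_format="pusht_zarr",
    push_repo="lerobot",
    vcodec="libsvtav1",
    pix_fmt="yuv420p",
    g=2,
    crf=30,
    dry_run=True,
)
```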
diff --git a/lerobot/common/datasets/push_dataset_to_hub/_umi_imagecodecs_numcodecs.py b/lerobot/common/datasets/push_dataset_to_hub/_umi_imagecodecs_numcodecs.py
deleted file mode 100644
index a118b7e7..00000000
--- a/lerobot/common/datasets/push_dataset_to_hub/_umi_imagecodecs_numcodecs.py
+++ /dev/null
@@ -1,326 +0,0 @@
-#!/usr/bin/env python
-
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# imagecodecs/numcodecs.py
-
-# Copyright (c) 2021-2022, Christoph Gohlke
-# All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are met:
-#
-# 1. Redistributions of source code must retain the above copyright notice,
-# this list of conditions and the following disclaimer.
-#
-# 2. Redistributions in binary form must reproduce the above copyright notice,
-# this list of conditions and the following disclaimer in the documentation
-# and/or other materials provided with the distribution.
-#
-# 3. Neither the name of the copyright holder nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
-# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
-# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
-# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
-# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
-# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
-# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
-# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
-# POSSIBILITY OF SUCH DAMAGE.
-
-# Copied from: https://github.com/real-stanford/universal_manipulation_interface/blob/298776ce251f33b6b3185a98d6e7d1f9ad49168b/diffusion_policy/codecs/imagecodecs_numcodecs.py#L1
-"""Additional numcodecs implemented using imagecodecs."""
-
-__version__ = "2022.9.26"
-
-__all__ = ("register_codecs",)
-
-import imagecodecs
-import numpy
-from numcodecs.abc import Codec
-from numcodecs.registry import get_codec, register_codec
-
-# TODO (azouitine): Remove useless codecs
-
-
-def protective_squeeze(x: numpy.ndarray):
- """
-    Squeeze or collapse leading dims while keeping the trailing image dims intact.
-    Image dims are expected to be (*, H, W, C).
- """
- img_shape = x.shape[-3:]
- if len(x.shape) > 3:
- n_imgs = numpy.prod(x.shape[:-3])
- if n_imgs > 1:
- img_shape = (-1,) + img_shape
- return x.reshape(img_shape)
-
-
-def get_default_image_compressor(**kwargs):
- if imagecodecs.JPEGXL:
- # has JPEGXL
- this_kwargs = {
- "effort": 3,
- "distance": 0.3,
- # bug in libjxl, invalid codestream for non-lossless
- # when decoding speed > 1
- "decodingspeed": 1,
- }
- this_kwargs.update(kwargs)
- return JpegXl(**this_kwargs)
- else:
- this_kwargs = {"level": 50}
- this_kwargs.update(kwargs)
- return Jpeg2k(**this_kwargs)
-
-
-class Jpeg2k(Codec):
- """JPEG 2000 codec for numcodecs."""
-
- codec_id = "imagecodecs_jpeg2k"
-
- def __init__(
- self,
- level=None,
- codecformat=None,
- colorspace=None,
- tile=None,
- reversible=None,
- bitspersample=None,
- resolutions=None,
- numthreads=None,
- verbose=0,
- ):
- self.level = level
- self.codecformat = codecformat
- self.colorspace = colorspace
- self.tile = None if tile is None else tuple(tile)
- self.reversible = reversible
- self.bitspersample = bitspersample
- self.resolutions = resolutions
- self.numthreads = numthreads
- self.verbose = verbose
-
- def encode(self, buf):
- buf = protective_squeeze(numpy.asarray(buf))
- return imagecodecs.jpeg2k_encode(
- buf,
- level=self.level,
- codecformat=self.codecformat,
- colorspace=self.colorspace,
- tile=self.tile,
- reversible=self.reversible,
- bitspersample=self.bitspersample,
- resolutions=self.resolutions,
- numthreads=self.numthreads,
- verbose=self.verbose,
- )
-
- def decode(self, buf, out=None):
- return imagecodecs.jpeg2k_decode(buf, verbose=self.verbose, numthreads=self.numthreads, out=out)
-
-
-class JpegXl(Codec):
- """JPEG XL codec for numcodecs."""
-
- codec_id = "imagecodecs_jpegxl"
-
- def __init__(
- self,
- # encode
- level=None,
- effort=None,
- distance=None,
- lossless=None,
- decodingspeed=None,
- photometric=None,
- planar=None,
- usecontainer=None,
- # decode
- index=None,
- keeporientation=None,
- # both
- numthreads=None,
- ):
- """
- Return JPEG XL image from numpy array.
- Float must be in nominal range 0..1.
-
- Currently L, LA, RGB, RGBA images are supported in contig mode.
- Extra channels are only supported for grayscale images in planar mode.
-
- Parameters
- ----------
-        level : Default to None, i.e. not overwriting lossless and decodingspeed options.
- When < 0: Use lossless compression
- When in [0,1,2,3,4]: Sets the decoding speed tier for the provided options.
- Minimum is 0 (slowest to decode, best quality/density), and maximum
- is 4 (fastest to decode, at the cost of some quality/density).
- effort : Default to 3.
- Sets encoder effort/speed level without affecting decoding speed.
- Valid values are, from faster to slower speed: 1:lightning 2:thunder
- 3:falcon 4:cheetah 5:hare 6:wombat 7:squirrel 8:kitten 9:tortoise.
- Speed: lightning, thunder, falcon, cheetah, hare, wombat, squirrel, kitten, tortoise
- control the encoder effort in ascending order.
- This also affects memory usage: using lower effort will typically reduce memory
- consumption during encoding.
- lightning and thunder are fast modes useful for lossless mode (modular).
- falcon disables all of the following tools.
- cheetah enables coefficient reordering, context clustering, and heuristics for selecting DCT sizes and quantization steps.
- hare enables Gaborish filtering, chroma from luma, and an initial estimate of quantization steps.
- wombat enables error diffusion quantization and full DCT size selection heuristics.
- squirrel (default) enables dots, patches, and spline detection, and full context clustering.
- kitten optimizes the adaptive quantization for a psychovisual metric.
- tortoise enables a more thorough adaptive quantization search.
-        distance : Defaults to 1.0.
- Sets the distance level for lossy compression: target max butteraugli distance,
- lower = higher quality. Range: 0 .. 15. 0.0 = mathematically lossless
- (however, use JxlEncoderSetFrameLossless instead to use true lossless,
- as setting distance to 0 alone is not the only requirement).
- 1.0 = visually lossless. Recommended range: 0.5 .. 3.0.
-        lossless : Defaults to False.
-            Use lossless encoding.
-        decodingspeed : Defaults to 0.
-            Duplicate of level. Valid range: [0, 4].
- photometric : Return JxlColorSpace value.
- Default logic is quite complicated but works most of the time.
- Accepted value:
- int: [-1,3]
- str: ['RGB',
- 'WHITEISZERO', 'MINISWHITE',
- 'BLACKISZERO', 'MINISBLACK', 'GRAY',
- 'XYB', 'KNOWN']
- planar : Enable multi-channel mode.
-            Defaults to False.
- usecontainer :
- Forces the encoder to use the box-based container format (BMFF)
- even when not necessary.
- When using JxlEncoderUseBoxes, JxlEncoderStoreJPEGMetadata or
- JxlEncoderSetCodestreamLevel with level 10, the encoder will
- automatically also use the container format, it is not necessary
- to use JxlEncoderUseContainer for those use cases.
- By default this setting is disabled.
- index : Selectively decode frames for animation.
-            Defaults to 0, i.e. decode all frames.
- When set to > 0, decode that frame index only.
- keeporientation :
- Enables or disables preserving of as-in-bitstream pixeldata orientation.
- Some images are encoded with an Orientation tag indicating that the
- decoder must perform a rotation and/or mirroring to the encoded image data.
-
- If skip_reorientation is JXL_FALSE (the default): the decoder will apply
- the transformation from the orientation setting, hence rendering the image
- according to its specified intent. When producing a JxlBasicInfo, the decoder
- will always set the orientation field to JXL_ORIENT_IDENTITY (matching the
- returned pixel data) and also align xsize and ysize so that they correspond
- to the width and the height of the returned pixel data.
-
- If skip_reorientation is JXL_TRUE: the decoder will skip applying the
- transformation from the orientation setting, returning the image in
- the as-in-bitstream pixeldata orientation. This may be faster to decode
-            since the decoder doesn't have to apply the transformation, but can
- cause wrong display of the image if the orientation tag is not correctly
- taken into account by the user.
-
- By default, this option is disabled, and the returned pixel data is
-            re-oriented according to the image's Orientation setting.
-        numthreads : Defaults to 1.
-            If <= 0, use all cores.
-            If > 32, clipped to 32.
- """
-
- self.level = level
- self.effort = effort
- self.distance = distance
- self.lossless = bool(lossless)
- self.decodingspeed = decodingspeed
- self.photometric = photometric
- self.planar = planar
- self.usecontainer = usecontainer
- self.index = index
- self.keeporientation = keeporientation
- self.numthreads = numthreads
-
- def encode(self, buf):
- # TODO: only squeeze all but last dim
- buf = protective_squeeze(numpy.asarray(buf))
- return imagecodecs.jpegxl_encode(
- buf,
- level=self.level,
- effort=self.effort,
- distance=self.distance,
- lossless=self.lossless,
- decodingspeed=self.decodingspeed,
- photometric=self.photometric,
- planar=self.planar,
- usecontainer=self.usecontainer,
- numthreads=self.numthreads,
- )
-
- def decode(self, buf, out=None):
- return imagecodecs.jpegxl_decode(
- buf,
- index=self.index,
- keeporientation=self.keeporientation,
- numthreads=self.numthreads,
- out=out,
- )
-
-
-def _flat(out):
- """Return numpy array as contiguous view of bytes if possible."""
- if out is None:
- return None
- view = memoryview(out)
- if view.readonly or not view.contiguous:
- return None
- return view.cast("B")
-
-
-def register_codecs(codecs=None, force=False, verbose=True):
- """Register codecs in this module with numcodecs."""
- for name, cls in globals().items():
- if not hasattr(cls, "codec_id") or name == "Codec":
- continue
- if codecs is not None and cls.codec_id not in codecs:
- continue
- try:
- try: # noqa: SIM105
- get_codec({"id": cls.codec_id})
- except TypeError:
- # registered, but failed
- pass
- except ValueError:
- # not registered yet
- pass
- else:
- if not force:
- if verbose:
- log_warning(f"numcodec {cls.codec_id!r} already registered")
- continue
- if verbose:
- log_warning(f"replacing registered numcodec {cls.codec_id!r}")
- register_codec(cls)
-
-
-def log_warning(msg, *args, **kwargs):
- """Log message with level WARNING."""
- import logging
-
- logging.getLogger(__name__).warning(msg, *args, **kwargs)
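A minimal usage sketch (the store path is illustrative, borrowed from the UMI loader later in this diff): register_codecs() has to run before zarr can decode chunks compressed with the codecs defined above.

import zarr

register_codecs()  # lets numcodecs resolve "imagecodecs_jpegxl" / "imagecodecs_jpeg2k"
store = zarr.open("cup_in_the_wild.zarr", mode="r")
first_frame = store["data/camera0_rgb"][0]  # the chunk is decompressed through the registered codec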
diff --git a/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py b/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py
deleted file mode 100644
index e2973ef8..00000000
--- a/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py
+++ /dev/null
@@ -1,233 +0,0 @@
-#!/usr/bin/env python
-
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-Contains utilities to process raw data format of HDF5 files like in: https://github.com/tonyzhaozh/act
-"""
-
-import gc
-import shutil
-from pathlib import Path
-
-import h5py
-import numpy as np
-import torch
-import tqdm
-from datasets import Dataset, Features, Image, Sequence, Value
-from PIL import Image as PILImage
-
-from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
-from lerobot.common.datasets.push_dataset_to_hub.utils import (
- calculate_episode_data_index,
- concatenate_episodes,
- get_default_encoding,
- save_images_concurrently,
-)
-from lerobot.common.datasets.utils import (
- hf_transform_to_torch,
-)
-from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
-
-
-def get_cameras(hdf5_data):
- # ignore depth channel, not currently handled
- # TODO(rcadene): add depth
- rgb_cameras = [key for key in hdf5_data["/observations/images"].keys() if "depth" not in key] # noqa: SIM118
- return rgb_cameras
-
-
-def check_format(raw_dir) -> bool:
- # only frames from simulation are uncompressed
- compressed_images = "sim" not in raw_dir.name
-
- hdf5_paths = list(raw_dir.glob("episode_*.hdf5"))
- assert len(hdf5_paths) != 0
- for hdf5_path in hdf5_paths:
- with h5py.File(hdf5_path, "r") as data:
- assert "/action" in data
- assert "/observations/qpos" in data
-
- assert data["/action"].ndim == 2
- assert data["/observations/qpos"].ndim == 2
-
- num_frames = data["/action"].shape[0]
- assert num_frames == data["/observations/qpos"].shape[0]
-
- for camera in get_cameras(data):
- assert num_frames == data[f"/observations/images/{camera}"].shape[0]
-
- if compressed_images:
- assert data[f"/observations/images/{camera}"].ndim == 2
- else:
- assert data[f"/observations/images/{camera}"].ndim == 4
- b, h, w, c = data[f"/observations/images/{camera}"].shape
- assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."
-
-
-def load_from_raw(
- raw_dir: Path,
- videos_dir: Path,
- fps: int,
- video: bool,
- episodes: list[int] | None = None,
- encoding: dict | None = None,
-):
- # only frames from simulation are uncompressed
- compressed_images = "sim" not in raw_dir.name
-
- hdf5_files = sorted(raw_dir.glob("episode_*.hdf5"))
- num_episodes = len(hdf5_files)
-
- ep_dicts = []
- ep_ids = episodes if episodes else range(num_episodes)
- for ep_idx in tqdm.tqdm(ep_ids):
- ep_path = hdf5_files[ep_idx]
- with h5py.File(ep_path, "r") as ep:
- num_frames = ep["/action"].shape[0]
-
- # last step of demonstration is considered done
- done = torch.zeros(num_frames, dtype=torch.bool)
- done[-1] = True
-
- state = torch.from_numpy(ep["/observations/qpos"][:])
- action = torch.from_numpy(ep["/action"][:])
- if "/observations/qvel" in ep:
- velocity = torch.from_numpy(ep["/observations/qvel"][:])
- if "/observations/effort" in ep:
- effort = torch.from_numpy(ep["/observations/effort"][:])
-
- ep_dict = {}
-
- for camera in get_cameras(ep):
- img_key = f"observation.images.{camera}"
-
- if compressed_images:
- import cv2
-
- # load one compressed image after the other in RAM and uncompress
- imgs_array = []
- for data in ep[f"/observations/images/{camera}"]:
- imgs_array.append(cv2.imdecode(data, 1))
- imgs_array = np.array(imgs_array)
-
- else:
- # load all images in RAM
- imgs_array = ep[f"/observations/images/{camera}"][:]
-
- if video:
- # save png images in temporary directory
- tmp_imgs_dir = videos_dir / "tmp_images"
- save_images_concurrently(imgs_array, tmp_imgs_dir)
-
- # encode images to a mp4 video
- fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
- video_path = videos_dir / fname
- encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {}))
-
- # clean temporary images directory
- shutil.rmtree(tmp_imgs_dir)
-
- # store the reference to the video frame
- ep_dict[img_key] = [
- {"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
- ]
- else:
- ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
-
- ep_dict["observation.state"] = state
-            if "/observations/qvel" in ep:
- ep_dict["observation.velocity"] = velocity
- if "/observations/effort" in ep:
- ep_dict["observation.effort"] = effort
- ep_dict["action"] = action
- ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
- ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
- ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
- ep_dict["next.done"] = done
- # TODO(rcadene): add reward and success by computing them in sim
-
- assert isinstance(ep_idx, int)
- ep_dicts.append(ep_dict)
-
- gc.collect()
-
- data_dict = concatenate_episodes(ep_dicts)
-
- total_frames = data_dict["frame_index"].shape[0]
- data_dict["index"] = torch.arange(0, total_frames, 1)
- return data_dict
-
-
-def to_hf_dataset(data_dict, video) -> Dataset:
- features = {}
-
- keys = [key for key in data_dict if "observation.images." in key]
- for key in keys:
- if video:
- features[key] = VideoFrame()
- else:
- features[key] = Image()
-
- features["observation.state"] = Sequence(
- length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
- )
- if "observation.velocity" in data_dict:
- features["observation.velocity"] = Sequence(
- length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
- )
- if "observation.effort" in data_dict:
- features["observation.effort"] = Sequence(
- length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
- )
- features["action"] = Sequence(
- length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
- )
- features["episode_index"] = Value(dtype="int64", id=None)
- features["frame_index"] = Value(dtype="int64", id=None)
- features["timestamp"] = Value(dtype="float32", id=None)
- features["next.done"] = Value(dtype="bool", id=None)
- features["index"] = Value(dtype="int64", id=None)
-
- hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
- hf_dataset.set_transform(hf_transform_to_torch)
- return hf_dataset
-
-
-def from_raw_to_lerobot_format(
- raw_dir: Path,
- videos_dir: Path,
- fps: int | None = None,
- video: bool = True,
- episodes: list[int] | None = None,
- encoding: dict | None = None,
-):
- # sanity check
- check_format(raw_dir)
-
- if fps is None:
- fps = 50
-
- data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding)
- hf_dataset = to_hf_dataset(data_dict, video)
- episode_data_index = calculate_episode_data_index(hf_dataset)
- info = {
- "codebase_version": CODEBASE_VERSION,
- "fps": fps,
- "video": video,
- }
- if video:
- info["encoding"] = get_default_encoding()
-
- return hf_dataset, episode_data_index, info
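A minimal sketch of how this converter was typically called (paths are hypothetical; fps falls back to 50 for ALOHA when left as None):

from pathlib import Path

hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
    raw_dir=Path("data/aloha_raw"),        # directory containing episode_*.hdf5 files
    videos_dir=Path("data/aloha_videos"),  # where the encoded mp4 files are written
    fps=50,
    video=True,  # encode camera frames to mp4 and store per-frame references
)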
diff --git a/lerobot/common/datasets/push_dataset_to_hub/cam_png_format.py b/lerobot/common/datasets/push_dataset_to_hub/cam_png_format.py
deleted file mode 100644
index 26492576..00000000
--- a/lerobot/common/datasets/push_dataset_to_hub/cam_png_format.py
+++ /dev/null
@@ -1,107 +0,0 @@
-#!/usr/bin/env python
-
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-Contains utilities to process raw png image files recorded with capture_camera_feed.py
-"""
-
-from pathlib import Path
-
-import torch
-from datasets import Dataset, Features, Image, Value
-from PIL import Image as PILImage
-
-from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
-from lerobot.common.datasets.push_dataset_to_hub.utils import (
- calculate_episode_data_index,
- concatenate_episodes,
-)
-from lerobot.common.datasets.utils import hf_transform_to_torch
-from lerobot.common.datasets.video_utils import VideoFrame
-
-
-def check_format(raw_dir: Path) -> bool:
- image_paths = list(raw_dir.glob("frame_*.png"))
- if len(image_paths) == 0:
-        raise ValueError(f"Missing 'frame_*.png' images in '{raw_dir}'")
-
-
-def load_from_raw(raw_dir: Path, fps: int, episodes: list[int] | None = None):
- if episodes is not None:
- # TODO(aliberts): add support for multi-episodes.
- raise NotImplementedError()
-
- ep_dict = {}
- ep_idx = 0
-
- image_paths = sorted(raw_dir.glob("frame_*.png"))
- num_frames = len(image_paths)
-
- ep_dict["observation.image"] = [PILImage.open(x) for x in image_paths]
- ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
- ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
- ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
-
- ep_dicts = [ep_dict]
- data_dict = concatenate_episodes(ep_dicts)
- total_frames = data_dict["frame_index"].shape[0]
- data_dict["index"] = torch.arange(0, total_frames, 1)
- return data_dict
-
-
-def to_hf_dataset(data_dict, video) -> Dataset:
- features = {}
- if video:
- features["observation.image"] = VideoFrame()
- else:
- features["observation.image"] = Image()
-
- features["episode_index"] = Value(dtype="int64", id=None)
- features["frame_index"] = Value(dtype="int64", id=None)
- features["timestamp"] = Value(dtype="float32", id=None)
- features["index"] = Value(dtype="int64", id=None)
-
- hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
- hf_dataset.set_transform(hf_transform_to_torch)
- return hf_dataset
-
-
-def from_raw_to_lerobot_format(
- raw_dir: Path,
- videos_dir: Path,
- fps: int | None = None,
- video: bool = True,
- episodes: list[int] | None = None,
- encoding: dict | None = None,
-):
- if video or episodes or encoding is not None:
- # TODO(aliberts): support this
- raise NotImplementedError
-
- # sanity check
- check_format(raw_dir)
-
- if fps is None:
- fps = 30
-
-    data_dict = load_from_raw(raw_dir, fps, episodes)
- hf_dataset = to_hf_dataset(data_dict, video)
- episode_data_index = calculate_episode_data_index(hf_dataset)
- info = {
- "codebase_version": CODEBASE_VERSION,
- "fps": fps,
- "video": video,
- }
- return hf_dataset, episode_data_index, info
diff --git a/lerobot/common/datasets/push_dataset_to_hub/dora_parquet_format.py b/lerobot/common/datasets/push_dataset_to_hub/dora_parquet_format.py
deleted file mode 100644
index 4968e002..00000000
--- a/lerobot/common/datasets/push_dataset_to_hub/dora_parquet_format.py
+++ /dev/null
@@ -1,233 +0,0 @@
-#!/usr/bin/env python
-
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-Contains utilities to process raw data format from dora-record
-"""
-
-import re
-import warnings
-from pathlib import Path
-
-import pandas as pd
-import torch
-from datasets import Dataset, Features, Image, Sequence, Value
-
-from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
-from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
-from lerobot.common.datasets.utils import (
- hf_transform_to_torch,
-)
-from lerobot.common.datasets.video_utils import VideoFrame
-
-
-def check_format(raw_dir) -> bool:
- assert raw_dir.exists()
-
- leader_file = list(raw_dir.glob("*.parquet"))
- if len(leader_file) == 0:
- raise ValueError(f"Missing parquet files in '{raw_dir}'")
- return True
-
-
-def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
- # Load data stream that will be used as reference for the timestamps synchronization
- reference_files = list(raw_dir.glob("observation.images.cam_*.parquet"))
- if len(reference_files) == 0:
-        raise ValueError(f"Missing 'observation.images.cam_*.parquet' reference files in '{raw_dir}'")
- # select first camera in alphanumeric order
- reference_key = sorted(reference_files)[0].stem
- reference_df = pd.read_parquet(raw_dir / f"{reference_key}.parquet")
- reference_df = reference_df[["timestamp_utc", reference_key]]
-
- # Merge all data stream using nearest backward strategy
- df = reference_df
- for path in raw_dir.glob("*.parquet"):
- key = path.stem # action or observation.state or ...
- if key == reference_key:
- continue
- if "failed_episode_index" in key:
- # TODO(rcadene): add support for removing episodes that are tagged as "failed"
- continue
- modality_df = pd.read_parquet(path)
- modality_df = modality_df[["timestamp_utc", key]]
- df = pd.merge_asof(
- df,
- modality_df,
- on="timestamp_utc",
-            # "nearest" is preferred over "backward": the latter can desynchronize camera timestamps by
-            # matching timestamps that are too far apart in order to satisfy the backward constraint,
-            # which is not the case for "nearest". Note, however, that "nearest" might synchronize the
-            # reference camera with other cameras on slightly future timestamps.
- direction="nearest",
- tolerance=pd.Timedelta(f"{1 / fps} seconds"),
- )
-    # Remove rows with episode_index -1, which correspond to data recorded in between episodes
- df = df[df["episode_index"] != -1]
-
- image_keys = [key for key in df if "observation.images." in key]
-
- def get_episode_index(row):
- episode_index_per_cam = {}
- for key in image_keys:
- path = row[key][0]["path"]
- match = re.search(r"_(\d{6}).mp4", path)
- if not match:
- raise ValueError(path)
- episode_index = int(match.group(1))
- episode_index_per_cam[key] = episode_index
- if len(set(episode_index_per_cam.values())) != 1:
- raise ValueError(
- f"All cameras are expected to belong to the same episode, but getting {episode_index_per_cam}"
- )
- return episode_index
-
- df["episode_index"] = df.apply(get_episode_index, axis=1)
-
-    # dora only uses arrays, so single values are encapsulated into a list
- df["frame_index"] = df.groupby("episode_index").cumcount()
- df = df.reset_index()
- df["index"] = df.index
-
- # set 'next.done' to True for the last frame of each episode
- df["next.done"] = False
- df.loc[df.groupby("episode_index").tail(1).index, "next.done"] = True
-
- df["timestamp"] = df["timestamp_utc"].map(lambda x: x.timestamp())
- # each episode starts with timestamp 0 to match the ones from the video
- df["timestamp"] = df.groupby("episode_index")["timestamp"].transform(lambda x: x - x.iloc[0])
-
- del df["timestamp_utc"]
-
- # sanity check
- has_nan = df.isna().any().any()
- if has_nan:
- raise ValueError("Dataset contains Nan values.")
-
- # sanity check episode indices go from 0 to n-1
- ep_ids = [ep_idx for ep_idx, _ in df.groupby("episode_index")]
- expected_ep_ids = list(range(df["episode_index"].max() + 1))
- if ep_ids != expected_ep_ids:
-        raise ValueError(f"Episode indices are {ep_ids} instead of the expected {expected_ep_ids}")
-
- # Create symlink to raw videos directory (that needs to be absolute not relative)
- videos_dir.parent.mkdir(parents=True, exist_ok=True)
- videos_dir.symlink_to((raw_dir / "videos").absolute())
-
-    # sanity check that the video paths are well formatted
- for key in df:
- if "observation.images." not in key:
- continue
- for ep_idx in ep_ids:
- video_path = videos_dir / f"{key}_episode_{ep_idx:06d}.mp4"
- if not video_path.exists():
- raise ValueError(f"Video file not found in {video_path}")
-
- data_dict = {}
- for key in df:
- # is video frame
- if "observation.images." in key:
-            # we need `[0]` because dora only uses arrays, so single values are encapsulated into a list.
-            # This is the case for the video_frame dictionary: [{"path": ..., "timestamp": ...}]
- data_dict[key] = [video_frame[0] for video_frame in df[key].values]
-
-            # sanity check that the video path is well formatted
- video_path = videos_dir.parent / data_dict[key][0]["path"]
- if not video_path.exists():
- raise ValueError(f"Video file not found in {video_path}")
- # is number
- elif df[key].iloc[0].ndim == 0 or df[key].iloc[0].shape[0] == 1:
- data_dict[key] = torch.from_numpy(df[key].values)
- # is vector
- elif df[key].iloc[0].shape[0] > 1:
- data_dict[key] = torch.stack([torch.from_numpy(x.copy()) for x in df[key].values])
- else:
- raise ValueError(key)
-
- return data_dict
-
-
-def to_hf_dataset(data_dict, video) -> Dataset:
- features = {}
-
- keys = [key for key in data_dict if "observation.images." in key]
- for key in keys:
- if video:
- features[key] = VideoFrame()
- else:
- features[key] = Image()
-
- features["observation.state"] = Sequence(
- length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
- )
- if "observation.velocity" in data_dict:
- features["observation.velocity"] = Sequence(
- length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
- )
- if "observation.effort" in data_dict:
- features["observation.effort"] = Sequence(
- length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
- )
- features["action"] = Sequence(
- length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
- )
- features["episode_index"] = Value(dtype="int64", id=None)
- features["frame_index"] = Value(dtype="int64", id=None)
- features["timestamp"] = Value(dtype="float32", id=None)
- features["next.done"] = Value(dtype="bool", id=None)
- features["index"] = Value(dtype="int64", id=None)
-
- hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
- hf_dataset.set_transform(hf_transform_to_torch)
- return hf_dataset
-
-
-def from_raw_to_lerobot_format(
- raw_dir: Path,
- videos_dir: Path,
- fps: int | None = None,
- video: bool = True,
- episodes: list[int] | None = None,
- encoding: dict | None = None,
-):
- # sanity check
- check_format(raw_dir)
-
- if fps is None:
- fps = 30
- else:
- raise NotImplementedError()
-
- if not video:
- raise NotImplementedError()
-
- if encoding is not None:
- warnings.warn(
- "Video encoding is currently done outside of LeRobot for the dora_parquet format.",
- stacklevel=1,
- )
-
-    data_df = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
- hf_dataset = to_hf_dataset(data_df, video)
- episode_data_index = calculate_episode_data_index(hf_dataset)
- info = {
- "codebase_version": CODEBASE_VERSION,
- "fps": fps,
- "video": video,
- }
- if video:
- info["encoding"] = "unknown"
-
- return hf_dataset, episode_data_index, info
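The heart of the dora conversion above is the timestamp synchronization with pandas. A standalone sketch of the same merge_asof call, using made-up timestamps and placeholder column values:

import pandas as pd

fps = 30
t0 = pd.Timestamp("2024-01-01 00:00:00")
reference_df = pd.DataFrame(
    {
        "timestamp_utc": [t0 + pd.Timedelta(seconds=i / fps) for i in range(3)],
        "observation.images.cam_high": [0, 1, 2],  # placeholder frame references
    }
)
modality_df = pd.DataFrame(
    {
        "timestamp_utc": [t0 + pd.Timedelta(seconds=i / fps + 0.004) for i in range(3)],
        "observation.state": [10.0, 11.0, 12.0],  # placeholder values, slightly off the camera clock
    }
)
merged = pd.merge_asof(
    reference_df,
    modality_df,
    on="timestamp_utc",
    direction="nearest",  # match the closest timestamp in either direction
    tolerance=pd.Timedelta(f"{1 / fps} seconds"),  # drop matches more than one frame period apart
)
print(merged)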
diff --git a/lerobot/common/datasets/push_dataset_to_hub/openx_rlds_format.py b/lerobot/common/datasets/push_dataset_to_hub/openx_rlds_format.py
deleted file mode 100644
index 1f8a5d14..00000000
--- a/lerobot/common/datasets/push_dataset_to_hub/openx_rlds_format.py
+++ /dev/null
@@ -1,312 +0,0 @@
-#!/usr/bin/env python
-
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-For all datasets in the RLDS format.
-For https://github.com/google-deepmind/open_x_embodiment (OPENX) datasets.
-
-NOTE: You need to install tensorflow and tensorflow_datasets before running this script.
-
-Example:
- python lerobot/scripts/push_dataset_to_hub.py \
- --raw-dir /path/to/data/bridge_dataset/1.0.0/ \
- --repo-id your_hub/sampled_bridge_data_v2 \
- --raw-format rlds \
- --episodes 3 4 5 8 9
-
-Exact dataset fps defined in openx/config.py, obtained from:
- https://docs.google.com/spreadsheets/d/1rPBD77tk60AEIGZrGSODwyyzs5FgCU9Uz3h-3_t2A9g/edit?gid=0#gid=0&range=R:R
-"""
-
-import shutil
-from pathlib import Path
-
-import numpy as np
-import tensorflow as tf
-import tensorflow_datasets as tfds
-import torch
-import tqdm
-from datasets import Dataset, Features, Image, Sequence, Value
-from PIL import Image as PILImage
-
-from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
-from lerobot.common.datasets.push_dataset_to_hub.utils import (
- calculate_episode_data_index,
- concatenate_episodes,
- get_default_encoding,
- save_images_concurrently,
-)
-from lerobot.common.datasets.utils import (
- hf_transform_to_torch,
-)
-from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
-
-np.set_printoptions(precision=2)
-
-
-def tf_to_torch(data):
- return torch.from_numpy(data.numpy())
-
-
-def tf_img_convert(img):
- if img.dtype == tf.string:
- img = tf.io.decode_image(img, expand_animations=False, dtype=tf.uint8)
- elif img.dtype != tf.uint8:
-        raise ValueError(f"Unsupported image dtype: {img.dtype}")
- return img.numpy()
-
-
-def _broadcast_metadata_rlds(i: tf.Tensor, traj: dict) -> dict:
- """
- In the RLDS format, each trajectory has some top-level metadata that is explicitly separated out, and a "steps"
- entry. This function moves the "steps" entry to the top level, broadcasting any metadata to the length of the
- trajectory. This function also adds the extra metadata fields `_len`, `_traj_index`, and `_frame_index`.
-
- NOTE: adapted from DLimp library https://github.com/kvablack/dlimp/
- """
- steps = traj.pop("steps")
-
- traj_len = tf.shape(tf.nest.flatten(steps)[0])[0]
-
- # broadcast metadata to the length of the trajectory
- metadata = tf.nest.map_structure(lambda x: tf.repeat(x, traj_len), traj)
-
- # put steps back in
- assert "traj_metadata" not in steps
- traj = {**steps, "traj_metadata": metadata}
-
- assert "_len" not in traj
- assert "_traj_index" not in traj
- assert "_frame_index" not in traj
- traj["_len"] = tf.repeat(traj_len, traj_len)
- traj["_traj_index"] = tf.repeat(i, traj_len)
- traj["_frame_index"] = tf.range(traj_len)
-
- return traj
-
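# Illustrative sketch (not from the original file): what the broadcasting above produces for a
# made-up 3-step trajectory carrying a single piece of trajectory-level metadata.
example_traj = {
    "robot_id": tf.constant(7),                                # trajectory-level metadata
    "steps": {"action": tf.constant([[0.0], [1.0], [2.0]])},   # a 3-step trajectory
}
example_out = _broadcast_metadata_rlds(tf.constant(0), example_traj)
# example_out["traj_metadata"]["robot_id"] -> [7, 7, 7]   (repeated to the trajectory length)
# example_out["_len"] -> [3, 3, 3]; example_out["_frame_index"] -> [0, 1, 2]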
-
-def load_from_raw(
- raw_dir: Path,
- videos_dir: Path,
- fps: int,
- video: bool,
- episodes: list[int] | None = None,
- encoding: dict | None = None,
-):
- """
- Args:
-        raw_dir (Path): directory containing the raw RLDS/TFDS dataset.
-        videos_dir (Path): directory where the encoded mp4 videos are written.
-        fps (int): frames per second used to generate the timestamps.
-        video (bool): encode camera frames into mp4 videos instead of storing them as individual images.
-        episodes (list[int] | None, optional): subset of episode indices to convert. Defaults to None.
- """
- ds_builder = tfds.builder_from_directory(str(raw_dir))
- dataset = ds_builder.as_dataset(
- split="all",
- decoders={"steps": tfds.decode.SkipDecoding()},
- )
-
- dataset_info = ds_builder.info
- print("dataset_info: ", dataset_info)
-
- ds_length = len(dataset)
- dataset = dataset.take(ds_length)
- # "flatten" the dataset as such we can apply trajectory level map() easily
- # each [obs][key] has a shape of (frame_size, ...)
- dataset = dataset.enumerate().map(_broadcast_metadata_rlds)
-
-    # we will apply the standardization transform if the dataset_name is provided;
-    # if the dataset name is not provided (the goal being to convert any RLDS-formatted dataset),
-    # search for 'image' keys in the observations
- image_keys = []
- state_keys = []
- observation_info = dataset_info.features["steps"]["observation"]
- for key in observation_info:
- # check whether the key is for an image or a vector observation
- if len(observation_info[key].shape) == 3:
- # only adding uint8 images discards depth images
- if observation_info[key].dtype == tf.uint8:
- image_keys.append(key)
- else:
- state_keys.append(key)
-
- lang_key = "language_instruction" if "language_instruction" in dataset.element_spec else None
-
- print(" - image_keys: ", image_keys)
- print(" - lang_key: ", lang_key)
-
- it = iter(dataset)
-
- ep_dicts = []
- # Init temp path to save ep_dicts in case of crash
- tmp_ep_dicts_dir = videos_dir.parent.joinpath("ep_dicts")
- tmp_ep_dicts_dir.mkdir(parents=True, exist_ok=True)
-
- # check if ep_dicts have already been saved in /tmp
- starting_ep_idx = 0
- saved_ep_dicts = [ep.__str__() for ep in tmp_ep_dicts_dir.iterdir()]
- if len(saved_ep_dicts) > 0:
- saved_ep_dicts.sort()
- # get last ep_idx number
- starting_ep_idx = int(saved_ep_dicts[-1][-13:-3]) + 1
- for i in range(starting_ep_idx):
- episode = next(it)
- ep_dicts.append(torch.load(saved_ep_dicts[i]))
-
-    # if the user specified episodes, skip the ones not in the list
- if episodes is not None:
- if ds_length == 0:
- raise ValueError("No episodes found.")
- # convert episodes index to sorted list
- episodes = sorted(episodes)
-
- for ep_idx in tqdm.tqdm(range(starting_ep_idx, ds_length)):
- episode = next(it)
-
- # if user specified episodes, skip the ones not in the list
- if episodes is not None:
- if len(episodes) == 0:
- break
- if ep_idx == episodes[0]:
- # process this episode
- print(" selecting episode idx: ", ep_idx)
- episodes.pop(0)
- else:
- continue # skip
-
- num_frames = episode["action"].shape[0]
-
- ep_dict = {}
- for key in state_keys:
- ep_dict[f"observation.{key}"] = tf_to_torch(episode["observation"][key])
-
- ep_dict["action"] = tf_to_torch(episode["action"])
- ep_dict["next.reward"] = tf_to_torch(episode["reward"]).float()
- ep_dict["next.done"] = tf_to_torch(episode["is_last"])
- ep_dict["is_terminal"] = tf_to_torch(episode["is_terminal"])
- ep_dict["is_first"] = tf_to_torch(episode["is_first"])
- ep_dict["discount"] = tf_to_torch(episode["discount"])
-
- # If lang_key is present, convert the entire tensor at once
- if lang_key is not None:
- ep_dict["language_instruction"] = [x.numpy().decode("utf-8") for x in episode[lang_key]]
-
- ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
- ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
- ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
-
- image_array_dict = {key: [] for key in image_keys}
-
- for im_key in image_keys:
- imgs = episode["observation"][im_key]
- image_array_dict[im_key] = [tf_img_convert(img) for img in imgs]
-
- # loop through all cameras
- for im_key in image_keys:
- img_key = f"observation.images.{im_key}"
- imgs_array = image_array_dict[im_key]
- imgs_array = np.array(imgs_array)
- if video:
- # save png images in temporary directory
- tmp_imgs_dir = videos_dir / "tmp_images"
- save_images_concurrently(imgs_array, tmp_imgs_dir)
-
- # encode images to a mp4 video
- fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
- video_path = videos_dir / fname
- encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {}))
-
- # clean temporary images directory
- shutil.rmtree(tmp_imgs_dir)
-
- # store the reference to the video frame
- ep_dict[img_key] = [
- {"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
- ]
- else:
- ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
-
-        path_ep_dict = tmp_ep_dicts_dir.joinpath(f"ep_dict_{ep_idx:010d}.pt")
- torch.save(ep_dict, path_ep_dict)
-
- ep_dicts.append(ep_dict)
-
- data_dict = concatenate_episodes(ep_dicts)
-
- total_frames = data_dict["frame_index"].shape[0]
- data_dict["index"] = torch.arange(0, total_frames, 1)
- return data_dict
-
-
-def to_hf_dataset(data_dict, video) -> Dataset:
- features = {}
-
- for key in data_dict:
- # check if vector state obs
- if key.startswith("observation.") and "observation.images." not in key:
- features[key] = Sequence(length=data_dict[key].shape[1], feature=Value(dtype="float32", id=None))
- # check if image obs
- elif "observation.images." in key:
- if video:
- features[key] = VideoFrame()
- else:
- features[key] = Image()
-
- if "language_instruction" in data_dict:
- features["language_instruction"] = Value(dtype="string", id=None)
-
- features["action"] = Sequence(
- length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
- )
-
- features["is_terminal"] = Value(dtype="bool", id=None)
- features["is_first"] = Value(dtype="bool", id=None)
- features["discount"] = Value(dtype="float32", id=None)
-
- features["episode_index"] = Value(dtype="int64", id=None)
- features["frame_index"] = Value(dtype="int64", id=None)
- features["timestamp"] = Value(dtype="float32", id=None)
- features["next.reward"] = Value(dtype="float32", id=None)
- features["next.done"] = Value(dtype="bool", id=None)
- features["index"] = Value(dtype="int64", id=None)
-
- hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
- hf_dataset.set_transform(hf_transform_to_torch)
- return hf_dataset
-
-
-def from_raw_to_lerobot_format(
- raw_dir: Path,
- videos_dir: Path,
- fps: int | None = None,
- video: bool = True,
- episodes: list[int] | None = None,
- encoding: dict | None = None,
-):
- data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding)
- hf_dataset = to_hf_dataset(data_dict, video)
- episode_data_index = calculate_episode_data_index(hf_dataset)
- info = {
- "codebase_version": CODEBASE_VERSION,
- "fps": fps,
- "video": video,
- }
- if video:
- info["encoding"] = get_default_encoding()
-
- return hf_dataset, episode_data_index, info
diff --git a/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py b/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py
deleted file mode 100644
index 27b31ba2..00000000
--- a/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py
+++ /dev/null
@@ -1,275 +0,0 @@
-#!/usr/bin/env python
-
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Process zarr files formatted like in: https://github.com/real-stanford/diffusion_policy"""
-
-import shutil
-from pathlib import Path
-
-import numpy as np
-import torch
-import tqdm
-import zarr
-from datasets import Dataset, Features, Image, Sequence, Value
-from PIL import Image as PILImage
-
-from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
-from lerobot.common.datasets.push_dataset_to_hub.utils import (
- calculate_episode_data_index,
- concatenate_episodes,
- get_default_encoding,
- save_images_concurrently,
-)
-from lerobot.common.datasets.utils import (
- hf_transform_to_torch,
-)
-from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
-
-
-def check_format(raw_dir):
- zarr_path = raw_dir / "pusht_cchi_v7_replay.zarr"
- zarr_data = zarr.open(zarr_path, mode="r")
-
- required_datasets = {
- "data/action",
- "data/img",
- "data/keypoint",
- "data/n_contacts",
- "data/state",
- "meta/episode_ends",
- }
- for dataset in required_datasets:
- assert dataset in zarr_data
- nb_frames = zarr_data["data/img"].shape[0]
-
- required_datasets.remove("meta/episode_ends")
-
- assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)
-
-
-def load_from_raw(
- raw_dir: Path,
- videos_dir: Path,
- fps: int,
- video: bool,
- episodes: list[int] | None = None,
- keypoints_instead_of_image: bool = False,
- encoding: dict | None = None,
-):
- try:
- import pymunk
- from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
-
- from lerobot.common.datasets.push_dataset_to_hub._diffusion_policy_replay_buffer import (
- ReplayBuffer as DiffusionPolicyReplayBuffer,
- )
- except ModuleNotFoundError as e:
- print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
- raise e
-    # as defined in the gym-pusht env: https://github.com/huggingface/gym-pusht/blob/e0684ff988d223808c0a9dcfaba9dc4991791370/gym_pusht/envs/pusht.py#L174
-    success_threshold = 0.95  # 95% coverage
-
- zarr_path = raw_dir / "pusht_cchi_v7_replay.zarr"
- zarr_data = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path)
-
- episode_ids = torch.from_numpy(zarr_data.get_episode_idxs())
-    assert (
-        len({zarr_data[key].shape[0] for key in zarr_data.keys()}) == 1  # noqa: SIM118
-    ), "Some data types don't have the same number of total frames."
-
- # TODO(rcadene): verify that goal pose is expected to be fixed
- goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians)
- goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle)
-
- imgs = torch.from_numpy(zarr_data["img"]) # b h w c
- states = torch.from_numpy(zarr_data["state"])
- actions = torch.from_numpy(zarr_data["action"])
-
- # load data indices from which each episode starts and ends
- from_ids, to_ids = [], []
- from_idx = 0
- for to_idx in zarr_data.meta["episode_ends"]:
- from_ids.append(from_idx)
- to_ids.append(to_idx)
- from_idx = to_idx
-
- num_episodes = len(from_ids)
-
- ep_dicts = []
- ep_ids = episodes if episodes else range(num_episodes)
- for ep_idx, selected_ep_idx in tqdm.tqdm(enumerate(ep_ids)):
- from_idx = from_ids[selected_ep_idx]
- to_idx = to_ids[selected_ep_idx]
- num_frames = to_idx - from_idx
-
- # sanity check
- assert (episode_ids[from_idx:to_idx] == ep_idx).all()
-
- # get image
- if not keypoints_instead_of_image:
- image = imgs[from_idx:to_idx]
- assert image.min() >= 0.0
- assert image.max() <= 255.0
- image = image.type(torch.uint8)
-
- # get state
- state = states[from_idx:to_idx]
- agent_pos = state[:, :2]
- block_pos = state[:, 2:4]
- block_angle = state[:, 4]
-
- # get reward, success, done, and (maybe) keypoints
- reward = torch.zeros(num_frames)
- success = torch.zeros(num_frames, dtype=torch.bool)
- if keypoints_instead_of_image:
- keypoints = torch.zeros(num_frames, 16) # 8 keypoints each with 2 coords
- done = torch.zeros(num_frames, dtype=torch.bool)
- for i in range(num_frames):
- space = pymunk.Space()
- space.gravity = 0, 0
- space.damping = 0
-
- # Add walls.
- walls = [
- PushTEnv.add_segment(space, (5, 506), (5, 5), 2),
- PushTEnv.add_segment(space, (5, 5), (506, 5), 2),
- PushTEnv.add_segment(space, (506, 5), (506, 506), 2),
- PushTEnv.add_segment(space, (5, 506), (506, 506), 2),
- ]
- space.add(*walls)
-
- block_body, block_shapes = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
- goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
- block_geom = pymunk_to_shapely(block_body, block_body.shapes)
- intersection_area = goal_geom.intersection(block_geom).area
- goal_area = goal_geom.area
- coverage = intersection_area / goal_area
- reward[i] = np.clip(coverage / success_threshold, 0, 1)
- success[i] = coverage > success_threshold
- if keypoints_instead_of_image:
- keypoints[i] = torch.from_numpy(PushTEnv.get_keypoints(block_shapes).flatten())
-
- # last step of demonstration is considered done
- done[-1] = True
-
- ep_dict = {}
-
- if not keypoints_instead_of_image:
- imgs_array = [x.numpy() for x in image]
- img_key = "observation.image"
- if video:
- # save png images in temporary directory
- tmp_imgs_dir = videos_dir / "tmp_images"
- save_images_concurrently(imgs_array, tmp_imgs_dir)
-
- # encode images to a mp4 video
- fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
- video_path = videos_dir / fname
- encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {}))
-
- # clean temporary images directory
- shutil.rmtree(tmp_imgs_dir)
-
- # store the reference to the video frame
- ep_dict[img_key] = [
- {"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
- ]
- else:
- ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
-
- ep_dict["observation.state"] = agent_pos
- if keypoints_instead_of_image:
- ep_dict["observation.environment_state"] = keypoints
- ep_dict["action"] = actions[from_idx:to_idx]
- ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
- ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
- ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
- # ep_dict["next.observation.image"] = image[1:],
- # ep_dict["next.observation.state"] = agent_pos[1:],
-        # TODO(rcadene): verify that reward and done are aligned with image and agent_pos
- ep_dict["next.reward"] = torch.cat([reward[1:], reward[[-1]]])
- ep_dict["next.done"] = torch.cat([done[1:], done[[-1]]])
- ep_dict["next.success"] = torch.cat([success[1:], success[[-1]]])
- ep_dicts.append(ep_dict)
- data_dict = concatenate_episodes(ep_dicts)
-
- total_frames = data_dict["frame_index"].shape[0]
- data_dict["index"] = torch.arange(0, total_frames, 1)
- return data_dict
-
-
-def to_hf_dataset(data_dict, video, keypoints_instead_of_image: bool = False):
- features = {}
-
- if not keypoints_instead_of_image:
- if video:
- features["observation.image"] = VideoFrame()
- else:
- features["observation.image"] = Image()
-
- features["observation.state"] = Sequence(
- length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
- )
- if keypoints_instead_of_image:
- features["observation.environment_state"] = Sequence(
- length=data_dict["observation.environment_state"].shape[1],
- feature=Value(dtype="float32", id=None),
- )
- features["action"] = Sequence(
- length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
- )
- features["episode_index"] = Value(dtype="int64", id=None)
- features["frame_index"] = Value(dtype="int64", id=None)
- features["timestamp"] = Value(dtype="float32", id=None)
- features["next.reward"] = Value(dtype="float32", id=None)
- features["next.done"] = Value(dtype="bool", id=None)
- features["next.success"] = Value(dtype="bool", id=None)
- features["index"] = Value(dtype="int64", id=None)
-
- hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
- hf_dataset.set_transform(hf_transform_to_torch)
- return hf_dataset
-
-
-def from_raw_to_lerobot_format(
- raw_dir: Path,
- videos_dir: Path,
- fps: int | None = None,
- video: bool = True,
- episodes: list[int] | None = None,
- encoding: dict | None = None,
-):
- # Manually change this to True to use keypoints of the T instead of an image observation (but don't merge
- # with True). Also make sure to use video = 0 in the `push_dataset_to_hub.py` script.
- keypoints_instead_of_image = False
-
- # sanity check
- check_format(raw_dir)
-
- if fps is None:
- fps = 10
-
- data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, keypoints_instead_of_image, encoding)
- hf_dataset = to_hf_dataset(data_dict, video, keypoints_instead_of_image)
- episode_data_index = calculate_episode_data_index(hf_dataset)
- info = {
- "codebase_version": CODEBASE_VERSION,
- "fps": fps,
- "video": video if not keypoints_instead_of_image else 0,
- }
- if video:
- info["encoding"] = get_default_encoding()
-
- return hf_dataset, episode_data_index, info
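The reward above is simply goal coverage rescaled by the success threshold. A small worked example with hypothetical areas:

import numpy as np

success_threshold = 0.95  # 95% coverage, same constant as in the loader above
goal_area = 100.0         # hypothetical area of the goal region
intersection_area = 80.0  # hypothetical overlap between the T block and the goal
coverage = intersection_area / goal_area              # 0.80
reward = np.clip(coverage / success_threshold, 0, 1)  # ~0.842
success = coverage > success_threshold                # False until coverage exceeds 95%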
diff --git a/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py b/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py
deleted file mode 100644
index fec893a7..00000000
--- a/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py
+++ /dev/null
@@ -1,234 +0,0 @@
-#!/usr/bin/env python
-
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Process UMI (Universal Manipulation Interface) data stored in Zarr format like in: https://github.com/real-stanford/universal_manipulation_interface"""
-
-import logging
-import shutil
-from pathlib import Path
-
-import torch
-import tqdm
-import zarr
-from datasets import Dataset, Features, Image, Sequence, Value
-from PIL import Image as PILImage
-
-from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
-from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs
-from lerobot.common.datasets.push_dataset_to_hub.utils import (
- calculate_episode_data_index,
- concatenate_episodes,
- get_default_encoding,
- save_images_concurrently,
-)
-from lerobot.common.datasets.utils import (
- hf_transform_to_torch,
-)
-from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
-
-
-def check_format(raw_dir) -> bool:
- zarr_path = raw_dir / "cup_in_the_wild.zarr"
- zarr_data = zarr.open(zarr_path, mode="r")
-
- required_datasets = {
- "data/robot0_demo_end_pose",
- "data/robot0_demo_start_pose",
- "data/robot0_eef_pos",
- "data/robot0_eef_rot_axis_angle",
- "data/robot0_gripper_width",
- "meta/episode_ends",
- "data/camera0_rgb",
- }
- for dataset in required_datasets:
- if dataset not in zarr_data:
- return False
-
- # mandatory to access zarr_data
- register_codecs()
- nb_frames = zarr_data["data/camera0_rgb"].shape[0]
-
- required_datasets.remove("meta/episode_ends")
- assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)
-
-
-def load_from_raw(
- raw_dir: Path,
- videos_dir: Path,
- fps: int,
- video: bool,
- episodes: list[int] | None = None,
- encoding: dict | None = None,
-):
- zarr_path = raw_dir / "cup_in_the_wild.zarr"
- zarr_data = zarr.open(zarr_path, mode="r")
-
- # We process the image data separately because it is too large to fit in memory
- end_pose = torch.from_numpy(zarr_data["data/robot0_demo_end_pose"][:])
- start_pos = torch.from_numpy(zarr_data["data/robot0_demo_start_pose"][:])
- eff_pos = torch.from_numpy(zarr_data["data/robot0_eef_pos"][:])
- eff_rot_axis_angle = torch.from_numpy(zarr_data["data/robot0_eef_rot_axis_angle"][:])
- gripper_width = torch.from_numpy(zarr_data["data/robot0_gripper_width"][:])
-
- states_pos = torch.cat([eff_pos, eff_rot_axis_angle], dim=1)
- states = torch.cat([states_pos, gripper_width], dim=1)
-
- episode_ends = zarr_data["meta/episode_ends"][:]
- num_episodes = episode_ends.shape[0]
-
-    # We convert it to a torch tensor later because the jit function does not support torch tensors
- episode_ends = torch.from_numpy(episode_ends)
-
- # load data indices from which each episode starts and ends
- from_ids, to_ids = [], []
- from_idx = 0
- for to_idx in episode_ends:
- from_ids.append(from_idx)
- to_ids.append(to_idx)
- from_idx = to_idx
-
- ep_dicts_dir = videos_dir / "ep_dicts"
- ep_dicts_dir.mkdir(exist_ok=True, parents=True)
- ep_dicts = []
-
- ep_ids = episodes if episodes else range(num_episodes)
- for ep_idx, selected_ep_idx in tqdm.tqdm(enumerate(ep_ids)):
- ep_dict_path = ep_dicts_dir / f"{ep_idx}"
- if not ep_dict_path.is_file():
- from_idx = from_ids[selected_ep_idx]
- to_idx = to_ids[selected_ep_idx]
- num_frames = to_idx - from_idx
-
- # TODO(rcadene): save temporary images of the episode?
-
- state = states[from_idx:to_idx]
-
- ep_dict = {}
-
- # load 57MB of images in RAM (400x224x224x3 uint8)
- imgs_array = zarr_data["data/camera0_rgb"][from_idx:to_idx]
- img_key = "observation.image"
- if video:
- fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
- video_path = videos_dir / fname
- if not video_path.is_file():
- # save png images in temporary directory
- tmp_imgs_dir = videos_dir / "tmp_images"
- save_images_concurrently(imgs_array, tmp_imgs_dir)
-
- # encode images to a mp4 video
- encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {}))
-
- # clean temporary images directory
- shutil.rmtree(tmp_imgs_dir)
-
- # store the reference to the video frame
- ep_dict[img_key] = [
- {"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
- ]
- else:
- ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
-
- ep_dict["observation.state"] = state
- ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
- ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
- ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
- ep_dict["episode_data_index_from"] = torch.tensor([from_idx] * num_frames)
- ep_dict["episode_data_index_to"] = torch.tensor([from_idx + num_frames] * num_frames)
- ep_dict["end_pose"] = end_pose[from_idx:to_idx]
- ep_dict["start_pos"] = start_pos[from_idx:to_idx]
- ep_dict["gripper_width"] = gripper_width[from_idx:to_idx]
- torch.save(ep_dict, ep_dict_path)
- else:
- ep_dict = torch.load(ep_dict_path)
-
- ep_dicts.append(ep_dict)
-
- data_dict = concatenate_episodes(ep_dicts)
-
- total_frames = data_dict["frame_index"].shape[0]
- data_dict["index"] = torch.arange(0, total_frames, 1)
- return data_dict
-
-
-def to_hf_dataset(data_dict, video):
- features = {}
-
- if video:
- features["observation.image"] = VideoFrame()
- else:
- features["observation.image"] = Image()
-
- features["observation.state"] = Sequence(
- length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
- )
- features["episode_index"] = Value(dtype="int64", id=None)
- features["frame_index"] = Value(dtype="int64", id=None)
- features["timestamp"] = Value(dtype="float32", id=None)
- features["index"] = Value(dtype="int64", id=None)
- features["episode_data_index_from"] = Value(dtype="int64", id=None)
- features["episode_data_index_to"] = Value(dtype="int64", id=None)
- # `start_pos` and `end_pos` respectively represent the positions of the end-effector
- # at the beginning and the end of the episode.
- # `gripper_width` indicates the distance between the grippers, and this value is included
- # in the state vector, which comprises the concatenation of the end-effector position
- # and gripper width.
- features["end_pose"] = Sequence(
- length=data_dict["end_pose"].shape[1], feature=Value(dtype="float32", id=None)
- )
- features["start_pos"] = Sequence(
- length=data_dict["start_pos"].shape[1], feature=Value(dtype="float32", id=None)
- )
- features["gripper_width"] = Sequence(
- length=data_dict["gripper_width"].shape[1], feature=Value(dtype="float32", id=None)
- )
-
- hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
- hf_dataset.set_transform(hf_transform_to_torch)
- return hf_dataset
-
-
-def from_raw_to_lerobot_format(
- raw_dir: Path,
- videos_dir: Path,
- fps: int | None = None,
- video: bool = True,
- episodes: list[int] | None = None,
- encoding: dict | None = None,
-):
- # sanity check
- check_format(raw_dir)
-
- if fps is None:
- # For umi cup in the wild: https://arxiv.org/pdf/2402.10329#table.caption.16
- fps = 10
-
- if not video:
- logging.warning(
- "Generating UMI dataset without `video=True` creates ~150GB on disk and requires ~80GB in RAM."
- )
-
- data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding)
- hf_dataset = to_hf_dataset(data_dict, video)
- episode_data_index = calculate_episode_data_index(hf_dataset)
- info = {
- "codebase_version": CODEBASE_VERSION,
- "fps": fps,
- "video": video,
- }
- if video:
- info["encoding"] = get_default_encoding()
-
- return hf_dataset, episode_data_index, info
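For reference, the `observation.state` vector assembled above concatenates end-effector position (3), end-effector rotation as axis-angle (3) and gripper width (1). A shape-only sketch, assuming the gripper width is stored with a trailing singleton dimension as the concatenation implies:

import torch

eff_pos = torch.zeros(2, 3)             # end-effector xyz for a hypothetical 2-frame episode
eff_rot_axis_angle = torch.zeros(2, 3)  # end-effector rotation (axis-angle)
gripper_width = torch.zeros(2, 1)       # distance between the gripper fingers
states = torch.cat([torch.cat([eff_pos, eff_rot_axis_angle], dim=1), gripper_width], dim=1)
assert states.shape == (2, 7)           # a 7-dim "observation.state" per frame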
diff --git a/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py b/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py
deleted file mode 100644
index 0047e48c..00000000
--- a/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py
+++ /dev/null
@@ -1,200 +0,0 @@
-#!/usr/bin/env python
-
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Process pickle files formatted like in: https://github.com/fyhMer/fowm"""
-
-import pickle
-import shutil
-from pathlib import Path
-
-import einops
-import torch
-import tqdm
-from datasets import Dataset, Features, Image, Sequence, Value
-from PIL import Image as PILImage
-
-from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
-from lerobot.common.datasets.push_dataset_to_hub.utils import (
- calculate_episode_data_index,
- concatenate_episodes,
- get_default_encoding,
- save_images_concurrently,
-)
-from lerobot.common.datasets.utils import (
- hf_transform_to_torch,
-)
-from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
-
-
-def check_format(raw_dir):
- keys = {"actions", "rewards", "dones"}
- nested_keys = {"observations": {"rgb", "state"}, "next_observations": {"rgb", "state"}}
-
- xarm_files = list(raw_dir.glob("*.pkl"))
- assert len(xarm_files) > 0
-
- with open(xarm_files[0], "rb") as f:
- dataset_dict = pickle.load(f)
-
- assert isinstance(dataset_dict, dict)
- assert all(k in dataset_dict for k in keys)
-
- # Check for consistent lengths in nested keys
- expected_len = len(dataset_dict["actions"])
- assert all(len(dataset_dict[key]) == expected_len for key in keys if key in dataset_dict)
-
- for key, subkeys in nested_keys.items():
- nested_dict = dataset_dict.get(key, {})
- assert all(len(nested_dict[subkey]) == expected_len for subkey in subkeys if subkey in nested_dict)
-
-
-def load_from_raw(
- raw_dir: Path,
- videos_dir: Path,
- fps: int,
- video: bool,
- episodes: list[int] | None = None,
- encoding: dict | None = None,
-):
- pkl_path = raw_dir / "buffer.pkl"
-
- with open(pkl_path, "rb") as f:
- pkl_data = pickle.load(f)
-
- # load data indices from which each episode starts and ends
- from_ids, to_ids = [], []
- from_idx, to_idx = 0, 0
- for done in pkl_data["dones"]:
- to_idx += 1
- if not done:
- continue
- from_ids.append(from_idx)
- to_ids.append(to_idx)
- from_idx = to_idx
-
- num_episodes = len(from_ids)
-
- ep_dicts = []
- ep_ids = episodes if episodes else range(num_episodes)
- for ep_idx, selected_ep_idx in tqdm.tqdm(enumerate(ep_ids)):
- from_idx = from_ids[selected_ep_idx]
- to_idx = to_ids[selected_ep_idx]
- num_frames = to_idx - from_idx
-
- image = torch.tensor(pkl_data["observations"]["rgb"][from_idx:to_idx])
- image = einops.rearrange(image, "b c h w -> b h w c")
- state = torch.tensor(pkl_data["observations"]["state"][from_idx:to_idx])
- action = torch.tensor(pkl_data["actions"][from_idx:to_idx])
- # TODO(rcadene): we have a missing last frame which is the observation when the env is done
- # it is critical to have this frame for tdmpc to predict a "done observation/state"
- # next_image = torch.tensor(pkl_data["next_observations"]["rgb"][from_idx:to_idx])
- # next_state = torch.tensor(pkl_data["next_observations"]["state"][from_idx:to_idx])
- next_reward = torch.tensor(pkl_data["rewards"][from_idx:to_idx])
- next_done = torch.tensor(pkl_data["dones"][from_idx:to_idx])
-
- ep_dict = {}
-
- imgs_array = [x.numpy() for x in image]
- img_key = "observation.image"
- if video:
- # save png images in temporary directory
- tmp_imgs_dir = videos_dir / "tmp_images"
- save_images_concurrently(imgs_array, tmp_imgs_dir)
-
- # encode images to a mp4 video
- fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
- video_path = videos_dir / fname
- encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {}))
-
- # clean temporary images directory
- shutil.rmtree(tmp_imgs_dir)
-
- # store the reference to the video frame
- ep_dict[img_key] = [{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)]
- else:
- ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
-
- ep_dict["observation.state"] = state
- ep_dict["action"] = action
- ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
- ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
- ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
- # ep_dict["next.observation.image"] = next_image
- # ep_dict["next.observation.state"] = next_state
- ep_dict["next.reward"] = next_reward
- ep_dict["next.done"] = next_done
- ep_dicts.append(ep_dict)
-
- data_dict = concatenate_episodes(ep_dicts)
-
- total_frames = data_dict["frame_index"].shape[0]
- data_dict["index"] = torch.arange(0, total_frames, 1)
- return data_dict
-
-
-def to_hf_dataset(data_dict, video):
- features = {}
-
- if video:
- features["observation.image"] = VideoFrame()
- else:
- features["observation.image"] = Image()
-
- features["observation.state"] = Sequence(
- length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
- )
- features["action"] = Sequence(
- length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
- )
- features["episode_index"] = Value(dtype="int64", id=None)
- features["frame_index"] = Value(dtype="int64", id=None)
- features["timestamp"] = Value(dtype="float32", id=None)
- features["next.reward"] = Value(dtype="float32", id=None)
- features["next.done"] = Value(dtype="bool", id=None)
- features["index"] = Value(dtype="int64", id=None)
- # TODO(rcadene): add success
- # features["next.success"] = Value(dtype='bool', id=None)
-
- hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
- hf_dataset.set_transform(hf_transform_to_torch)
- return hf_dataset
-
-
-def from_raw_to_lerobot_format(
- raw_dir: Path,
- videos_dir: Path,
- fps: int | None = None,
- video: bool = True,
- episodes: list[int] | None = None,
- encoding: dict | None = None,
-):
- # sanity check
- check_format(raw_dir)
-
- if fps is None:
- fps = 15
-
- data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding)
- hf_dataset = to_hf_dataset(data_dict, video)
- episode_data_index = calculate_episode_data_index(hf_dataset)
- info = {
- "codebase_version": CODEBASE_VERSION,
- "fps": fps,
- "video": video,
- }
- if video:
- info["encoding"] = get_default_encoding()
-
- return hf_dataset, episode_data_index, info
diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index 612bac39..7e297b35 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -13,10 +13,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import contextlib
import importlib.resources
import json
import logging
-import textwrap
from collections.abc import Iterator
from itertools import accumulate
from pathlib import Path
@@ -27,14 +27,21 @@ from typing import Any
import datasets
import jsonlines
import numpy as np
-import pyarrow.compute as pc
+import packaging.version
import torch
from datasets.table import embed_table_storage
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
+from huggingface_hub.errors import RevisionNotFoundError
from PIL import Image as PILImage
from torchvision import transforms
+from lerobot.common.datasets.backward_compatibility import (
+ V21_MESSAGE,
+ BackwardCompatibilityError,
+ ForwardCompatibilityError,
+)
from lerobot.common.robot_devices.robots.utils import Robot
+from lerobot.common.utils.utils import is_valid_numpy_dtype_string
from lerobot.configs.types import DictLike, FeatureType, PolicyFeature
DEFAULT_CHUNK_SIZE = 1000 # Max number of episodes per chunk
@@ -42,6 +49,7 @@ DEFAULT_CHUNK_SIZE = 1000 # Max number of episodes per chunk
INFO_PATH = "meta/info.json"
EPISODES_PATH = "meta/episodes.jsonl"
STATS_PATH = "meta/stats.json"
+EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
TASKS_PATH = "meta/tasks.jsonl"
DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
@@ -112,17 +120,26 @@ def get_nested_item(obj: DictLike, flattened_key: str, sep: str = "/") -> Any:
def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
- serialized_dict = {key: value.tolist() for key, value in flatten_dict(stats).items()}
+ serialized_dict = {}
+ for key, value in flatten_dict(stats).items():
+ if isinstance(value, (torch.Tensor, np.ndarray)):
+ serialized_dict[key] = value.tolist()
+ elif isinstance(value, np.generic):
+ serialized_dict[key] = value.item()
+ elif isinstance(value, (int, float)):
+ serialized_dict[key] = value
+ else:
+ raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.")
return unflatten_dict(serialized_dict)
-def write_parquet(dataset: datasets.Dataset, fpath: Path) -> None:
+def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
# Embed image bytes into the table before saving to parquet
format = dataset.format
dataset = dataset.with_format("arrow")
dataset = dataset.map(embed_table_storage, batched=False)
dataset = dataset.with_format(**format)
- dataset.to_parquet(fpath)
+ return dataset
def load_json(fpath: Path) -> Any:
@@ -153,6 +170,10 @@ def append_jsonlines(data: dict, fpath: Path) -> None:
writer.write(data)
+def write_info(info: dict, local_dir: Path):
+ write_json(info, local_dir / INFO_PATH)
+
+
def load_info(local_dir: Path) -> dict:
info = load_json(local_dir / INFO_PATH)
for ft in info["features"].values():
@@ -160,29 +181,76 @@ def load_info(local_dir: Path) -> dict:
return info
-def load_stats(local_dir: Path) -> dict:
- if not (local_dir / STATS_PATH).exists():
- return None
- stats = load_json(local_dir / STATS_PATH)
- stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
+def write_stats(stats: dict, local_dir: Path):
+ serialized_stats = serialize_dict(stats)
+ write_json(serialized_stats, local_dir / STATS_PATH)
+
+
+def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]:
+ stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
return unflatten_dict(stats)
-def load_tasks(local_dir: Path) -> dict:
+def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]:
+ if not (local_dir / STATS_PATH).exists():
+ return None
+ stats = load_json(local_dir / STATS_PATH)
+ return cast_stats_to_numpy(stats)
+
+
+def write_task(task_index: int, task: dict, local_dir: Path):
+ task_dict = {
+ "task_index": task_index,
+ "task": task,
+ }
+ append_jsonlines(task_dict, local_dir / TASKS_PATH)
+
+
+def load_tasks(local_dir: Path) -> tuple[dict, dict]:
tasks = load_jsonlines(local_dir / TASKS_PATH)
- return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
+ tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
+ task_to_task_index = {task: task_index for task_index, task in tasks.items()}
+ return tasks, task_to_task_index
+
+
+def write_episode(episode: dict, local_dir: Path):
+ append_jsonlines(episode, local_dir / EPISODES_PATH)
def load_episodes(local_dir: Path) -> dict:
- return load_jsonlines(local_dir / EPISODES_PATH)
+ episodes = load_jsonlines(local_dir / EPISODES_PATH)
+ return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}
-def load_image_as_numpy(fpath: str | Path, dtype="float32", channel_first: bool = True) -> np.ndarray:
+def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
+ # We wrap episode_stats in a dictionary since `episode_stats["episode_index"]`
+ # is a dictionary of stats and not an integer.
+ episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)}
+ append_jsonlines(episode_stats, local_dir / EPISODES_STATS_PATH)
+
+
+def load_episodes_stats(local_dir: Path) -> dict:
+ episodes_stats = load_jsonlines(local_dir / EPISODES_STATS_PATH)
+ return {
+ item["episode_index"]: cast_stats_to_numpy(item["stats"])
+ for item in sorted(episodes_stats, key=lambda x: x["episode_index"])
+ }
+
+
+def backward_compatible_episodes_stats(
+ stats: dict[str, dict[str, np.ndarray]], episodes: list[int]
+) -> dict[str, dict[str, np.ndarray]]:
+ return {ep_idx: stats for ep_idx in episodes}
+
+
+def load_image_as_numpy(
+ fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True
+) -> np.ndarray:
img = PILImage.open(fpath).convert("RGB")
img_array = np.array(img, dtype=dtype)
if channel_first: # (H, W, C) -> (C, H, W)
img_array = np.transpose(img_array, (2, 0, 1))
- if "float" in dtype:
+ if np.issubdtype(dtype, np.floating):
img_array /= 255.0
return img_array
@@ -201,77 +269,95 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
elif first_item is None:
pass
else:
- items_dict[key] = [torch.tensor(x) for x in items_dict[key]]
+ items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
return items_dict
-def _get_major_minor(version: str) -> tuple[int]:
- split = version.strip("v").split(".")
- return int(split[0]), int(split[1])
-
-
-class BackwardCompatibilityError(Exception):
- def __init__(self, repo_id, version):
- message = textwrap.dedent(f"""
- BackwardCompatibilityError: The dataset you requested ({repo_id}) is in {version} format.
-
- We introduced a new format since v2.0 which is not backward compatible with v1.x.
- Please, use our conversion script. Modify the following command with your own task description:
- ```
- python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\
- --repo-id {repo_id} \\
- --single-task "TASK DESCRIPTION." # <---- /!\\ Replace TASK DESCRIPTION /!\\
- ```
-
- A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.",
- "Insert the peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.",
- "Open the top cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped target.",
- "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the sweatshirt.", ...
-
- If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
- or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
- """)
- super().__init__(message)
+def is_valid_version(version: str) -> bool:
+ try:
+ packaging.version.parse(version)
+ return True
+ except packaging.version.InvalidVersion:
+ return False
def check_version_compatibility(
- repo_id: str, version_to_check: str, current_version: str, enforce_breaking_major: bool = True
+ repo_id: str,
+ version_to_check: str | packaging.version.Version,
+ current_version: str | packaging.version.Version,
+ enforce_breaking_major: bool = True,
) -> None:
- current_major, _ = _get_major_minor(current_version)
- major_to_check, _ = _get_major_minor(version_to_check)
- if major_to_check < current_major and enforce_breaking_major:
- raise BackwardCompatibilityError(repo_id, version_to_check)
- elif float(version_to_check.strip("v")) < float(current_version.strip("v")):
- logging.warning(
- f"""The dataset you requested ({repo_id}) was created with a previous version ({version_to_check}) of the
- codebase. The current codebase version is {current_version}. You should be fine since
- backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
- Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
- )
+ v_check = (
+ packaging.version.parse(version_to_check)
+ if not isinstance(version_to_check, packaging.version.Version)
+ else version_to_check
+ )
+ v_current = (
+ packaging.version.parse(current_version)
+ if not isinstance(current_version, packaging.version.Version)
+ else current_version
+ )
+ if v_check.major < v_current.major and enforce_breaking_major:
+ raise BackwardCompatibilityError(repo_id, v_check)
+ elif v_check.minor < v_current.minor:
+ logging.warning(V21_MESSAGE.format(repo_id=repo_id, version=v_check))
-def get_hub_safe_version(repo_id: str, version: str) -> str:
+def get_repo_versions(repo_id: str) -> list[packaging.version.Version]:
+ """Returns available valid versions (branches and tags) on given repo."""
api = HfApi()
- dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
- branches = [b.name for b in dataset_info.branches]
- if version not in branches:
- num_version = float(version.strip("v"))
- hub_num_versions = [float(v.strip("v")) for v in branches if v.startswith("v")]
- if num_version >= 2.0 and all(v < 2.0 for v in hub_num_versions):
- raise BackwardCompatibilityError(repo_id, version)
+ repo_refs = api.list_repo_refs(repo_id, repo_type="dataset")
+ repo_refs = [b.name for b in repo_refs.branches + repo_refs.tags]
+ repo_versions = []
+ for ref in repo_refs:
+ with contextlib.suppress(packaging.version.InvalidVersion):
+ repo_versions.append(packaging.version.parse(ref))
- logging.warning(
- f"""You are trying to load a dataset from {repo_id} created with a previous version of the
- codebase. The following versions are available: {branches}.
- The requested version ('{version}') is not found. You should be fine since
- backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
- Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
+ return repo_versions
+
+
+def get_safe_version(repo_id: str, version: str | packaging.version.Version) -> str:
+ """
+ Returns the version if available on repo or the latest compatible one.
+ Otherwise, will throw a `CompatibilityError`.
+ """
+ target_version = (
+ packaging.version.parse(version) if not isinstance(version, packaging.version.Version) else version
+ )
+ hub_versions = get_repo_versions(repo_id)
+
+ if not hub_versions:
+ raise RevisionNotFoundError(
+ f"""Your dataset must be tagged with a codebase version.
+ Assuming _version_ is the codebase_version value in the info.json, you can run this:
+ ```python
+ from huggingface_hub import HfApi
+
+ hub_api = HfApi()
+ hub_api.create_tag("{repo_id}", tag="_version_", repo_type="dataset")
+ ```
+ """
)
- if "main" not in branches:
- raise ValueError(f"Version 'main' not found on {repo_id}")
- return "main"
- else:
- return version
+
+ if target_version in hub_versions:
+ return f"v{target_version}"
+
+ compatibles = [
+ v for v in hub_versions if v.major == target_version.major and v.minor <= target_version.minor
+ ]
+ if compatibles:
+ return_version = max(compatibles)
+ if return_version < target_version:
+ logging.warning(f"Revision {version} for {repo_id} not found, using version v{return_version}")
+ return f"v{return_version}"
+
+ lower_major = [v for v in hub_versions if v.major < target_version.major]
+ if lower_major:
+ raise BackwardCompatibilityError(repo_id, max(lower_major))
+
+ upper_versions = [v for v in hub_versions if v > target_version]
+ assert len(upper_versions) > 0
+ raise ForwardCompatibilityError(repo_id, min(upper_versions))
def get_hf_features_from_features(features: dict) -> datasets.Features:
@@ -283,11 +369,20 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
hf_features[key] = datasets.Image()
elif ft["shape"] == (1,):
hf_features[key] = datasets.Value(dtype=ft["dtype"])
- else:
- assert len(ft["shape"]) == 1
+ elif len(ft["shape"]) == 1:
hf_features[key] = datasets.Sequence(
length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"])
)
+ elif len(ft["shape"]) == 2:
+ hf_features[key] = datasets.Array2D(shape=ft["shape"], dtype=ft["dtype"])
+ elif len(ft["shape"]) == 3:
+ hf_features[key] = datasets.Array3D(shape=ft["shape"], dtype=ft["dtype"])
+ elif len(ft["shape"]) == 4:
+ hf_features[key] = datasets.Array4D(shape=ft["shape"], dtype=ft["dtype"])
+ elif len(ft["shape"]) == 5:
+ hf_features[key] = datasets.Array5D(shape=ft["shape"], dtype=ft["dtype"])
+ else:
+ raise ValueError(f"Corresponding feature is not valid: {ft}")
    return datasets.Features(hf_features)
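For reference, a minimal sketch of what the extended shape handling above produces (feature names and shapes are made up; shapes are given as tuples):

```python
from lerobot.common.datasets.utils import get_hf_features_from_features

# Hypothetical feature specs: a 6-dim state vector and a depth map.
features = {
    "observation.state": {"dtype": "float32", "shape": (6,)},
    "observation.depth": {"dtype": "float32", "shape": (480, 640)},
}

hf_features = get_hf_features_from_features(features)
# Expected mapping:
#   observation.state -> Sequence(length=6, feature=Value("float32"))
#   observation.depth -> Array2D(shape=(480, 640), dtype="float32")
print(hf_features)
```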
@@ -358,88 +453,85 @@ def create_empty_dataset_info(
def get_episode_data_index(
- episode_dicts: list[dict], episodes: list[int] | None = None
+ episode_dicts: dict[dict], episodes: list[int] | None = None
) -> dict[str, torch.Tensor]:
- episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(episode_dicts)}
+ episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()}
if episodes is not None:
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}
- cumulative_lenghts = list(accumulate(episode_lengths.values()))
+ cumulative_lengths = list(accumulate(episode_lengths.values()))
return {
- "from": torch.LongTensor([0] + cumulative_lenghts[:-1]),
- "to": torch.LongTensor(cumulative_lenghts),
- }
-
-
-def calculate_total_episode(
- hf_dataset: datasets.Dataset, raise_if_not_contiguous: bool = True
-) -> dict[str, torch.Tensor]:
- episode_indices = sorted(hf_dataset.unique("episode_index"))
- total_episodes = len(episode_indices)
- if raise_if_not_contiguous and episode_indices != list(range(total_episodes)):
- raise ValueError("episode_index values are not sorted and contiguous.")
- return total_episodes
-
-
-def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, torch.Tensor]:
- episode_lengths = []
- table = hf_dataset.data.table
- total_episodes = calculate_total_episode(hf_dataset)
- for ep_idx in range(total_episodes):
- ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
- episode_lengths.insert(ep_idx, len(ep_table))
-
- cumulative_lenghts = list(accumulate(episode_lengths))
- return {
- "from": torch.LongTensor([0] + cumulative_lenghts[:-1]),
- "to": torch.LongTensor(cumulative_lenghts),
+ "from": torch.LongTensor([0] + cumulative_lengths[:-1]),
+ "to": torch.LongTensor(cumulative_lengths),
}
def check_timestamps_sync(
- hf_dataset: datasets.Dataset,
- episode_data_index: dict[str, torch.Tensor],
+ timestamps: np.ndarray,
+ episode_indices: np.ndarray,
+ episode_data_index: dict[str, np.ndarray],
fps: int,
tolerance_s: float,
raise_value_error: bool = True,
) -> bool:
"""
- This check is to make sure that each timestamps is separated to the next by 1/fps +/- tolerance to
- account for possible numerical error.
- """
- timestamps = torch.stack(hf_dataset["timestamp"])
- diffs = torch.diff(timestamps)
- within_tolerance = torch.abs(diffs - 1 / fps) <= tolerance_s
+ This check is to make sure that each timestamp is separated from the next by (1/fps) +/- tolerance
+ to account for possible numerical error.
- # We mask differences between the timestamp at the end of an episode
- # and the one at the start of the next episode since these are expected
- # to be outside tolerance.
- mask = torch.ones(len(diffs), dtype=torch.bool)
- ignored_diffs = episode_data_index["to"][:-1] - 1
+ Args:
+ timestamps (np.ndarray): Array of timestamps in seconds.
+ episode_indices (np.ndarray): Array indicating the episode index for each timestamp.
+ episode_data_index (dict[str, np.ndarray]): A dictionary that includes 'to',
+ which identifies indices for the end of each episode.
+ fps (int): Frames per second. Used to check the expected difference between consecutive timestamps.
+ tolerance_s (float): Allowed deviation from the expected (1/fps) difference.
+ raise_value_error (bool): Whether to raise a ValueError if the check fails.
+
+ Returns:
+ bool: True if all checked timestamp differences lie within tolerance, False otherwise.
+
+ Raises:
+ ValueError: If the check fails and `raise_value_error` is True.
+ """
+ if timestamps.shape != episode_indices.shape:
+ raise ValueError(
+ "timestamps and episode_indices should have the same shape. "
+ f"Found {timestamps.shape=} and {episode_indices.shape=}."
+ )
+
+ # Consecutive differences
+ diffs = np.diff(timestamps)
+ within_tolerance = np.abs(diffs - (1.0 / fps)) <= tolerance_s
+
+ # Mask to ignore differences at the boundaries between episodes
+ mask = np.ones(len(diffs), dtype=bool)
+ ignored_diffs = episode_data_index["to"][:-1] - 1 # indices at the end of each episode
mask[ignored_diffs] = False
filtered_within_tolerance = within_tolerance[mask]
- if not torch.all(filtered_within_tolerance):
+ # Check if all remaining diffs are within tolerance
+ if not np.all(filtered_within_tolerance):
# Track original indices before masking
- original_indices = torch.arange(len(diffs))
+ original_indices = np.arange(len(diffs))
filtered_indices = original_indices[mask]
- outside_tolerance_filtered_indices = torch.nonzero(~filtered_within_tolerance) # .squeeze()
+ outside_tolerance_filtered_indices = np.nonzero(~filtered_within_tolerance)[0]
outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices]
- episode_indices = torch.stack(hf_dataset["episode_index"])
outside_tolerances = []
for idx in outside_tolerance_indices:
entry = {
"timestamps": [timestamps[idx], timestamps[idx + 1]],
"diff": diffs[idx],
- "episode_index": episode_indices[idx].item(),
+ "episode_index": episode_indices[idx].item()
+ if hasattr(episode_indices[idx], "item")
+ else episode_indices[idx],
}
outside_tolerances.append(entry)
if raise_value_error:
raise ValueError(
f"""One or several timestamps unexpectedly violate the tolerance inside episode range.
- This might be due to synchronization issues with timestamps during data collection.
+ This might be due to synchronization issues during data collection.
\n{pformat(outside_tolerances)}"""
)
return False
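A minimal sketch of how the reworked, numpy-based sync check could be exercised (values are made up): two 3-frame episodes at 10 fps, where the jump between episodes is masked out via `episode_data_index["to"]`.

```python
import numpy as np

from lerobot.common.datasets.utils import check_timestamps_sync

fps = 10
# Two 3-frame episodes, each restarting at t=0 and sampled every 1/fps seconds.
timestamps = np.array([0.0, 0.1, 0.2, 0.0, 0.1, 0.2])
episode_indices = np.array([0, 0, 0, 1, 1, 1])
episode_data_index = {"from": np.array([0, 3]), "to": np.array([3, 6])}

ok = check_timestamps_sync(timestamps, episode_indices, episode_data_index, fps, tolerance_s=1e-4)
print(ok)  # True: the 0.2 -> 0.0 boundary difference is ignored
```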
@@ -604,3 +696,118 @@ class IterableNamespace(SimpleNamespace):
def keys(self):
return vars(self).keys()
+
+
+def validate_frame(frame: dict, features: dict):
+ optional_features = {"timestamp"}
+ expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"}
+ actual_features = set(frame.keys())
+
+ error_message = validate_features_presence(actual_features, expected_features, optional_features)
+
+ if "task" in frame:
+ error_message += validate_feature_string("task", frame["task"])
+
+ common_features = actual_features & (expected_features | optional_features)
+ for name in common_features - {"task"}:
+ error_message += validate_feature_dtype_and_shape(name, features[name], frame[name])
+
+ if error_message:
+ raise ValueError(error_message)
+
+
+def validate_features_presence(
+ actual_features: set[str], expected_features: set[str], optional_features: set[str]
+):
+ error_message = ""
+ missing_features = expected_features - actual_features
+ extra_features = actual_features - (expected_features | optional_features)
+
+ if missing_features or extra_features:
+ error_message += "Feature mismatch in `frame` dictionary:\n"
+ if missing_features:
+ error_message += f"Missing features: {missing_features}\n"
+ if extra_features:
+ error_message += f"Extra features: {extra_features}\n"
+
+ return error_message
+
+
+def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
+ expected_dtype = feature["dtype"]
+ expected_shape = feature["shape"]
+ if is_valid_numpy_dtype_string(expected_dtype):
+ return validate_feature_numpy_array(name, expected_dtype, expected_shape, value)
+ elif expected_dtype in ["image", "video"]:
+ return validate_feature_image_or_video(name, expected_shape, value)
+ elif expected_dtype == "string":
+ return validate_feature_string(name, value)
+ else:
+ raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
+
+
+def validate_feature_numpy_array(
+ name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray
+):
+ error_message = ""
+ if isinstance(value, np.ndarray):
+ actual_dtype = value.dtype
+ actual_shape = value.shape
+
+ if actual_dtype != np.dtype(expected_dtype):
+ error_message += f"The feature '{name}' of dtype '{actual_dtype}' is not of the expected dtype '{expected_dtype}'.\n"
+
+ if actual_shape != expected_shape:
+ error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{expected_shape}'.\n"
+ else:
+ error_message += f"The feature '{name}' is not a 'np.ndarray'. Expected type is '{expected_dtype}', but type '{type(value)}' provided instead.\n"
+
+ return error_message
+
+
+def validate_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
+ # Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
+ error_message = ""
+ if isinstance(value, np.ndarray):
+ actual_shape = value.shape
+ c, h, w = expected_shape
+ if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
+ error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
+ elif isinstance(value, PILImage.Image):
+ pass
+ else:
+ error_message += f"The feature '{name}' is expected to be of type 'PIL.Image' or 'np.ndarray' channel first or channel last, but type '{type(value)}' provided instead.\n"
+
+ return error_message
+
+
+def validate_feature_string(name: str, value: str):
+ if not isinstance(value, str):
+ return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"
+ return ""
+
+
+def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict):
+ if "size" not in episode_buffer:
+ raise ValueError("size key not found in episode_buffer")
+
+ if "task" not in episode_buffer:
+ raise ValueError("task key not found in episode_buffer")
+
+ if episode_buffer["episode_index"] != total_episodes:
+ # TODO(aliberts): Add option to use existing episode_index
+ raise NotImplementedError(
+ "You might have manually provided the episode_buffer with an episode_index that doesn't "
+ "match the total number of episodes already in the dataset. This is not supported for now."
+ )
+
+ if episode_buffer["size"] == 0:
+ raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.")
+
+ buffer_keys = set(episode_buffer.keys()) - {"task", "size"}
+ if not buffer_keys == set(features):
+ raise ValueError(
+            f"Features from `episode_buffer` don't match the ones in `features`.\n"
+            f"In episode_buffer not in features: {buffer_keys - set(features)}\n"
+            f"In features not in episode_buffer: {set(features) - buffer_keys}"
+ )
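As a rough usage sketch of the new frame validation helpers (the feature spec below is made up, and we assume `DEFAULT_FEATURES` only covers bookkeeping keys such as indices and timestamps):

```python
import numpy as np

from lerobot.common.datasets.utils import validate_frame

# Hypothetical feature spec: a single 6-dim state vector, shape given as a tuple.
features = {"observation.state": {"dtype": "float32", "shape": (6,)}}

frame = {
    "observation.state": np.zeros(6, dtype=np.float32),
    "task": "Pick up the cube.",
}

# Raises a ValueError listing missing/extra features or dtype/shape mismatches; silent on success.
validate_frame(frame, features)
```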
diff --git a/lerobot/common/datasets/v2/batch_convert_dataset_v1_to_v2.py b/lerobot/common/datasets/v2/batch_convert_dataset_v1_to_v2.py
index 4cd93a2d..99ab2cbf 100644
--- a/lerobot/common/datasets/v2/batch_convert_dataset_v1_to_v2.py
+++ b/lerobot/common/datasets/v2/batch_convert_dataset_v1_to_v2.py
@@ -31,6 +31,7 @@ from lerobot.common.robot_devices.robots.configs import AlohaRobotConfig
LOCAL_DIR = Path("data/")
+# spellchecker:off
ALOHA_MOBILE_INFO = {
"robot_config": AlohaRobotConfig(),
"license": "mit",
@@ -856,6 +857,7 @@ DATASETS = {
}""").lstrip(),
},
}
+# spellchecker:on
def batch_convert():
diff --git a/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py b/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py
index 62ca9932..acf0282f 100644
--- a/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py
+++ b/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py
@@ -17,7 +17,7 @@
"""
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 1.6 to
2.0. You will be required to provide the 'tasks', which is a short but accurate description in plain English
-for each of the task performed in the dataset. This will allow to easily train models with task-conditionning.
+for each of the tasks performed in the dataset. This will make it easy to train models with task-conditioning.
We support 3 different scenarios for these tasks (see instructions below):
1. Single task dataset: all episodes of your dataset have the same single task.
@@ -130,7 +130,7 @@ from lerobot.common.datasets.utils import (
create_branch,
create_lerobot_dataset_card,
flatten_dict,
- get_hub_safe_version,
+ get_safe_version,
load_json,
unflatten_dict,
write_json,
@@ -443,7 +443,7 @@ def convert_dataset(
test_branch: str | None = None,
**card_kwargs,
):
- v1 = get_hub_safe_version(repo_id, V16)
+ v1 = get_safe_version(repo_id, V16)
v1x_dir = local_dir / V16 / repo_id
v20_dir = local_dir / V20 / repo_id
v1x_dir.mkdir(parents=True, exist_ok=True)
diff --git a/lerobot/common/datasets/v21/_remove_language_instruction.py b/lerobot/common/datasets/v21/_remove_language_instruction.py
new file mode 100644
index 00000000..643ddd3f
--- /dev/null
+++ b/lerobot/common/datasets/v21/_remove_language_instruction.py
@@ -0,0 +1,87 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import traceback
+from pathlib import Path
+
+from datasets import get_dataset_config_info
+from huggingface_hub import HfApi
+
+from lerobot import available_datasets
+from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
+from lerobot.common.datasets.utils import INFO_PATH, write_info
+from lerobot.common.datasets.v21.convert_dataset_v20_to_v21 import V20, SuppressWarnings
+
+LOCAL_DIR = Path("data/")
+
+hub_api = HfApi()
+
+
+def fix_dataset(repo_id: str) -> str:
+ if not hub_api.revision_exists(repo_id, V20, repo_type="dataset"):
+ return f"{repo_id}: skipped (not in {V20})."
+
+ dataset_info = get_dataset_config_info(repo_id, "default")
+ with SuppressWarnings():
+ lerobot_metadata = LeRobotDatasetMetadata(repo_id, revision=V20, force_cache_sync=True)
+
+ meta_features = {key for key, ft in lerobot_metadata.features.items() if ft["dtype"] != "video"}
+ parquet_features = set(dataset_info.features)
+
+ diff_parquet_meta = parquet_features - meta_features
+ diff_meta_parquet = meta_features - parquet_features
+
+ if diff_parquet_meta:
+ raise ValueError(f"In parquet not in info.json: {parquet_features - meta_features}")
+
+ if not diff_meta_parquet:
+ return f"{repo_id}: skipped (no diff)"
+
+ if diff_meta_parquet:
+ logging.warning(f"In info.json not in parquet: {meta_features - parquet_features}")
+ assert diff_meta_parquet == {"language_instruction"}
+ lerobot_metadata.features.pop("language_instruction")
+ write_info(lerobot_metadata.info, lerobot_metadata.root)
+ commit_info = hub_api.upload_file(
+ path_or_fileobj=lerobot_metadata.root / INFO_PATH,
+ path_in_repo=INFO_PATH,
+ repo_id=repo_id,
+ repo_type="dataset",
+ revision=V20,
+ commit_message="Remove 'language_instruction'",
+ create_pr=True,
+ )
+ return f"{repo_id}: success - PR: {commit_info.pr_url}"
+
+
+def batch_fix():
+ status = {}
+ LOCAL_DIR.mkdir(parents=True, exist_ok=True)
+ logfile = LOCAL_DIR / "fix_features_v20.txt"
+ for num, repo_id in enumerate(available_datasets):
+ print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
+ print("---------------------------------------------------------")
+ try:
+ status = fix_dataset(repo_id)
+ except Exception:
+ status = f"{repo_id}: failed\n {traceback.format_exc()}"
+
+ logging.info(status)
+ with open(logfile, "a") as file:
+ file.write(status + "\n")
+
+
+if __name__ == "__main__":
+ batch_fix()
diff --git a/lerobot/common/datasets/v21/batch_convert_dataset_v20_to_v21.py b/lerobot/common/datasets/v21/batch_convert_dataset_v20_to_v21.py
new file mode 100644
index 00000000..cc9272a8
--- /dev/null
+++ b/lerobot/common/datasets/v21/batch_convert_dataset_v20_to_v21.py
@@ -0,0 +1,54 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script is for internal use to convert all datasets under the 'lerobot' hub user account to v2.1.
+"""
+
+import traceback
+from pathlib import Path
+
+from huggingface_hub import HfApi
+
+from lerobot import available_datasets
+from lerobot.common.datasets.v21.convert_dataset_v20_to_v21 import V21, convert_dataset
+
+LOCAL_DIR = Path("data/")
+
+
+def batch_convert():
+ status = {}
+ LOCAL_DIR.mkdir(parents=True, exist_ok=True)
+ logfile = LOCAL_DIR / "conversion_log_v21.txt"
+ hub_api = HfApi()
+ for num, repo_id in enumerate(available_datasets):
+ print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
+ print("---------------------------------------------------------")
+ try:
+ if hub_api.revision_exists(repo_id, V21, repo_type="dataset"):
+ status = f"{repo_id}: success (already in {V21})."
+ else:
+ convert_dataset(repo_id)
+ status = f"{repo_id}: success."
+ except Exception:
+ status = f"{repo_id}: failed\n {traceback.format_exc()}"
+
+ with open(logfile, "a") as file:
+ file.write(status + "\n")
+
+
+if __name__ == "__main__":
+ batch_convert()
diff --git a/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py b/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py
new file mode 100644
index 00000000..176d16d0
--- /dev/null
+++ b/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py
@@ -0,0 +1,114 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
+2.1. It will:
+
+- Generate per-episode stats and write them to `episodes_stats.jsonl`.
+- Check consistency between these new stats and the old ones.
+- Remove the deprecated `stats.json`.
+- Update codebase_version in `info.json`.
+- Push this new version to the hub on the 'main' branch and tag it with "v2.1".
+
+Usage:
+
+```bash
+python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py \
+ --repo-id=aliberts/koch_tutorial
+```
+
+"""
+
+import argparse
+import logging
+
+from huggingface_hub import HfApi
+
+from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
+from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
+from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
+
+V20 = "v2.0"
+V21 = "v2.1"
+
+
+class SuppressWarnings:
+ def __enter__(self):
+ self.previous_level = logging.getLogger().getEffectiveLevel()
+ logging.getLogger().setLevel(logging.ERROR)
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ logging.getLogger().setLevel(self.previous_level)
+
+
+def convert_dataset(
+ repo_id: str,
+ branch: str | None = None,
+ num_workers: int = 4,
+):
+ with SuppressWarnings():
+ dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
+
+ if (dataset.root / EPISODES_STATS_PATH).is_file():
+ (dataset.root / EPISODES_STATS_PATH).unlink()
+
+ convert_stats(dataset, num_workers=num_workers)
+ ref_stats = load_stats(dataset.root)
+ check_aggregate_stats(dataset, ref_stats)
+
+ dataset.meta.info["codebase_version"] = CODEBASE_VERSION
+ write_info(dataset.meta.info, dataset.root)
+
+ dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/")
+
+ # delete old stats.json file
+    if (dataset.root / STATS_PATH).is_file():
+ (dataset.root / STATS_PATH).unlink()
+
+ hub_api = HfApi()
+ if hub_api.file_exists(
+ repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
+ ):
+ hub_api.delete_file(
+ path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
+ )
+
+ hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--repo-id",
+ type=str,
+ required=True,
+ help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
+ "(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
+ )
+ parser.add_argument(
+ "--branch",
+ type=str,
+ default=None,
+ help="Repo branch to push your dataset. Defaults to the main branch.",
+ )
+ parser.add_argument(
+ "--num-workers",
+ type=int,
+ default=4,
+ help="Number of workers for parallelizing stats compute. Defaults to 4.",
+ )
+
+ args = parser.parse_args()
+ convert_dataset(**vars(args))
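Besides the CLI shown in the module docstring, the conversion can presumably be driven from Python as well (repo id taken from the docstring example):

```python
from lerobot.common.datasets.v21.convert_dataset_v20_to_v21 import convert_dataset

# Recompute per-episode stats, bump codebase_version to v2.1 and push the metadata to the hub.
convert_dataset(repo_id="aliberts/koch_tutorial", num_workers=8)
```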
diff --git a/lerobot/common/datasets/v21/convert_stats.py b/lerobot/common/datasets/v21/convert_stats.py
new file mode 100644
index 00000000..4a20b427
--- /dev/null
+++ b/lerobot/common/datasets/v21/convert_stats.py
@@ -0,0 +1,99 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+import numpy as np
+from tqdm import tqdm
+
+from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.utils import write_episode_stats
+
+
+def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:
+ ep_len = dataset.meta.episodes[episode_index]["length"]
+ sampled_indices = sample_indices(ep_len)
+ query_timestamps = dataset._get_query_timestamps(0.0, {ft_key: sampled_indices})
+ video_frames = dataset._query_videos(query_timestamps, episode_index)
+ return video_frames[ft_key].numpy()
+
+
+def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
+ ep_start_idx = dataset.episode_data_index["from"][ep_idx]
+ ep_end_idx = dataset.episode_data_index["to"][ep_idx]
+ ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx))
+
+ ep_stats = {}
+ for key, ft in dataset.features.items():
+ if ft["dtype"] == "video":
+ # We sample only for videos
+ ep_ft_data = sample_episode_video_frames(dataset, ep_idx, key)
+ else:
+ ep_ft_data = np.array(ep_data[key])
+
+ axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0
+ keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1
+ ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims)
+
+ if ft["dtype"] in ["image", "video"]: # remove batch dim
+ ep_stats[key] = {
+ k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
+ }
+
+ dataset.meta.episodes_stats[ep_idx] = ep_stats
+
+
+def convert_stats(dataset: LeRobotDataset, num_workers: int = 0):
+ assert dataset.episodes is None
+ print("Computing episodes stats")
+ total_episodes = dataset.meta.total_episodes
+ if num_workers > 0:
+ with ThreadPoolExecutor(max_workers=num_workers) as executor:
+ futures = {
+ executor.submit(convert_episode_stats, dataset, ep_idx): ep_idx
+ for ep_idx in range(total_episodes)
+ }
+ for future in tqdm(as_completed(futures), total=total_episodes):
+ future.result()
+ else:
+ for ep_idx in tqdm(range(total_episodes)):
+ convert_episode_stats(dataset, ep_idx)
+
+ for ep_idx in tqdm(range(total_episodes)):
+ write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root)
+
+
+def check_aggregate_stats(
+ dataset: LeRobotDataset,
+ reference_stats: dict[str, dict[str, np.ndarray]],
+    video_rtol_atol: tuple[float, float] = (1e-2, 1e-2),
+    default_rtol_atol: tuple[float, float] = (5e-6, 6e-5),
+):
+ """Verifies that the aggregated stats from episodes_stats are close to reference stats."""
+ agg_stats = aggregate_stats(list(dataset.meta.episodes_stats.values()))
+ for key, ft in dataset.features.items():
+ # These values might need some fine-tuning
+ if ft["dtype"] == "video":
+ # to account for image sub-sampling
+ rtol, atol = video_rtol_atol
+ else:
+ rtol, atol = default_rtol_atol
+
+ for stat, val in agg_stats[key].items():
+ if key in reference_stats and stat in reference_stats[key]:
+ err_msg = f"feature='{key}' stats='{stat}'"
+ np.testing.assert_allclose(
+ val, reference_stats[key][stat], rtol=rtol, atol=atol, err_msg=err_msg
+ )
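For intuition on what `check_aggregate_stats` verifies, here is an illustrative pooled-statistics sketch (not the actual `aggregate_stats` implementation) for combining per-episode means and stds weighted by frame counts:

```python
import numpy as np


def pool_mean_std(means, stds, counts):
    """Combine per-episode means/stds into dataset-level stats using standard pooled formulas."""
    counts = np.asarray(counts, dtype=np.float64)
    means = np.asarray(means, dtype=np.float64)
    stds = np.asarray(stds, dtype=np.float64)
    total = counts.sum()
    mean = (counts * means).sum() / total
    # Pool E[x^2] from each episode's variance + squared mean, then recenter.
    second_moment = (counts * (stds**2 + means**2)).sum() / total
    return mean, np.sqrt(second_moment - mean**2)


# Two episodes of 100 frames each, centered at 0.0 and 1.0 with unit std.
print(pool_mean_std(means=[0.0, 1.0], stds=[1.0, 1.0], counts=[100, 100]))  # (0.5, ~1.118)
```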
diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py
index 8ed3318d..c38d570d 100644
--- a/lerobot/common/datasets/video_utils.py
+++ b/lerobot/common/datasets/video_utils.py
@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import importlib
import json
import logging
import subprocess
@@ -29,6 +30,46 @@ from datasets.features.features import register_feature
from PIL import Image
+def get_safe_default_codec():
+ if importlib.util.find_spec("torchcodec"):
+ return "torchcodec"
+ else:
+ logging.warning(
+            "'torchcodec' is not available on your platform, falling back to 'pyav' as the default decoder"
+ )
+ return "pyav"
+
+
+def decode_video_frames(
+ video_path: Path | str,
+ timestamps: list[float],
+ tolerance_s: float,
+ backend: str | None = None,
+) -> torch.Tensor:
+ """
+ Decodes video frames using the specified backend.
+
+ Args:
+ video_path (Path): Path to the video file.
+ timestamps (list[float]): List of timestamps to extract frames.
+ tolerance_s (float): Allowed deviation in seconds for frame retrieval.
+        backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available on the platform; otherwise, defaults to "pyav".
+
+ Returns:
+ torch.Tensor: Decoded frames.
+
+ Currently supports torchcodec on cpu and pyav.
+ """
+ if backend is None:
+ backend = get_safe_default_codec()
+ if backend == "torchcodec":
+ return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
+ elif backend in ["pyav", "video_reader"]:
+ return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
+ else:
+ raise ValueError(f"Unsupported video backend: {backend}")
+
+
def decode_video_frames_torchvision(
video_path: Path | str,
timestamps: list[float],
@@ -69,11 +110,11 @@ def decode_video_frames_torchvision(
# set the first and last requested timestamps
# Note: previous timestamps are usually loaded, since we need to access the previous key frame
- first_ts = timestamps[0]
- last_ts = timestamps[-1]
+ first_ts = min(timestamps)
+ last_ts = max(timestamps)
# access closest key frame of the first requested frame
- # Note: closest key frame timestamp is usally smaller than `first_ts` (e.g. key frame can be the first frame of the video)
+ # Note: closest key frame timestamp is usually smaller than `first_ts` (e.g. key frame can be the first frame of the video)
# for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek
reader.seek(first_ts, keyframes_only=keyframes_only)
@@ -127,6 +168,81 @@ def decode_video_frames_torchvision(
return closest_frames
+def decode_video_frames_torchcodec(
+ video_path: Path | str,
+ timestamps: list[float],
+ tolerance_s: float,
+ device: str = "cpu",
+ log_loaded_timestamps: bool = False,
+) -> torch.Tensor:
+ """Loads frames associated with the requested timestamps of a video using torchcodec.
+
+ Note: Setting device="cuda" outside the main process, e.g. in data loader workers, will lead to CUDA initialization errors.
+
+ Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
+ the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
+ that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
+ and all subsequent frames until reaching the requested frame. The number of key frames in a video
+ can be adjusted during encoding to take into account decoding time and video size in bytes.
+ """
+
+ if importlib.util.find_spec("torchcodec"):
+ from torchcodec.decoders import VideoDecoder
+ else:
+ raise ImportError("torchcodec is required but not available.")
+
+ # initialize video decoder
+ decoder = VideoDecoder(video_path, device=device, seek_mode="approximate")
+ loaded_frames = []
+ loaded_ts = []
+ # get metadata for frame information
+ metadata = decoder.metadata
+ average_fps = metadata.average_fps
+
+ # convert timestamps to frame indices
+ frame_indices = [round(ts * average_fps) for ts in timestamps]
+
+ # retrieve frames based on indices
+ frames_batch = decoder.get_frames_at(indices=frame_indices)
+
+ for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=False):
+ loaded_frames.append(frame)
+ loaded_ts.append(pts.item())
+ if log_loaded_timestamps:
+ logging.info(f"Frame loaded at timestamp={pts:.4f}")
+
+ query_ts = torch.tensor(timestamps)
+ loaded_ts = torch.tensor(loaded_ts)
+
+ # compute distances between each query timestamp and loaded timestamps
+ dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
+ min_, argmin_ = dist.min(1)
+
+ is_within_tol = min_ < tolerance_s
+ assert is_within_tol.all(), (
+ f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
+        " It means that the closest frame that can be loaded from the video is too far away in time."
+        " This might be due to synchronization issues with timestamps during data collection."
+        " To be safe, we advise ignoring this item during training."
+ f"\nqueried timestamps: {query_ts}"
+ f"\nloaded timestamps: {loaded_ts}"
+ f"\nvideo: {video_path}"
+ )
+
+ # get closest frames to the query timestamps
+ closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
+ closest_ts = loaded_ts[argmin_]
+
+ if log_loaded_timestamps:
+ logging.info(f"{closest_ts=}")
+
+ # convert to float32 in [0,1] range (channel first)
+ closest_frames = closest_frames.type(torch.float32) / 255
+
+ assert len(timestamps) == len(closest_frames)
+ return closest_frames
+
+
def encode_video_frames(
imgs_dir: Path | str,
video_path: Path | str,
@@ -141,6 +257,7 @@ def encode_video_frames(
) -> None:
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
video_path = Path(video_path)
+ imgs_dir = Path(imgs_dir)
video_path.parent.mkdir(parents=True, exist_ok=True)
ffmpeg_args = OrderedDict(
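A small usage sketch of the new decoding entry point (the video path and timestamps are made up); leaving `backend=None` resolves to "torchcodec" when available and falls back to "pyav" otherwise:

```python
from lerobot.common.datasets.video_utils import decode_video_frames

# Fetch two frames from a 30 fps video, allowing ~1/4 frame of timestamp slack.
frames = decode_video_frames(
    video_path="videos/chunk-000/observation.image/episode_000000.mp4",  # hypothetical path
    timestamps=[0.0, 1.5],
    tolerance_s=1 / 30 / 4,
    backend=None,  # "torchcodec" if installed, else "pyav"
)
print(frames.shape)  # (2, C, H, W), float32 in [0, 1]
```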
diff --git a/lerobot/common/envs/__init__.py b/lerobot/common/envs/__init__.py
index a583ffc5..4977d11d 100644
--- a/lerobot/common/envs/__init__.py
+++ b/lerobot/common/envs/__init__.py
@@ -1 +1,15 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv # noqa: F401
diff --git a/lerobot/common/envs/configs.py b/lerobot/common/envs/configs.py
index 6259ca94..cf90048a 100644
--- a/lerobot/common/envs/configs.py
+++ b/lerobot/common/envs/configs.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import abc
from dataclasses import dataclass, field
diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py
index 49239363..8450f84b 100644
--- a/lerobot/common/envs/factory.py
+++ b/lerobot/common/envs/factory.py
@@ -37,12 +37,12 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
Args:
cfg (EnvConfig): the config of the environment to instantiate.
n_envs (int, optional): The number of parallelized env to return. Defaults to 1.
- use_async_envs (bool, optional): Wether to return an AsyncVectorEnv or a SyncVectorEnv. Defaults to
+ use_async_envs (bool, optional): Whether to return an AsyncVectorEnv or a SyncVectorEnv. Defaults to
False.
Raises:
ValueError: if n_envs < 1
- ModuleNotFoundError: If the requested env package is not intalled
+ ModuleNotFoundError: If the requested env package is not installed
Returns:
gym.vector.VectorEnv: The parallelized gym.env instance.
diff --git a/lerobot/common/envs/utils.py b/lerobot/common/envs/utils.py
index 30bbaf39..83334f87 100644
--- a/lerobot/common/envs/utils.py
+++ b/lerobot/common/envs/utils.py
@@ -13,7 +13,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import warnings
+from typing import Any
+
import einops
+import gymnasium as gym
import numpy as np
import torch
from torch import Tensor
@@ -86,3 +90,38 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
policy_features[policy_key] = feature
return policy_features
+
+
+def are_all_envs_same_type(env: gym.vector.VectorEnv) -> bool:
+ first_type = type(env.envs[0]) # Get type of first env
+ return all(type(e) is first_type for e in env.envs) # Fast type check
+
+
+def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None:
+ with warnings.catch_warnings():
+ warnings.simplefilter("once", UserWarning) # Apply filter only in this function
+
+ if not (hasattr(env.envs[0], "task_description") and hasattr(env.envs[0], "task")):
+ warnings.warn(
+ "The environment does not have 'task_description' and 'task'. Some policies require these features.",
+ UserWarning,
+ stacklevel=2,
+ )
+ if not are_all_envs_same_type(env):
+ warnings.warn(
+            "The environments have different types. Make sure you infer the right task from each environment; an empty task will be passed instead.",
+ UserWarning,
+ stacklevel=2,
+ )
+
+
+def add_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> dict[str, Any]:
+    """Adds the task feature to the observation dict, based on the attributes of the first environment."""
+ if hasattr(env.envs[0], "task_description"):
+ observation["task"] = env.call("task_description")
+ elif hasattr(env.envs[0], "task"):
+ observation["task"] = env.call("task")
+    else:  # For envs without language instructions, e.g. aloha transfer cube.
+ num_envs = observation[list(observation.keys())[0]].shape[0]
+ observation["task"] = ["" for _ in range(num_envs)]
+ return observation
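A small sketch of how the new env helpers could be exercised (the dummy environment below is made up and stands in for envs without language instructions):

```python
import gymnasium as gym
import numpy as np

from lerobot.common.envs.utils import add_envs_task, check_env_attributes_and_types


class DummyEnv(gym.Env):
    """Toy env exposing neither `task` nor `task_description`."""

    observation_space = gym.spaces.Box(-1.0, 1.0, shape=(6,), dtype=np.float32)
    action_space = gym.spaces.Box(-1.0, 1.0, shape=(2,), dtype=np.float32)

    def reset(self, *, seed=None, options=None):
        return np.zeros(6, dtype=np.float32), {}

    def step(self, action):
        return np.zeros(6, dtype=np.float32), 0.0, False, False, {}


env = gym.vector.SyncVectorEnv([DummyEnv, DummyEnv])
check_env_attributes_and_types(env)  # warns once: no 'task_description'/'task' attribute

# Batched observation for the 2 envs; the helper appends one task string per env.
observation = {"observation.state": np.zeros((2, 6), dtype=np.float32)}
observation = add_envs_task(env, observation)
print(observation["task"])  # ['', '']
```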
diff --git a/lerobot/common/optim/__init__.py b/lerobot/common/optim/__init__.py
index e1e65966..de2c4c99 100644
--- a/lerobot/common/optim/__init__.py
+++ b/lerobot/common/optim/__init__.py
@@ -1 +1,15 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .optimizers import OptimizerConfig as OptimizerConfig
diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py
index 4f724c12..7a5819b7 100644
--- a/lerobot/common/policies/act/configuration_act.py
+++ b/lerobot/common/policies/act/configuration_act.py
@@ -64,7 +64,7 @@ class ACTConfig(PreTrainedConfig):
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
original scale. Note that this is also used for normalizing the training targets.
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
- pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone.
+ pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
`None` means no pretrained weights.
replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated
convolution.
diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py
index f2b16a1e..72d4df03 100644
--- a/lerobot/common/policies/act/modeling_act.py
+++ b/lerobot/common/policies/act/modeling_act.py
@@ -119,9 +119,7 @@ class ACTPolicy(PreTrainedPolicy):
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
- batch["observation.images"] = torch.stack(
- [batch[key] for key in self.config.image_features], dim=-4
- )
+ batch["observation.images"] = [batch[key] for key in self.config.image_features]
# If we are doing temporal ensembling, do online updates where we keep track of the number of actions
# we are ensembling over.
@@ -149,9 +147,8 @@ class ACTPolicy(PreTrainedPolicy):
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
- batch["observation.images"] = torch.stack(
- [batch[key] for key in self.config.image_features], dim=-4
- )
+ batch["observation.images"] = [batch[key] for key in self.config.image_features]
+
batch = self.normalize_targets(batch)
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
@@ -413,11 +410,10 @@ class ACT(nn.Module):
"actions must be provided when using the variational objective in training mode."
)
- batch_size = (
- batch["observation.images"]
- if "observation.images" in batch
- else batch["observation.environment_state"]
- ).shape[0]
+ if "observation.images" in batch:
+ batch_size = batch["observation.images"][0].shape[0]
+ else:
+ batch_size = batch["observation.environment_state"].shape[0]
# Prepare the latent for input to the transformer encoder.
if self.config.use_vae and "action" in batch:
@@ -490,20 +486,21 @@ class ACT(nn.Module):
all_cam_features = []
all_cam_pos_embeds = []
- for cam_index in range(batch["observation.images"].shape[-4]):
- cam_features = self.backbone(batch["observation.images"][:, cam_index])["feature_map"]
- # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use
- # buffer
+ # For a list of images, the H and W may vary but H*W is constant.
+ for img in batch["observation.images"]:
+ cam_features = self.backbone(img)["feature_map"]
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
- cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
+ cam_features = self.encoder_img_feat_input_proj(cam_features)
+
+ # Rearrange features to (sequence, batch, dim).
+ cam_features = einops.rearrange(cam_features, "b c h w -> (h w) b c")
+ cam_pos_embed = einops.rearrange(cam_pos_embed, "b c h w -> (h w) b c")
+
all_cam_features.append(cam_features)
all_cam_pos_embeds.append(cam_pos_embed)
- # Concatenate camera observation feature maps and positional embeddings along the width dimension,
- # and move to (sequence, batch, dim).
- all_cam_features = torch.cat(all_cam_features, axis=-1)
- encoder_in_tokens.extend(einops.rearrange(all_cam_features, "b c h w -> (h w) b c"))
- all_cam_pos_embeds = torch.cat(all_cam_pos_embeds, axis=-1)
- encoder_in_pos_embed.extend(einops.rearrange(all_cam_pos_embeds, "b c h w -> (h w) b c"))
+
+ encoder_in_tokens.extend(torch.cat(all_cam_features, axis=0))
+ encoder_in_pos_embed.extend(torch.cat(all_cam_pos_embeds, axis=0))
# Stack all tokens along the sequence dimension.
encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)
diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py
index d571e152..e73c65fe 100644
--- a/lerobot/common/policies/diffusion/configuration_diffusion.py
+++ b/lerobot/common/policies/diffusion/configuration_diffusion.py
@@ -68,7 +68,7 @@ class DiffusionConfig(PreTrainedConfig):
within the image size. If None, no cropping is done.
crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
mode).
- pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone.
+ pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
`None` means no pretrained weights.
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
@@ -99,7 +99,7 @@ class DiffusionConfig(PreTrainedConfig):
num_inference_steps: Number of reverse diffusion steps to use at inference time (steps are evenly
spaced). If not provided, this defaults to be the same as `num_train_timesteps`.
do_mask_loss_for_padding: Whether to mask the loss when there are copy-padded actions. See
- `LeRobotDataset` and `load_previous_and_future_frames` for mor information. Note, this defaults
+ `LeRobotDataset` and `load_previous_and_future_frames` for more information. Note, this defaults
to False as the original Diffusion Policy implementation does the same.
"""
diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py
index 249ea8cd..4d9b24e8 100644
--- a/lerobot/common/policies/factory.py
+++ b/lerobot/common/policies/factory.py
@@ -26,6 +26,7 @@ from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.dexvla.configuration_dexvla import DexVLAConfig
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
+from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
@@ -59,6 +60,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
from lerobot.common.policies.dexvla.modeling_dexvla import DexVLAPolicy
return DexVLAPolicy
+ elif name == "pi0fast":
+ from lerobot.common.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy
+
+ return PI0FASTPolicy
else:
raise NotImplementedError(f"Policy with name {name} is not implemented.")
@@ -76,6 +81,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return PI0Config(**kwargs)
elif policy_type == "dexvla":
return DexVLAConfig(**kwargs)
+ elif policy_type == "pi0fast":
+ return PI0FASTConfig(**kwargs)
else:
raise ValueError(f"Policy type '{policy_type}' is not available.")
diff --git a/lerobot/common/policies/normalize.py b/lerobot/common/policies/normalize.py
index 95219273..b3255ec1 100644
--- a/lerobot/common/policies/normalize.py
+++ b/lerobot/common/policies/normalize.py
@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import numpy as np
import torch
from torch import Tensor, nn
@@ -77,17 +78,29 @@ def create_stats_buffers(
}
)
+ # TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
if stats:
- # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
- # tensors anywhere (for example, when we use the same stats for normalization and
- # unnormalization). See the logic here
- # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
- if norm_mode is NormalizationMode.MEAN_STD:
- buffer["mean"].data = stats[key]["mean"].clone()
- buffer["std"].data = stats[key]["std"].clone()
- elif norm_mode is NormalizationMode.MIN_MAX:
- buffer["min"].data = stats[key]["min"].clone()
- buffer["max"].data = stats[key]["max"].clone()
+ if isinstance(stats[key]["mean"], np.ndarray):
+ if norm_mode is NormalizationMode.MEAN_STD:
+ buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
+ buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
+ elif norm_mode is NormalizationMode.MIN_MAX:
+ buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
+ buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
+ elif isinstance(stats[key]["mean"], torch.Tensor):
+ # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
+ # tensors anywhere (for example, when we use the same stats for normalization and
+ # unnormalization). See the logic here
+ # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
+ if norm_mode is NormalizationMode.MEAN_STD:
+ buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
+ buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
+ elif norm_mode is NormalizationMode.MIN_MAX:
+ buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
+ buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
+ else:
+ type_ = type(stats[key]["mean"])
+ raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
stats_buffers[key] = buffer
return stats_buffers
@@ -141,6 +154,7 @@ class Normalize(nn.Module):
batch = dict(batch) # shallow copy avoids mutating the input batch
for key, ft in self.features.items():
if key not in batch:
+ # FIXME(aliberts, rcadene): This might lead to silent fail!
continue
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
diff --git a/lerobot/common/policies/pi0/configuration_pi0.py b/lerobot/common/policies/pi0/configuration_pi0.py
index 8d2eedf6..8c7cc130 100644
--- a/lerobot/common/policies/pi0/configuration_pi0.py
+++ b/lerobot/common/policies/pi0/configuration_pi0.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from dataclasses import dataclass, field
from lerobot.common.optim.optimizers import AdamWConfig
@@ -76,6 +90,7 @@ class PI0Config(PreTrainedConfig):
def __post_init__(self):
super().__post_init__()
+ # TODO(Steven): Validate device and amp? in all policy configs?
"""Input validation (not exhaustive)."""
if self.n_action_steps > self.chunk_size:
raise ValueError(
diff --git a/lerobot/common/policies/pi0/conversion_scripts/benchmark.py b/lerobot/common/policies/pi0/conversion_scripts/benchmark.py
index 31bd1b66..cb3c0e9b 100644
--- a/lerobot/common/policies/pi0/conversion_scripts/benchmark.py
+++ b/lerobot/common/policies/pi0/conversion_scripts/benchmark.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
@@ -31,7 +45,7 @@ def main():
cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
cfg.pretrained_path = ckpt_torch_dir
- policy = make_policy(cfg, device, ds_meta=dataset.meta)
+ policy = make_policy(cfg, ds_meta=dataset.meta)
# policy = torch.compile(policy, mode="reduce-overhead")
diff --git a/lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py b/lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py
index 8b2e1c66..6bd7c91f 100644
--- a/lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py
+++ b/lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import json
import pickle
from pathlib import Path
@@ -87,7 +101,7 @@ def main():
cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
cfg.pretrained_path = ckpt_torch_dir
- policy = make_policy(cfg, device, dataset_meta)
+ policy = make_policy(cfg, dataset_meta)
# loss_dict = policy.forward(batch, noise=noise, time=time_beta)
# loss_dict["loss"].backward()
diff --git a/lerobot/common/policies/pi0/conversion_scripts/conversion_utils.py b/lerobot/common/policies/pi0/conversion_scripts/conversion_utils.py
index 8e35d0d4..8835da31 100644
--- a/lerobot/common/policies/pi0/conversion_scripts/conversion_utils.py
+++ b/lerobot/common/policies/pi0/conversion_scripts/conversion_utils.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from transformers import GemmaConfig, PaliGemmaConfig
diff --git a/lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py b/lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py
index f85437a5..73ff506f 100644
--- a/lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py
+++ b/lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py
@@ -1,8 +1,22 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""
Convert pi0 parameters from Jax to Pytorch
Follow [README of openpi](https://github.com/Physical-Intelligence/openpi) to create a new environment
-and install the required librairies.
+and install the required libraries.
```bash
cd ~/code/openpi
diff --git a/lerobot/common/policies/pi0/flex_attention.py b/lerobot/common/policies/pi0/flex_attention.py
index 38a5b597..35628cdd 100644
--- a/lerobot/common/policies/pi0/flex_attention.py
+++ b/lerobot/common/policies/pi0/flex_attention.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import torch
import torch.nn.functional as F # noqa: N812
from packaging.version import Version
diff --git a/lerobot/common/policies/pi0/modeling_pi0.py b/lerobot/common/policies/pi0/modeling_pi0.py
index c8b12caf..4462f162 100644
--- a/lerobot/common/policies/pi0/modeling_pi0.py
+++ b/lerobot/common/policies/pi0/modeling_pi0.py
@@ -24,7 +24,7 @@ Designed by Physical Intelligence. Ported from Jax by Hugging Face.
Install pi0 extra dependencies:
```bash
-pip install -e ".[pi0]"
+pip install --no-binary=av -e ".[pi0]"
```
Example of finetuning the pi0 pretrained model (`pi0_base` in `openpi`):
@@ -313,7 +313,7 @@ class PI0Policy(PreTrainedPolicy):
state = self.prepare_state(batch)
lang_tokens, lang_masks = self.prepare_language(batch)
actions = self.prepare_action(batch)
- actions_is_pad = batch.get("actions_id_pad")
+ actions_is_pad = batch.get("action_is_pad")
loss_dict = {}
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
diff --git a/lerobot/common/policies/pi0/paligemma_with_expert.py b/lerobot/common/policies/pi0/paligemma_with_expert.py
index 08c36c11..76e2ce60 100644
--- a/lerobot/common/policies/pi0/paligemma_with_expert.py
+++ b/lerobot/common/policies/pi0/paligemma_with_expert.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import List, Optional, Union
import torch
diff --git a/lerobot/common/policies/pi0fast/configuration_pi0fast.py b/lerobot/common/policies/pi0fast/configuration_pi0fast.py
new file mode 100644
index 00000000..29c856e0
--- /dev/null
+++ b/lerobot/common/policies/pi0fast/configuration_pi0fast.py
@@ -0,0 +1,136 @@
+from dataclasses import dataclass, field
+
+from lerobot.common.optim.optimizers import AdamWConfig
+from lerobot.common.optim.schedulers import (
+ CosineDecayWithWarmupSchedulerConfig,
+)
+from lerobot.configs.policies import PreTrainedConfig
+from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
+
+
+@PreTrainedConfig.register_subclass("pi0fast")
+@dataclass
+class PI0FASTConfig(PreTrainedConfig):
+ # Input / output structure.
+ n_obs_steps: int = 1
+ chunk_size: int = 10
+ n_action_steps: int = 5
+
+ normalization_mapping: dict[str, NormalizationMode] = field(
+ default_factory=lambda: {
+ "VISUAL": NormalizationMode.IDENTITY,
+ "STATE": NormalizationMode.MEAN_STD,
+ "ACTION": NormalizationMode.MEAN_STD,
+ }
+ )
+
+ # Shorter state and action vectors will be padded
+    max_state_dim: int = 32
+    max_action_dim: int = 32
+
+ # Image preprocessing
+ resize_imgs_with_padding: tuple[int, int] = (224, 224)
+ interpolate_like_pi: bool = False
+
+ # Add empty images. Used by pi0_aloha_sim which adds the empty
+ # left and right wrist cameras in addition to the top camera.
+ empty_cameras: int = 0
+
+ # Converts the joint and gripper values from the standard Aloha space to
+ # the space used by the pi internal runtime which was used to train the base model.
+ adapt_to_pi_aloha: bool = False
+
+ # Converts joint dimensions to deltas with respect to the current state before passing to the model.
+ # Gripper dimensions will remain in absolute values.
+ use_delta_joint_actions_aloha: bool = False
+
+ # Tokenizer
+ tokenizer_max_length: int = 48
+
+ # Projector
+ proj_width: int = 1024
+
+ # Decoding
+ max_decoding_steps: int = 256
+ fast_skip_tokens: int = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens
+ max_input_seq_len: int = 256 # 512
+
+ # Utils
+ use_cache: bool = True
+
+ # Frozen parameters
+ freeze_vision_encoder: bool = True
+ freeze_lm_head: bool = True
+
+ # Training presets
+ optimizer_lr: float = 1e-4
+ optimizer_betas: tuple[float, float] = (0.9, 0.95)
+ optimizer_eps: float = 1e-8
+ optimizer_weight_decay: float = 1e-5
+
+ scheduler_warmup_steps: int = 1_000
+ scheduler_decay_steps: int = 30_000
+ scheduler_decay_lr: float = 2.5e-6
+
+    checkpoint_path: str | None = None
+
+ padding_side: str = "right"
+
+ precision: str = "bfloat16"
+ grad_clip_norm: float = 1
+
+    # Allow padding/truncating the generated action tokens during detokenization so decoding always succeeds.
+    # In the original implementation, tensors of zeros were returned whenever the decoded shape did not match.
+ relaxed_action_decoding: bool = True
+
+ def __post_init__(self):
+ super().__post_init__()
+
+ """Input validation (not exhaustive)."""
+ if self.n_action_steps > self.chunk_size:
+ raise ValueError(
+ f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
+ f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
+ )
+        if self.n_obs_steps != 1:
+            raise ValueError(
+                f"Multiple observation steps not handled yet. Got `n_obs_steps={self.n_obs_steps}`"
+            )
+
+ def validate_features(self) -> None:
+ for i in range(self.empty_cameras):
+ key = f"observation.images.empty_camera_{i}"
+ empty_camera = PolicyFeature(
+ type=FeatureType.VISUAL,
+ shape=(3, 480, 640),
+ )
+ self.input_features[key] = empty_camera
+
+ def get_optimizer_preset(self) -> AdamWConfig:
+ return AdamWConfig(
+ lr=self.optimizer_lr,
+ betas=self.optimizer_betas,
+ eps=self.optimizer_eps,
+ weight_decay=self.optimizer_weight_decay,
+ grad_clip_norm=self.grad_clip_norm,
+ )
+
+ def get_scheduler_preset(self):
+ return CosineDecayWithWarmupSchedulerConfig(
+ peak_lr=self.optimizer_lr,
+ decay_lr=self.scheduler_decay_lr,
+ num_warmup_steps=self.scheduler_warmup_steps,
+ num_decay_steps=self.scheduler_decay_steps,
+ )
+
+ @property
+ def observation_delta_indices(self) -> None:
+ return None
+
+ @property
+ def action_delta_indices(self) -> list:
+ return list(range(self.chunk_size))
+
+ @property
+ def reward_delta_indices(self) -> None:
+ return None
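+
+
+# Illustrative sketch (not part of the module): instantiating the config with a custom action horizon.
+# The values below are assumptions for demonstration only.
+#
+#   cfg = PI0FASTConfig(chunk_size=10, n_action_steps=5, max_decoding_steps=256)
+#   cfg.action_delta_indices  # -> [0, 1, ..., 9]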
diff --git a/lerobot/common/policies/pi0fast/modeling_pi0fast.py b/lerobot/common/policies/pi0fast/modeling_pi0fast.py
new file mode 100644
index 00000000..36aafce9
--- /dev/null
+++ b/lerobot/common/policies/pi0fast/modeling_pi0fast.py
@@ -0,0 +1,973 @@
+#!/usr/bin/env python
+
+# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+π0+FAST: Efficient Action Tokenization for Vision-Language-Action Models
+
+[Paper](https://arxiv.org/abs/2501.09747)
+[Jax code](https://github.com/Physical-Intelligence/openpi)
+
+Designed by Physical Intelligence. Ported from Jax by Hugging Face.
+
+Example of finetuning the pi0+FAST pretrained model (`pi0_fast_base` in `openpi`):
+```bash
+python lerobot/scripts/train.py \
+--policy.path=lerobot/pi0fast_base \
+--dataset.repo_id=danaaubakirova/koch_test
+```
+
+Example of training the pi0+FAST neural network from scratch:
+```bash
+python lerobot/scripts/train.py \
+--policy.type=pi0fast \
+--dataset.repo_id=danaaubakirova/koch_test
+```
+
+Example of using the pi0+FAST pretrained model outside the LeRobot training framework:
+```python
+policy = PI0FASTPolicy.from_pretrained("lerobot/pi0fast_base")
+```
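+
+Minimal inference sketch (illustrative only; assumes `batch` already contains the image, state and
+"task" entries matching the policy's input features):
+```python
+policy.reset()
+action = policy.select_action(batch)  # returns one action step and queues the rest of the chunk
+```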
+
+"""
+
+from collections import deque
+from functools import partial
+
+import numpy as np
+import torch
+import torch.nn.functional as F # noqa: N812
+from PIL import Image
+from scipy.fft import idct
+from torch import Tensor, nn
+from transformers import AutoProcessor, AutoTokenizer, PaliGemmaForConditionalGeneration
+from transformers.cache_utils import HybridCache, StaticCache
+from transformers.models.auto import CONFIG_MAPPING
+
+from lerobot.common.constants import ACTION, OBS_ROBOT
+from lerobot.common.policies.normalize import Normalize, Unnormalize
+from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
+from lerobot.common.policies.pretrained import PreTrainedPolicy
+
+PRECISION = {
+ "float16": torch.float16,
+ "float32": torch.float32,
+ "bfloat16": torch.bfloat16,
+}
+
+
+def normalize(x, min_val, max_val):
+ return (x - min_val) / (max_val - min_val)
+
+
+def unnormalize(x, min_val, max_val):
+ return x * (max_val - min_val) + min_val
+
+
+def safe_arcsin(value):
+    # This ensures that the input stays within [-1, 1] to avoid invalid values for arcsin.
+ return torch.arcsin(torch.clamp(value, -1.0, 1.0))
+
+
+def aloha_gripper_to_angular(value):
+ # Aloha transforms the gripper positions into a linear space. The following code
+ # reverses this transformation to be consistent with pi0 which is pretrained in
+ # angular space.
+ #
+ # These values are coming from the Aloha code:
+ # PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
+ value = unnormalize(value, min_val=0.01844, max_val=0.05800)
+
+ # This is the inverse of the angular to linear transformation inside the Interbotix code.
+ def linear_to_radian(linear_position, arm_length, horn_radius):
+ value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
+ return safe_arcsin(value)
+
+ # The constants are taken from the Interbotix code.
+ value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)
+
+ # Normalize to [0, 1].
+ # The values 0.4 and 1.5 were measured on an actual Trossen robot.
+ return normalize(value, min_val=0.4, max_val=1.5)
+
+
+def aloha_gripper_from_angular(value):
+ # Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.
+ # Note that the units are still angular but the range is different.
+
+ # The values 0.4 and 1.5 were measured on an actual Trossen robot.
+ value = unnormalize(value, min_val=0.4, max_val=1.5)
+
+ # These values are coming from the Aloha code:
+ # PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
+ return normalize(value, min_val=-0.6213, max_val=1.4910)
+
+
+def aloha_gripper_from_angular_inv(value):
+ # Directly inverts the gripper_from_angular function.
+ value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
+ return normalize(value, min_val=0.4, max_val=1.5)
+
+
+class PI0FASTPolicy(PreTrainedPolicy):
+ """Wrapper class around PI0FAST tokenizer and model to train and run inference within LeRobot."""
+
+ config_class = PI0FASTConfig
+ name = "pi0fast"
+
+ def __init__(
+ self,
+ config: PI0FASTConfig,
+ dataset_stats: dict[str, dict[str, Tensor]] | None = None,
+ ):
+ """
+ Args:
+ config: Policy configuration class instance or None, in which case the default instantiation of
+ the configuration class is used.
+ dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
+ that they will be passed with a call to `load_state_dict` before the policy is used.
+ """
+
+ super().__init__(config)
+ config.validate_features()
+ self.config = config
+
+ self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
+ self.normalize_targets = Normalize(
+ config.output_features, config.normalization_mapping, dataset_stats
+ )
+ self.unnormalize_outputs = Unnormalize(
+ config.output_features, config.normalization_mapping, dataset_stats
+ )
+
+ self.language_tokenizer = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")
+ self.model = PI0FAST(config)
+
+ self.reset()
+
+ def reset(self):
+ """This should be called whenever the environment is reset."""
+ self._action_queue = deque([], maxlen=self.config.n_action_steps)
+
+ def get_optim_params(self) -> dict:
+ return self.parameters()
+
+ def _pi_aloha_decode_state(self, state):
+ # Flip the joints.
+ for motor_idx in [1, 2, 8, 9]:
+ state[:, motor_idx] *= -1
+ # Reverse the gripper transformation that is being applied by the Aloha runtime.
+ for motor_idx in [6, 13]:
+ state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
+ return state
+
+ def _pi_aloha_encode_actions(self, actions):
+ # Flip the joints.
+ for motor_idx in [1, 2, 8, 9]:
+ actions[:, :, motor_idx] *= -1
+ # Reverse the gripper transformation that is being applied by the Aloha runtime.
+ for motor_idx in [6, 13]:
+ actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
+ return actions
+
+ def _pi_aloha_encode_actions_inv(self, actions):
+ # Flip the joints again.
+ for motor_idx in [1, 2, 8, 9]:
+ actions[:, :, motor_idx] *= -1
+ # Reverse the gripper transformation that is being applied by the Aloha runtime.
+ for motor_idx in [6, 13]:
+ actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
+ return actions
+
+ @torch.no_grad
+ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
+ """Select a single action given environment observations.
+
+ This method wraps `select_actions` in order to return one action at a time for execution in the
+ environment. It works by managing the actions in a queue and only calling `select_actions` when the
+ queue is empty.
+ """
+ self.eval()
+
+ if self.config.adapt_to_pi_aloha:
+ batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
+
+ batch = self.normalize_inputs(batch)
+
+ # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
+ # querying the policy.
+ if len(self._action_queue) == 0:
+ actions = self.model.generate_actions(batch)
+
+ actions = actions[:, : self.config.n_action_steps]
+
+            original_action_dim = self.config.action_feature.shape[0]
+ actions = actions[:, :, :original_action_dim]
+
+ actions = self.unnormalize_outputs({"action": actions})["action"]
+
+ if self.config.adapt_to_pi_aloha:
+ actions = self._pi_aloha_encode_actions(actions)
+
+            # `self.model.generate_actions` returns a (batch_size, n_action_steps, action_dim) tensor, but
+            # the queue effectively has shape (n_action_steps, batch_size, *), hence the transpose.
+ self._action_queue.extend(actions.transpose(0, 1))
+ return self._action_queue.popleft()
+
+ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
+ if self.config.adapt_to_pi_aloha:
+ batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
+ batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
+ batch = self.normalize_inputs(batch)
+ batch = self.normalize_targets(batch)
+ loss_dict = self.model.forward(batch)
+ return loss_dict["loss"], loss_dict
+
+
+def block_causal_update_causal_mask(
+ attention_mask,
+ token_type_ids=None,
+ past_key_values=None,
+ cache_position=None,
+ input_tensor=None,
+ attn_implementation: str = "eager",
+    dtype: torch.dtype = torch.float32,
+):
+ """
+    Update the causal mask during training and generation. It can be customized for different attention
+    patterns, including block-causal masking via `token_type_ids`.
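+
+    Illustrative example (not from the original implementation): with token_type_ids = [0, 0, 1, 1],
+    the cumulative sums are [0, 0, 1, 2]; the rule cumsum[:, None, :] <= cumsum[:, :, None] lets the
+    prefix tokens (type 0) attend to the whole prefix, while each suffix token (type 1) attends to the
+    prefix plus the suffix tokens up to and including itself.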
+ """
+ if attn_implementation == "flash_attention_2":
+ if attention_mask is not None and 0.0 in attention_mask:
+ return attention_mask
+ return None
+ using_static_cache = isinstance(past_key_values, StaticCache)
+ min_dtype = torch.finfo(dtype).min
+
+ if input_tensor is None:
+ input_tensor = attention_mask
+
+ inputs_lead_dim, sequence_length = input_tensor.shape[:2]
+
+ if using_static_cache or isinstance(past_key_values, HybridCache):
+ target_length = past_key_values.get_max_cache_shape()
+ else:
+ target_length = (
+ attention_mask.shape[-1]
+ if isinstance(attention_mask, torch.Tensor)
+ else cache_position[0] + sequence_length + 1
+ )
+
+ # Handle precomputed attention masks
+ if attention_mask is not None and attention_mask.dim() == 4:
+ return attention_mask
+
+ # Causal mask initialization
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+
+ # Standard causal masking (triu ensures tokens can only attend to past)
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+
+ # Apply block causal mask
+ if token_type_ids is not None:
+ token_type_ids = token_type_ids.to(causal_mask.device).bool()
+ cumsum = torch.cumsum(token_type_ids, dim=1)
+ block_causal_mask = cumsum[:, None, :] <= cumsum[:, :, None]
+
+ # Combine causal_mask with block-wise attention mask
+ causal_mask = torch.where(block_causal_mask, 0.0, causal_mask)
+ causal_mask = causal_mask[:, None, :, :]
+ else:
+ # Apply past cache position constraint
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
+ -1, 1
+ )
+ causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
+ else:
+ # Apply past cache position constraint
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
+ -1, 1
+ )
+ causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
+
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # Copy to contiguous memory for in-place edits
+ mask_length = attention_mask.shape[-1]
+
+ # Apply padding mask
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
+
+def prepare_inputs_for_generation(
+ # self,
+ input_ids,
+ past_key_values=None,
+ inputs_embeds=None,
+ cache_position=None,
+ position_ids=None,
+ pixel_values=None,
+ attention_mask=None,
+ token_type_ids=None,
+ use_cache=True,
+ num_logits_to_keep=None,
+ labels=None,
+ self=None,
+ **kwargs,
+):
+ # create block causal attention
+ if cache_position[0] > 0 and input_ids.shape[1] > 0:
+ input_tensor = input_ids[:, -1:]
+ new_positions = (
+ torch.ones(
+ (position_ids.shape[0], input_ids.shape[1]),
+ dtype=position_ids.dtype,
+ device=position_ids.device,
+ ).cumsum(-1)
+ + position_ids[:, -1:]
+ )
+ position_ids = torch.cat([position_ids, new_positions], dim=-1)
+ else:
+ input_tensor = inputs_embeds
+ attention_mask = block_causal_update_causal_mask(
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ input_tensor=input_tensor,
+ token_type_ids=token_type_ids,
+ dtype=self.dtype,
+ attn_implementation=self.config.text_config._attn_implementation,
+ )
+ # Overwritten -- custom `position_ids` and `pixel_values` handling
+ model_inputs = self.language_model.prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ cache_position=cache_position,
+ use_cache=use_cache,
+ num_logits_to_keep=num_logits_to_keep,
+ token_type_ids=token_type_ids,
+ **kwargs,
+ )
+
+ # Position_ids in Paligemma are 1-indexed
+ if model_inputs.get("position_ids") is not None:
+ model_inputs["position_ids"] += 1
+    # During cached decoding, pixel_values should be None because the input ids no longer contain the
+    # special image token. Otherwise, pixel_values must be passed to the model.
+    # NOTE: with use_cache=False, pixel_values are always needed.
+ if cache_position[0] == 0:
+ model_inputs["pixel_values"] = pixel_values
+ is_training = token_type_ids is not None and labels is not None
+ if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
+ input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
+ causal_mask = self._update_causal_mask(
+ attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
+ )
+ model_inputs["attention_mask"] = causal_mask
+
+ return model_inputs
+
+
+class PI0FAST(nn.Module):
+ def __init__(self, config: PI0FASTConfig):
+ super().__init__()
+ self.config = config
+
+ # TODO: move tokenizers in Policy
+ fast_tokenizer_path = "physical-intelligence/fast"
+ pi0_paligemma_path = "google/paligemma-3b-pt-224"
+ self.paligemma_tokenizer = AutoTokenizer.from_pretrained(pi0_paligemma_path)
+ self.processor = AutoProcessor.from_pretrained(pi0_paligemma_path)
+ self.fast_tokenizer = AutoProcessor.from_pretrained(fast_tokenizer_path, trust_remote_code=True)
+ self.fast_skip_tokens = self.config.fast_skip_tokens
+ self.max_input_seq_len = self.config.max_input_seq_len
+ self.action_horizon = self.config.chunk_size
+        self.action_dim = self.config.action_feature.shape[0]
+ precision = config.precision
+ torch_precision = PRECISION.get(precision, torch.float32)
+ self.pad_token_id = (
+ self.paligemma_tokenizer.pad_token_id
+ if hasattr(self.paligemma_tokenizer, "pad_token_id")
+ else self.paligemma_tokenizer.eos_token_id
+ )
+
+ paligemma_config = CONFIG_MAPPING["paligemma"](
+ transformers_version="4.48.1",
+ _vocab_size=257152,
+ bos_token_id=2,
+ eos_token_id=1,
+ hidden_size=2048,
+ image_token_index=257152,
+ model_type="paligemma",
+ pad_token_id=0,
+ projection_dim=2048,
+ text_config={
+ "hidden_activation": "gelu_pytorch_tanh",
+ "hidden_size": 2048,
+ "intermediate_size": 16384,
+ "model_type": "gemma",
+ "num_attention_heads": 8,
+ "num_hidden_layers": 18,
+ "num_image_tokens": 256,
+ "num_key_value_heads": 1,
+ "torch_dtype": precision,
+ "vocab_size": 257152,
+ "_attn_implementation": "eager",
+ },
+ vision_config={
+ "hidden_size": 1152,
+ "intermediate_size": 4304,
+ "model_type": "siglip_vision_model",
+ "num_attention_heads": 16,
+ "num_hidden_layers": 27,
+ "num_image_tokens": 256,
+ "patch_size": 14,
+ "projection_dim": 2048,
+ "projector_hidden_act": "gelu_pytorch_tanh",
+ "torch_dtype": precision,
+ "vision_use_head": False,
+ },
+ )
+ self.pi0_paligemma = PaliGemmaForConditionalGeneration(config=paligemma_config)
+
+ self.pi0_paligemma.prepare_inputs_for_generation = partial(
+ prepare_inputs_for_generation, self=self.pi0_paligemma
+ )
+        # Cast the main submodules (language model, vision tower, multimodal projector) to the requested precision.
+ params_to_change_dtype = [
+ "language_model",
+ "vision_tower",
+ "multi_modal",
+ ]
+ for name, param in self.pi0_paligemma.named_parameters():
+ if any(selector in name for selector in params_to_change_dtype):
+ param.data = param.data.to(dtype=torch_precision)
+ self.set_requires_grad()
+ self.image_keys = self.config.image_features.keys()
+ self.ignore_index = self.pi0_paligemma.config.ignore_index
+ self.padding_side = self.config.padding_side
+
+ def set_requires_grad(self):
+ if self.config.freeze_vision_encoder:
+ self.pi0_paligemma.vision_tower.eval()
+ for params in self.pi0_paligemma.vision_tower.parameters():
+ params.requires_grad = False
+ # To avoid unused params issue with distributed training
+ if self.config.freeze_lm_head:
+ for name, params in self.pi0_paligemma.named_parameters():
+ if "embed_tokens" in name: # lm heads and embedding layer are tied
+ params.requires_grad = False
+
+ def embed_tokens(self, tokens: torch.Tensor):
+ return self.pi0_paligemma.language_model.model.embed_tokens(tokens)
+
+ def prepare_inputs_for_generation(self, *args, **kwargs):
+ return self.pi0_paligemma.prepare_inputs_for_generation(*args, **kwargs)
+
+ def prepare_images(self, batch):
+ """Preprocess LeRobot batch into Pi0 inputs"""
+ images = []
+ img_masks = []
+ present_img_keys = [key for key in self.image_keys if key in batch]
+ if len(present_img_keys) == 0:
+ raise ValueError(
+ f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
+ )
+
+ # Preprocess image features present in the batch
+ num_empty_cameras = 0
+ for key in self.image_keys:
+ if key in present_img_keys:
+ img = batch[key]
+
+ if self.config.resize_imgs_with_padding is not None:
+ img = resize_with_pad(
+ img,
+ *self.config.resize_imgs_with_padding,
+ pad_value=0,
+ interpolate_like_pi=self.config.interpolate_like_pi,
+ )
+
+                # Normalize from range [0,1] to [-1,1] as expected by SigLIP
+ img = img * 2.0 - 1.0
+
+ bsize = img.shape[0]
+ device = img.device
+ mask = torch.ones(bsize, dtype=torch.bool, device=device)
+ else:
+ if num_empty_cameras >= self.config.empty_cameras:
+ continue
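+                # Note: this reuses the shape/device of the last image processed in the loop to
+                # synthesize an all -1 (black) frame for the missing empty-camera slot.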
+ img = torch.ones_like(img) * -1
+ bsize = img.shape[0]
+ device = img.device
+ mask = torch.ones(bsize, dtype=torch.bool, device=device)
+ num_empty_cameras += 1
+
+ images.append(img)
+ img_masks.append(mask)
+ return images, img_masks
+
+ def normalize_actions(self, actions: torch.Tensor) -> torch.Tensor:
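+        # Per-sample min-max normalization over the (time, action_dim) axes into [-1, 1],
+        # applied before tokenizing the action chunk with the FAST tokenizer.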
+        mins = actions.amin(dim=(1, 2), keepdim=True)
+        maxs = actions.amax(dim=(1, 2), keepdim=True)
+ return 2 * (actions - mins) / (maxs - mins + 1e-8) - 1
+
+ def _act_tokens_to_paligemma_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
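+        # Map FAST action token ids into the tail of the PaliGemma vocabulary, skipping the last
+        # `fast_skip_tokens` entries that are reserved as special tokens.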
+ out = self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens
+ return out
+
+ def fast_tokenizer_wrapper(self, actions_norm):
+ """
+        A wrapper around self.fast_tokenizer that ensures batch processing, converts the output to
+        PyTorch tensors, and returns a padded dictionary of input ids and attention masks.
+ """
+ batch_tokens = self.fast_tokenizer(actions_norm)
+ fast_out = self.processor.tokenizer.pad({"input_ids": batch_tokens}, return_tensors="pt")
+
+ return fast_out
+
+ def create_token_type_ids(self, padded_mask: torch.Tensor, prefix_len: int) -> torch.Tensor:
+        # Tokens whose cumulative count of non-pad positions exceeds the prefix length belong to the suffix block.
+        cumsum_mask = (padded_mask != 0).cumsum(dim=1)
+        token_type_ids = cumsum_mask > prefix_len
+        return token_type_ids
+
+ def create_input_tokens(self, state, lang_text, actions=None):
+ bsize = state.shape[0]
+ device = state.device
+ bins = torch.linspace(-1, 1, 256 + 1, device=device)[:-1]
+ discretized = torch.bucketize(state, bins) - 1
+ discretized = discretized[:, :32]
+
+ prefix_texts = []
+ state_text = []
+ for txt, disc in zip(lang_text, discretized, strict=False):
+ cleaned = txt.lower().strip().replace("_", " ")
+ state_str = " ".join(str(val.item()) for val in disc)
+ prefix_texts.append(f"Task: {cleaned}, State: {state_str};\n")
+ state_text.append(f"State: {state_str};\n")
+
+ prefix_out = self.paligemma_tokenizer(
+ prefix_texts, add_special_tokens=True, return_tensors="pt", padding="longest", truncation=False
+ )
+ prefix_ids = prefix_out["input_ids"].to(device)
+ prefix_mask = prefix_out["attention_mask"].to(device)
+ prefix_lens = prefix_mask.sum(dim=1)[:, None].cpu()
+
+ if actions is not None:
+ actions_norm = self.normalize_actions(actions)
+ actions_pad = F.pad(
+ actions_norm, (0, max(0, self.config.max_action_dim - actions_norm.shape[2])), value=0
+ )[:, :, : self.config.max_action_dim]
+ fast_out = self.fast_tokenizer_wrapper(
+ actions_pad.cpu(),
+ )
+ act_ids = fast_out["input_ids"]
+ act_mask = fast_out["attention_mask"].to(device)
+
+ act_ids = self._act_tokens_to_paligemma_tokens(act_ids).to(device)
+ # Replace action with 0 to pad tokens
+ act_ids = torch.where(
+ act_ids == self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens,
+ self.pad_token_id,
+ act_ids,
+ )
+
+ eos_token = torch.tensor(
+ [self.paligemma_tokenizer.eos_token_id], dtype=torch.long, device=device
+ ).expand(bsize, -1)
+ eos_mask = torch.tensor([1], dtype=torch.long, device=device).expand(bsize, -1)
+ bos = self.paligemma_tokenizer("Action: ", add_special_tokens=False, return_tensors="pt")
+ bos_token = bos["input_ids"].expand(act_ids.shape[0], -1).to(device)
+ bos_mask = bos["attention_mask"].expand(act_ids.shape[0], -1).to(device)
+ act_ids = torch.cat([bos_token, act_ids, eos_token], dim=1)
+ act_mask = torch.cat([bos_mask, act_mask, eos_mask], dim=1)
+ act_mask = act_mask.to(device)
+ else:
+            act_ids = torch.empty(bsize, 0, dtype=torch.long, device=device)
+ act_mask = torch.empty(bsize, 0, dtype=torch.long, device=device)
+ final_ids = torch.cat([prefix_ids, act_ids], dim=1)
+
+ final_mask = torch.cat([prefix_mask, act_mask], dim=1)
+ batch_inputs = {"input_ids": final_ids.tolist(), "attention_mask": final_mask.tolist()}
+
+ # Use tokenizer pad function
+ padded_output = self.paligemma_tokenizer.pad(
+ batch_inputs, padding="longest", max_length=180, return_tensors="pt"
+ )
+ padded_mask = padded_output["attention_mask"]
+
+        # Attention mask marking suffix (action) tokens: non-pad positions past the prefix length
+ att_mask = (padded_mask != 0).cumsum(dim=1) > prefix_lens
+
+ token_type_ids = self.create_token_type_ids(padded_mask=padded_mask, prefix_len=prefix_lens)
+
+ padded_output["padded_mask"] = padded_output.pop("attention_mask")
+ padded_output["attention_mask"] = att_mask
+ # loss is computed not on prefix, and not on padding
+ padded_output["loss_mask"] = att_mask & padded_output["padded_mask"]
+ padded_output["token_type_ids"] = token_type_ids
+ return padded_output
+
+ def shift_padding_side(
+ self,
+ tokens: torch.Tensor,
+ ar_mask: torch.Tensor,
+ padding_mask: torch.Tensor,
+ loss_mask: torch.Tensor,
+ targets: torch.Tensor,
+ token_type_ids: torch.Tensor,
+ padding_side: str = "right",
+    ) -> tuple[torch.Tensor, ...]:
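+        """Reorder each sequence so that pad positions are grouped on `padding_side`, applying the
+        same permutation to every per-token tensor (embeddings, masks, targets, token type ids)."""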
+ if padding_side not in ["right", "left"]:
+ return tokens, ar_mask, padding_mask, loss_mask, targets, token_type_ids
+
+ new_tokens = torch.empty_like(tokens)
+ new_ar_masks = torch.empty_like(ar_mask)
+ new_padding_mask = torch.empty_like(padding_mask)
+ new_loss_mask = torch.empty_like(loss_mask)
+ new_targets = torch.empty_like(targets)
+ new_token_type_ids = torch.empty_like(token_type_ids)
+ batch_size = tokens.shape[0]
+ for i in range(batch_size):
+ padding_indices = torch.where(padding_mask[i] == 0)[0]
+ non_padding_indices = torch.where(padding_mask[i] == 1)[0]
+ if padding_side == "left":
+ new_indices = torch.cat((padding_indices, non_padding_indices), dim=0)
+ else:
+ new_indices = torch.cat((non_padding_indices, padding_indices), dim=0)
+ new_tokens[i] = tokens[i].index_select(0, new_indices)
+ new_ar_masks[i] = ar_mask[i].index_select(0, new_indices)
+ new_padding_mask[i] = padding_mask[i].index_select(0, new_indices)
+ new_loss_mask[i] = loss_mask[i].index_select(0, new_indices)
+ new_targets[i] = targets[i].index_select(0, new_indices)
+ new_token_type_ids[i] = token_type_ids[i].index_select(0, new_indices)
+
+ return new_tokens, new_ar_masks, new_padding_mask, new_loss_mask, new_targets, new_token_type_ids
+
+ def forward(self, batch: dict[str, Tensor]):
+ device = batch[OBS_ROBOT].device
+ # TODO: keep like this or move to the policy .forward
+ images, img_masks = self.prepare_images(batch)
+
+ padded_outs = self.create_input_tokens(
+ state=batch[OBS_ROBOT],
+ lang_text=batch["task"],
+ actions=batch[ACTION],
+ )
+
+ embs, pad_masks, _, targets, loss_mask, token_type_ids = self.embed_inputs(
+ images,
+ img_masks,
+ padded_outs["input_ids"],
+ padded_outs["padded_mask"],
+ padded_outs["attention_mask"],
+ padded_outs["loss_mask"],
+ padded_outs["token_type_ids"],
+ padding_side=self.padding_side,
+ )
+ position_ids = torch.cumsum(pad_masks, dim=1) - 1
+ token_type_ids = token_type_ids.to(dtype=torch.int64)
+ past_seen_tokens = 0
+ cache_position = torch.arange(past_seen_tokens, past_seen_tokens + embs.shape[1], device=embs.device)
+ pad_masks = block_causal_update_causal_mask(
+ attention_mask=pad_masks,
+ past_key_values=None,
+ cache_position=cache_position,
+ input_tensor=embs,
+ token_type_ids=token_type_ids,
+ dtype=self.pi0_paligemma.dtype,
+ attn_implementation=self.pi0_paligemma.config.text_config._attn_implementation,
+ )
+ outputs = self.pi0_paligemma.forward(
+ input_ids=None,
+ token_type_ids=None,
+ attention_mask=pad_masks,
+ position_ids=position_ids,
+ past_key_values=None,
+ inputs_embeds=embs,
+ use_cache=False,
+ labels=None,
+ )
+
+ logits = outputs.logits
+
+ loss_fct = nn.CrossEntropyLoss(reduction="none")
+
+ # Shift left for next-step prediction
+ logits = logits[:, :-1, :]
+ targets = targets[:, 1:].to(device) # Shift targets
+ loss_mask = loss_mask[:, 1:].to(device) # Ensure correct shape
+
+ # Compute per-token loss
+ token_loss = loss_fct(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1))
+
+ # Apply loss mask
+ token_loss = token_loss * loss_mask.reshape(-1)
+
+ # Compute final loss
+ loss = token_loss.sum() / torch.clamp(loss_mask.sum(), min=1)
+
+ # Return loss dictionary
+ loss_dict = {"ce_loss": loss.item(), "loss": loss}
+ return loss_dict
+
+ def decode_actions_with_fast(
+ self,
+ tokens: list[list[int]],
+ *,
+ time_horizon: int | None = None,
+ action_dim: int | None = None,
+ relaxed_decoding: bool = True,
+    ) -> np.ndarray:
+ """
+        Adapts the original FAST decoding so that actions are returned instead of zeros when the decoded
+        sequence length does not match the expected shape.
+ """
+ self.time_horizon = (
+ time_horizon or self.fast_tokenizer.time_horizon or self.fast_tokenizer.called_time_horizon
+ )
+ self.action_dim = (
+ action_dim or self.fast_tokenizer.action_dim or self.fast_tokenizer.called_action_dim
+ )
+
+ # Cache the time horizon and action dimension for the next call
+ self.called_time_horizon = self.time_horizon
+ self.called_action_dim = self.action_dim
+
+ assert self.time_horizon is not None and self.action_dim is not None, (
+ "Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim."
+ )
+
+ decoded_actions = []
+ for token in tokens:
+ try:
+ decoded_tokens = self.fast_tokenizer.bpe_tokenizer.decode(token)
+ decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.fast_tokenizer.min_token
+ if relaxed_decoding:
+ # Expected sequence length
+ expected_seq_len = self.time_horizon * self.action_dim
+ diff = expected_seq_len - decoded_dct_coeff.shape[0]
+ # Apply truncation if too long
+ if diff < 0:
+ decoded_dct_coeff = decoded_dct_coeff[:expected_seq_len] # Truncate on the right
+ # Apply padding if too short
+ elif diff > 0:
+ decoded_dct_coeff = np.pad(
+ decoded_dct_coeff, (0, diff), mode="constant", constant_values=0
+ )
+
+ decoded_dct_coeff = decoded_dct_coeff.reshape(-1, self.action_dim)
+ assert decoded_dct_coeff.shape == (
+ self.time_horizon,
+ self.action_dim,
+ ), (
+ f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})"
+ )
+ except Exception as e:
+ print(f"Error decoding tokens: {e}")
+ print(f"Tokens: {token}")
+ decoded_dct_coeff = np.zeros((self.time_horizon, self.action_dim))
+ decoded_actions.append(idct(decoded_dct_coeff / self.fast_tokenizer.scale, axis=0, norm="ortho"))
+ return np.stack(decoded_actions)
+
+ def extract_actions(self, tokens: torch.Tensor, action_horizon: int, action_dim: int) -> torch.Tensor:
+ """
+ Extracts actions from predicted output tokens using the FAST model.
+
+ Args:
+ tokens (torch.Tensor): The input tensor of tokenized outputs.
+ action_horizon (int): The number of timesteps for actions.
+ action_dim (int): The dimensionality of each action.
+
+ Returns:
+ torch.Tensor: The extracted actions as a tensor of shape (action_horizon, action_dim).
+ """
+ # Decode predicted output tokens
+ decoded_tokens = self.paligemma_tokenizer.batch_decode(tokens, skip_special_tokens=True)
+ cleaned_tokens = [
+ tokens_sequence.replace("Action:", "").replace(":", "").strip().split("|")[0].strip()
+ for tokens_sequence in decoded_tokens
+ ]
+ raw_action_tokens = [
+ self.processor.tokenizer.encode(sample_tokens, return_tensors="pt", padding=False)
+ for sample_tokens in cleaned_tokens
+        ]
+ action_tokens = [
+ self._act_tokens_to_paligemma_tokens(raw_action_token) for raw_action_token in raw_action_tokens
+ ]
+ # returns the tensor of decoded actions per sample in a list
+ decoded_actions = [
+ torch.tensor(
+ self.decode_actions_with_fast(
+ tok.tolist(),
+ time_horizon=action_horizon,
+ action_dim=action_dim,
+ relaxed_decoding=self.config.relaxed_action_decoding,
+ ),
+ device=tokens.device,
+ ).squeeze(0)
+ for tok in action_tokens
+ ]
+
+ return torch.stack(
+ decoded_actions,
+ dim=0,
+ )
+
+ def generate_actions(self, batch: dict[str, Tensor]):
+ # TODO: keep like this or move to the policy .forward
+ images, img_masks = self.prepare_images(batch)
+
+ padded_outs = self.create_input_tokens(state=batch[OBS_ROBOT], lang_text=batch["task"], actions=None)
+ embs, pad_masks, att_masks2, targets, loss_mask, token_type_ids = self.embed_inputs(
+ images,
+ img_masks,
+ padded_outs["input_ids"],
+ padded_outs["padded_mask"],
+ padded_outs["attention_mask"],
+ padded_outs["loss_mask"],
+ padded_outs["token_type_ids"],
+ padding_side="left",
+ )
+ token_type_ids = token_type_ids.to(dtype=torch.int64)
+ prefix_position_ids = torch.cumsum(pad_masks, dim=1) - 1
+ output_tokens = self.pi0_paligemma.generate(
+ input_ids=None,
+ attention_mask=pad_masks,
+ position_ids=prefix_position_ids,
+ past_key_values=None,
+ inputs_embeds=embs,
+ use_cache=self.config.use_cache,
+ max_new_tokens=self.config.max_decoding_steps,
+ do_sample=False,
+ num_beams=1,
+ token_type_ids=token_type_ids,
+ )
+ actions = self.extract_actions(output_tokens, self.action_horizon, self.action_dim)
+ return actions
+
+ def embed_image(self, image: torch.Tensor):
+ return self.pi0_paligemma.get_image_features(image)
+
+ def embed_inputs(
+ self,
+ images,
+ img_masks,
+ tokens,
+ pad_mask,
+ ar_mask,
+ loss_mask,
+ token_type_ids,
+ padding_side: str = "right",
+ ):
+        # TODO: avoid Python lists and torch.cat; prefer pre-allocation with torch.empty
+        # All images in the list share the same shape, so they can be stacked and embedded in a single batch.
+ device = images[0].device
+ image_embedding_dim = images[0].shape[-1] # TODO should be from self.config
+ all_images = torch.stack(images, dim=1).to(device)
+ b, n, c, h, w = all_images.shape
+ all_images = all_images.view(b * n, c, h, w)
+ embedded = self.embed_image(all_images).to(device)
+ b_n, p, image_embedding_dim = embedded.shape # Extract current dimensions
+ m = b_n // b # Compute the number of images per sample dynamically
+
+ # Reshape dynamically
+ embedded = embedded.view(b, m, p, image_embedding_dim)
+ tokens_embs = self.embed_tokens(tokens.to(device))
+
+ img_masks = torch.stack(img_masks, dim=1).unsqueeze(-1).to(device)
+ num_img_emb = embedded.shape[2]
+ img_pad_masks = img_masks.repeat(1, 1, num_img_emb).view(b, -1)
+ img_att_masks = torch.zeros((b, n, num_img_emb), dtype=torch.long, device=device).reshape(b, -1)
+
+ image_target_tokens = (
+ torch.ones((b, n, num_img_emb), dtype=torch.long, device=device) * self.pad_token_id
+ ).reshape(b, -1)
+ image_loss_mask = torch.zeros((b, n, num_img_emb), dtype=torch.long, device=device).reshape(b, -1)
+
+ embedded = embedded.reshape(b, n * num_img_emb, image_embedding_dim) # Shape: (B, N*P, D)
+
+ embs = torch.cat([embedded, tokens_embs], dim=1).to(device)
+ pad_masks = torch.cat([img_pad_masks, pad_mask.to(device)], dim=1)
+ att_masks = torch.cat([img_att_masks, ar_mask.to(device)], dim=1)
+ loss_masks = torch.cat([image_loss_mask, loss_mask.to(device)], dim=1)
+ targets = torch.cat([image_target_tokens, tokens.to(device)], dim=1)
+ token_type_ids = torch.cat([img_att_masks, token_type_ids.to(device)], dim=1)
+
+ # Shift pad tokens to the left (.generate()) or right (.train())
+ embs, att_masks, pad_masks, loss_masks, targets, token_type_ids = self.shift_padding_side(
+ embs, att_masks, pad_masks, loss_masks, targets, token_type_ids, padding_side=padding_side
+ )
+
+ targets = torch.where(targets == self.pad_token_id, self.ignore_index, targets)
+ return embs, pad_masks, att_masks, targets, loss_masks, token_type_ids
+
+
+def resize_with_pad(img, width, height, pad_value=0, interpolate_like_pi=True):
+    # Resize so the image fits within (width, height) while preserving aspect ratio, then pad to the target size.
+    if img.ndim != 4:
+        raise ValueError(f"Expected a (b, c, h, w) tensor, but got shape {img.shape}")
+
+ cur_height, cur_width = img.shape[2:]
+
+ ratio = max(cur_width / width, cur_height / height)
+ resized_height = int(cur_height / ratio)
+ resized_width = int(cur_width / ratio)
+
+ if interpolate_like_pi:
+ img = (img * 255.0).to(dtype=torch.uint8)
+ img = img.permute(0, 2, 3, 1)
+ original_device = img.device
+ img = img.to(device="cpu").numpy()
+ imgs = []
+ for sub_img in img:
+ sub_img = Image.fromarray(sub_img)
+ resized_img = sub_img.resize((resized_width, resized_height), resample=2)
+ resized_img = torch.from_numpy(np.array(resized_img))
+ imgs.append(resized_img)
+ img = torch.stack(imgs, dim=0)
+ img = img.permute(0, 3, 1, 2)
+ resized_img = img.to(device=original_device, dtype=torch.float32) / 255.0
+ else:
+ resized_img = F.interpolate(
+ img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
+ )
+
+ pad_height = max(0, int(height - resized_height))
+ pad_width = max(0, int(width - resized_width))
+
+ # pad on left and top of image
+ padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
+ return padded_img
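
A minimal usage sketch of the `resize_with_pad` helper above (illustrative only, not part of the patch), assuming the patched module is in scope and that inputs are float images in [0, 1] with shape (b, c, h, w); the sizes below are made up:

    import torch

    # Hypothetical batch of two 480x640 RGB frames.
    imgs = torch.rand(2, 3, 480, 640)
    # The longer side is scaled to fit 224, the aspect ratio is preserved,
    # and the remaining space is padded on the left and top with pad_value.
    out = resize_with_pad(imgs, width=224, height=224, pad_value=0)
    assert out.shape == (2, 3, 224, 224)
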
diff --git a/lerobot/common/policies/pretrained.py b/lerobot/common/policies/pretrained.py
index 1729dfb0..da4ef157 100644
--- a/lerobot/common/policies/pretrained.py
+++ b/lerobot/common/policies/pretrained.py
@@ -1,3 +1,16 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import abc
import logging
import os
@@ -73,7 +86,6 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
- map_location: str = "cpu",
strict: bool = False,
**kwargs,
) -> T:
@@ -98,7 +110,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
if os.path.isdir(model_id):
print("Loading weights from local directory")
model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
- policy = cls._load_as_safetensor(instance, model_file, map_location, strict)
+ policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
else:
try:
model_file = hf_hub_download(
@@ -112,13 +124,13 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
token=token,
local_files_only=local_files_only,
)
- policy = cls._load_as_safetensor(instance, model_file, map_location, strict)
+ policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
except HfHubHTTPError as e:
raise FileNotFoundError(
f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
) from e
- policy.to(map_location)
+ policy.to(config.device)
policy.eval()
return policy
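
A minimal loading sketch for the updated `from_pretrained` path (illustrative only, not part of the patch); the policy class and repo id below are assumptions, the point being that the target device now comes from the saved `config.device` rather than a `map_location` argument:

    from lerobot.common.policies.act.modeling_act import ACTPolicy  # assumed policy class

    # Hypothetical repo id; the safetensors weights are loaded, the policy is
    # moved to policy.config.device, and eval() is called before returning.
    policy = ACTPolicy.from_pretrained("lerobot/act_aloha_sim_transfer_cube_human")
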
diff --git a/lerobot/common/policies/tdmpc/configuration_tdmpc.py b/lerobot/common/policies/tdmpc/configuration_tdmpc.py
index c3e8aee6..3fce01df 100644
--- a/lerobot/common/policies/tdmpc/configuration_tdmpc.py
+++ b/lerobot/common/policies/tdmpc/configuration_tdmpc.py
@@ -76,7 +76,7 @@ class TDMPCConfig(PreTrainedConfig):
n_pi_samples: Number of samples to draw from the policy / world model rollout every CEM iteration. Can
be zero.
uncertainty_regularizer_coeff: Coefficient for the uncertainty regularization used when estimating
- trajectory values (this is the λ coeffiecient in eqn 4 of FOWM).
+ trajectory values (this is the λ coefficient in eqn 4 of FOWM).
n_elites: The number of elite samples to use for updating the gaussian parameters every CEM iteration.
elite_weighting_temperature: The temperature to use for softmax weighting (by trajectory value) of the
elites, when updating the gaussian parameters for CEM.
@@ -165,7 +165,7 @@ class TDMPCConfig(PreTrainedConfig):
"""Input validation (not exhaustive)."""
if self.n_gaussian_samples <= 0:
raise ValueError(
- f"The number of guassian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`"
+ f"The number of gaussian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`"
)
if self.normalization_mapping["ACTION"] is not NormalizationMode.MIN_MAX:
raise ValueError(
diff --git a/lerobot/common/policies/tdmpc/modeling_tdmpc.py b/lerobot/common/policies/tdmpc/modeling_tdmpc.py
index 0940f198..b46ae903 100644
--- a/lerobot/common/policies/tdmpc/modeling_tdmpc.py
+++ b/lerobot/common/policies/tdmpc/modeling_tdmpc.py
@@ -122,7 +122,7 @@ class TDMPCPolicy(PreTrainedPolicy):
# When the action queue is depleted, populate it again by querying the policy.
if len(self._queues["action"]) == 0:
- batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
+ batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch if key in self._queues}
# Remove the time dimensions as it is not handled yet.
for key in batch:
diff --git a/lerobot/common/policies/vqbet/configuration_vqbet.py b/lerobot/common/policies/vqbet/configuration_vqbet.py
index 59389d6e..28e9c433 100644
--- a/lerobot/common/policies/vqbet/configuration_vqbet.py
+++ b/lerobot/common/policies/vqbet/configuration_vqbet.py
@@ -66,7 +66,7 @@ class VQBeTConfig(PreTrainedConfig):
within the image size. If None, no cropping is done.
crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
mode).
- pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone.
+ pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
`None` means no pretrained weights.
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
diff --git a/lerobot/common/policies/vqbet/modeling_vqbet.py b/lerobot/common/policies/vqbet/modeling_vqbet.py
index 1f70b186..97a08e2f 100644
--- a/lerobot/common/policies/vqbet/modeling_vqbet.py
+++ b/lerobot/common/policies/vqbet/modeling_vqbet.py
@@ -485,7 +485,7 @@ class VQBeTHead(nn.Module):
def forward(self, x, **kwargs) -> dict:
# N is the batch size, and T is number of action query tokens, which are process through same GPT
N, T, _ = x.shape
- # we calculate N and T side parallely. Thus, the dimensions would be
+ # we calculate N and T in parallel. Thus, the dimensions would be
# (batch size * number of action query tokens, action chunk size, action dimension)
x = einops.rearrange(x, "N T WA -> (N T) WA")
@@ -772,7 +772,7 @@ class VqVae(nn.Module):
Encoder and decoder are MLPs consisting of an input, output layer, and hidden layer, respectively.
The vq_layer uses residual VQs.
- This class contains functions for training the encoder and decoder along with the residual VQ layer (for trainign phase 1),
+ This class contains functions for training the encoder and decoder along with the residual VQ layer (for training phase 1),
as well as functions to help BeT training part in training phase 2.
"""
diff --git a/lerobot/common/policies/vqbet/vqbet_utils.py b/lerobot/common/policies/vqbet/vqbet_utils.py
index a2bd2df3..139d119e 100644
--- a/lerobot/common/policies/vqbet/vqbet_utils.py
+++ b/lerobot/common/policies/vqbet/vqbet_utils.py
@@ -38,7 +38,7 @@ from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
This file is part of a VQ-BeT that utilizes code from the following repositories:
- Vector Quantize PyTorch code is licensed under the MIT License:
- Origianl source: https://github.com/lucidrains/vector-quantize-pytorch
+ Original source: https://github.com/lucidrains/vector-quantize-pytorch
- nanoGPT part is an adaptation of Andrej Karpathy's nanoGPT implementation in PyTorch.
Original source: https://github.com/karpathy/nanoGPT
@@ -289,7 +289,7 @@ class GPT(nn.Module):
This file is a part for Residual Vector Quantization that utilizes code from the following repository:
- Phil Wang's vector-quantize-pytorch implementation in PyTorch.
- Origianl source: https://github.com/lucidrains/vector-quantize-pytorch
+ Original source: https://github.com/lucidrains/vector-quantize-pytorch
- The vector-quantize-pytorch code is licensed under the MIT License:
@@ -1349,9 +1349,9 @@ class EuclideanCodebook(nn.Module):
# calculate distributed variance
- variance_numer = reduce((data - batch_mean) ** 2, "h n d -> h 1 d", "sum")
- distributed.all_reduce(variance_numer)
- batch_variance = variance_numer / num_vectors
+ variance_number = reduce((data - batch_mean) ** 2, "h n d -> h 1 d", "sum")
+ distributed.all_reduce(variance_number)
+ batch_variance = variance_number / num_vectors
self.update_with_decay("batch_variance", batch_variance, self.affine_param_batch_decay)
diff --git a/lerobot/common/robot_devices/cameras/configs.py b/lerobot/common/robot_devices/cameras/configs.py
index 6acdbd3e..013419a9 100644
--- a/lerobot/common/robot_devices/cameras/configs.py
+++ b/lerobot/common/robot_devices/cameras/configs.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import abc
from dataclasses import dataclass
diff --git a/lerobot/common/robot_devices/cameras/intelrealsense.py b/lerobot/common/robot_devices/cameras/intelrealsense.py
index 7e65dba9..7a21661a 100644
--- a/lerobot/common/robot_devices/cameras/intelrealsense.py
+++ b/lerobot/common/robot_devices/cameras/intelrealsense.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""
This file contains utilities for recording frames from Intel Realsense cameras.
"""
@@ -34,7 +48,7 @@ def find_cameras(raise_when_empty=True, mock=False) -> list[dict]:
connected to the computer.
"""
if mock:
- import tests.mock_pyrealsense2 as rs
+ import tests.cameras.mock_pyrealsense2 as rs
else:
import pyrealsense2 as rs
@@ -86,7 +100,7 @@ def save_images_from_cameras(
serial_numbers = [cam["serial_number"] for cam in camera_infos]
if mock:
- import tests.mock_cv2 as cv2
+ import tests.cameras.mock_cv2 as cv2
else:
import cv2
@@ -100,7 +114,7 @@ def save_images_from_cameras(
camera = IntelRealSenseCamera(config)
camera.connect()
print(
- f"IntelRealSenseCamera({camera.serial_number}, fps={camera.fps}, width={camera.width}, height={camera.height}, color_mode={camera.color_mode})"
+ f"IntelRealSenseCamera({camera.serial_number}, fps={camera.fps}, width={camera.capture_width}, height={camera.capture_height}, color_mode={camera.color_mode})"
)
cameras.append(camera)
@@ -210,9 +224,20 @@ class IntelRealSenseCamera:
self.serial_number = self.find_serial_number_from_name(config.name)
else:
self.serial_number = config.serial_number
+
+ # Store the raw (capture) resolution from the config.
+ self.capture_width = config.width
+ self.capture_height = config.height
+
+ # If rotated by ±90, swap width and height.
+ if config.rotation in [-90, 90]:
+ self.width = config.height
+ self.height = config.width
+ else:
+ self.width = config.width
+ self.height = config.height
+
self.fps = config.fps
- self.width = config.width
- self.height = config.height
self.channels = config.channels
self.color_mode = config.color_mode
self.use_depth = config.use_depth
@@ -228,11 +253,10 @@ class IntelRealSenseCamera:
self.logs = {}
if self.mock:
- import tests.mock_cv2 as cv2
+ import tests.cameras.mock_cv2 as cv2
else:
import cv2
- # TODO(alibets): Do we keep original width/height or do we define them after rotation?
self.rotation = None
if config.rotation == -90:
self.rotation = cv2.ROTATE_90_COUNTERCLOCKWISE
@@ -263,22 +287,26 @@ class IntelRealSenseCamera:
)
if self.mock:
- import tests.mock_pyrealsense2 as rs
+ import tests.cameras.mock_pyrealsense2 as rs
else:
import pyrealsense2 as rs
config = rs.config()
config.enable_device(str(self.serial_number))
- if self.fps and self.width and self.height:
+ if self.fps and self.capture_width and self.capture_height:
# TODO(rcadene): can we set rgb8 directly?
- config.enable_stream(rs.stream.color, self.width, self.height, rs.format.rgb8, self.fps)
+ config.enable_stream(
+ rs.stream.color, self.capture_width, self.capture_height, rs.format.rgb8, self.fps
+ )
else:
config.enable_stream(rs.stream.color)
if self.use_depth:
- if self.fps and self.width and self.height:
- config.enable_stream(rs.stream.depth, self.width, self.height, rs.format.z16, self.fps)
+ if self.fps and self.capture_width and self.capture_height:
+ config.enable_stream(
+ rs.stream.depth, self.capture_width, self.capture_height, rs.format.z16, self.fps
+ )
else:
config.enable_stream(rs.stream.depth)
@@ -316,18 +344,18 @@ class IntelRealSenseCamera:
raise OSError(
f"Can't set {self.fps=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_fps}."
)
- if self.width is not None and self.width != actual_width:
+ if self.capture_width is not None and self.capture_width != actual_width:
raise OSError(
- f"Can't set {self.width=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_width}."
+ f"Can't set {self.capture_width=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_width}."
)
- if self.height is not None and self.height != actual_height:
+ if self.capture_height is not None and self.capture_height != actual_height:
raise OSError(
- f"Can't set {self.height=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_height}."
+ f"Can't set {self.capture_height=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_height}."
)
self.fps = round(actual_fps)
- self.width = round(actual_width)
- self.height = round(actual_height)
+ self.capture_width = round(actual_width)
+ self.capture_height = round(actual_height)
self.is_connected = True
@@ -347,7 +375,7 @@ class IntelRealSenseCamera:
)
if self.mock:
- import tests.mock_cv2 as cv2
+ import tests.cameras.mock_cv2 as cv2
else:
import cv2
@@ -373,7 +401,7 @@ class IntelRealSenseCamera:
color_image = cv2.cvtColor(color_image, cv2.COLOR_RGB2BGR)
h, w, _ = color_image.shape
- if h != self.height or w != self.width:
+ if h != self.capture_height or w != self.capture_width:
raise OSError(
f"Can't capture color image with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead."
)
@@ -395,7 +423,7 @@ class IntelRealSenseCamera:
depth_map = np.asanyarray(depth_frame.get_data())
h, w = depth_map.shape
- if h != self.height or w != self.width:
+ if h != self.capture_height or w != self.capture_width:
raise OSError(
f"Can't capture depth map with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead."
)
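
The capture vs. output resolution handling introduced above can be summarized by this standalone sketch (illustrative only, not part of the patch): the configured resolution is kept as the capture resolution, and the user-facing width and height are swapped when a ±90° rotation is requested:

    def output_resolution(capture_width, capture_height, rotation):
        """Return (width, height) of frames after an optional ±90° rotation."""
        if rotation in (-90, 90):
            return capture_height, capture_width
        return capture_width, capture_height

    assert output_resolution(640, 480, 90) == (480, 640)
    assert output_resolution(640, 480, None) == (640, 480)
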
diff --git a/lerobot/common/robot_devices/cameras/opencv.py b/lerobot/common/robot_devices/cameras/opencv.py
index 93c791fa..f279f315 100644
--- a/lerobot/common/robot_devices/cameras/opencv.py
+++ b/lerobot/common/robot_devices/cameras/opencv.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""
This file contains utilities for recording frames from cameras. For more info look at `OpenCVCamera` docstring.
"""
@@ -66,7 +80,7 @@ def _find_cameras(
possible_camera_ids: list[int | str], raise_when_empty=False, mock=False
) -> list[int | str]:
if mock:
- import tests.mock_cv2 as cv2
+ import tests.cameras.mock_cv2 as cv2
else:
import cv2
@@ -130,8 +144,8 @@ def save_images_from_cameras(
camera = OpenCVCamera(config)
camera.connect()
print(
- f"OpenCVCamera({camera.camera_index}, fps={camera.fps}, width={camera.width}, "
- f"height={camera.height}, color_mode={camera.color_mode})"
+ f"OpenCVCamera({camera.camera_index}, fps={camera.fps}, width={camera.capture_width}, "
+ f"height={camera.capture_height}, color_mode={camera.color_mode})"
)
cameras.append(camera)
@@ -230,9 +244,19 @@ class OpenCVCamera:
else:
raise ValueError(f"Please check the provided camera_index: {self.camera_index}")
+ # Store the raw (capture) resolution from the config.
+ self.capture_width = config.width
+ self.capture_height = config.height
+
+ # If rotated by ±90, swap width and height.
+ if config.rotation in [-90, 90]:
+ self.width = config.height
+ self.height = config.width
+ else:
+ self.width = config.width
+ self.height = config.height
+
self.fps = config.fps
- self.width = config.width
- self.height = config.height
self.channels = config.channels
self.color_mode = config.color_mode
self.mock = config.mock
@@ -245,11 +269,10 @@ class OpenCVCamera:
self.logs = {}
if self.mock:
- import tests.mock_cv2 as cv2
+ import tests.cameras.mock_cv2 as cv2
else:
import cv2
- # TODO(aliberts): Do we keep original width/height or do we define them after rotation?
self.rotation = None
if config.rotation == -90:
self.rotation = cv2.ROTATE_90_COUNTERCLOCKWISE
@@ -263,7 +286,7 @@ class OpenCVCamera:
raise RobotDeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.")
if self.mock:
- import tests.mock_cv2 as cv2
+ import tests.cameras.mock_cv2 as cv2
else:
import cv2
@@ -271,10 +294,20 @@ class OpenCVCamera:
# when other threads are used to save the images.
cv2.setNumThreads(1)
+ backend = (
+ cv2.CAP_V4L2
+ if platform.system() == "Linux"
+ else cv2.CAP_DSHOW
+ if platform.system() == "Windows"
+ else cv2.CAP_AVFOUNDATION
+ if platform.system() == "Darwin"
+ else cv2.CAP_ANY
+ )
+
camera_idx = f"/dev/video{self.camera_index}" if platform.system() == "Linux" else self.camera_index
# First create a temporary camera trying to access `camera_index`,
# and verify it is a valid camera by calling `isOpened`.
- tmp_camera = cv2.VideoCapture(camera_idx)
+ tmp_camera = cv2.VideoCapture(camera_idx, backend)
is_camera_open = tmp_camera.isOpened()
# Release camera to make it accessible for `find_camera_indices`
tmp_camera.release()
@@ -297,14 +330,14 @@ class OpenCVCamera:
# Secondly, create the camera that will be used downstream.
# Note: For some unknown reason, calling `isOpened` blocks the camera which then
# needs to be re-created.
- self.camera = cv2.VideoCapture(camera_idx)
+ self.camera = cv2.VideoCapture(camera_idx, backend)
if self.fps is not None:
self.camera.set(cv2.CAP_PROP_FPS, self.fps)
- if self.width is not None:
- self.camera.set(cv2.CAP_PROP_FRAME_WIDTH, self.width)
- if self.height is not None:
- self.camera.set(cv2.CAP_PROP_FRAME_HEIGHT, self.height)
+ if self.capture_width is not None:
+ self.camera.set(cv2.CAP_PROP_FRAME_WIDTH, self.capture_width)
+ if self.capture_height is not None:
+ self.camera.set(cv2.CAP_PROP_FRAME_HEIGHT, self.capture_height)
actual_fps = self.camera.get(cv2.CAP_PROP_FPS)
actual_width = self.camera.get(cv2.CAP_PROP_FRAME_WIDTH)
@@ -316,19 +349,22 @@ class OpenCVCamera:
raise OSError(
f"Can't set {self.fps=} for OpenCVCamera({self.camera_index}). Actual value is {actual_fps}."
)
- if self.width is not None and not math.isclose(self.width, actual_width, rel_tol=1e-3):
+ if self.capture_width is not None and not math.isclose(
+ self.capture_width, actual_width, rel_tol=1e-3
+ ):
raise OSError(
- f"Can't set {self.width=} for OpenCVCamera({self.camera_index}). Actual value is {actual_width}."
+ f"Can't set {self.capture_width=} for OpenCVCamera({self.camera_index}). Actual value is {actual_width}."
)
- if self.height is not None and not math.isclose(self.height, actual_height, rel_tol=1e-3):
+ if self.capture_height is not None and not math.isclose(
+ self.capture_height, actual_height, rel_tol=1e-3
+ ):
raise OSError(
- f"Can't set {self.height=} for OpenCVCamera({self.camera_index}). Actual value is {actual_height}."
+ f"Can't set {self.capture_height=} for OpenCVCamera({self.camera_index}). Actual value is {actual_height}."
)
self.fps = round(actual_fps)
- self.width = round(actual_width)
- self.height = round(actual_height)
-
+ self.capture_width = round(actual_width)
+ self.capture_height = round(actual_height)
self.is_connected = True
def read(self, temporary_color_mode: str | None = None) -> np.ndarray:
@@ -362,14 +398,14 @@ class OpenCVCamera:
# so we convert the image color from BGR to RGB.
if requested_color_mode == "rgb":
if self.mock:
- import tests.mock_cv2 as cv2
+ import tests.cameras.mock_cv2 as cv2
else:
import cv2
color_image = cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB)
h, w, _ = color_image.shape
- if h != self.height or w != self.width:
+ if h != self.capture_height or w != self.capture_width:
raise OSError(
f"Can't capture color image with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead."
)
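
For readability, the backend selection added above is equivalent to the following sketch (illustrative only, not part of the patch); the constants are standard OpenCV capture backends:

    import platform

    import cv2

    def select_backend() -> int:
        # Pick the capture backend matching the host OS, falling back to CAP_ANY.
        system = platform.system()
        if system == "Linux":
            return cv2.CAP_V4L2
        if system == "Windows":
            return cv2.CAP_DSHOW
        if system == "Darwin":
            return cv2.CAP_AVFOUNDATION
        return cv2.CAP_ANY
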
diff --git a/lerobot/common/robot_devices/cameras/utils.py b/lerobot/common/robot_devices/cameras/utils.py
index 88288ea3..c6431646 100644
--- a/lerobot/common/robot_devices/cameras/utils.py
+++ b/lerobot/common/robot_devices/cameras/utils.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import Protocol
import numpy as np
@@ -31,7 +45,7 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> list[C
cameras[key] = IntelRealSenseCamera(cfg)
else:
- raise ValueError(f"The motor type '{cfg.type}' is not valid.")
+ raise ValueError(f"The camera type '{cfg.type}' is not valid.")
return cameras
diff --git a/lerobot/common/robot_devices/control_configs.py b/lerobot/common/robot_devices/control_configs.py
index 6dae8cb6..0ecd8683 100644
--- a/lerobot/common/robot_devices/control_configs.py
+++ b/lerobot/common/robot_devices/control_configs.py
@@ -1,14 +1,25 @@
-import logging
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from dataclasses import dataclass
from pathlib import Path
import draccus
from lerobot.common.robot_devices.robots.configs import RobotConfig
-from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
-from lerobot.configs.train import TrainPipelineConfig
@dataclass
@@ -43,11 +54,6 @@ class RecordControlConfig(ControlConfig):
# Root directory where the dataset will be stored (e.g. 'dataset/path').
root: str | Path | None = None
policy: PreTrainedConfig | None = None
- # TODO(rcadene, aliberts): By default, use device and use_amp values from policy checkpoint.
- device: str | None = None # cuda | cpu | mps
- # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
- # automatic gradient scaling is used.
- use_amp: bool | None = None
# Limit the frames per second. By default, uses the policy fps.
fps: int | None = None
# Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize.
@@ -60,15 +66,13 @@ class RecordControlConfig(ControlConfig):
num_episodes: int = 50
# Encode frames in the dataset into video
video: bool = True
- # By default, run the computation of the data statistics at the end of data collection. Compute intensive and not required to just replay an episode.
- run_compute_stats: bool = True
# Upload dataset to Hugging Face hub.
push_to_hub: bool = True
# Upload on private repository on the Hugging Face hub.
private: bool = False
# Add tags to your dataset on the hub.
tags: list[str] | None = None
- # Number of subprocesses handling the saving of frames as PNGs. Set to 0 to use threads only;
+ # Number of subprocesses handling the saving of frames as PNG files. Set to 0 to use threads only;
# set to ≥1 to use subprocesses, each using threads to write images. The best number of processes
# and threads depends on your system. We recommend 4 threads per camera with 0 processes.
# If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses.
@@ -83,9 +87,6 @@ class RecordControlConfig(ControlConfig):
play_sounds: bool = True
# Resume recording on an existing dataset.
resume: bool = False
- # TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument
- # Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.
- local_files_only: bool = False
def __post_init__(self):
# HACK: We parse again the cli args here to get the pretrained path if there was one.
@@ -95,27 +96,6 @@ class RecordControlConfig(ControlConfig):
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
- # When no device or use_amp are given, use the one from training config.
- if self.device is None or self.use_amp is None:
- train_cfg = TrainPipelineConfig.from_pretrained(policy_path)
- if self.device is None:
- self.device = train_cfg.device
- if self.use_amp is None:
- self.use_amp = train_cfg.use_amp
-
- # Automatically switch to available device if necessary
- if not is_torch_device_available(self.device):
- auto_device = auto_select_torch_device()
- logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
- self.device = auto_device
-
- # Automatically deactivate AMP if necessary
- if self.use_amp and not is_amp_available(self.device):
- logging.warning(
- f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
- )
- self.use_amp = False
-
@ControlConfig.register_subclass("replay")
@dataclass
@@ -130,9 +110,6 @@ class ReplayControlConfig(ControlConfig):
fps: int | None = None
# Use vocal synthesis to read events.
play_sounds: bool = True
- # TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument
- # Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.
- local_files_only: bool = False
@ControlConfig.register_subclass("remote_robot")
diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py
index 7264f078..78a8c6a6 100644
--- a/lerobot/common/robot_devices/control_utils.py
+++ b/lerobot/common/robot_devices/control_utils.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
########################################################################################
# Utilities
########################################################################################
@@ -12,13 +26,13 @@ from functools import cache
import cv2
import torch
-import tqdm
from deepdiff import DeepDiff
from termcolor import colored
from lerobot.common.datasets.image_writer import safe_stop_image_writer
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import get_features_from_robot
+from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.robot_devices.utils import busy_wait
from lerobot.common.utils.utils import get_safe_torch_device, has_method
@@ -180,9 +194,8 @@ def record_episode(
episode_time_s,
display_cameras,
policy,
- device,
- use_amp,
fps,
+ single_task,
):
control_loop(
robot=robot,
@@ -191,10 +204,9 @@ def record_episode(
dataset=dataset,
events=events,
policy=policy,
- device=device,
- use_amp=use_amp,
fps=fps,
teleoperate=policy is None,
+ single_task=single_task,
)
@@ -206,10 +218,9 @@ def control_loop(
display_cameras=False,
dataset: LeRobotDataset | None = None,
events=None,
- policy=None,
- device: torch.device | str | None = None,
- use_amp: bool | None = None,
+ policy: PreTrainedPolicy | None = None,
fps: int | None = None,
+ single_task: str | None = None,
):
# TODO(rcadene): Add option to record logs
if not robot.is_connected:
@@ -224,12 +235,12 @@ def control_loop(
if teleoperate and policy is not None:
raise ValueError("When `teleoperate` is True, `policy` should be None.")
+ if dataset is not None and single_task is None:
+ raise ValueError("You need to provide a task as argument in `single_task`.")
+
if dataset is not None and fps is not None and dataset.fps != fps:
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
- if isinstance(device, str):
- device = get_safe_torch_device(device)
-
timestamp = 0
start_episode_t = time.perf_counter()
while timestamp < control_time_s:
@@ -241,14 +252,16 @@ def control_loop(
observation = robot.capture_observation()
if policy is not None:
- pred_action = predict_action(observation, policy, device, use_amp)
+ pred_action = predict_action(
+ observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
+ )
# Action can eventually be clipped using `max_relative_target`,
# so action actually sent is saved in the dataset.
action = robot.send_action(pred_action)
action = {"action": action}
if dataset is not None:
- frame = {**observation, **action}
+ frame = {**observation, **action, "task": single_task}
dataset.add_frame(frame)
if display_cameras and not is_headless():
@@ -270,24 +283,18 @@ def control_loop(
break
-def reset_environment(robot, events, reset_time_s):
+def reset_environment(robot, events, reset_time_s, fps):
# TODO(rcadene): refactor warmup_record and reset_environment
- # TODO(alibets): allow for teleop during reset
if has_method(robot, "teleop_safety_stop"):
robot.teleop_safety_stop()
- timestamp = 0
- start_vencod_t = time.perf_counter()
-
- # Wait if necessary
- with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar:
- while timestamp < reset_time_s:
- time.sleep(1)
- timestamp = time.perf_counter() - start_vencod_t
- pbar.update(1)
- if events["exit_early"]:
- events["exit_early"] = False
- break
+ control_loop(
+ robot=robot,
+ control_time_s=reset_time_s,
+ events=events,
+ fps=fps,
+ teleoperate=True,
+ )
def stop_recording(robot, listener, display_cameras):
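
Two behavioral notes on the `control_loop` changes above, with a small runnable sketch (illustrative only, not part of the patch): the device and AMP settings are now read from `policy.config` inside the loop, and a `single_task` string becomes mandatory whenever a dataset is being recorded. The helper below only mirrors that new validation; it is not a function in the codebase:

    def check_recording_args(dataset, single_task):
        # Mirrors the new guard in control_loop: recording without a task is rejected.
        if dataset is not None and single_task is None:
            raise ValueError("You need to provide a task as argument in `single_task`.")

    check_recording_args(dataset=None, single_task=None)                    # teleop/replay: ok
    check_recording_args(dataset=object(), single_task="Pick up the cube")  # recording: ok
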
diff --git a/lerobot/common/robot_devices/motors/configs.py b/lerobot/common/robot_devices/motors/configs.py
index 37b781f9..0bfbaf83 100644
--- a/lerobot/common/robot_devices/motors/configs.py
+++ b/lerobot/common/robot_devices/motors/configs.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import abc
from dataclasses import dataclass
diff --git a/lerobot/common/robot_devices/motors/dynamixel.py b/lerobot/common/robot_devices/motors/dynamixel.py
index 54836d8e..6096ceb5 100644
--- a/lerobot/common/robot_devices/motors/dynamixel.py
+++ b/lerobot/common/robot_devices/motors/dynamixel.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import enum
import logging
import math
@@ -242,7 +256,7 @@ class DriveMode(enum.Enum):
class CalibrationMode(enum.Enum):
# Joints with rotational motions are expressed in degrees in nominal range of [-180, 180]
DEGREE = 0
- # Joints with linear motions (like gripper of Aloha) are experessed in nominal range of [0, 100]
+ # Joints with linear motions (like gripper of Aloha) are expressed in nominal range of [0, 100]
LINEAR = 1
@@ -318,7 +332,7 @@ class DynamixelMotorsBus:
)
if self.mock:
- import tests.mock_dynamixel_sdk as dxl
+ import tests.motors.mock_dynamixel_sdk as dxl
else:
import dynamixel_sdk as dxl
@@ -342,7 +356,7 @@ class DynamixelMotorsBus:
def reconnect(self):
if self.mock:
- import tests.mock_dynamixel_sdk as dxl
+ import tests.motors.mock_dynamixel_sdk as dxl
else:
import dynamixel_sdk as dxl
@@ -610,7 +624,7 @@ class DynamixelMotorsBus:
# 0-centered resolution range (e.g. [-2048, 2048] for resolution=4096)
values[i] = values[i] / HALF_TURN_DEGREE * (resolution // 2)
- # Substract the homing offsets to come back to actual motor range of values
+ # Subtract the homing offsets to come back to actual motor range of values
# which can be arbitrary.
values[i] -= homing_offset
@@ -632,7 +646,7 @@ class DynamixelMotorsBus:
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
if self.mock:
- import tests.mock_dynamixel_sdk as dxl
+ import tests.motors.mock_dynamixel_sdk as dxl
else:
import dynamixel_sdk as dxl
@@ -677,7 +691,7 @@ class DynamixelMotorsBus:
start_time = time.perf_counter()
if self.mock:
- import tests.mock_dynamixel_sdk as dxl
+ import tests.motors.mock_dynamixel_sdk as dxl
else:
import dynamixel_sdk as dxl
@@ -743,7 +757,7 @@ class DynamixelMotorsBus:
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
if self.mock:
- import tests.mock_dynamixel_sdk as dxl
+ import tests.motors.mock_dynamixel_sdk as dxl
else:
import dynamixel_sdk as dxl
@@ -779,7 +793,7 @@ class DynamixelMotorsBus:
start_time = time.perf_counter()
if self.mock:
- import tests.mock_dynamixel_sdk as dxl
+ import tests.motors.mock_dynamixel_sdk as dxl
else:
import dynamixel_sdk as dxl
diff --git a/lerobot/common/robot_devices/motors/feetech.py b/lerobot/common/robot_devices/motors/feetech.py
index a59db7df..64c7f413 100644
--- a/lerobot/common/robot_devices/motors/feetech.py
+++ b/lerobot/common/robot_devices/motors/feetech.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import enum
import logging
import math
@@ -221,7 +235,7 @@ class DriveMode(enum.Enum):
class CalibrationMode(enum.Enum):
# Joints with rotational motions are expressed in degrees in nominal range of [-180, 180]
DEGREE = 0
- # Joints with linear motions (like gripper of Aloha) are experessed in nominal range of [0, 100]
+ # Joints with linear motions (like gripper of Aloha) are expressed in nominal range of [0, 100]
LINEAR = 1
@@ -299,7 +313,7 @@ class FeetechMotorsBus:
)
if self.mock:
- import tests.mock_scservo_sdk as scs
+ import tests.motors.mock_scservo_sdk as scs
else:
import scservo_sdk as scs
@@ -323,7 +337,7 @@ class FeetechMotorsBus:
def reconnect(self):
if self.mock:
- import tests.mock_scservo_sdk as scs
+ import tests.motors.mock_scservo_sdk as scs
else:
import scservo_sdk as scs
@@ -591,7 +605,7 @@ class FeetechMotorsBus:
# 0-centered resolution range (e.g. [-2048, 2048] for resolution=4096)
values[i] = values[i] / HALF_TURN_DEGREE * (resolution // 2)
- # Substract the homing offsets to come back to actual motor range of values
+ # Subtract the homing offsets to come back to actual motor range of values
# which can be arbitrary.
values[i] -= homing_offset
@@ -632,7 +646,7 @@ class FeetechMotorsBus:
track["prev"][idx] = values[i]
continue
- # Detect a full rotation occured
+ # Detect whether a full rotation occurred
if abs(track["prev"][idx] - values[i]) > 2048:
# Position went below 0 and got reset to 4095
if track["prev"][idx] < values[i]:
@@ -650,7 +664,7 @@ class FeetechMotorsBus:
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
if self.mock:
- import tests.mock_scservo_sdk as scs
+ import tests.motors.mock_scservo_sdk as scs
else:
import scservo_sdk as scs
@@ -688,7 +702,7 @@ class FeetechMotorsBus:
def read(self, data_name, motor_names: str | list[str] | None = None):
if self.mock:
- import tests.mock_scservo_sdk as scs
+ import tests.motors.mock_scservo_sdk as scs
else:
import scservo_sdk as scs
@@ -768,7 +782,7 @@ class FeetechMotorsBus:
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
if self.mock:
- import tests.mock_scservo_sdk as scs
+ import tests.motors.mock_scservo_sdk as scs
else:
import scservo_sdk as scs
@@ -804,7 +818,7 @@ class FeetechMotorsBus:
start_time = time.perf_counter()
if self.mock:
- import tests.mock_scservo_sdk as scs
+ import tests.motors.mock_scservo_sdk as scs
else:
import scservo_sdk as scs
diff --git a/lerobot/common/robot_devices/motors/utils.py b/lerobot/common/robot_devices/motors/utils.py
index fc64f050..bd86f4c6 100644
--- a/lerobot/common/robot_devices/motors/utils.py
+++ b/lerobot/common/robot_devices/motors/utils.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import Protocol
from lerobot.common.robot_devices.motors.configs import (
diff --git a/lerobot/common/robot_devices/robots/configs.py b/lerobot/common/robot_devices/robots/configs.py
index 88cb4e6f..e940b442 100644
--- a/lerobot/common/robot_devices/robots/configs.py
+++ b/lerobot/common/robot_devices/robots/configs.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import abc
from dataclasses import dataclass, field
from typing import Sequence
diff --git a/lerobot/common/robot_devices/robots/dynamixel_calibration.py b/lerobot/common/robot_devices/robots/dynamixel_calibration.py
index 5c4932d2..98fe8754 100644
--- a/lerobot/common/robot_devices/robots/dynamixel_calibration.py
+++ b/lerobot/common/robot_devices/robots/dynamixel_calibration.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""Logic to calibrate a robot arm built with dynamixel motors"""
# TODO(rcadene, aliberts): move this logic into the robot code when refactoring
@@ -87,7 +101,7 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
# For instance, if the motor rotates 90 degree, and its value is -90 after applying the homing offset, then we know its rotation direction
# is inverted. However, for the calibration being successful, we need everyone to follow the same target position.
# Sometimes, there is only one possible rotation direction. For instance, if the gripper is closed, there is only one direction which
- # corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarely rotate clockwise from the point of view
+ # corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarily rotate clockwise from the point of view
# of the previous motor in the kinetic chain.
print("\nMove arm to rotated target position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated"))
@@ -115,7 +129,7 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
# TODO(rcadene): make type of joints (DEGREE or LINEAR) configurable from yaml?
if robot_type in ["aloha"] and "gripper" in arm.motor_names:
- # Joints with linear motions (like gripper of Aloha) are experessed in nominal range of [0, 100]
+ # Joints with linear motions (like gripper of Aloha) are expressed in nominal range of [0, 100]
calib_idx = arm.motor_names.index("gripper")
calib_mode[calib_idx] = CalibrationMode.LINEAR.name
diff --git a/lerobot/common/robot_devices/robots/feetech_calibration.py b/lerobot/common/robot_devices/robots/feetech_calibration.py
index b015951a..2c1e7180 100644
--- a/lerobot/common/robot_devices/robots/feetech_calibration.py
+++ b/lerobot/common/robot_devices/robots/feetech_calibration.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""Logic to calibrate a robot arm built with feetech motors"""
# TODO(rcadene, aliberts): move this logic into the robot code when refactoring
@@ -443,7 +457,7 @@ def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, a
# For instance, if the motor rotates 90 degree, and its value is -90 after applying the homing offset, then we know its rotation direction
# is inverted. However, for the calibration being successful, we need everyone to follow the same target position.
# Sometimes, there is only one possible rotation direction. For instance, if the gripper is closed, there is only one direction which
- # corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarely rotate clockwise from the point of view
+ # corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarily rotate clockwise from the point of view
# of the previous motor in the kinetic chain.
print("\nMove arm to rotated target position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated"))
diff --git a/lerobot/common/robot_devices/robots/lekiwi_remote.py b/lerobot/common/robot_devices/robots/lekiwi_remote.py
index fd9491fa..7bf52d21 100644
--- a/lerobot/common/robot_devices/robots/lekiwi_remote.py
+++ b/lerobot/common/robot_devices/robots/lekiwi_remote.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import base64
import json
import threading
diff --git a/lerobot/common/robot_devices/robots/manipulator.py b/lerobot/common/robot_devices/robots/manipulator.py
index 9c82d069..9173abc6 100644
--- a/lerobot/common/robot_devices/robots/manipulator.py
+++ b/lerobot/common/robot_devices/robots/manipulator.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""Contains logic to instantiate a robot, read information from its motors and cameras,
and send orders to its motors.
"""
@@ -44,7 +58,7 @@ class ManipulatorRobot:
# TODO(rcadene): Implement force feedback
"""This class allows to control any manipulator robot of various number of motors.
- Non exaustive list of robots:
+ Non-exhaustive list of robots:
- [Koch v1.0](https://github.com/AlexanderKoch-Koch/low_cost_robot), with and without the wrist-to-elbow expansion, developed
by Alexander Koch from [Tau Robotics](https://tau-robotics.com)
- [Koch v1.1](https://github.com/jess-moss/koch-v1-1) developed by Jess Moss
@@ -55,7 +69,7 @@ class ManipulatorRobot:
robot = ManipulatorRobot(KochRobotConfig())
```
- Example of overwritting motors during instantiation:
+ Example of overwriting motors during instantiation:
```python
# Defines how to communicate with the motors of the leader and follower arms
leader_arms = {
@@ -90,7 +104,7 @@ class ManipulatorRobot:
robot = ManipulatorRobot(robot_config)
```
- Example of overwritting cameras during instantiation:
+ Example of overwriting cameras during instantiation:
```python
# Defines how to communicate with 2 cameras connected to the computer.
# Here, the webcam of the laptop and the phone (connected in USB to the laptop)
@@ -348,7 +362,7 @@ class ManipulatorRobot:
set_operating_mode_(self.follower_arms[name])
# Set better PID values to close the gap between recorded states and actions
- # TODO(rcadene): Implement an automatic procedure to set optimial PID values for each motor
+ # TODO(rcadene): Implement an automatic procedure to set optimal PID values for each motor
self.follower_arms[name].write("Position_P_Gain", 1500, "elbow_flex")
self.follower_arms[name].write("Position_I_Gain", 0, "elbow_flex")
self.follower_arms[name].write("Position_D_Gain", 600, "elbow_flex")
@@ -460,7 +474,7 @@ class ManipulatorRobot:
# Used when record_data=True
follower_goal_pos[name] = goal_pos
- goal_pos = goal_pos.numpy().astype(np.int32)
+ goal_pos = goal_pos.numpy().astype(np.float32)
self.follower_arms[name].write("Goal_Position", goal_pos)
self.logs[f"write_follower_{name}_goal_pos_dt_s"] = time.perf_counter() - before_fwrite_t
@@ -500,7 +514,7 @@ class ManipulatorRobot:
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
- # Populate output dictionnaries
+ # Populate output dictionaries
obs_dict, action_dict = {}, {}
obs_dict["observation.state"] = state
action_dict["action"] = action
@@ -540,7 +554,7 @@ class ManipulatorRobot:
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
- # Populate output dictionnaries and format to pytorch
+ # Populate output dictionaries and format to pytorch
obs_dict = {}
obs_dict["observation.state"] = state
for name in self.cameras:
@@ -582,7 +596,7 @@ class ManipulatorRobot:
action_sent.append(goal_pos)
# Send goal position to each follower
- goal_pos = goal_pos.numpy().astype(np.int32)
+ goal_pos = goal_pos.numpy().astype(np.float32)
self.follower_arms[name].write("Goal_Position", goal_pos)
return torch.cat(action_sent)
diff --git a/lerobot/common/robot_devices/robots/mobile_manipulator.py b/lerobot/common/robot_devices/robots/mobile_manipulator.py
index b20c61f7..385e218b 100644
--- a/lerobot/common/robot_devices/robots/mobile_manipulator.py
+++ b/lerobot/common/robot_devices/robots/mobile_manipulator.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import base64
import json
import os
@@ -392,21 +406,19 @@ class MobileManipulator:
for name in self.leader_arms:
pos = self.leader_arms[name].read("Present_Position")
pos_tensor = torch.from_numpy(pos).float()
- # Instead of pos_tensor.item(), use tolist() to convert the entire tensor to a list
arm_positions.extend(pos_tensor.tolist())
- # (The rest of your code for generating wheel commands remains unchanged)
- x_cmd = 0.0 # m/s forward/backward
- y_cmd = 0.0 # m/s lateral
+ y_cmd = 0.0 # m/s forward/backward
+ x_cmd = 0.0 # m/s lateral
theta_cmd = 0.0 # deg/s rotation
if self.pressed_keys["forward"]:
- x_cmd += xy_speed
- if self.pressed_keys["backward"]:
- x_cmd -= xy_speed
- if self.pressed_keys["left"]:
y_cmd += xy_speed
- if self.pressed_keys["right"]:
+ if self.pressed_keys["backward"]:
y_cmd -= xy_speed
+ if self.pressed_keys["left"]:
+ x_cmd += xy_speed
+ if self.pressed_keys["right"]:
+ x_cmd -= xy_speed
if self.pressed_keys["rotate_left"]:
theta_cmd += theta_speed
if self.pressed_keys["rotate_right"]:
@@ -584,8 +596,8 @@ class MobileManipulator:
# Create the body velocity vector [x, y, theta_rad].
velocity_vector = np.array([x_cmd, y_cmd, theta_rad])
- # Define the wheel mounting angles with a -90° offset.
- angles = np.radians(np.array([240, 120, 0]) - 90)
+ # Define the wheel mounting angles (measured clockwise from the y axis).
+ angles = np.radians(np.array([300, 180, 60]))
# Build the kinematic matrix: each row maps body velocities to a wheel’s linear speed.
# The third column (base_radius) accounts for the effect of rotation.
m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles])
@@ -641,8 +653,8 @@ class MobileManipulator:
# Compute each wheel’s linear speed (m/s) from its angular speed.
wheel_linear_speeds = wheel_radps * wheel_radius
- # Define the wheel mounting angles with a -90° offset.
- angles = np.radians(np.array([240, 120, 0]) - 90)
+ # Define the wheel mounting angles (measured clockwise from the y axis).
+ angles = np.radians(np.array([300, 180, 60]))
m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles])
# Solve the inverse kinematics: body_velocity = M⁻¹ · wheel_linear_speeds.
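
A small numeric sketch of the omnidirectional-wheel kinematics used above (illustrative only, not part of the patch), with three wheels mounted at 300°, 180° and 60° clockwise from the y axis; wheel_radius and base_radius are assumed values:

    import numpy as np

    wheel_radius = 0.05  # meters (assumed)
    base_radius = 0.125  # meters (assumed)

    angles = np.radians(np.array([300, 180, 60]))
    m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles])

    # Forward kinematics: body velocity [x m/s, y m/s, theta rad/s] -> wheel linear speeds.
    body_velocity = np.array([0.0, 0.2, 0.0])  # drive forward along y
    wheel_linear_speeds = m @ body_velocity
    wheel_radps = wheel_linear_speeds / wheel_radius

    # Inverse kinematics: recover the body velocity from the wheel speeds.
    recovered = np.linalg.solve(m, wheel_radps * wheel_radius)
    assert np.allclose(recovered, body_velocity)
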
diff --git a/lerobot/common/robot_devices/robots/stretch.py b/lerobot/common/robot_devices/robots/stretch.py
index b63bf941..9cfe6e49 100644
--- a/lerobot/common/robot_devices/robots/stretch.py
+++ b/lerobot/common/robot_devices/robots/stretch.py
@@ -108,7 +108,7 @@ class StretchRobot(StretchAPI):
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
- # Populate output dictionnaries
+ # Populate output dictionaries
obs_dict, action_dict = {}, {}
obs_dict["observation.state"] = state
action_dict["action"] = action
@@ -153,7 +153,7 @@ class StretchRobot(StretchAPI):
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
- # Populate output dictionnaries
+ # Populate output dictionaries
obs_dict = {}
obs_dict["observation.state"] = state
for name in self.cameras:
diff --git a/lerobot/common/robot_devices/robots/utils.py b/lerobot/common/robot_devices/robots/utils.py
index 47e2519b..dab514d5 100644
--- a/lerobot/common/robot_devices/robots/utils.py
+++ b/lerobot/common/robot_devices/robots/utils.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import Protocol
from lerobot.common.robot_devices.robots.configs import (
diff --git a/lerobot/common/robot_devices/utils.py b/lerobot/common/robot_devices/utils.py
index 19bb637e..837c9d2e 100644
--- a/lerobot/common/robot_devices/utils.py
+++ b/lerobot/common/robot_devices/utils.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import platform
import time
diff --git a/lerobot/common/utils/hub.py b/lerobot/common/utils/hub.py
index 63fcf918..df7435c0 100644
--- a/lerobot/common/utils/hub.py
+++ b/lerobot/common/utils/hub.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Type, TypeVar
diff --git a/lerobot/common/utils/utils.py b/lerobot/common/utils/utils.py
index 015d1ede..563a7b81 100644
--- a/lerobot/common/utils/utils.py
+++ b/lerobot/common/utils/utils.py
@@ -17,10 +17,12 @@ import logging
import os
import os.path as osp
import platform
+import subprocess
from copy import copy
from datetime import datetime, timezone
from pathlib import Path
+import numpy as np
import torch
@@ -49,8 +51,10 @@ def auto_select_torch_device() -> torch.device:
return torch.device("cpu")
+# TODO(Steven): Remove log. log shouldn't be an argument, this should be handled by the logger level
def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
"""Given a string, return a torch.device with checks on whether the device is available."""
+ try_device = str(try_device)
match try_device:
case "cuda":
assert torch.cuda.is_available()
@@ -83,6 +87,7 @@ def get_safe_dtype(dtype: torch.dtype, device: str | torch.device):
def is_torch_device_available(try_device: str) -> bool:
+ try_device = str(try_device) # Ensure try_device is a string
if try_device == "cuda":
return torch.cuda.is_available()
elif try_device == "mps":
@@ -90,7 +95,7 @@ def is_torch_device_available(try_device: str) -> bool:
elif try_device == "cpu":
return True
else:
- raise ValueError(f"Unknown device '{try_device}.")
+ raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu.")
def is_amp_available(device: str):
@@ -164,23 +169,31 @@ def capture_timestamp_utc():
def say(text, blocking=False):
- # Check if mac, linux, or windows.
- if platform.system() == "Darwin":
- cmd = f'say "{text}"'
- if not blocking:
- cmd += " &"
- elif platform.system() == "Linux":
- cmd = f'spd-say "{text}"'
- if blocking:
- cmd += " --wait"
- elif platform.system() == "Windows":
- # TODO(rcadene): Make blocking option work for Windows
- cmd = (
- 'PowerShell -Command "Add-Type -AssemblyName System.Speech; '
- f"(New-Object System.Speech.Synthesis.SpeechSynthesizer).Speak('{text}')\""
- )
+ system = platform.system()
- os.system(cmd)
+ if system == "Darwin":
+ cmd = ["say", text]
+
+ elif system == "Linux":
+ cmd = ["spd-say", text]
+ if blocking:
+ cmd.append("--wait")
+
+ elif system == "Windows":
+ cmd = [
+ "PowerShell",
+ "-Command",
+ "Add-Type -AssemblyName System.Speech; "
+ f"(New-Object System.Speech.Synthesis.SpeechSynthesizer).Speak('{text}')",
+ ]
+
+ else:
+ raise RuntimeError("Unsupported operating system for text-to-speech.")
+
+ if blocking:
+ subprocess.run(cmd, check=True)
+ else:
+ subprocess.Popen(cmd, creationflags=subprocess.CREATE_NO_WINDOW if system == "Windows" else 0)
def log_say(text, play_sounds, blocking=False):
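A note on the `say` rewrite above: building argument lists for `subprocess` instead of interpolating `text` into a shell string avoids quoting and shell-injection issues, and `blocking` now maps cleanly onto `subprocess.run` versus `subprocess.Popen`. A small usage sketch (the speech backends `say`, `spd-say`, and PowerShell must be installed, as before):

```python
from lerobot.common.utils.utils import say

say("Recording episode 3", blocking=True)  # waits for the speech command to finish
say("Reset the environment")               # returns immediately; speech runs in the background
```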
@@ -200,5 +213,18 @@ def get_channel_first_image_shape(image_shape: tuple) -> tuple:
return shape
-def has_method(cls: object, method_name: str):
+def has_method(cls: object, method_name: str) -> bool:
return hasattr(cls, method_name) and callable(getattr(cls, method_name))
+
+
+def is_valid_numpy_dtype_string(dtype_str: str) -> bool:
+ """
+ Return True if a given string can be converted to a numpy dtype.
+ """
+ try:
+ # Attempt to convert the string to a numpy dtype
+ np.dtype(dtype_str)
+ return True
+ except TypeError:
+ # If a TypeError is raised, the string is not a valid dtype
+ return False
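A quick, hedged sanity check of the new helper, assuming it is imported from `lerobot.common.utils.utils`:

```python
from lerobot.common.utils.utils import is_valid_numpy_dtype_string

assert is_valid_numpy_dtype_string("float32")       # np.dtype("float32") succeeds
assert not is_valid_numpy_dtype_string("float999")  # np.dtype raises TypeError -> False
```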
diff --git a/lerobot/common/utils/wandb_utils.py b/lerobot/common/utils/wandb_utils.py
index 9985b894..3fe241d4 100644
--- a/lerobot/common/utils/wandb_utils.py
+++ b/lerobot/common/utils/wandb_utils.py
@@ -69,7 +69,13 @@ class WandBLogger:
os.environ["WANDB_SILENT"] = "True"
import wandb
- wandb_run_id = get_wandb_run_id_from_filesystem(self.log_dir) if cfg.resume else None
+ wandb_run_id = (
+ cfg.wandb.run_id
+ if cfg.wandb.run_id
+ else get_wandb_run_id_from_filesystem(self.log_dir)
+ if cfg.resume
+ else None
+ )
wandb.init(
id=wandb_run_id,
project=self.cfg.project,
@@ -84,6 +90,7 @@ class WandBLogger:
# TODO(rcadene): split train and eval, and run async eval with job_type="eval"
job_type="train_eval",
resume="must" if cfg.resume else None,
+ mode=self.cfg.mode if self.cfg.mode in ["online", "offline", "disabled"] else "online",
)
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
diff --git a/lerobot/configs/default.py b/lerobot/configs/default.py
index 5dd2f898..ce72466a 100644
--- a/lerobot/configs/default.py
+++ b/lerobot/configs/default.py
@@ -20,6 +20,7 @@ from lerobot.common import (
policies, # noqa: F401
)
from lerobot.common.datasets.transforms import ImageTransformsConfig
+from lerobot.common.datasets.video_utils import get_safe_default_codec
@dataclass
@@ -27,13 +28,15 @@ class DatasetConfig:
# You may provide a list of datasets here. `train.py` creates them all and concatenates them. Note: only data
# keys common between the datasets are kept. Each dataset gets and additional transform that inserts the
# "dataset_index" into the returned item. The index mapping is made according to the order in which the
- # datsets are provided.
+ # datasets are provided.
repo_id: str
+ # Root directory where the dataset will be stored (e.g. 'dataset/path').
+ root: str | None = None
episodes: list[int] | None = None
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
- local_files_only: bool = False
+ revision: str | None = None
use_imagenet_stats: bool = True
- video_backend: str = "pyav"
+ video_backend: str = field(default_factory=get_safe_default_codec)
@dataclass
@@ -44,6 +47,8 @@ class WandBConfig:
project: str = "lerobot"
entity: str | None = None
notes: str | None = None
+ run_id: str | None = None
+ mode: str | None = None # Allowed values: 'online', 'offline', 'disabled'. Defaults to 'online'
@dataclass
diff --git a/lerobot/configs/eval.py b/lerobot/configs/eval.py
index 11873352..16b35291 100644
--- a/lerobot/configs/eval.py
+++ b/lerobot/configs/eval.py
@@ -1,14 +1,26 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import datetime as dt
import logging
from dataclasses import dataclass, field
from pathlib import Path
from lerobot.common import envs, policies # noqa: F401
-from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
from lerobot.configs import parser
from lerobot.configs.default import EvalConfig
from lerobot.configs.policies import PreTrainedConfig
-from lerobot.configs.train import TrainPipelineConfig
@dataclass
@@ -21,11 +33,6 @@ class EvalPipelineConfig:
policy: PreTrainedConfig | None = None
output_dir: Path | None = None
job_name: str | None = None
- # TODO(rcadene, aliberts): By default, use device and use_amp values from policy checkpoint.
- device: str | None = None # cuda | cpu | mps
- # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
- # automatic gradient scaling is used.
- use_amp: bool = False
seed: int | None = 1000
def __post_init__(self):
@@ -36,27 +43,6 @@ class EvalPipelineConfig:
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
- # When no device or use_amp are given, use the one from training config.
- if self.device is None or self.use_amp is None:
- train_cfg = TrainPipelineConfig.from_pretrained(policy_path)
- if self.device is None:
- self.device = train_cfg.device
- if self.use_amp is None:
- self.use_amp = train_cfg.use_amp
-
- # Automatically switch to available device if necessary
- if not is_torch_device_available(self.device):
- auto_device = auto_select_torch_device()
- logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
- self.device = auto_device
-
- # Automatically deactivate AMP if necessary
- if self.use_amp and not is_amp_available(self.device):
- logging.warning(
- f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
- )
- self.use_amp = False
-
else:
logging.warning(
"No pretrained path was provided, evaluated policy will be built from scratch (random weights)."
@@ -73,11 +59,6 @@ class EvalPipelineConfig:
eval_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
self.output_dir = Path("outputs/eval") / eval_dir
- if self.device is None:
- raise ValueError("Set one of the following device: cuda, cpu or mps")
- elif self.device == "cuda" and self.use_amp is None:
- raise ValueError("Set 'use_amp' to True or False.")
-
@classmethod
def __get_path_fields__(cls) -> list[str]:
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
diff --git a/lerobot/configs/parser.py b/lerobot/configs/parser.py
index ee784877..39e31515 100644
--- a/lerobot/configs/parser.py
+++ b/lerobot/configs/parser.py
@@ -1,4 +1,19 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import importlib
import inspect
+import pkgutil
import sys
from argparse import ArgumentError
from functools import wraps
@@ -10,6 +25,7 @@ import draccus
from lerobot.common.utils.utils import has_method
PATH_KEY = "path"
+PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path"
draccus.set_config_type("json")
@@ -45,6 +61,86 @@ def parse_arg(arg_name: str, args: Sequence[str] | None = None) -> str | None:
return None
+def parse_plugin_args(plugin_arg_suffix: str, args: Sequence[str]) -> dict:
+ """Parse plugin-related arguments from command-line arguments.
+
+ This function extracts arguments from command-line arguments that match a specified suffix pattern.
+ It processes arguments in the format '--key=value' and returns them as a dictionary.
+
+ Args:
+ plugin_arg_suffix (str): The suffix to identify plugin-related arguments.
+ args (Sequence[str]): A sequence of command-line arguments to parse.
+
+ Returns:
+ dict: A dictionary containing the parsed plugin arguments where:
+ - Keys are the argument names (with '--' prefix removed if present)
+ - Values are the corresponding argument values
+
+ Example:
+ >>> args = ['--env.discover_packages_path=my_package',
+ ... '--other_arg=value']
+ >>> parse_plugin_args('discover_packages_path', args)
+ {'env.discover_packages_path': 'my_package'}
+ """
+ plugin_args = {}
+ for arg in args:
+ if "=" in arg and plugin_arg_suffix in arg:
+ key, value = arg.split("=", 1)
+ # Remove leading '--' if present
+ if key.startswith("--"):
+ key = key[2:]
+ plugin_args[key] = value
+ return plugin_args
+
+
+class PluginLoadError(Exception):
+ """Raised when a plugin fails to load."""
+
+
+def load_plugin(plugin_path: str) -> None:
+ """Load and initialize a plugin from a given Python package path.
+
+ This function attempts to load a plugin by importing its package and any submodules.
+ Plugin registration is expected to happen during package initialization, i.e. when
+ the package is imported the gym environment should be registered and the config classes
+ registered with their parents using the `register_subclass` decorator.
+
+ Args:
+ plugin_path (str): The Python package path to the plugin (e.g. "mypackage.plugins.myplugin")
+
+ Raises:
+ PluginLoadError: If the plugin cannot be loaded due to import errors or if the package path is invalid.
+
+ Examples:
+ >>> load_plugin("external_plugin.core") # Loads plugin from external package
+
+ Notes:
+ - The plugin package should handle its own registration during import
+ - All submodules in the plugin package will be imported
+ - Implementation follows the plugin discovery pattern from Python packaging guidelines
+
+ See Also:
+ https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/
+ """
+ try:
+ package_module = importlib.import_module(plugin_path, __package__)
+ except (ImportError, ModuleNotFoundError) as e:
+ raise PluginLoadError(
+ f"Failed to load plugin '{plugin_path}'. Verify the path and installation: {str(e)}"
+ ) from e
+
+ def iter_namespace(ns_pkg):
+ return pkgutil.iter_modules(ns_pkg.__path__, ns_pkg.__name__ + ".")
+
+ try:
+ for _finder, pkg_name, _ispkg in iter_namespace(package_module):
+ importlib.import_module(pkg_name)
+ except ImportError as e:
+ raise PluginLoadError(
+ f"Failed to load plugin '{plugin_path}'. Verify the path and installation: {str(e)}"
+ ) from e
+
+
def get_path_arg(field_name: str, args: Sequence[str] | None = None) -> str | None:
return parse_arg(f"{field_name}.{PATH_KEY}", args)
@@ -92,10 +188,13 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
def wrap(config_path: Path | None = None):
"""
- HACK: Similar to draccus.wrap but does two additional things:
+ HACK: Similar to draccus.wrap but does three additional things:
- Will remove '.path' arguments from CLI in order to process them later on.
- If a 'config_path' is passed and the main config class has a 'from_pretrained' method, will
initialize it from there to allow to fetch configs from the hub directly
+ - Will load plugins specified in the CLI arguments. These plugins will typically register
+ their own subclasses of config classes, so that draccus can find the right class to instantiate
+ from the CLI '.type' arguments
"""
def wrapper_outer(fn):
@@ -108,6 +207,14 @@ def wrap(config_path: Path | None = None):
args = args[1:]
else:
cli_args = sys.argv[1:]
+ plugin_args = parse_plugin_args(PLUGIN_DISCOVERY_SUFFIX, cli_args)
+ for plugin_cli_arg, plugin_path in plugin_args.items():
+ try:
+ load_plugin(plugin_path)
+ except PluginLoadError as e:
+ # add the relevant CLI arg to the error message
+ raise PluginLoadError(f"{e}\nFailed plugin CLI Arg: {plugin_cli_arg}") from e
+ cli_args = filter_arg(plugin_cli_arg, cli_args)
config_path_cli = parse_arg("config_path", cli_args)
if has_method(argtype, "__get_path_fields__"):
path_fields = argtype.__get_path_fields__()
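To make the plugin mechanism above concrete, here is a hedged sketch of what a third-party plugin package might look like. The package name `my_lerobot_plugin`, the env id, and the config fields are invented; the import path of `EnvConfig` is an assumption, and registration relies on the `register_subclass` decorator mentioned in the `load_plugin` docstring.

```python
# my_lerobot_plugin/__init__.py -- hypothetical external package
from dataclasses import dataclass

import gymnasium as gym

from lerobot.common.envs.configs import EnvConfig  # assumed location of the env config registry

# Register the gym environment at import time, as expected by load_plugin.
gym.register(id="MyTask-v0", entry_point="my_lerobot_plugin.env:MyTaskEnv")


@EnvConfig.register_subclass("my_env")
@dataclass
class MyEnvConfig(EnvConfig):
    # Invented fields; draccus picks this class when the CLI passes --env.type=my_env.
    task: str = "MyTask-v0"
    fps: int = 30
```

With such a package installed, an illustrative command like `python lerobot/scripts/train.py --env.discover_packages_path=my_lerobot_plugin --env.type=my_env` would trigger `load_plugin` before draccus parses the remaining arguments.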
diff --git a/lerobot/configs/policies.py b/lerobot/configs/policies.py
index 9b5a7c5c..022d1fb5 100644
--- a/lerobot/configs/policies.py
+++ b/lerobot/configs/policies.py
@@ -1,4 +1,18 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import abc
+import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
@@ -12,6 +26,7 @@ from huggingface_hub.errors import HfHubHTTPError
from lerobot.common.optim.optimizers import OptimizerConfig
from lerobot.common.optim.schedulers import LRSchedulerConfig
from lerobot.common.utils.hub import HubMixin
+from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
# Generic variable that is either PreTrainedConfig or a subclass thereof
@@ -40,8 +55,24 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
+ device: str | None = None # cuda | cpu | mps
+ # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
+ # automatic gradient scaling is used.
+ use_amp: bool = False
+
def __post_init__(self):
self.pretrained_path = None
+ if not self.device or not is_torch_device_available(self.device):
+ auto_device = auto_select_torch_device()
+ logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
+ self.device = auto_device.type
+
+ # Automatically deactivate AMP if necessary
+ if self.use_amp and not is_amp_available(self.device):
+ logging.warning(
+ f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
+ )
+ self.use_amp = False
@property
def type(self) -> str:
diff --git a/lerobot/configs/train.py b/lerobot/configs/train.py
index 93f6e2a4..7a787b83 100644
--- a/lerobot/configs/train.py
+++ b/lerobot/configs/train.py
@@ -1,5 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import datetime as dt
-import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
@@ -13,7 +25,6 @@ from lerobot.common import envs
from lerobot.common.optim import OptimizerConfig
from lerobot.common.optim.schedulers import LRSchedulerConfig
from lerobot.common.utils.hub import HubMixin
-from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available
from lerobot.configs import parser
from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig
from lerobot.configs.policies import PreTrainedConfig
@@ -35,10 +46,6 @@ class TrainPipelineConfig(HubMixin):
# Note that when resuming a run, the default behavior is to use the configuration from the checkpoint,
# regardless of what's provided with the training command at the time of resumption.
resume: bool = False
- device: str | None = None # cuda | cpu | mp
- # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
- # automatic gradient scaling is used.
- use_amp: bool = False
# `seed` is used for training (eg: model initialization, dataset shuffling)
# AND for the evaluation environments.
seed: int | None = 1000
@@ -61,18 +68,6 @@ class TrainPipelineConfig(HubMixin):
self.checkpoint_path = None
def validate(self):
- if not self.device:
- logging.warning("No device specified, trying to infer device automatically")
- device = auto_select_torch_device()
- self.device = device.type
-
- # Automatically deactivate AMP if necessary
- if self.use_amp and not is_amp_available(self.device):
- logging.warning(
- f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
- )
- self.use_amp = False
-
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
policy_path = parser.get_path_arg("policy")
if policy_path:
@@ -84,7 +79,9 @@ class TrainPipelineConfig(HubMixin):
# The entire train config is already loaded, we just need to get the checkpoint dir
config_path = parser.parse_arg("config_path")
if not config_path:
- raise ValueError("A config_path is expected when resuming a run.")
+ raise ValueError(
+ f"A config_path is expected when resuming a run. Please specify path to {TRAIN_CONFIG_NAME}"
+ )
if not Path(config_path).resolve().exists():
raise NotADirectoryError(
f"{config_path=} is expected to be a local path. "
@@ -102,7 +99,7 @@ class TrainPipelineConfig(HubMixin):
if not self.resume and isinstance(self.output_dir, Path) and self.output_dir.is_dir():
raise FileExistsError(
- f"Output directory {self.output_dir} alreay exists and resume is {self.resume}. "
+ f"Output directory {self.output_dir} already exists and resume is {self.resume}. "
f"Please change your output directory so that {self.output_dir} is not overwritten."
)
elif not self.output_dir:
diff --git a/lerobot/configs/types.py b/lerobot/configs/types.py
index 0ca45a19..6b3d92e8 100644
--- a/lerobot/configs/types.py
+++ b/lerobot/configs/types.py
@@ -1,3 +1,16 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
# Note: We subclass str so that serialization is straightforward
# https://stackoverflow.com/questions/24481852/serialising-an-enum-member-to-json
from dataclasses import dataclass
diff --git a/lerobot/scripts/configure_motor.py b/lerobot/scripts/configure_motor.py
index f7e07070..b0dc8a97 100644
--- a/lerobot/scripts/configure_motor.py
+++ b/lerobot/scripts/configure_motor.py
@@ -1,3 +1,16 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
"""
This script configure a single motor at a time to a given ID and baudrate.
diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py
index 9129c9e3..3c3c43f9 100644
--- a/lerobot/scripts/control_robot.py
+++ b/lerobot/scripts/control_robot.py
@@ -1,3 +1,16 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
"""
Utilities to control a robot.
@@ -92,7 +105,6 @@ python lerobot/scripts/control_robot.py \
This might require a sudo permission to allow your terminal to monitor keyboard events.
**NOTE**: You can resume/continue data recording by running the same data recording command and adding `--control.resume=true`.
-If the dataset you want to extend is not on the hub, you also need to add `--control.local_files_only=true`.
- Train on this dataset with the ACT policy:
```bash
@@ -234,7 +246,6 @@ def record(
dataset = LeRobotDataset(
cfg.repo_id,
root=cfg.root,
- local_files_only=cfg.local_files_only,
)
if len(robot.cameras) > 0:
dataset.start_image_writer(
@@ -256,7 +267,7 @@ def record(
)
# Load pretrained policy
- policy = None if cfg.policy is None else make_policy(cfg.policy, cfg.device, ds_meta=dataset.meta)
+ policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
if not robot.is_connected:
robot.connect()
@@ -281,15 +292,14 @@ def record(
log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
record_episode(
- dataset=dataset,
robot=robot,
+ dataset=dataset,
events=events,
episode_time_s=cfg.episode_time_s,
display_cameras=cfg.display_cameras,
policy=policy,
- device=cfg.device,
- use_amp=cfg.use_amp,
fps=cfg.fps,
+ single_task=cfg.single_task,
)
# Execute a few seconds without recording to give time to manually reset the environment
@@ -300,7 +310,7 @@ def record(
(recorded_episodes < cfg.num_episodes - 1) or events["rerecord_episode"]
):
log_say("Reset the environment", cfg.play_sounds)
- reset_environment(robot, events, cfg.reset_time_s)
+ reset_environment(robot, events, cfg.reset_time_s, cfg.fps)
if events["rerecord_episode"]:
log_say("Re-record episode", cfg.play_sounds)
@@ -309,7 +319,7 @@ def record(
dataset.clear_episode_buffer()
continue
- dataset.save_episode(cfg.single_task)
+ dataset.save_episode()
recorded_episodes += 1
if events["stop_recording"]:
@@ -318,11 +328,6 @@ def record(
log_say("Stop recording", cfg.play_sounds, blocking=True)
stop_recording(robot, listener, cfg.display_cameras)
- if cfg.run_compute_stats:
- logging.info("Computing dataset statistics")
-
- dataset.consolidate(cfg.run_compute_stats)
-
if cfg.push_to_hub:
dataset.push_to_hub(tags=cfg.tags, private=cfg.private)
@@ -338,9 +343,7 @@ def replay(
# TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
# TODO(rcadene): Add option to record logs
- dataset = LeRobotDataset(
- cfg.repo_id, root=cfg.root, episodes=[cfg.episode], local_files_only=cfg.local_files_only
- )
+ dataset = LeRobotDataset(cfg.repo_id, root=cfg.root, episodes=[cfg.episode])
actions = dataset.hf_dataset.select_columns("action")
if not robot.is_connected:
diff --git a/lerobot/scripts/control_sim_robot.py b/lerobot/scripts/control_sim_robot.py
index 8d69bf31..5347822c 100644
--- a/lerobot/scripts/control_sim_robot.py
+++ b/lerobot/scripts/control_sim_robot.py
@@ -1,3 +1,16 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
"""
Utilities to control a robot in simulation.
@@ -59,8 +72,8 @@ python lerobot/scripts/control_sim_robot.py record \
```
**NOTE**: You can use your keyboard to control data recording flow.
-- Tap right arrow key '->' to early exit while recording an episode and go to reseting the environment.
-- Tap right arrow key '->' to early exit while reseting the environment and got to recording the next episode.
+- Tap right arrow key '->' to early exit while recording an episode and go to resetting the environment.
+- Tap right arrow key '->' to early exit while resetting the environment and go to recording the next episode.
- Tap left arrow key '<-' to early exit and re-record the current episode.
- Tap escape key 'esc' to stop the data recording.
This might require a sudo permission to allow your terminal to monitor keyboard events.
@@ -131,7 +144,7 @@ def none_or_int(value):
def init_sim_calibration(robot, cfg):
# Constants necessary for transforming the joint pos of the real robot to the sim
- # depending on the robot discription used in that sim.
+ # depending on the robot description used in that sim.
start_pos = np.array(robot.leader_arms.main.calibration["start_pos"])
axis_directions = np.array(cfg.get("axis_directions", [1]))
offsets = np.array(cfg.get("offsets", [0])) * np.pi
@@ -445,7 +458,7 @@ if __name__ == "__main__":
type=int,
default=0,
help=(
- "Number of subprocesses handling the saving of frames as PNGs. Set to 0 to use threads only; "
+ "Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only; "
"set to ≥1 to use subprocesses, each using threads to write images. The best number of processes "
"and threads depends on your system. We recommend 4 threads per camera with 0 processes. "
"If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses."
diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index a4f79afc..9790f8b3 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -66,7 +66,7 @@ from torch import Tensor, nn
from tqdm import trange
from lerobot.common.envs.factory import make_env
-from lerobot.common.envs.utils import preprocess_observation
+from lerobot.common.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.utils import get_device_from_parameters
@@ -124,7 +124,6 @@ def rollout(
# Reset the policy and environments.
policy.reset()
-
observation, info = env.reset(seed=seeds)
if render_callback is not None:
render_callback(env)
@@ -145,6 +144,7 @@ def rollout(
disable=inside_slurm(), # we dont want progress bar when we use slurm, since it clutters the logs
leave=False,
)
+ check_env_attributes_and_types(env)
while not np.all(done):
# Numpy array to tensor and changing dictionary keys to LeRobot policy format.
observation = preprocess_observation(observation)
@@ -155,6 +155,10 @@ def rollout(
key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation
}
+ # Infer "task" from attributes of environments.
+ # TODO: works with SyncVectorEnv but not AsyncVectorEnv
+ observation = add_envs_task(env, observation)
+
with torch.inference_mode():
action = policy.select_action(observation)
@@ -454,11 +458,11 @@ def _compile_episode_data(
@parser.wrap()
-def eval(cfg: EvalPipelineConfig):
+def eval_main(cfg: EvalPipelineConfig):
logging.info(pformat(asdict(cfg)))
# Check device is available
- device = get_safe_torch_device(cfg.device, log=True)
+ device = get_safe_torch_device(cfg.policy.device, log=True)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
@@ -470,14 +474,14 @@ def eval(cfg: EvalPipelineConfig):
env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
logging.info("Making policy.")
+
policy = make_policy(
cfg=cfg.policy,
- device=device,
env_cfg=cfg.env,
)
policy.eval()
- with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
+ with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
info = eval_policy(
env,
policy,
@@ -499,4 +503,4 @@ def eval(cfg: EvalPipelineConfig):
if __name__ == "__main__":
init_logging()
- eval()
+ eval_main()
diff --git a/lerobot/scripts/find_motors_bus_port.py b/lerobot/scripts/find_motors_bus_port.py
index 67b92ad7..68f2315d 100644
--- a/lerobot/scripts/find_motors_bus_port.py
+++ b/lerobot/scripts/find_motors_bus_port.py
@@ -1,3 +1,16 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import os
import time
from pathlib import Path
diff --git a/lerobot/scripts/push_dataset_to_hub.py b/lerobot/scripts/push_dataset_to_hub.py
deleted file mode 100644
index 0233ede6..00000000
--- a/lerobot/scripts/push_dataset_to_hub.py
+++ /dev/null
@@ -1,364 +0,0 @@
-#!/usr/bin/env python
-
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-Use this script to convert your dataset into LeRobot dataset format and upload it to the Hugging Face hub,
-or store it locally. LeRobot dataset format is lightweight, fast to load from, and does not require any
-installation of neural net specific packages like pytorch, tensorflow, jax.
-
-Example of how to download raw datasets, convert them into LeRobotDataset format, and push them to the hub:
-```
-python lerobot/scripts/push_dataset_to_hub.py \
---raw-dir data/pusht_raw \
---raw-format pusht_zarr \
---repo-id lerobot/pusht
-
-python lerobot/scripts/push_dataset_to_hub.py \
---raw-dir data/xarm_lift_medium_raw \
---raw-format xarm_pkl \
---repo-id lerobot/xarm_lift_medium
-
-python lerobot/scripts/push_dataset_to_hub.py \
---raw-dir data/aloha_sim_insertion_scripted_raw \
---raw-format aloha_hdf5 \
---repo-id lerobot/aloha_sim_insertion_scripted
-
-python lerobot/scripts/push_dataset_to_hub.py \
---raw-dir data/umi_cup_in_the_wild_raw \
---raw-format umi_zarr \
---repo-id lerobot/umi_cup_in_the_wild
-```
-"""
-
-import argparse
-import json
-import shutil
-import warnings
-from pathlib import Path
-from typing import Any
-
-import torch
-from huggingface_hub import HfApi
-from safetensors.torch import save_file
-
-from lerobot.common.datasets.compute_stats import compute_stats
-from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
-from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id
-from lerobot.common.datasets.utils import create_branch, create_lerobot_dataset_card, flatten_dict
-
-
-def get_from_raw_to_lerobot_format_fn(raw_format: str):
- if raw_format == "pusht_zarr":
- from lerobot.common.datasets.push_dataset_to_hub.pusht_zarr_format import from_raw_to_lerobot_format
- elif raw_format == "umi_zarr":
- from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format
- elif raw_format == "aloha_hdf5":
- from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
- elif raw_format in ["rlds", "openx"]:
- from lerobot.common.datasets.push_dataset_to_hub.openx_rlds_format import from_raw_to_lerobot_format
- elif raw_format == "dora_parquet":
- from lerobot.common.datasets.push_dataset_to_hub.dora_parquet_format import from_raw_to_lerobot_format
- elif raw_format == "xarm_pkl":
- from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import from_raw_to_lerobot_format
- elif raw_format == "cam_png":
- from lerobot.common.datasets.push_dataset_to_hub.cam_png_format import from_raw_to_lerobot_format
- else:
- raise ValueError(
- f"The selected {raw_format} can't be found. Did you add it to `lerobot/scripts/push_dataset_to_hub.py::get_from_raw_to_lerobot_format_fn`?"
- )
-
- return from_raw_to_lerobot_format
-
-
-def save_meta_data(
- info: dict[str, Any], stats: dict, episode_data_index: dict[str, list], meta_data_dir: Path
-):
- meta_data_dir.mkdir(parents=True, exist_ok=True)
-
- # save info
- info_path = meta_data_dir / "info.json"
- with open(str(info_path), "w") as f:
- json.dump(info, f, indent=4)
-
- # save stats
- stats_path = meta_data_dir / "stats.safetensors"
- save_file(flatten_dict(stats), stats_path)
-
- # save episode_data_index
- episode_data_index = {key: torch.tensor(episode_data_index[key]) for key in episode_data_index}
- ep_data_idx_path = meta_data_dir / "episode_data_index.safetensors"
- save_file(episode_data_index, ep_data_idx_path)
-
-
-def push_meta_data_to_hub(repo_id: str, meta_data_dir: str | Path, revision: str | None):
- """Expect all meta data files to be all stored in a single "meta_data" directory.
- On the hugging face repositery, they will be uploaded in a "meta_data" directory at the root.
- """
- api = HfApi()
- api.upload_folder(
- folder_path=meta_data_dir,
- path_in_repo="meta_data",
- repo_id=repo_id,
- revision=revision,
- repo_type="dataset",
- )
-
-
-def push_dataset_card_to_hub(
- repo_id: str,
- revision: str | None,
- tags: list | None = None,
- license: str = "apache-2.0",
- **card_kwargs,
-):
- """Creates and pushes a LeRobotDataset Card with appropriate tags to easily find it on the hub."""
- card = create_lerobot_dataset_card(tags=tags, license=license, **card_kwargs)
- card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=revision)
-
-
-def push_videos_to_hub(repo_id: str, videos_dir: str | Path, revision: str | None):
- """Expect mp4 files to be all stored in a single "videos" directory.
- On the hugging face repositery, they will be uploaded in a "videos" directory at the root.
- """
- api = HfApi()
- api.upload_folder(
- folder_path=videos_dir,
- path_in_repo="videos",
- repo_id=repo_id,
- revision=revision,
- repo_type="dataset",
- allow_patterns="*.mp4",
- )
-
-
-def push_dataset_to_hub(
- raw_dir: Path,
- raw_format: str,
- repo_id: str,
- push_to_hub: bool = True,
- local_dir: Path | None = None,
- fps: int | None = None,
- video: bool = True,
- batch_size: int = 32,
- num_workers: int = 8,
- episodes: list[int] | None = None,
- force_override: bool = False,
- resume: bool = False,
- cache_dir: Path = Path("/tmp"),
- tests_data_dir: Path | None = None,
- encoding: dict | None = None,
-):
- check_repo_id(repo_id)
- user_id, dataset_id = repo_id.split("/")
-
- # Robustify when `raw_dir` is str instead of Path
- raw_dir = Path(raw_dir)
- if not raw_dir.exists():
- raise NotADirectoryError(
- f"{raw_dir} does not exists. Check your paths or run this command to download an existing raw dataset on the hub: "
- f"`python lerobot/common/datasets/push_dataset_to_hub/_download_raw.py --raw-dir your/raw/dir --repo-id your/repo/id_raw`"
- )
-
- if local_dir:
- # Robustify when `local_dir` is str instead of Path
- local_dir = Path(local_dir)
-
- # Send warning if local_dir isn't well formated
- if local_dir.parts[-2] != user_id or local_dir.parts[-1] != dataset_id:
- warnings.warn(
- f"`local_dir` ({local_dir}) doesn't contain a community or user id `/` the name of the dataset that match the `repo_id` (e.g. 'data/lerobot/pusht'). Following this naming convention is advised, but not mandatory.",
- stacklevel=1,
- )
-
- # Check we don't override an existing `local_dir` by mistake
- if local_dir.exists():
- if force_override:
- shutil.rmtree(local_dir)
- elif not resume:
- raise ValueError(f"`local_dir` already exists ({local_dir}). Use `--force-override 1`.")
-
- meta_data_dir = local_dir / "meta_data"
- videos_dir = local_dir / "videos"
- else:
- # Temporary directory used to store images, videos, meta_data
- meta_data_dir = Path(cache_dir) / "meta_data"
- videos_dir = Path(cache_dir) / "videos"
-
- if raw_format is None:
- # TODO(rcadene, adilzouitine): implement auto_find_raw_format
- raise NotImplementedError()
- # raw_format = auto_find_raw_format(raw_dir)
-
- # convert dataset from original raw format to LeRobot format
- from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
-
- hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
- raw_dir,
- videos_dir,
- fps,
- video,
- episodes,
- encoding,
- )
-
- lerobot_dataset = LeRobotDataset.from_preloaded(
- repo_id=repo_id,
- hf_dataset=hf_dataset,
- episode_data_index=episode_data_index,
- info=info,
- videos_dir=videos_dir,
- )
- stats = compute_stats(lerobot_dataset, batch_size, num_workers)
-
- if local_dir:
- hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
- hf_dataset.save_to_disk(str(local_dir / "train"))
-
- if push_to_hub or local_dir:
- # mandatory for upload
- save_meta_data(info, stats, episode_data_index, meta_data_dir)
-
- if push_to_hub:
- hf_dataset.push_to_hub(repo_id, revision="main")
- push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
- push_dataset_card_to_hub(repo_id, revision="main")
- if video:
- push_videos_to_hub(repo_id, videos_dir, revision="main")
- create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
-
- if tests_data_dir:
- # get the first episode
- num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
- test_hf_dataset = hf_dataset.select(range(num_items_first_ep))
- episode_data_index = {k: v[:1] for k, v in episode_data_index.items()}
-
- test_hf_dataset = test_hf_dataset.with_format(None)
- test_hf_dataset.save_to_disk(str(tests_data_dir / repo_id / "train"))
-
- tests_meta_data = tests_data_dir / repo_id / "meta_data"
- save_meta_data(info, stats, episode_data_index, tests_meta_data)
-
- # copy videos of first episode to tests directory
- episode_index = 0
- tests_videos_dir = tests_data_dir / repo_id / "videos"
- tests_videos_dir.mkdir(parents=True, exist_ok=True)
- for key in lerobot_dataset.camera_keys:
- fname = f"{key}_episode_{episode_index:06d}.mp4"
- shutil.copy(videos_dir / fname, tests_videos_dir / fname)
-
- if local_dir is None:
- # clear cache
- shutil.rmtree(meta_data_dir)
- shutil.rmtree(videos_dir)
-
- return lerobot_dataset
-
-
-def main():
- parser = argparse.ArgumentParser()
-
- parser.add_argument(
- "--raw-dir",
- type=Path,
- required=True,
- help="Directory containing input raw datasets (e.g. `data/aloha_mobile_chair_raw` or `data/pusht_raw).",
- )
- # TODO(rcadene): add automatic detection of the format
- parser.add_argument(
- "--raw-format",
- type=str,
- required=True,
- help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`, `dora_parquet`, `rlds`, `openx`).",
- )
- parser.add_argument(
- "--repo-id",
- type=str,
- required=True,
- help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
- )
- parser.add_argument(
- "--local-dir",
- type=Path,
- help="When provided, writes the dataset converted to LeRobotDataset format in this directory (e.g. `data/lerobot/aloha_mobile_chair`).",
- )
- parser.add_argument(
- "--push-to-hub",
- type=int,
- default=1,
- help="Upload to hub.",
- )
- parser.add_argument(
- "--fps",
- type=int,
- help="Frame rate used to collect videos. If not provided, use the default one specified in the code.",
- )
- parser.add_argument(
- "--video",
- type=int,
- default=1,
- help="Convert each episode of the raw dataset to an mp4 video. This option allows 60 times lower disk space consumption and 25 faster loading time during training.",
- )
- parser.add_argument(
- "--batch-size",
- type=int,
- default=32,
- help="Batch size loaded by DataLoader for computing the dataset statistics.",
- )
- parser.add_argument(
- "--num-workers",
- type=int,
- default=8,
- help="Number of processes of Dataloader for computing the dataset statistics.",
- )
- parser.add_argument(
- "--episodes",
- type=int,
- nargs="*",
- help="When provided, only converts the provided episodes (e.g `--episodes 2 3 4`). Useful to test the code on 1 episode.",
- )
- parser.add_argument(
- "--force-override",
- type=int,
- default=0,
- help="When set to 1, removes provided output directory if it already exists. By default, raises a ValueError exception.",
- )
- parser.add_argument(
- "--resume",
- type=int,
- default=0,
- help="When set to 1, resumes a previous run.",
- )
- parser.add_argument(
- "--cache-dir",
- type=Path,
- required=False,
- default="/tmp",
- help="Directory to store the temporary videos and images generated while creating the dataset.",
- )
- parser.add_argument(
- "--tests-data-dir",
- type=Path,
- help=(
- "When provided, save tests artifacts into the given directory "
- "(e.g. `--tests-data-dir tests/data` will save to tests/data/{--repo-id})."
- ),
- )
-
- args = parser.parse_args()
- push_dataset_to_hub(**vars(args))
-
-
-if __name__ == "__main__":
- main()
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index f3c57fe2..0de247be 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -72,7 +72,7 @@ def update_policy(
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
grad_scaler.scale(loss).backward()
- # Unscale the graident of the optimzer's assigned params in-place **prior to gradient clipping**.
+ # Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
grad_scaler.unscale_(optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_(
@@ -120,7 +120,7 @@ def train(cfg: TrainPipelineConfig):
set_seed(cfg.seed)
# Check device is available
- device = get_safe_torch_device(cfg.device, log=True)
+ device = get_safe_torch_device(cfg.policy.device, log=True)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
@@ -133,18 +133,17 @@ def train(cfg: TrainPipelineConfig):
eval_env = None
if cfg.eval_freq > 0 and cfg.env is not None:
logging.info("Creating env")
- eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size)
+ eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
logging.info("Creating policy")
policy = make_policy(
cfg=cfg.policy,
- device=device,
ds_meta=dataset.meta,
)
logging.info("Creating optimizer and scheduler")
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
- grad_scaler = GradScaler(device, enabled=cfg.use_amp)
+ grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)
step = 0 # number of policy updates (forward + backward + optim)
@@ -218,7 +217,7 @@ def train(cfg: TrainPipelineConfig):
cfg.optimizer.grad_clip_norm,
grad_scaler=grad_scaler,
lr_scheduler=lr_scheduler,
- use_amp=cfg.use_amp,
+ use_amp=cfg.policy.use_amp,
)
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
@@ -249,7 +248,10 @@ def train(cfg: TrainPipelineConfig):
if cfg.env and is_eval_step:
step_id = get_step_identifier(step, cfg.steps)
logging.info(f"Eval policy at step {step}")
- with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
+ with (
+ torch.no_grad(),
+ torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
+ ):
eval_info = eval_policy(
eval_env,
policy,
diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py
index 626b0bde..cdfea6b8 100644
--- a/lerobot/scripts/visualize_dataset.py
+++ b/lerobot/scripts/visualize_dataset.py
@@ -207,12 +207,6 @@ def main():
required=True,
help="Episode to visualize.",
)
- parser.add_argument(
- "--local-files-only",
- type=int,
- default=0,
- help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
- )
parser.add_argument(
"--root",
type=Path,
@@ -271,14 +265,25 @@ def main():
),
)
+ parser.add_argument(
+ "--tolerance-s",
+ type=float,
+ default=1e-4,
+ help=(
+ "Tolerance in seconds used to ensure data timestamps respect the dataset fps value"
+ "This is argument passed to the constructor of LeRobotDataset and maps to its tolerance_s constructor argument"
+ "If not given, defaults to 1e-4."
+ ),
+ )
+
args = parser.parse_args()
kwargs = vars(args)
repo_id = kwargs.pop("repo_id")
root = kwargs.pop("root")
- local_files_only = kwargs.pop("local_files_only")
+ tolerance_s = kwargs.pop("tolerance_s")
logging.info("Loading dataset")
- dataset = LeRobotDataset(repo_id, root=root, local_files_only=local_files_only)
+ dataset = LeRobotDataset(repo_id, root=root, tolerance_s=tolerance_s)
visualize_dataset(dataset, **vars(args))
diff --git a/lerobot/scripts/visualize_dataset_html.py b/lerobot/scripts/visualize_dataset_html.py
index cc3f3930..0fc21a8f 100644
--- a/lerobot/scripts/visualize_dataset_html.py
+++ b/lerobot/scripts/visualize_dataset_html.py
@@ -150,7 +150,7 @@ def run_server(
400,
)
dataset_version = (
- dataset.meta._version if isinstance(dataset, LeRobotDataset) else dataset.codebase_version
+ str(dataset.meta._version) if isinstance(dataset, LeRobotDataset) else dataset.codebase_version
)
match = re.search(r"v(\d+)\.", dataset_version)
if match:
@@ -158,7 +158,7 @@ def run_server(
if major_version < 2:
return "Make sure to convert your LeRobotDataset to v2 & above."
- episode_data_csv_str, columns = get_episode_data(dataset, episode_id)
+ episode_data_csv_str, columns, ignored_columns = get_episode_data(dataset, episode_id)
dataset_info = {
"repo_id": f"{dataset_namespace}/{dataset_name}",
"num_samples": dataset.num_frames
@@ -194,7 +194,7 @@ def run_server(
]
response = requests.get(
- f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl"
+ f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl", timeout=5
)
response.raise_for_status()
# Split into lines and parse each line as JSON
@@ -218,6 +218,7 @@ def run_server(
videos_info=videos_info,
episode_data_csv_str=episode_data_csv_str,
columns=columns,
+ ignored_columns=ignored_columns,
)
app.run(host=host, port=port)
@@ -233,9 +234,17 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
This file will be loaded by Dygraph javascript to plot data in real time."""
columns = []
- selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] == "float32"]
+ selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] in ["float32", "int32"]]
selected_columns.remove("timestamp")
+ ignored_columns = []
+ for column_name in list(selected_columns):  # iterate over a copy: items are removed below
+ shape = dataset.features[column_name]["shape"]
+ shape_dim = len(shape)
+ if shape_dim > 1:
+ selected_columns.remove(column_name)
+ ignored_columns.append(column_name)
+
# init header of csv with state and action names
header = ["timestamp"]
@@ -245,16 +254,17 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
if isinstance(dataset, LeRobotDataset)
else dataset.features[column_name].shape[0]
)
- header += [f"{column_name}_{i}" for i in range(dim_state)]
if "names" in dataset.features[column_name] and dataset.features[column_name]["names"]:
column_names = dataset.features[column_name]["names"]
while not isinstance(column_names, list):
column_names = list(column_names.values())[0]
else:
- column_names = [f"motor_{i}" for i in range(dim_state)]
+ column_names = [f"{column_name}_{i}" for i in range(dim_state)]
columns.append({"key": column_name, "value": column_names})
+ header += column_names
+
selected_columns.insert(0, "timestamp")
if isinstance(dataset, LeRobotDataset):
@@ -290,7 +300,7 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
csv_writer.writerows(rows)
csv_string = csv_buffer.getvalue()
- return csv_string, columns
+ return csv_string, columns, ignored_columns
def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
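As a worked illustration of the column selection above (a simplified re-expression, not the function itself): only one-dimensional float32/int32 features are kept, higher-dimensional ones are reported back as ignored, and per-dimension names fall back to `{column_name}_{i}` when the feature has no 'names' entry. The features dict below is invented.

```python
features = {
    "timestamp": {"dtype": "float32", "shape": (1,)},
    "action": {"dtype": "float32", "shape": (2,), "names": ["shoulder", "elbow"]},
    "observation.pose_matrix": {"dtype": "float32", "shape": (4, 4)},  # > 1 dim -> ignored
}

selected_columns = [col for col, ft in features.items() if ft["dtype"] in ["float32", "int32"]]
selected_columns.remove("timestamp")
ignored_columns = [col for col in selected_columns if len(features[col]["shape"]) > 1]
selected_columns = [col for col in selected_columns if col not in ignored_columns]

header = ["timestamp"]
for col in selected_columns:
    names = features[col].get("names") or [f"{col}_{i}" for i in range(features[col]["shape"][0])]
    header += names

print(selected_columns)  # ['action']
print(ignored_columns)   # ['observation.pose_matrix']
print(header)            # ['timestamp', 'shoulder', 'elbow']
```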
@@ -317,7 +327,9 @@ def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) ->
def get_dataset_info(repo_id: str) -> IterableNamespace:
- response = requests.get(f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json")
+ response = requests.get(
+ f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json", timeout=5
+ )
response.raise_for_status() # Raises an HTTPError for bad responses
dataset_info = response.json()
dataset_info["repo_id"] = repo_id
@@ -364,7 +376,7 @@ def visualize_dataset_html(
template_folder=template_dir,
)
else:
- # Create a simlink from the dataset video folder containg mp4 files to the output directory
+ # Create a symlink from the dataset video folder containing mp4 files to the output directory
# so that the http server can get access to the mp4 files.
if isinstance(dataset, LeRobotDataset):
ln_videos_dir = static_dir / "videos"
@@ -384,12 +396,6 @@ def main():
default=None,
help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
)
- parser.add_argument(
- "--local-files-only",
- type=int,
- default=0,
- help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
- )
parser.add_argument(
"--root",
type=Path,
@@ -440,17 +446,28 @@ def main():
help="Delete the output directory if it exists already.",
)
+ parser.add_argument(
+ "--tolerance-s",
+ type=float,
+ default=1e-4,
+ help=(
+ "Tolerance in seconds used to ensure data timestamps respect the dataset fps value"
+ "This is argument passed to the constructor of LeRobotDataset and maps to its tolerance_s constructor argument"
+ "If not given, defaults to 1e-4."
+ ),
+ )
+
args = parser.parse_args()
kwargs = vars(args)
repo_id = kwargs.pop("repo_id")
load_from_hf_hub = kwargs.pop("load_from_hf_hub")
root = kwargs.pop("root")
- local_files_only = kwargs.pop("local_files_only")
+ tolerance_s = kwargs.pop("tolerance_s")
dataset = None
if repo_id:
dataset = (
- LeRobotDataset(repo_id, root=root, local_files_only=local_files_only)
+ LeRobotDataset(repo_id, root=root, tolerance_s=tolerance_s)
if not load_from_hf_hub
else get_dataset_info(repo_id)
)
diff --git a/lerobot/scripts/visualize_image_transforms.py b/lerobot/scripts/visualize_image_transforms.py
index 727fe178..80935d32 100644
--- a/lerobot/scripts/visualize_image_transforms.py
+++ b/lerobot/scripts/visualize_image_transforms.py
@@ -109,7 +109,7 @@ def visualize_image_transforms(cfg: DatasetConfig, output_dir: Path = OUTPUT_DIR
dataset = LeRobotDataset(
repo_id=cfg.repo_id,
episodes=cfg.episodes,
- local_files_only=cfg.local_files_only,
+ revision=cfg.revision,
video_backend=cfg.video_backend,
)
diff --git a/lerobot/templates/visualize_dataset_template.html b/lerobot/templates/visualize_dataset_template.html
index 08de3e3d..cf9d40f1 100644
--- a/lerobot/templates/visualize_dataset_template.html
+++ b/lerobot/templates/visualize_dataset_template.html
@@ -14,21 +14,7 @@
- {
- // Use the space bar to play and pause, instead of default action (e.g. scrolling)
- const { keyCode, key } = e;
- if (keyCode === 32 || key === ' ') {
- e.preventDefault();
- $refs.btnPause.classList.contains('hidden') ? $refs.btnPlay.click() : $refs.btnPause.click();
- }else if (key === 'ArrowDown' || key === 'ArrowUp'){
- const nextEpisodeId = key === 'ArrowDown' ? {{ episode_id }} + 1 : {{ episode_id }} - 1;
- const lowestEpisodeId = {{ episodes }}.at(0);
- const highestEpisodeId = {{ episodes }}.at(-1);
- if(nextEpisodeId >= lowestEpisodeId && nextEpisodeId <= highestEpisodeId){
- window.location.href = `./episode_${nextEpisodeId}`;
- }
- }
-}">
+