Merge remote-tracking branch 'origin/main' into user/aliberts/2024_05_06_add_coverage

Simon Alibert 2024-05-23 14:39:41 +02:00
commit fc07f0e2bc
648 changed files with 4262 additions and 1898 deletions

.gitattributes

@@ -1,2 +1,6 @@
 *.memmap filter=lfs diff=lfs merge=lfs -text
 *.stl filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+*.mp4 filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.json filter=lfs diff=lfs merge=lfs -text

.github/workflows/build-docker-images.yml

@@ -10,7 +10,6 @@ on:
 env:
   PYTHON_VERSION: "3.10"
-  # CI_SLACK_CHANNEL: ${{ secrets.CI_DOCKER_CHANNEL }}
 
 jobs:
   latest-cpu:
@@ -51,30 +50,6 @@ jobs:
           tags: huggingface/lerobot-cpu
           build-args: PYTHON_VERSION=${{ env.PYTHON_VERSION }}
 
-      # - name: Post to a Slack channel
-      #   id: slack
-      #   #uses: slackapi/slack-github-action@v1.25.0
-      #   uses: slackapi/slack-github-action@6c661ce58804a1a20f6dc5fbee7f0381b469e001
-      #   with:
-      #     # Slack channel id, channel name, or user id to post message.
-      #     # See also: https://api.slack.com/methods/chat.postMessage#channels
-      #     channel-id: ${{ env.CI_SLACK_CHANNEL }}
-      #     # For posting a rich message using Block Kit
-      #     payload: |
-      #       {
-      #         "text": "lerobot-cpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}",
-      #         "blocks": [
-      #           {
-      #             "type": "section",
-      #             "text": {
-      #               "type": "mrkdwn",
-      #               "text": "lerobot-cpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}"
-      #             }
-      #           }
-      #         ]
-      #       }
-      #   env:
-      #     SLACK_BOT_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
 
   latest-cuda:
     name: GPU
@@ -113,27 +88,40 @@ jobs:
           tags: huggingface/lerobot-gpu
           build-args: PYTHON_VERSION=${{ env.PYTHON_VERSION }}
 
-      # - name: Post to a Slack channel
-      #   id: slack
-      #   #uses: slackapi/slack-github-action@v1.25.0
-      #   uses: slackapi/slack-github-action@6c661ce58804a1a20f6dc5fbee7f0381b469e001
-      #   with:
-      #     # Slack channel id, channel name, or user id to post message.
-      #     # See also: https://api.slack.com/methods/chat.postMessage#channels
-      #     channel-id: ${{ env.CI_SLACK_CHANNEL }}
-      #     # For posting a rich message using Block Kit
-      #     payload: |
-      #       {
-      #         "text": "lerobot-gpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}",
-      #         "blocks": [
-      #           {
-      #             "type": "section",
-      #             "text": {
-      #               "type": "mrkdwn",
-      #               "text": "lerobot-gpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}"
-      #             }
-      #           }
-      #         ]
-      #       }
-      #   env:
-      #     SLACK_BOT_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
+  latest-cuda-dev:
+    name: GPU Dev
+    runs-on: ubuntu-latest
+    steps:
+      - name: Cleanup disk
+        run: |
+          sudo df -h
+          # sudo ls -l /usr/local/lib/
+          # sudo ls -l /usr/share/
+          sudo du -sh /usr/local/lib/
+          sudo du -sh /usr/share/
+          sudo rm -rf /usr/local/lib/android
+          sudo rm -rf /usr/share/dotnet
+          sudo du -sh /usr/local/lib/
+          sudo du -sh /usr/share/
+          sudo df -h
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v3
+      - name: Check out code
+        uses: actions/checkout@v4
+      - name: Login to DockerHub
+        uses: docker/login-action@v3
+        with:
+          username: ${{ secrets.DOCKERHUB_USERNAME }}
+          password: ${{ secrets.DOCKERHUB_PASSWORD }}
+      - name: Build and Push GPU dev
+        uses: docker/build-push-action@v5
+        with:
+          context: .
+          file: ./docker/lerobot-gpu-dev/Dockerfile
+          push: true
+          tags: huggingface/lerobot-gpu:dev
+          build-args: PYTHON_VERSION=${{ env.PYTHON_VERSION }}

.github/workflows/test.yml

@@ -29,6 +29,8 @@ jobs:
       MUJOCO_GL: egl
     steps:
       - uses: actions/checkout@v4
+        with:
+          lfs: true  # Ensure LFS files are pulled
       - name: Install EGL
         run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev
@@ -57,6 +59,40 @@ jobs:
           && rm -rf tests/outputs outputs
 
+  pytest-minimal:
+    name: Pytest (minimal install)
+    runs-on: ubuntu-latest
+    env:
+      DATA_DIR: tests/data
+      MUJOCO_GL: egl
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          lfs: true  # Ensure LFS files are pulled
+
+      - name: Install poetry
+        run: |
+          pipx install poetry && poetry config virtualenvs.in-project true
+          echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH
+
+      - name: Set up Python 3.10
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.10"
+
+      - name: Install poetry dependencies
+        run: |
+          poetry install --extras "test"
+
+      - name: Test with pytest
+        run: |
+          pytest tests -v --cov=./lerobot --durations=0 \
+            -W ignore::DeprecationWarning:imageio_ffmpeg._utils:7 \
+            -W ignore::UserWarning:torch.utils.data.dataloader:558 \
+            -W ignore::UserWarning:gymnasium.utils.env_checker:247 \
+            && rm -rf tests/outputs outputs
+
   end-to-end:
     name: End-to-end
     runs-on: ubuntu-latest
@@ -65,6 +101,8 @@ jobs:
       MUJOCO_GL: egl
     steps:
       - uses: actions/checkout@v4
+        with:
+          lfs: true  # Ensure LFS files are pulled
       - name: Install EGL
         run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev

.gitignore

@@ -2,12 +2,17 @@
 logs
 tmp
 wandb
+
+# Data
 data
 outputs
-.vscode
-rl
+
+# Apple
 .DS_Store
+
+# VS Code
+.vscode
 
 # HPC
 nautilus/*.yaml
 *.key
@@ -90,6 +95,7 @@ instance/
 docs/_build/
 
 # PyBuilder
+.pybuilder/
 target/
 
 # Jupyter Notebook
@@ -102,13 +108,6 @@ ipython_config.py
 # pyenv
 .python-version
 
-# pipenv
-# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
-# However, in case of collaboration, if having platform-specific dependencies or dependencies
-# having no cross-platform support, pipenv may install dependencies that don't work, or not
-# install all needed dependencies.
-#Pipfile.lock
-
 # PEP 582; used by e.g. github.com/David-OConnor/pyflow
 __pypackages__/
@@ -119,6 +118,15 @@ celerybeat.pid
 # SageMath parsed files
 *.sage.py
 
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
 # Spyder project settings
 .spyderproject
 .spyproject
@@ -136,3 +144,9 @@ dmypy.json
 # Pyre type checker
 .pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/

.pre-commit-config.yaml

@@ -18,7 +18,7 @@ repos:
     hooks:
       - id: pyupgrade
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.4.2
+    rev: v0.4.3
     hooks:
       - id: ruff
         args: [--fix]

CONTRIBUTING.md

@@ -195,6 +195,11 @@ Follow these steps to start contributing:
    git commit
    ```
 
+   Note, if you already committed some changes that have a wrong formatting, you can use:
+   ```bash
+   pre-commit run --all-files
+   ```
+
 Please write [good commit messages](https://chris.beams.io/posts/git-commit/).
 
 It is a good idea to sync your copy of the code with the original

Makefile

@@ -20,16 +20,19 @@ build-gpu:
 test-end-to-end:
 	${MAKE} test-act-ete-train
 	${MAKE} test-act-ete-eval
+	${MAKE} test-act-ete-train-amp
+	${MAKE} test-act-ete-eval-amp
 	${MAKE} test-diffusion-ete-train
 	${MAKE} test-diffusion-ete-eval
-	# TODO(rcadene, alexander-soare): enable end-to-end tests for tdmpc
-	# ${MAKE} test-tdmpc-ete-train
-	# ${MAKE} test-tdmpc-ete-eval
+	${MAKE} test-tdmpc-ete-train
+	${MAKE} test-tdmpc-ete-eval
 	${MAKE} test-default-ete-eval
+	${MAKE} test-act-pusht-tutorial
 
 test-act-ete-train:
 	python lerobot/scripts/train.py \
 		policy=act \
+		policy.dim_model=64 \
 		env=aloha \
 		wandb.enable=False \
 		training.offline_steps=2 \
@@ -52,9 +55,40 @@ test-act-ete-eval:
 		env.episode_length=8 \
 		device=cpu \
 
+test-act-ete-train-amp:
+	python lerobot/scripts/train.py \
+		policy=act \
+		policy.dim_model=64 \
+		env=aloha \
+		wandb.enable=False \
+		training.offline_steps=2 \
+		training.online_steps=0 \
+		eval.n_episodes=1 \
+		eval.batch_size=1 \
+		device=cpu \
+		training.save_model=true \
+		training.save_freq=2 \
+		policy.n_action_steps=20 \
+		policy.chunk_size=20 \
+		training.batch_size=2 \
+		hydra.run.dir=tests/outputs/act/ \
+		use_amp=true
+
+test-act-ete-eval-amp:
+	python lerobot/scripts/eval.py \
+		-p tests/outputs/act/checkpoints/000002 \
+		eval.n_episodes=1 \
+		eval.batch_size=1 \
+		env.episode_length=8 \
+		device=cpu \
+		use_amp=true
+
 test-diffusion-ete-train:
 	python lerobot/scripts/train.py \
 		policy=diffusion \
+		policy.down_dims=\[64,128,256\] \
+		policy.diffusion_step_embed_dim=32 \
+		policy.num_inference_steps=10 \
 		env=pusht \
 		wandb.enable=False \
 		training.offline_steps=2 \
@@ -75,15 +109,16 @@ test-diffusion-ete-eval:
 		env.episode_length=8 \
 		device=cpu \
 
+# TODO(alexander-soare): Restore online_steps to 2 when it is reinstated.
 test-tdmpc-ete-train:
 	python lerobot/scripts/train.py \
 		policy=tdmpc \
 		env=xarm \
 		env.task=XarmLift-v0 \
-		dataset_repo_id=lerobot/xarm_lift_medium_replay \
+		dataset_repo_id=lerobot/xarm_lift_medium \
 		wandb.enable=False \
 		training.offline_steps=2 \
-		training.online_steps=2 \
+		training.online_steps=0 \
 		eval.n_episodes=1 \
 		eval.batch_size=1 \
 		env.episode_length=2 \
@@ -101,7 +136,6 @@ test-tdmpc-ete-eval:
 		env.episode_length=8 \
 		device=cpu \
 
 test-default-ete-eval:
 	python lerobot/scripts/eval.py \
 		--config lerobot/configs/default.yaml \
@@ -109,3 +143,21 @@ test-default-ete-eval:
 		eval.batch_size=1 \
 		env.episode_length=8 \
 		device=cpu \
+
+test-act-pusht-tutorial:
+	cp examples/advanced/1_train_act_pusht/act_pusht.yaml lerobot/configs/policy/created_by_Makefile.yaml
+	python lerobot/scripts/train.py \
+		policy=created_by_Makefile.yaml \
+		env=pusht \
+		wandb.enable=False \
+		training.offline_steps=2 \
+		eval.n_episodes=1 \
+		eval.batch_size=1 \
+		env.episode_length=2 \
+		device=cpu \
+		training.save_model=true \
+		training.save_freq=2 \
+		training.batch_size=2 \
+		hydra.run.dir=tests/outputs/act_pusht/
+	rm lerobot/configs/policy/created_by_Makefile.yaml

README.md

@@ -57,7 +57,6 @@
 - Thanks to Tony Zhao, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io).
 - Thanks to Cheng Chi, Zhenjia Xu and colleagues for open sourcing Diffusion policy, Pusht environment and datasets, as well as UMI datasets. Ours are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu) and [UMI Gripper](https://umi-gripper.github.io).
 - Thanks to Nicklas Hansen, Yunhai Feng and colleagues for open sourcing TDMPC policy, Simxarm environments and datasets. Ours are adapted from [TDMPC](https://github.com/nicklashansen/tdmpc) and [FOWM](https://www.yunhaifeng.com/FOWM).
-- Thanks to Vincent Moens and colleagues for open sourcing [TorchRL](https://github.com/pytorch/rl). It allowed for quick experimentations on the design of `LeRobot`.
 - Thanks to Antonio Loquercio and Ashish Kumar for their early support.

@@ -78,6 +77,10 @@ Install 🤗 LeRobot:
 pip install .
 ```
 
+**NOTE:** If you encounter build errors during this step, you may need to install `cmake` and `build-essential` to build some of our dependencies (depending on your platform).
+On Linux: `sudo apt-get install cmake build-essential`
+
 For simulations, 🤗 LeRobot comes with gymnasium environments that can be installed as extras:
 - [aloha](https://github.com/huggingface/gym-aloha)
 - [xarm](https://github.com/huggingface/gym-xarm)

@@ -93,11 +96,14 @@ To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with
 wandb login
 ```
+(note: you will also need to enable WandB in the configuration. See below.)
 
 ## Walkthrough
 
 ```
 .
 ├── examples             # contains demonstration examples, start here to learn about LeRobot
+|   └── advanced         # contains even more examples for those who have mastered the basics
 ├── lerobot
 |   ├── configs          # contains hydra yaml files with all options that you can override in the command line
 |   |   ├── default.yaml   # selected by default, it loads pusht environment and diffusion policy

@@ -157,15 +163,16 @@ See `python lerobot/scripts/eval.py --help` for more instructions.
 
 ### Train your own policy
 
-Check out [example 3](./examples/3_train_policy.py) that illustrates how to start training a model.
+Check out [example 3](./examples/3_train_policy.py) that illustrates how to train a model using our core library in python, and [example 4](./examples/4_train_policy_with_script.md) that shows how to use our training script from command line.
 
-In general, you can use our training script to easily train any policy. Here is an example of training the ACT policy on trajectories collected by humans on the Aloha simulation environment for the insertion task:
+In general, you can use our training script to easily train any policy. To use wandb for logging training and evaluation curves, make sure you ran `wandb login`. Here is an example of training the ACT policy on trajectories collected by humans on the Aloha simulation environment for the insertion task:
 
 ```bash
 python lerobot/scripts/train.py \
   policy=act \
   env=aloha \
   env.task=AlohaInsertion-v0 \
-  dataset_repo_id=lerobot/aloha_sim_insertion_human
+  dataset_repo_id=lerobot/aloha_sim_insertion_human \
 ```
 
 The experiment directory is automatically generated and will show up in yellow in your terminal. It looks like `outputs/train/2024-05-05/20-21-12_aloha_act_default`. You can manually specify an experiment directory by adding this argument to the `train.py` python command:

@@ -173,17 +180,29 @@
 hydra.run.dir=your/new/experiment/dir
 ```
 
-A link to the wandb logs for the run will also show up in yellow in your terminal. Here is an example of logs from wandb:
-![](media/wandb.png)
-
-You can deactivate wandb by adding these arguments to the `train.py` python command:
+To use wandb for logging training and evaluation curves, make sure you've run `wandb login` as a one-time setup step. Then, when running the training command above, enable WandB in the configuration by adding:
 
 ```bash
-wandb.disable_artifact=true \
-wandb.enable=false
+wandb.enable=true
 ```
 
-Note: For efficiency, during training every checkpoint is evaluated on a low number of episodes. After training, you may want to re-evaluate your best checkpoints on more episodes or change the evaluation settings. See `python lerobot/scripts/eval.py --help` for more instructions.
+A link to the wandb logs for the run will also show up in yellow in your terminal. Here is an example of what they look like in your browser:
+![](media/wandb.png)
+
+Note: For efficiency, during training every checkpoint is evaluated on a low number of episodes. You may use `eval.n_episodes=500` to evaluate on more episodes than the default. Or, after training, you may want to re-evaluate your best checkpoints on more episodes or change the evaluation settings. See `python lerobot/scripts/eval.py --help` for more instructions.
+
+#### Reproduce state-of-the-art (SOTA)
+
+We have organized our configuration files (found under [`lerobot/configs`](./lerobot/configs)) such that they reproduce SOTA results from a given model variant in their respective original works. Simply running:
+
+```bash
+python lerobot/scripts/train.py policy=diffusion env=pusht
+```
+
+reproduces SOTA results for Diffusion Policy on the PushT task.
+
+Pretrained policies, along with reproduction details, can be found under the "Models" section of https://huggingface.co/lerobot.
 
 ## Contribute

@@ -196,11 +215,11 @@ To add a dataset to the hub, you need to login using a write-access token, which
 huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
 ```
 
-Then move your dataset folder in `data` directory (e.g. `data/aloha_ping_pong`), and push your dataset to the hub with:
+Then move your dataset folder in `data` directory (e.g. `data/aloha_static_pingpong_test`), and push your dataset to the hub with:
 
 ```bash
 python lerobot/scripts/push_dataset_to_hub.py \
   --data-dir data \
-  --dataset-id aloha_ping_ping \
+  --dataset-id aloha_static_pingpong_test \
   --raw-format aloha_hdf5 \
   --community-id lerobot
 ```

docker/lerobot-gpu-dev/Dockerfile (new file)

@@ -0,0 +1,40 @@
FROM nvidia/cuda:12.4.1-base-ubuntu22.04
# Configure image
ARG PYTHON_VERSION=3.10
ARG DEBIAN_FRONTEND=noninteractive
# Install apt dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential cmake \
git git-lfs openssh-client \
nano vim less util-linux \
htop atop nvtop \
sed gawk grep curl wget \
tcpdump sysstat screen tmux \
libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \
&& apt-get clean && rm -rf /var/lib/apt/lists/*
# Install gh cli tool
RUN (type -p wget >/dev/null || (apt update && apt-get install wget -y)) \
&& mkdir -p -m 755 /etc/apt/keyrings \
&& wget -qO- https://cli.github.com/packages/githubcli-archive-keyring.gpg | tee /etc/apt/keyrings/githubcli-archive-keyring.gpg > /dev/null \
&& chmod go+r /etc/apt/keyrings/githubcli-archive-keyring.gpg \
&& echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/githubcli-archive-keyring.gpg] https://cli.github.com/packages stable main" | tee /etc/apt/sources.list.d/github-cli.list > /dev/null \
&& apt update \
&& apt install gh -y \
&& apt clean && rm -rf /var/lib/apt/lists/*
# Setup `python`
RUN ln -s /usr/bin/python3 /usr/bin/python
# Install poetry
RUN curl -sSL https://install.python-poetry.org | python -
ENV PATH="/root/.local/bin:$PATH"
RUN echo 'if [ "$HOME" != "/root" ]; then ln -sf /root/.local/bin/poetry $HOME/.local/bin/poetry; fi' >> /root/.bashrc
RUN poetry config virtualenvs.create false
RUN poetry config virtualenvs.in-project true
# Set EGL as the rendering backend for MuJoCo
ENV MUJOCO_GL="egl"

docker/lerobot-gpu/Dockerfile

@@ -4,6 +4,7 @@ FROM nvidia/cuda:12.4.1-base-ubuntu22.04
 ARG PYTHON_VERSION=3.10
 ARG DEBIAN_FRONTEND=noninteractive
 
+
 # Install apt dependencies
 RUN apt-get update && apt-get install -y --no-install-recommends \
     build-essential cmake \
@@ -11,6 +12,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
     python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \
     && apt-get clean && rm -rf /var/lib/apt/lists/*
 
+
 # Create virtual environment
 RUN ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python
 RUN python -m venv /opt/venv

examples/4_train_policy_with_script.md (new file)

@@ -0,0 +1,165 @@
This tutorial will explain the training script, how to use it, and particularly the use of Hydra to configure everything needed for the training run.
## The training script
LeRobot offers a training script at [`lerobot/scripts/train.py`](../../lerobot/scripts/train.py). At a high level it does the following:
- Loads a Hydra configuration file for the following steps (more on Hydra in a moment).
- Makes a simulation environment.
- Makes a dataset corresponding to that simulation environment.
- Makes a policy.
- Runs a standard training loop with forward pass, backward pass, optimization step, and occasional logging, evaluation (of the policy on the environment), and checkpointing (see the sketch just below).
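To make that loop concrete, here is a minimal sketch of its structure. This is illustrative rather than the actual `train.py`: the factory helpers (`make_dataset`, `make_env`, `make_policy`) do exist in `lerobot.common`, but the loop body below is simplified and exact signatures may differ.

```python
import torch
from omegaconf import DictConfig


def training_loop_sketch(cfg: DictConfig, dataset, policy):
    # `cfg` is the merged Hydra configuration described in the next section.
    optimizer = torch.optim.Adam(policy.parameters(), lr=cfg.training.lr)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=cfg.training.batch_size, shuffle=True
    )

    step = 0
    while step < cfg.training.offline_steps:
        for batch in dataloader:
            loss = policy.forward(batch)["loss"]  # forward pass
            loss.backward()                       # backward pass
            optimizer.step()                      # optimization step
            optimizer.zero_grad()
            step += 1
            # The real script interleaves occasional logging, evaluation
            # rollouts in the environment, and checkpointing here.
```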
## Our use of Hydra
Explaining the ins and outs of [Hydra](https://hydra.cc/docs/intro/) is beyond the scope of this document, but here we'll share the main points you need to know.
First, `lerobot/configs` has a directory structure like this:
```
.
├── default.yaml
├── env
│ ├── aloha.yaml
│ ├── pusht.yaml
│ └── xarm.yaml
└── policy
├── act.yaml
├── diffusion.yaml
└── tdmpc.yaml
```
**_For brevity, in the rest of this document we'll drop the leading `lerobot/configs` path. So `default.yaml` really refers to `lerobot/configs/default.yaml`._**
When you run the training script with
```bash
python lerobot/scripts/train.py
```

Hydra is set up to read `default.yaml` (via the `@hydra.main` decorator). If you take a look at `@hydra.main`'s arguments you will see `config_path="../configs", config_name="default"`. At the top of `default.yaml` is a `defaults` section which looks like this:
```yaml
defaults:
- _self_
- env: pusht
- policy: diffusion
```
This logic tells Hydra to incorporate configuration parameters from `env/pusht.yaml` and `policy/diffusion.yaml`. _Note: Be aware of the order, as any configuration parameters with the same name will be overridden. Thus, `default.yaml` is overridden by `env/pusht.yaml`, which is overridden by `policy/diffusion.yaml`._

Then, `default.yaml` also contains common configuration parameters such as `device: cuda` or `use_amp: false` (for enabling fp16 training). Some other parameters are set to `???` which indicates that they are expected to be set in additional yaml files. For instance, `training.offline_steps: ???` in `default.yaml` is set to `200000` in `diffusion.yaml`.
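As a hedged sketch of how this looks from Python (the entry point in the real `train.py` differs in its details), the decorator and the merge order behave roughly like this:

```python
import hydra
from omegaconf import DictConfig


@hydra.main(config_path="../configs", config_name="default", version_base="1.2")
def train_cli(cfg: DictConfig):
    # By the time this function runs, Hydra has merged default.yaml with
    # env/pusht.yaml and policy/diffusion.yaml (later entries win).
    print(cfg.device)                  # "cuda", from default.yaml
    print(cfg.training.offline_steps)  # 200000, filled in by policy/diffusion.yaml
    # Reading a value that is still "???" raises omegaconf's MissingMandatoryValue.


if __name__ == "__main__":
    train_cli()
```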
Thanks to this `defaults` section in `default.yaml`, if you want to train Diffusion Policy with PushT, you really only need to run:
```bash
python lerobot/scripts/train.py
```
However, you can be more explicit and launch the exact same Diffusion Policy training on PushT with:
```bash
python lerobot/scripts/train.py policy=diffusion env=pusht
```
This way of overriding defaults via the CLI is especially useful when you want to change the policy and/or environment. For instance, you can train ACT on the default Aloha environment with:
```bash
python lerobot/scripts/train.py policy=act env=aloha
```
There are two things to note here:
- Config overrides are passed as `param_name=param_value`.
- Here we have overridden the defaults section. `policy=act` tells Hydra to use `policy/act.yaml`, and `env=aloha` tells Hydra to use `env/aloha.yaml`.
_As an aside: we've set up all of our configurations so that they reproduce state-of-the-art results from papers in the literature._
## Overriding configuration parameters in the CLI
Now let's say that we want to train on a different task in the Aloha environment. If you look in `env/aloha.yaml` you will see something like:
```yaml
# lerobot/configs/env/aloha.yaml
env:
  task: AlohaInsertion-v0
```
And if you look in `policy/act.yaml` you will see something like:
```yaml
# lerobot/configs/policy/act.yaml
dataset_repo_id: lerobot/aloha_sim_insertion_human
```
But our Aloha environment actually supports a cube transfer task as well. To train for this task, you could manually modify the two yaml configuration files respectively.
First, we'd need to switch to using the cube transfer task for the ALOHA environment.
```diff
# lerobot/configs/env/aloha.yaml
env:
-  task: AlohaInsertion-v0
+  task: AlohaTransferCube-v0
```
Then, we'd also need to switch to using the cube transfer dataset.
```diff
# lerobot/configs/policy/act.yaml
-dataset_repo_id: lerobot/aloha_sim_insertion_human
+dataset_repo_id: lerobot/aloha_sim_transfer_cube_human
```
Then, you'd be able to run:
```bash
python lerobot/scripts/train.py policy=act env=aloha
```
and you'd be training and evaluating on the cube transfer task.
An alternative approach to editing the yaml configuration files would be to override the defaults via the command line:
```bash
python lerobot/scripts/train.py \
  policy=act \
  dataset_repo_id=lerobot/aloha_sim_transfer_cube_human \
  env=aloha \
  env.task=AlohaTransferCube-v0
```
There's something new here. Notice the `.` delimiter used to traverse the configuration hierarchy. _But be aware that the `defaults` section is an exception. As you saw above, we didn't need to write `defaults.policy=act` in the CLI. `policy=act` was enough._
Putting all that knowledge together, here's the command that was used to train https://huggingface.co/lerobot/act_aloha_sim_transfer_cube_human.
```bash
python lerobot/scripts/train.py \
  hydra.run.dir=outputs/train/act_aloha_sim_transfer_cube_human \
  device=cuda \
  env=aloha \
  env.task=AlohaTransferCube-v0 \
  dataset_repo_id=lerobot/aloha_sim_transfer_cube_human \
  policy=act \
  training.eval_freq=10000 \
  training.log_freq=250 \
  training.offline_steps=100000 \
  training.save_model=true \
  training.save_freq=25000 \
  eval.n_episodes=50 \
  eval.batch_size=50 \
  wandb.enable=false \
```
There's one new thing here: `hydra.run.dir=outputs/train/act_aloha_sim_transfer_cube_human`, which specifies where to save the training output.
---
So far we've seen how to train Diffusion Policy for PushT and ACT for ALOHA. Now, what if we want to train ACT for PushT? Well, there are aspects of the ACT configuration that are specific to the ALOHA environments, and these happen to be incompatible with PushT. Therefore, trying to run the following will almost certainly raise an exception of sorts (e.g. a feature dimension mismatch):
```bash
python lerobot/scripts/train.py policy=act env=pusht dataset_repo_id=lerobot/pusht
```
Please, head on over to our [advanced tutorial on adapting policy configuration to various environments](./advanced/train_act_pusht/train_act_pusht.md) to learn more.
Or in the meantime, happy coding! 🤗

examples/advanced/1_train_act_pusht/act_pusht.yaml (new file)

@@ -0,0 +1,87 @@
# @package _global_

# Change the seed to match what PushT eval uses
# (to avoid evaluating on seeds used for generating the training data).
seed: 100000
# Change the dataset repository to the PushT one.
dataset_repo_id: lerobot/pusht

override_dataset_stats:
  observation.image:
    # stats from imagenet, since we use a pretrained vision model
    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)

training:
  offline_steps: 80000
  online_steps: 0
  eval_freq: 10000
  save_freq: 100000
  log_freq: 250
  save_model: true

  batch_size: 8
  lr: 1e-5
  lr_backbone: 1e-5
  weight_decay: 1e-4
  grad_clip_norm: 10
  online_steps_between_rollouts: 1

  delta_timestamps:
    action: "[i / ${fps} for i in range(${policy.chunk_size})]"

eval:
  n_episodes: 50
  batch_size: 50

# See `configuration_act.py` for more details.
policy:
  name: act

  # Input / output structure.
  n_obs_steps: 1
  chunk_size: 100 # chunk_size
  n_action_steps: 100

  input_shapes:
    observation.image: [3, 96, 96]
    observation.state: ["${env.state_dim}"]
  output_shapes:
    action: ["${env.action_dim}"]

  # Normalization / Unnormalization
  input_normalization_modes:
    observation.image: mean_std
    # Use min_max normalization just because it's more standard.
    observation.state: min_max
  output_normalization_modes:
    # Use min_max normalization just because it's more standard.
    action: min_max

  # Architecture.
  # Vision backbone.
  vision_backbone: resnet18
  pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
  replace_final_stride_with_dilation: false
  # Transformer layers.
  pre_norm: false
  dim_model: 512
  n_heads: 8
  dim_feedforward: 3200
  feedforward_activation: relu
  n_encoder_layers: 4
  # Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
  # that means only the first layer is used. Here we match the original implementation by setting this to 1.
  # See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
  n_decoder_layers: 1
  # VAE.
  use_vae: true
  latent_dim: 32
  n_vae_encoder_layers: 4
  # Inference.
  temporal_ensemble_momentum: null
  # Training and loss computation.
  dropout: 0.1
  kl_weight: 10.0
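One line worth unpacking is `delta_timestamps`: the quoted expression is evaluated after Hydra interpolates `${fps}` and `${policy.chunk_size}`. As a hedged illustration (assuming PushT's `fps` is 10, and using the `chunk_size: 100` set above):

```python
# What the delta_timestamps expression above expands to after interpolation.
fps = 10          # assumed PushT frame rate
chunk_size = 100  # policy.chunk_size from the config above
action_deltas = [i / fps for i in range(chunk_size)]
print(action_deltas[:5])  # [0.0, 0.1, 0.2, 0.3, 0.4] -> 100 future actions, 0.1 s apart
```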

examples/advanced/1_train_act_pusht/train_act_pusht.md (new file)

@@ -0,0 +1,70 @@
In this tutorial we will learn how to adapt a policy configuration to be compatible with a new environment and dataset. As a concrete example, we will adapt the default configuration for ACT to be compatible with the PushT environment and dataset.
If you haven't already read our tutorial on the [training script and configuration tooling](../4_train_policy_with_script.md) please do so prior to tackling this tutorial.
Let's get started!
Suppose we want to train ACT for PushT. Well, there are aspects of the ACT configuration that are specific to the ALOHA environments, and these happen to be incompatible with PushT. Therefore, trying to run the following will almost certainly raise an exception of sorts (e.g. a feature dimension mismatch):
```bash
python lerobot/scripts/train.py policy=act env=pusht dataset_repo_id=lerobot/pusht
```
We need to adapt the parameters of the ACT policy configuration to the PushT environment. The most important ones are the image keys.
ALOHA's datasets and environments typically use a variable number of cameras. In `lerobot/configs/policy/act.yaml` you may notice two relevant sections. Here we show you the minimal diff needed to adjust to PushT:
```diff
override_dataset_stats:
-  observation.images.top:
+  observation.image:
     # stats from imagenet, since we use a pretrained vision model
     mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
     std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)

policy:
   input_shapes:
-    observation.images.top: [3, 480, 640]
+    observation.image: [3, 96, 96]
     observation.state: ["${env.state_dim}"]
   output_shapes:
     action: ["${env.action_dim}"]

   input_normalization_modes:
-    observation.images.top: mean_std
+    observation.image: mean_std
     observation.state: min_max
   output_normalization_modes:
     action: min_max
```
Here we've accounted for the following:
- PushT uses "observation.image" for its image key.
- PushT provides smaller images.
_Side note: technically we could override these via the CLI, but with many changes it gets a bit messy, and we also have a bit of a challenge in that we're using `.` in our observation keys which is treated by Hydra as a hierarchical separator_.
For your convenience, we provide [`act_pusht.yaml`](./act_pusht.yaml) in this directory. It contains the diff above, plus some other (optional) ones that are explained within. Please copy it into `lerobot/configs/policy` with:
```bash
cp examples/advanced/1_train_act_pusht/act_pusht.yaml lerobot/configs/policy/act_pusht.yaml
```
(remember from a [previous tutorial](../4_train_policy_with_script.md) that Hydra will look in the `lerobot/configs` directory). Now try running the following.
<!-- Note to contributor: are you changing this command? Note that it's tested in `Makefile`, so change it there too! -->
```bash
python lerobot/scripts/train.py policy=act_pusht env=pusht
```
Notice that this is much the same as the command that failed at the start of the tutorial, only:
- Now we are using `policy=act_pusht` to point to our new configuration file.
- We can drop `dataset_repo_id=lerobot/pusht` as the change is incorporated in our new configuration file.
Hurrah! You're now training ACT for the PushT environment.
---
The bottom line of this tutorial is that when training policies for different environments and datasets you will need to understand what parts of the policy configuration are specific to those and make changes accordingly.
Happy coding! 🤗

examples/advanced/2_calculate_validation_loss.py (new file)

@@ -0,0 +1,90 @@
"""This script demonstrates how to slice a dataset and calculate the loss on a subset of the data.
This technique can be useful for debugging and testing purposes, as well as identifying whether a policy
is learning effectively.
Note, however, that relying on validation loss to evaluate performance is generally not considered a good practice,
especially in the context of imitation learning. The most reliable approach is to evaluate the policy directly
on the target environment, whether that be in simulation or the real world.
"""
import math
from pathlib import Path
import torch
from huggingface_hub import snapshot_download
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
device = torch.device("cuda")
# Download the diffusion policy for pusht environment
pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))
# OR uncomment the following to evaluate a policy from the local outputs/train folder.
# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")
policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
policy.eval()
policy.to(device)
# Set up the dataset.
delta_timestamps = {
    # Load the previous image and state at -0.1 seconds before current frame,
    # then load current image and state corresponding to 0.0 second.
    "observation.image": [-0.1, 0.0],
    "observation.state": [-0.1, 0.0],
    # Load the previous action (-0.1), the next action to be executed (0.0),
    # and 14 future actions with a 0.1 seconds spacing. All these actions will be
    # used to calculate the loss.
    "action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
}
# Load the last 10% of episodes of the dataset as a validation set.
# - Load full dataset
full_dataset = LeRobotDataset("lerobot/pusht", split="train")
# - Calculate train and val subsets
num_train_episodes = math.floor(full_dataset.num_episodes * 90 / 100)
num_val_episodes = full_dataset.num_episodes - num_train_episodes
print(f"Number of episodes in full dataset: {full_dataset.num_episodes}")
print(f"Number of episodes in training dataset (90% subset): {num_train_episodes}")
print(f"Number of episodes in validation dataset (10% subset): {num_val_episodes}")
# - Get first frame index of the validation set
first_val_frame_index = full_dataset.episode_data_index["from"][num_train_episodes].item()
# - Load frames subset belonging to validation set using the `split` argument.
# It utilizes the `datasets` library's syntax for slicing datasets.
# For more information on the Slice API, please see:
# https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits
train_dataset = LeRobotDataset(
    "lerobot/pusht", split=f"train[:{first_val_frame_index}]", delta_timestamps=delta_timestamps
)
val_dataset = LeRobotDataset(
    "lerobot/pusht", split=f"train[{first_val_frame_index}:]", delta_timestamps=delta_timestamps
)
print(f"Number of frames in training dataset (90% subset): {len(train_dataset)}")
print(f"Number of frames in validation dataset (10% subset): {len(val_dataset)}")
# Create dataloader for evaluation.
val_dataloader = torch.utils.data.DataLoader(
    val_dataset,
    num_workers=4,
    batch_size=64,
    shuffle=False,
    pin_memory=device != torch.device("cpu"),
    drop_last=False,
)
# Run validation loop.
loss_cumsum = 0
n_examples_evaluated = 0
for batch in val_dataloader:
    batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
    output_dict = policy.forward(batch)
    loss_cumsum += output_dict["loss"].item()
    n_examples_evaluated += batch["index"].shape[0]
# Calculate the average loss over the validation set.
average_loss = loss_cumsum / n_examples_evaluated
print(f"Average loss on validation set: {average_loss:.4f}")

lerobot/__init__.py

@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 """
 This file contains lists of available environments, dataset and policies to reflect the current state of LeRobot library.
 We do not want to import all the dependencies, but instead we keep it lightweight to ensure fast access to these variables.
@@ -46,13 +61,21 @@ available_datasets_per_env = {
         "lerobot/aloha_sim_insertion_scripted",
         "lerobot/aloha_sim_transfer_cube_human",
         "lerobot/aloha_sim_transfer_cube_scripted",
+        "lerobot/aloha_sim_insertion_human_image",
+        "lerobot/aloha_sim_insertion_scripted_image",
+        "lerobot/aloha_sim_transfer_cube_human_image",
+        "lerobot/aloha_sim_transfer_cube_scripted_image",
     ],
-    "pusht": ["lerobot/pusht"],
+    "pusht": ["lerobot/pusht", "lerobot/pusht_image"],
     "xarm": [
         "lerobot/xarm_lift_medium",
         "lerobot/xarm_lift_medium_replay",
         "lerobot/xarm_push_medium",
         "lerobot/xarm_push_medium_replay",
+        "lerobot/xarm_lift_medium_image",
+        "lerobot/xarm_lift_medium_replay_image",
+        "lerobot/xarm_push_medium_image",
+        "lerobot/xarm_push_medium_replay_image",
     ],
 }

lerobot/__version__.py

@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 """To enable `lerobot.__version__`"""
 
 from importlib.metadata import PackageNotFoundError, version

@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import json
 import random
 import shutil

@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import logging
 
 import torch

lerobot/common/datasets/lerobot_dataset.py

@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import os
 from pathlib import Path
@@ -5,17 +20,19 @@ import datasets
 import torch
 
 from lerobot.common.datasets.utils import (
+    calculate_episode_data_index,
     load_episode_data_index,
     load_hf_dataset,
     load_info,
     load_previous_and_future_frames,
     load_stats,
     load_videos,
+    reset_episode_index,
 )
 from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos
 
 DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
-CODEBASE_VERSION = "v1.3"
+CODEBASE_VERSION = "v1.4"
 
 
 class LeRobotDataset(torch.utils.data.Dataset):
@@ -39,7 +56,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
         # TODO(rcadene, aliberts): implement faster transfer
         # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
         self.hf_dataset = load_hf_dataset(repo_id, version, root, split)
-        self.episode_data_index = load_episode_data_index(repo_id, version, root)
+        if split == "train":
+            self.episode_data_index = load_episode_data_index(repo_id, version, root)
+        else:
+            self.episode_data_index = calculate_episode_data_index(self.hf_dataset)
+            self.hf_dataset = reset_episode_index(self.hf_dataset)
         self.stats = load_stats(repo_id, version, root)
         self.info = load_info(repo_id, version, root)
         if self.video:
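To illustrate what this change enables, here is a hedged usage sketch (the slice syntax mirrors the validation-loss example earlier in this commit; the frame index is arbitrary):

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# The plain "train" split still loads the precomputed episode index from the hub.
train_ds = LeRobotDataset("lerobot/pusht", split="train")
# A sliced split now recomputes the index and re-numbers episodes on the fly
# via calculate_episode_data_index / reset_episode_index.
val_ds = LeRobotDataset("lerobot/pusht", split="train[1000:]")  # frame 1000 onward
print(train_ds.num_episodes, val_ds.num_episodes)
```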

@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 """Helper code for loading PushT dataset from Diffusion Policy (https://diffusion-policy.cs.columbia.edu/)
 
 Copied from the original Diffusion Policy repository and used in our `download_and_upload_dataset.py` script.

lerobot/common/datasets/push_dataset_to_hub/_download_raw.py

@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 """
 This file contains all obsolete download scripts. They are centralized here to not have to load
 useless dependencies when using datasets.
@@ -9,17 +24,16 @@ import shutil
 from pathlib import Path
 
 import tqdm
+from huggingface_hub import snapshot_download
 
-ALOHA_RAW_URLS_DIR = "lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls"
 
 def download_raw(raw_dir, dataset_id):
-    if "pusht" in dataset_id:
+    if "aloha" in dataset_id or "image" in dataset_id:
+        download_hub(raw_dir, dataset_id)
+    elif "pusht" in dataset_id:
         download_pusht(raw_dir)
     elif "xarm" in dataset_id:
         download_xarm(raw_dir)
-    elif "aloha" in dataset_id:
-        download_aloha(raw_dir, dataset_id)
     elif "umi" in dataset_id:
         download_umi(raw_dir)
     else:
@@ -88,37 +102,13 @@ def download_xarm(raw_dir: Path):
     zip_path.unlink()
 
-def download_aloha(raw_dir: Path, dataset_id: str):
-    import gdown
-
-    subset_id = dataset_id.replace("aloha_", "")
-    urls_path = Path(ALOHA_RAW_URLS_DIR) / f"{subset_id}.txt"
-    assert urls_path.exists(), f"{subset_id}.txt not found in '{ALOHA_RAW_URLS_DIR}' directory."
-
-    with open(urls_path) as f:
-        # strip lines and ignore empty lines
-        urls = [url.strip() for url in f if url.strip()]
-
-    # sanity check
-    for url in urls:
-        assert (
-            "drive.google.com/drive/folders" in url or "drive.google.com/file" in url
-        ), f"Wrong url provided '{url}' in file '{urls_path}'."
-
+def download_hub(raw_dir: Path, dataset_id: str):
     raw_dir = Path(raw_dir)
     raw_dir.mkdir(parents=True, exist_ok=True)
 
-    logging.info(f"Start downloading from google drive for {dataset_id}")
-    for url in urls:
-        if "drive.google.com/drive/folders" in url:
-            # when a folder url is given, download up to 50 files from the folder
-            gdown.download_folder(url, output=str(raw_dir), remaining_ok=True)
-        elif "drive.google.com/file" in url:
-            # because of the 50 files limit per folder, we download the remaining files (file by file)
-            gdown.download(url, output=str(raw_dir), fuzzy=True)
-    logging.info(f"End downloading from google drive for {dataset_id}")
+    logging.info(f"Start downloading from huggingface.co/cadene for {dataset_id}")
+    snapshot_download(f"cadene/{dataset_id}_raw", repo_type="dataset", local_dir=raw_dir)
+    logging.info(f"Finish downloading from huggingface.co/cadene for {dataset_id}")
 
 
 def download_umi(raw_dir: Path):
@@ -133,21 +123,30 @@ def download_umi(raw_dir: Path):
 if __name__ == "__main__":
     data_dir = Path("data")
     dataset_ids = [
+        "pusht_image",
+        "xarm_lift_medium_image",
+        "xarm_lift_medium_replay_image",
+        "xarm_push_medium_image",
+        "xarm_push_medium_replay_image",
+        "aloha_sim_insertion_human_image",
+        "aloha_sim_insertion_scripted_image",
+        "aloha_sim_transfer_cube_human_image",
+        "aloha_sim_transfer_cube_scripted_image",
         "pusht",
         "xarm_lift_medium",
        "xarm_lift_medium_replay",
         "xarm_push_medium",
         "xarm_push_medium_replay",
+        "aloha_sim_insertion_human",
+        "aloha_sim_insertion_scripted",
+        "aloha_sim_transfer_cube_human",
+        "aloha_sim_transfer_cube_scripted",
         "aloha_mobile_cabinet",
         "aloha_mobile_chair",
         "aloha_mobile_elevator",
         "aloha_mobile_shrimp",
         "aloha_mobile_wash_pan",
         "aloha_mobile_wipe_wine",
-        "aloha_sim_insertion_human",
-        "aloha_sim_insertion_scripted",
-        "aloha_sim_transfer_cube_human",
-        "aloha_sim_transfer_cube_scripted",
         "aloha_static_battery",
         "aloha_static_candy",
         "aloha_static_coffee",

@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 # imagecodecs/numcodecs.py
 
 # Copyright (c) 2021-2022, Christoph Gohlke

lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py

@@ -1,8 +1,23 @@
+#!/usr/bin/env python
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 """
 Contains utilities to process raw data format of HDF5 files like in: https://github.com/tonyzhaozh/act
 """
 
-import re
+import gc
 import shutil
 from pathlib import Path
@@ -64,10 +79,8 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
     episode_data_index = {"from": [], "to": []}
     id_from = 0
 
-    for ep_path in tqdm.tqdm(hdf5_files, total=len(hdf5_files)):
+    for ep_idx, ep_path in tqdm.tqdm(enumerate(hdf5_files), total=len(hdf5_files)):
         with h5py.File(ep_path, "r") as ep:
-            ep_idx = int(re.search(r"episode_(\d+)", ep_path.name).group(1))
             num_frames = ep["/action"].shape[0]
 
             # last step of demonstration is considered done
@@ -76,6 +89,10 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
 
             state = torch.from_numpy(ep["/observations/qpos"][:])
             action = torch.from_numpy(ep["/action"][:])
+            if "/observations/qvel" in ep:
+                velocity = torch.from_numpy(ep["/observations/qvel"][:])
+            if "/observations/effort" in ep:
+                effort = torch.from_numpy(ep["/observations/effort"][:])
 
             ep_dict = {}
@@ -116,6 +133,10 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
                 ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
 
             ep_dict["observation.state"] = state
+            if "/observations/velocity" in ep:
+                ep_dict["observation.velocity"] = velocity
+            if "/observations/effort" in ep:
+                ep_dict["observation.effort"] = effort
             ep_dict["action"] = action
             ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
             ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
@@ -131,6 +152,8 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
 
         id_from += num_frames
 
+        gc.collect()
+
         # process first episode only
         if debug:
             break
@@ -152,6 +175,14 @@ def to_hf_dataset(data_dict, video) -> Dataset:
     features["observation.state"] = Sequence(
         length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
     )
+    if "observation.velocity" in data_dict:
+        features["observation.velocity"] = Sequence(
+            length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
+        )
+    if "observation.effort" in data_dict:
+        features["observation.effort"] = Sequence(
+            length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
+        )
     features["action"] = Sequence(
         length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
     )

@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from copy import deepcopy
 from math import ceil

View File

@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Process zarr files formatted like in: https://github.com/real-stanford/diffusion_policy""" """Process zarr files formatted like in: https://github.com/real-stanford/diffusion_policy"""
import shutil import shutil


@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Process UMI (Universal Manipulation Interface) data stored in Zarr format like in: https://github.com/real-stanford/universal_manipulation_interface""" """Process UMI (Universal Manipulation Interface) data stored in Zarr format like in: https://github.com/real-stanford/universal_manipulation_interface"""
import logging import logging


@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path


@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Process pickle files formatted like in: https://github.com/fyhMer/fowm""" """Process pickle files formatted like in: https://github.com/fyhMer/fowm"""
import pickle import pickle


@@ -1,5 +1,22 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
from pathlib import Path
from typing import Dict

import datasets
import torch
@@ -64,7 +81,23 @@ def hf_transform_to_torch(items_dict):

def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset:
    """hf_dataset contains all the observations, states, actions, rewards, etc."""
    if root is not None:
        hf_dataset = load_from_disk(str(Path(root) / repo_id / split))
        hf_dataset = load_from_disk(str(Path(root) / repo_id / "train"))
        # TODO(rcadene): clean this which enables getting a subset of dataset
        if split != "train":
            if "%" in split:
                raise NotImplementedError(f"We don't support splitting based on percentage for now ({split}).")
            match_from = re.search(r"train\[(\d+):\]", split)
            match_to = re.search(r"train\[:(\d+)\]", split)
            if match_from:
                from_frame_index = int(match_from.group(1))
                hf_dataset = hf_dataset.select(range(from_frame_index, len(hf_dataset)))
            elif match_to:
                to_frame_index = int(match_to.group(1))
                hf_dataset = hf_dataset.select(range(to_frame_index))
            else:
                raise ValueError(
                    f'`split` ({split}) should either be "train", "train[INT:]", or "train[:INT]"'
                )
    else:
        hf_dataset = load_dataset(repo_id, revision=version, split=split)
    hf_dataset.set_transform(hf_transform_to_torch)
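To see what the branch above does, here is a self-contained sketch of the same `train[INT:]` / `train[:INT]` parsing applied to a toy dataset (the regexes mirror the ones above; data is illustrative):

# Toy demonstration of the split parsing above.
import re
from datasets import Dataset

hf_dataset = Dataset.from_dict({"x": list(range(10))})
split = "train[3:]"  # keep frames from index 3 onwards
match_from = re.search(r"train\[(\d+):\]", split)
match_to = re.search(r"train\[:(\d+)\]", split)
if match_from:
    hf_dataset = hf_dataset.select(range(int(match_from.group(1)), len(hf_dataset)))
elif match_to:
    hf_dataset = hf_dataset.select(range(int(match_to.group(1))))
print(hf_dataset["x"])  # [3, 4, 5, 6, 7, 8, 9]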
@@ -230,6 +263,84 @@ def load_previous_and_future_frames(
    return item
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
    """
    Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.

    Parameters:
    - hf_dataset (datasets.Dataset): A HuggingFace dataset containing the episode index.

    Returns:
    - episode_data_index: A dictionary containing the data index for each episode. The dictionary has two keys:
        - "from": A tensor containing the starting index of each episode.
        - "to": A tensor containing the ending index of each episode.
    """
    episode_data_index = {"from": [], "to": []}

    current_episode = None
    """
    The episode_index is a list of integers, each representing the episode index of the corresponding example.
    For instance, the following is a valid episode_index:
      [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]

    Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and
    ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this:
        {
            "from": [0, 3, 7],
            "to": [3, 7, 12]
        }
    """
    if len(hf_dataset) == 0:
        episode_data_index = {
            "from": torch.tensor([]),
            "to": torch.tensor([]),
        }
        return episode_data_index
    for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
        if episode_idx != current_episode:
            # We encountered a new episode, so we append its starting location to the "from" list
            episode_data_index["from"].append(idx)
            # If this is not the first episode, we append the ending location of the previous episode to the "to" list
            if current_episode is not None:
                episode_data_index["to"].append(idx)
            # Let's keep track of the current episode index
            current_episode = episode_idx
        else:
            # We are still in the same episode, so there is nothing for us to do here
            pass
    # We have reached the end of the dataset, so we append the ending location of the last episode to the "to" list
    episode_data_index["to"].append(idx + 1)

    for k in ["from", "to"]:
        episode_data_index[k] = torch.tensor(episode_data_index[k])

    return episode_data_index
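A quick usage sketch (assuming the function above is in scope), run on the docstring's example episode_index:

# Sanity check of calculate_episode_data_index on the docstring example.
from datasets import Dataset

hf_dataset = Dataset.from_dict({"episode_index": [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]})
episode_data_index = calculate_episode_data_index(hf_dataset)
print(episode_data_index["from"])  # tensor([0, 3, 7])
print(episode_data_index["to"])    # tensor([ 3,  7, 12])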
def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset:
    """
    Reset the `episode_index` of the provided HuggingFace Dataset.

    `episode_data_index` (and related functionality such as `load_previous_and_future_frames`) requires the
    `episode_index` to be sorted, continuous (1,1,1 and not 1,2,1) and to start at 0.

    This brings the `episode_index` to the required format.
    """
    if len(hf_dataset) == 0:
        return hf_dataset
    unique_episode_idxs = torch.stack(hf_dataset["episode_index"]).unique().tolist()
    episode_idx_to_reset_idx_mapping = {
        ep_id: reset_ep_id for reset_ep_id, ep_id in enumerate(unique_episode_idxs)
    }

    def modify_ep_idx_func(example):
        example["episode_index"] = episode_idx_to_reset_idx_mapping[example["episode_index"].item()]
        return example

    hf_dataset = hf_dataset.map(modify_ep_idx_func)
    return hf_dataset
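A usage sketch (assuming the function above is in scope; a tensor-producing transform is set first because the function calls torch.stack() and .item() on the column values, as hf_transform_to_torch does):

# Remap episode ids [10, 10, 12, 12, 12] -> [0, 0, 1, 1, 1].
import torch
from datasets import Dataset

hf_dataset = Dataset.from_dict({"episode_index": [10, 10, 12, 12, 12]})
hf_dataset.set_transform(
    lambda items: {k: [torch.tensor(x) for x in v] for k, v in items.items()}
)
hf_dataset = reset_episode_index(hf_dataset)
print(list(hf_dataset["episode_index"]))  # episode ids now start at 0 and are contiguous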
def cycle(iterable):
    """The equivalent of itertools.cycle, but safe for Pytorch dataloaders.


@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import subprocess
import warnings


@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib

import gymnasium as gym


@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import einops
import numpy as np
import torch


@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO(rcadene, alexander-soare): clean this file
"""Borrowed from https://github.com/fyhMer/fowm/blob/main/src/logger.py"""
@@ -17,9 +32,10 @@ def log_output_dir(out_dir):

def cfg_to_group(cfg, return_list=False):
    """Return a wandb-safe group name for logging. Optionally returns group name as list."""
    """Return a group name for logging. Optionally returns group name as list."""
    # lst = [cfg.task, cfg.modality, re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)]
    lst = [
        f"policy:{cfg.policy.name}",
        f"dataset:{cfg.dataset_repo_id}",
        f"env:{cfg.env.name}",
        f"seed:{cfg.seed}",
    ]
@@ -81,9 +97,9 @@ class Logger:
            # Also save the full Hydra config for the env configuration.
            OmegaConf.save(self._cfg, save_dir / "config.yaml")
            if self._wandb and not self._disable_wandb_artifact:
                # note wandb artifact does not accept ":" in its name
                # note wandb artifact does not accept ":" or "/" in its name
                artifact = self._wandb.Artifact(
                    self._group.replace(":", "_") + "-" + str(self._seed) + "-" + str(identifier),
                    f"{self._group.replace(':', '_').replace('/', '_')}-{self._seed}-{identifier}",
                    type="model",
                )
                artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
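For illustration only, the sanitization above applied to a made-up group name:

# Toy illustration: wandb artifact names reject ":" and "/", so both become "_".
group = "policy:act-dataset:lerobot/aloha_sim_insertion_human-env:aloha"  # made-up value
seed, identifier = 1000, "005000"  # made-up values
print(f"{group.replace(':', '_').replace('/', '_')}-{seed}-{identifier}")
# policy_act-dataset_lerobot_aloha_sim_insertion_human-env_aloha-1000-005000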
@@ -93,9 +109,10 @@ class Logger:
        self._buffer_dir.mkdir(parents=True, exist_ok=True)
        fp = self._buffer_dir / f"{str(identifier)}.pkl"
        buffer.save(fp)
        if self._wandb:
        if self._wandb and not self._disable_wandb_artifact:
            # note wandb artifact does not accept ":" or "/" in its name
            artifact = self._wandb.Artifact(
                self._group + "-" + str(self._seed) + "-" + str(identifier),
                f"{self._group.replace(':', '_').replace('/', '_')}-{self._seed}-{identifier}",
                type="buffer",
            )
            artifact.add_file(fp)
@@ -113,6 +130,11 @@ class Logger:
        assert mode in {"train", "eval"}
        if self._wandb is not None:
            for k, v in d.items():
                if not isinstance(v, (int, float, str)):
                    logging.warning(
                        f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
                    )
                    continue
                self._wandb.log({f"{mode}/{k}": v}, step=step)

    def log_video(self, video_path: str, step: int, mode: str = "train"):


@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
@@ -51,8 +66,12 @@ class ACTConfig:
            documentation in the policy class).
        latent_dim: The VAE's latent dimension.
        n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder.
        use_temporal_aggregation: Whether to blend the actions of multiple policy invocations for any given
            environment step.
        temporal_ensemble_momentum: Exponential moving average (EMA) momentum parameter (α) for ensembling
            actions for a given time step over multiple policy invocations. Updates are calculated as:
            x = αx + (1-α)xₙ. Note that the ACT paper and original ACT code describes a different
            parameter here: they refer to a weighting scheme wᵢ = exp(-m⋅i) and set m = 0.01. With our
            formulation, this is equivalent to α = exp(-0.01) ≈ 0.99. When this parameter is provided, we
            require `n_action_steps == 1` (since we need to query the policy every step anyway).
        dropout: Dropout to use in the transformer layers (see code for details).
        kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
            is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
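The update rule is easy to sanity-check in isolation; a standalone sketch with toy shapes (not the policy's actual forward pass — the random chunks below stand in for model outputs):

# Sketch of EMA temporal ensembling: each step produces a fresh action chunk;
# overlapping entries are blended with x = alpha * x + (1 - alpha) * x_new,
# the new tail is appended, and the first action is consumed.
import math
import torch

alpha = math.exp(-0.01)  # ~0.99, matching the original ACT weighting
chunk_size, action_dim = 5, 2
ensembled = None

for _ in range(3):  # three consecutive environment steps
    new_chunk = torch.randn(1, chunk_size, action_dim)  # stand-in for model output
    if ensembled is None:
        ensembled = new_chunk.clone()
    else:
        # Blend the overlapping chunk_size - 1 entries, then append the new tail.
        ensembled = alpha * ensembled + (1 - alpha) * new_chunk[:, :-1]
        ensembled = torch.cat([ensembled, new_chunk[:, -1:]], dim=1)
    action, ensembled = ensembled[:, 0], ensembled[:, 1:]
    print(action.shape)  # torch.Size([1, 2])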
@@ -100,6 +119,9 @@ class ACTConfig:
    dim_feedforward: int = 3200
    feedforward_activation: str = "relu"
    n_encoder_layers: int = 4
    # Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
    # that means only the first layer is used. Here we match the original implementation by setting this to 1.
    # See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
    n_decoder_layers: int = 1
    # VAE.
    use_vae: bool = True
@@ -107,7 +129,7 @@ class ACTConfig:
    n_vae_encoder_layers: int = 4
    # Inference.
    use_temporal_aggregation: bool = False
    temporal_ensemble_momentum: float | None = None
    # Training and loss computation.
    dropout: float = 0.1
@@ -119,8 +141,11 @@ class ACTConfig:
            raise ValueError(
                f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
            )
        if self.use_temporal_aggregation:
            raise NotImplementedError("Temporal aggregation is not yet implemented.")
        if self.temporal_ensemble_momentum is not None and self.n_action_steps > 1:
            raise NotImplementedError(
                "`n_action_steps` must be 1 when using temporal ensembling. This is "
                "because the policy needs to be queried every step to compute the ensembled action."
            )
        if self.n_action_steps > self.chunk_size:
            raise ValueError(
                f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
@@ -130,10 +155,3 @@ class ACTConfig:
            raise ValueError(
                f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
            )
        # Check that there is only one image.
        # TODO(alexander-soare): generalize this to multiple images.
        if (
            sum(k.startswith("observation.images.") for k in self.input_shapes) != 1
            or "observation.images.top" not in self.input_shapes
        ):
            raise ValueError('For now, only "observation.images.top" is accepted for an image input.')


@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Action Chunking Transformer Policy """Action Chunking Transformer Policy
As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705). As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705).
@@ -46,7 +61,8 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
        super().__init__()
        if config is None:
            config = ACTConfig()
        self.config = config
        self.config: ACTConfig = config
        self.normalize_inputs = Normalize(
            config.input_shapes, config.input_normalization_modes, dataset_stats
        )
@@ -56,11 +72,18 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
        self.unnormalize_outputs = Unnormalize(
            config.output_shapes, config.output_normalization_modes, dataset_stats
        )
        self.model = ACT(config)
        self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]

        self.reset()

    def reset(self):
        """This should be called whenever the environment is reset."""
        if self.config.n_action_steps is not None:
        if self.config.temporal_ensemble_momentum is not None:
            self._ensembled_actions = None
        else:
            self._action_queue = deque([], maxlen=self.config.n_action_steps)

    @torch.no_grad
@@ -71,37 +94,56 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
        environment. It works by managing the actions in a queue and only calling `select_actions` when the
        queue is empty.
        """
        assert "observation.images.top" in batch
        assert "observation.state" in batch
        self.eval()

        batch = self.normalize_inputs(batch)
        self._stack_images(batch)
        batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)

        # If we are doing temporal ensembling, keep track of the exponential moving average (EMA), and return
        # the first action.
        if self.config.temporal_ensemble_momentum is not None:
            actions = self.model(batch)[0]  # (batch_size, chunk_size, action_dim)
            actions = self.unnormalize_outputs({"action": actions})["action"]
            if self._ensembled_actions is None:
                # Initializes `self._ensembled_actions` to the sequence of actions predicted during the first
                # time step of the episode.
                self._ensembled_actions = actions.clone()
            else:
                # self._ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
                # the EMA update for those entries.
                alpha = self.config.temporal_ensemble_momentum
                self._ensembled_actions = alpha * self._ensembled_actions + (1 - alpha) * actions[:, :-1]
                # The last action, which has no prior moving average, needs to get concatenated onto the end.
                self._ensembled_actions = torch.cat([self._ensembled_actions, actions[:, -1:]], dim=1)
            # "Consume" the first action.
            action, self._ensembled_actions = self._ensembled_actions[:, 0], self._ensembled_actions[:, 1:]
            return action

        # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
        # querying the policy.
        if len(self._action_queue) == 0:
            # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
            # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
            actions = self.model(batch)[0][: self.config.n_action_steps]
            actions = self.model(batch)[0][:, : self.config.n_action_steps]
            # TODO(rcadene): make _forward return output dictionary?
            actions = self.unnormalize_outputs({"action": actions})["action"]
            # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
            # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
            self._action_queue.extend(actions.transpose(0, 1))
        return self._action_queue.popleft()
    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        """Run the batch through the model and compute the loss for training or validation."""
        batch = self.normalize_inputs(batch)
        batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
        batch = self.normalize_targets(batch)
        self._stack_images(batch)
        actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)

        l1_loss = (
            F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
        ).mean()

        loss_dict = {"l1_loss": l1_loss}
        loss_dict = {"l1_loss": l1_loss.item()}
        if self.config.use_vae:
            # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
            # each dimension independently, we sum over the latent dimension to get the total
@@ -110,28 +152,13 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
            mean_kld = (
                (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
            )
            loss_dict["kld_loss"] = mean_kld
            loss_dict["kld_loss"] = mean_kld.item()
            loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
        else:
            loss_dict["loss"] = l1_loss

        return loss_dict

    def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        """Stacks all the images in a batch and puts them in a new key: "observation.images".

        This function expects `batch` to have (at least):
        {
            "observation.state": (B, state_dim) batch of robot states.
            "observation.images.{name}": (B, C, H, W) tensor of images.
        }
        """
        # Stack images in the order dictated by input_shapes.
        batch["observation.images"] = torch.stack(
            [batch[k] for k in self.config.input_shapes if k.startswith("observation.images.")],
            dim=-4,
        )

class ACT(nn.Module):
    """Action Chunking Transformer: The underlying neural network for ACTPolicy.
@@ -161,10 +188,10 @@ class ACT(nn.Module):
    [ASCII architecture diagram in the class docstring; this hunk relabels the transformer
    encoder's "inputs" as "image emb." and "state emb."]
    """
@@ -306,18 +333,18 @@ class ACT(nn.Module):
            all_cam_features.append(cam_features)
            all_cam_pos_embeds.append(cam_pos_embed)
        # Concatenate camera observation feature maps and positional embeddings along the width dimension.
        encoder_in = torch.cat(all_cam_features, axis=3)
        encoder_in = torch.cat(all_cam_features, axis=-1)
        cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=3)
        cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1)

        # Get positional embeddings for robot state and latent.
        robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])  # (B, C)
        latent_embed = self.encoder_latent_input_proj(latent_sample)  # (B, C)

        # Stack encoder input and positional embeddings moving to (S, B, C).
        encoder_in = torch.cat(
            [
                torch.stack([latent_embed, robot_state_embed], axis=0),
                encoder_in.flatten(2).permute(2, 0, 1),
                einops.rearrange(encoder_in, "b c h w -> (h w) b c"),
            ]
        )
        pos_embed = torch.cat(
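The einops swap above is behavior-preserving; a quick standalone check with toy sizes:

# Verify einops.rearrange matches the original flatten/permute:
# (B, C, H, W) feature maps -> (H*W, B, C) token sequence.
import einops
import torch

x = torch.randn(2, 8, 3, 4)
assert torch.equal(x.flatten(2).permute(2, 0, 1), einops.rearrange(x, "b c h w -> (h w) b c"))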


@@ -1,3 +1,19 @@
#!/usr/bin/env python
# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab,
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
@@ -51,6 +67,7 @@ class DiffusionConfig:
        use_film_scale_modulation: FiLM (https://arxiv.org/abs/1709.07871) is used for the Unet conditioning.
            Bias modulation is used by default, while this parameter indicates whether to also use scale
            modulation.
        noise_scheduler_type: Name of the noise scheduler to use. Supported options: ["DDPM", "DDIM"].
        num_train_timesteps: Number of diffusion steps for the forward diffusion schedule.
        beta_schedule: Name of the diffusion beta schedule as per DDPMScheduler from Hugging Face diffusers.
        beta_start: Beta value for the first forward-diffusion step.
@@ -110,6 +127,7 @@ class DiffusionConfig:
    diffusion_step_embed_dim: int = 128
    use_film_scale_modulation: bool = True
    # Noise scheduler.
    noise_scheduler_type: str = "DDPM"
    num_train_timesteps: int = 100
    beta_schedule: str = "squaredcos_cap_v2"
    beta_start: float = 0.0001
@@ -130,17 +148,30 @@ class DiffusionConfig:
            raise ValueError(
                f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
            )
        # There should only be one image key.
        image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
        if len(image_keys) != 1:
            raise ValueError(
                f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
            )
        image_key = next(iter(image_keys))
        if (
            self.crop_shape[0] > self.input_shapes["observation.image"][1]
            or self.crop_shape[1] > self.input_shapes["observation.image"][2]
            self.crop_shape[0] > self.input_shapes[image_key][1]
            or self.crop_shape[1] > self.input_shapes[image_key][2]
        ):
            raise ValueError(
                f'`crop_shape` should fit within `input_shapes["observation.image"]`. Got {self.crop_shape} '
                f'for `crop_shape` and {self.input_shapes["observation.image"]} for '
                '`input_shapes["observation.image"]`.'
                f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} "
                f"for `crop_shape` and {self.input_shapes[image_key]} for "
                f"`input_shapes[{image_key}]`."
            )
        supported_prediction_types = ["epsilon", "sample"]
        if self.prediction_type not in supported_prediction_types:
            raise ValueError(
                f"`prediction_type` must be one of {supported_prediction_types}. Got {self.prediction_type}."
            )
        supported_noise_schedulers = ["DDPM", "DDIM"]
        if self.noise_scheduler_type not in supported_noise_schedulers:
            raise ValueError(
                f"`noise_scheduler_type` must be one of {supported_noise_schedulers}. "
                f"Got {self.noise_scheduler_type}."
            )


@@ -1,8 +1,24 @@
#!/usr/bin/env python
# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab,
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion" """Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
TODO(alexander-soare): TODO(alexander-soare):
- Remove reliance on Robomimic for SpatialSoftmax.
- Remove reliance on diffusers for DDPMScheduler and LR scheduler. - Remove reliance on diffusers for DDPMScheduler and LR scheduler.
- Make compatible with multiple image keys.
""" """
import math import math
@ -10,12 +26,13 @@ from collections import deque
from typing import Callable from typing import Callable
import einops import einops
import numpy as np
import torch import torch
import torch.nn.functional as F # noqa: N812 import torch.nn.functional as F # noqa: N812
import torchvision import torchvision
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from huggingface_hub import PyTorchModelHubMixin from huggingface_hub import PyTorchModelHubMixin
from robomimic.models.base_nets import SpatialSoftmax
from torch import Tensor, nn from torch import Tensor, nn
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
@@ -66,10 +83,18 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
        self.diffusion = DiffusionModel(config)

        image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
        # Note: This check is covered in the post-init of the config but have a sanity check just in case.
        if len(image_keys) != 1:
            raise NotImplementedError(
                f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
            )
        self.input_image_key = image_keys[0]

        self.reset()

    def reset(self):
        """
        Clear observation and action queues. Should be called on `env.reset()`
        """
        """Clear observation and action queues. Should be called on `env.reset()`"""
        self._queues = {
            "observation.image": deque(maxlen=self.config.n_obs_steps),
            "observation.state": deque(maxlen=self.config.n_obs_steps),
@@ -98,16 +123,14 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
        "horizon" may not be the best name to describe what the variable actually means, because this period is
        actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
        """
        assert "observation.image" in batch
        assert "observation.state" in batch
        batch = self.normalize_inputs(batch)
        batch["observation.image"] = batch[self.input_image_key]

        self._queues = populate_queues(self._queues, batch)

        if len(self._queues["action"]) == 0:
            # stack n latest observations from the queue
            batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
            batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
            actions = self.diffusion.generate_actions(batch)

            # TODO(rcadene): make above methods return output dictionary?
@@ -121,11 +144,25 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        """Run the batch through the model and compute the loss for training or validation."""
        batch = self.normalize_inputs(batch)
        batch["observation.image"] = batch[self.input_image_key]
        batch = self.normalize_targets(batch)
        loss = self.diffusion.compute_loss(batch)
        return {"loss": loss}
def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
    """
    Factory for noise scheduler instances of the requested type. All kwargs are passed
    to the scheduler.
    """
    if name == "DDPM":
        return DDPMScheduler(**kwargs)
    elif name == "DDIM":
        return DDIMScheduler(**kwargs)
    else:
        raise ValueError(f"Unsupported noise scheduler type {name}")
class DiffusionModel(nn.Module):
    def __init__(self, config: DiffusionConfig):
        super().__init__()
@@ -138,12 +175,12 @@ class DiffusionModel(nn.Module):
            * config.n_obs_steps,
        )

        self.noise_scheduler = DDPMScheduler(
        self.noise_scheduler = _make_noise_scheduler(
            config.noise_scheduler_type,
            num_train_timesteps=config.num_train_timesteps,
            beta_start=config.beta_start,
            beta_end=config.beta_end,
            beta_schedule=config.beta_schedule,
            variance_type="fixed_small",
            clip_sample=config.clip_sample,
            clip_sample_range=config.clip_sample_range,
            prediction_type=config.prediction_type,
@@ -185,13 +222,12 @@ class DiffusionModel(nn.Module):
    def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
        """
        This function expects `batch` to have (at least):
        This function expects `batch` to have:
        {
            "observation.state": (B, n_obs_steps, state_dim)
            "observation.image": (B, n_obs_steps, C, H, W)
        }
        """
        assert set(batch).issuperset({"observation.state", "observation.image"})
        batch_size, n_obs_steps = batch["observation.state"].shape[:2]
        assert n_obs_steps == self.config.n_obs_steps
@@ -275,6 +311,77 @@ class DiffusionModel(nn.Module):
        return loss.mean()
class SpatialSoftmax(nn.Module):
    """
    Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al.
    (https://arxiv.org/pdf/1509.06113). A minimal port of the robomimic implementation.

    At a high level, this takes 2D feature maps (from a convnet/ViT) and returns the "center of mass"
    of activations of each channel, i.e., keypoints in the image space for the policy to focus on.

    Example: take feature maps of size (512x10x12). We generate a grid of normalized coordinates (10x12x2):
    -----------------------------------------------------
    | (-1., -1.)   | (-0.82, -1.)   | ... | (1., -1.)   |
    | (-1., -0.78) | (-0.82, -0.78) | ... | (1., -0.78) |
    | ...          | ...            | ... | ...         |
    | (-1., 1.)    | (-0.82, 1.)    | ... | (1., 1.)    |
    -----------------------------------------------------
    This is achieved by applying channel-wise softmax over the activations (512x120) and computing the dot
    product with the coordinates (120x2) to get expected points of maximal activation (512x2).

    The example above results in 512 keypoints (corresponding to the 512 input channels). We can optionally
    provide num_kp != None to control the number of keypoints. This is achieved by first applying a learnable
    linear mapping (in_channels, H, W) -> (num_kp, H, W).
    """

    def __init__(self, input_shape, num_kp=None):
        """
        Args:
            input_shape (list): (C, H, W) input feature map shape.
            num_kp (int): number of keypoints in output. If None, output will have the same number of channels as input.
        """
        super().__init__()

        assert len(input_shape) == 3
        self._in_c, self._in_h, self._in_w = input_shape

        if num_kp is not None:
            self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
            self._out_c = num_kp
        else:
            self.nets = None
            self._out_c = self._in_c

        # we could use torch.linspace directly but that seems to behave slightly differently than numpy
        # and causes a small degradation in pc_success of pre-trained models.
        pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
        pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
        pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
        # register as buffer so it's moved to the correct device.
        self.register_buffer("pos_grid", torch.cat([pos_x, pos_y], dim=1))

    def forward(self, features: Tensor) -> Tensor:
        """
        Args:
            features: (B, C, H, W) input feature maps.
        Returns:
            (B, K, 2) image-space coordinates of keypoints.
        """
        if self.nets is not None:
            features = self.nets(features)

        # [B, K, H, W] -> [B * K, H * W] where K is number of keypoints
        features = features.reshape(-1, self._in_h * self._in_w)
        # 2d softmax normalization
        attention = F.softmax(features, dim=-1)
        # [B * K, H * W] x [H * W, 2] -> [B * K, 2] for spatial coordinate mean in x and y dimensions
        expected_xy = attention @ self.pos_grid
        # reshape to [B, K, 2]
        feature_keypoints = expected_xy.view(-1, self._out_c, 2)

        return feature_keypoints
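A shape check for the module above, on a toy batch:

# 512 input channels pooled down to 32 keypoints, each an (x, y) coordinate.
import torch

pool = SpatialSoftmax(input_shape=(512, 10, 12), num_kp=32)
keypoints = pool(torch.randn(4, 512, 10, 12))  # (B, C, H, W) in
print(keypoints.shape)  # torch.Size([4, 32, 2])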
class DiffusionRgbEncoder(nn.Module):
    """Encodes an RGB image into a 1D feature vector.
@@ -315,11 +422,16 @@ class DiffusionRgbEncoder(nn.Module):
        # Set up pooling and final layers.
        # Use a dry run to get the feature map shape.
        # The dummy input should take the number of image channels from `config.input_shapes` and it should
        # use the height and width from `config.crop_shape`.
        image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
        assert len(image_keys) == 1
        image_key = image_keys[0]
        dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *config.crop_shape))
        with torch.inference_mode():
            feat_map_shape = tuple(
                self.backbone(torch.zeros(size=(1, *config.input_shapes["observation.image"]))).shape[1:]
            )
        self.pool = SpatialSoftmax(feat_map_shape, num_kp=config.spatial_softmax_num_keypoints)
            dummy_feature_map = self.backbone(dummy_input)
        feature_map_shape = tuple(dummy_feature_map.shape[1:])
        self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
        self.feature_dim = config.spatial_softmax_num_keypoints * 2
        self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
        self.relu = nn.ReLU()


@@ -1,4 +1,20 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import logging

from omegaconf import DictConfig, OmegaConf
@@ -8,9 +24,10 @@ from lerobot.common.utils.utils import get_safe_torch_device

def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
    expected_kwargs = set(inspect.signature(policy_cfg_class).parameters)
    assert set(hydra_cfg.policy).issuperset(
        expected_kwargs
    ), f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}"
    if not set(hydra_cfg.policy).issuperset(expected_kwargs):
        logging.warning(
            f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}"
        )
    policy_cfg = policy_cfg_class(
        **{
            k: v
@@ -62,11 +79,18 @@ def make_policy(
    policy_cls, policy_cfg_class = get_policy_and_config_classes(hydra_cfg.policy.name)

    policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg)
    if pretrained_policy_name_or_path is None:
        policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg)
        # Make a fresh policy.
        policy = policy_cls(policy_cfg, dataset_stats)
    else:
        policy = policy_cls.from_pretrained(pretrained_policy_name_or_path)
        # Load a pretrained policy and override the config if needed (for example, if there are inference-time
        # hyperparameters that we want to vary).
        # TODO(alexander-soare): This hack makes use of huggingface_hub's tooling to load the policy with
        # pretrained weights which are then loaded into a fresh policy with the desired config. This PR in
        # huggingface_hub should make it possible to avoid the hack:
        # https://github.com/huggingface/huggingface_hub/pull/2274.
        policy = policy_cls(policy_cfg)
        policy.load_state_dict(policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict())

    policy.to(get_safe_torch_device(hydra_cfg.device))
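The load-then-copy pattern above, sketched with hypothetical values (the hub repo id and the config override below are illustrative, not from this diff):

# Sketch: build a fresh policy with the desired config, then pull in the
# pretrained weights via a throwaway from_pretrained() instance.
from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.act.modeling_act import ACTPolicy

policy_cfg = ACTConfig(temporal_ensemble_momentum=0.99, n_action_steps=1)  # inference-time override
policy = ACTPolicy(policy_cfg)
pretrained = ACTPolicy.from_pretrained("lerobot/act_aloha_sim_insertion_human")  # placeholder repo id
policy.load_state_dict(pretrained.state_dict())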


@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import Tensor, nn


@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A protocol that all policies should follow. """A protocol that all policies should follow.
This provides a mechanism for type-hinting and isinstance checks without requiring the policies classes This provides a mechanism for type-hinting and isinstance checks without requiring the policies classes
@@ -38,7 +53,8 @@ class Policy(Protocol):
    def forward(self, batch: dict[str, Tensor]) -> dict:
        """Run the batch through the model and compute the loss for training or validation.

        Returns a dictionary with "loss" and maybe other information.
        Returns a dictionary with "loss" and potentially other information. Apart from "loss" which is a Tensor, all
        other items should be logging-friendly, native Python types.
        """

    def select_action(self, batch: dict[str, Tensor]):


@@ -1,3 +1,19 @@
#!/usr/bin/env python
# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su,
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
@@ -47,7 +63,7 @@ class TDMPCConfig:
        elite_weighting_temperature: The temperature to use for softmax weighting (by trajectory value) of the
            elites, when updating the gaussian parameters for CEM.
        gaussian_mean_momentum: Momentum (α) used for EMA updates of the mean parameter μ of the gaussian
            paramters optimized in CEM. Updates are calculated as μ ← αμ + (1-α)μ.
            parameters optimized in CEM. Updates are calculated as μ ← αμ + (1-α)μ.
        max_random_shift_ratio: Maximum random shift (as a proportion of the image size) to apply to the
            image(s) (in units of pixels) for training-time augmentation. If set to 0, no such augmentation
            is applied. Note that the input images are assumed to be square for this augmentation.
@@ -131,12 +147,18 @@ class TDMPCConfig:
    def __post_init__(self):
        """Input validation (not exhaustive)."""
        # There should only be one image key.
        image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
        if len(image_keys) != 1:
            raise ValueError(
                f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
            )
        image_key = next(iter(image_keys))
        if self.input_shapes["observation.image"][-2] != self.input_shapes["observation.image"][-1]:
        if self.input_shapes[image_key][-2] != self.input_shapes[image_key][-1]:
            # TODO(alexander-soare): This limitation is solely because of code in the random shift
            # augmentation. It should be able to be removed.
            raise ValueError(
                "Only square images are handled now. Got image shape "
                f"{self.input_shapes['observation.image']}."
                f"Only square images are handled now. Got image shape {self.input_shapes[image_key]}."
            )
        if self.n_gaussian_samples <= 0:
            raise ValueError(


@@ -1,3 +1,19 @@
#!/usr/bin/env python
# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su,
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of Finetuning Offline World Models in the Real World. """Implementation of Finetuning Offline World Models in the Real World.
The comments in this code may sometimes refer to these references: The comments in this code may sometimes refer to these references:
@@ -96,13 +112,12 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
            config.output_shapes, config.output_normalization_modes, dataset_stats
        )

    def save(self, fp):
        """Save state dict of TOLD model to filepath."""
        torch.save(self.state_dict(), fp)

    def load(self, fp):
        """Load a saved state dict from filepath into current agent."""
        self.load_state_dict(torch.load(fp))

        image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
        # Note: This check is covered in the post-init of the config but have a sanity check just in case.
        assert len(image_keys) == 1
        self.input_image_key = image_keys[0]

        self.reset()
    def reset(self):
        """
@@ -121,10 +136,8 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
    @torch.no_grad()
    def select_action(self, batch: dict[str, Tensor]):
        """Select a single action given environment observations."""
        assert "observation.image" in batch
        assert "observation.state" in batch
        batch = self.normalize_inputs(batch)
        batch["observation.image"] = batch[self.input_image_key]

        self._queues = populate_queues(self._queues, batch)
@@ -303,13 +316,11 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
        device = get_device_from_parameters(self)

        batch = self.normalize_inputs(batch)
        batch["observation.image"] = batch[self.input_image_key]
        batch = self.normalize_targets(batch)

        info = {}

        # TODO(alexander-soare): Refactor TDMPC and make it comply with the policy interface documentation.
        batch_size = batch["index"].shape[0]

        # (b, t) -> (t, b)
        for key in batch:
            if batch[key].ndim > 1:
@@ -337,6 +348,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
        # Run latent rollout using the latent dynamics model and policy model.
        # Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
        # gives us a next `z`.
        batch_size = batch["index"].shape[0]
        z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device)
        z_preds[0] = self.model.encode(current_observation)
        reward_preds = torch.empty_like(reward, device=device)


@@ -1,9 +1,28 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch import torch
from torch import nn from torch import nn
def populate_queues(queues, batch): def populate_queues(queues, batch):
for key in batch: for key in batch:
# Ignore keys not in the queues already (leaving the responsibility to the caller to make sure the
# queues have the keys they want).
if key not in queues:
continue
if len(queues[key]) != queues[key].maxlen: if len(queues[key]) != queues[key].maxlen:
# initialize by copying the first observation several times until the queue is full # initialize by copying the first observation several times until the queue is full
while len(queues[key]) != queues[key].maxlen: while len(queues[key]) != queues[key].maxlen:
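The new guard makes `populate_queues` skip batch keys that have no queue instead of assuming every key is queued. A small sketch of the behaviour, assuming deque-backed queues as in the surrounding code:

```
from collections import deque

import torch

queues = {"observation.state": deque(maxlen=3)}
batch = {"observation.state": torch.zeros(4), "action": torch.zeros(2)}

for key in batch:
    if key not in queues:
        continue  # "action" is skipped: the caller decides which keys get queued
    while len(queues[key]) != queues[key].maxlen:
        # a fresh queue is filled by repeating the first observation
        queues[key].append(batch[key])
```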

@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import importlib
 import logging

@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import warnings
 
 import imageio

@@ -1,8 +1,25 @@
+#!/usr/bin/env python
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import logging
 import os.path as osp
 import random
+from contextlib import contextmanager
 from datetime import datetime
 from pathlib import Path
+from typing import Generator
 
 import hydra
 import numpy as np
@@ -39,6 +56,31 @@ def set_global_seed(seed):
         torch.cuda.manual_seed_all(seed)
 
 
+@contextmanager
+def seeded_context(seed: int) -> Generator[None, None, None]:
+    """Set the seed when entering a context, and restore the prior random state at exit.
+
+    Example usage:
+
+    ```
+    a = random.random()  # produces some random number
+    with seeded_context(1337):
+        b = random.random()  # produces some other random number
+    c = random.random()  # produces yet another random number, the same one it would have produced if we never made `b`
+    ```
+    """
+    random_state = random.getstate()
+    np_random_state = np.random.get_state()
+    torch_random_state = torch.random.get_rng_state()
+    torch_cuda_random_state = torch.cuda.random.get_rng_state()
+    set_global_seed(seed)
+    yield None
+    random.setstate(random_state)
+    np.random.set_state(np_random_state)
+    torch.random.set_rng_state(torch_random_state)
+    torch.cuda.random.set_rng_state(torch_cuda_random_state)
+
+
 def init_logging():
     def custom_format(record):
         dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
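Assuming the module path `lerobot.common.utils.utils`, a tiny demonstration that the surrounding RNG state really is restored on exit:

```
import random

from lerobot.common.utils.utils import seeded_context, set_global_seed

# Note: as written above, seeded_context also saves/restores CUDA RNG state,
# so running this sketch assumes a CUDA-enabled torch build.
set_global_seed(0)
random.random()          # advance the global stream
with seeded_context(1337):
    b = random.random()  # reproducible: depends only on the seed 1337
c = random.random()      # same value as if the `with` block had never run
```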

@@ -10,6 +10,9 @@ hydra:
     name: default
 
 device: cuda # cpu
+# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
+# automatic gradient scaling is used.
+use_amp: false
 # `seed` is used for training (eg: model initialization, dataset shuffling)
 # AND for the evaluation environments.
 seed: ???
@@ -17,6 +20,7 @@ dataset_repo_id: lerobot/pusht
 
 training:
   offline_steps: ???
+  # NOTE: `online_steps` is not implemented yet. It's here as a placeholder.
   online_steps: ???
   online_steps_between_rollouts: ???
   online_sampling_ratio: 0.5
@@ -35,7 +39,7 @@ eval:
   use_async_envs: false
 
 wandb:
-  enable: true
+  enable: false
   # Set to true to disable saving an artifact despite save_model == True
   disable_artifact: false
   project: lerobot

@@ -3,6 +3,12 @@
 seed: 1000
 dataset_repo_id: lerobot/aloha_sim_insertion_human
 
+override_dataset_stats:
+  observation.images.top:
+    # stats from imagenet, since we use a pretrained vision model
+    mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
+    std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
+
 training:
   offline_steps: 80000
   online_steps: 0
@@ -18,12 +24,6 @@ training:
   grad_clip_norm: 10
   online_steps_between_rollouts: 1
 
-override_dataset_stats:
-  observation.images.top:
-    # stats from imagenet, since we use a pretrained vision model
-    mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
-    std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
-
 delta_timestamps:
   action: "[i / ${fps} for i in range(${policy.chunk_size})]"
@@ -66,6 +66,9 @@ policy:
   dim_feedforward: 3200
   feedforward_activation: relu
   n_encoder_layers: 4
+  # Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
+  # that means only the first layer is used. Here we match the original implementation by setting this to 1.
+  # See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
   n_decoder_layers: 1
   # VAE.
   use_vae: true
@@ -73,7 +76,7 @@ policy:
   n_vae_encoder_layers: 4
   # Inference.
-  use_temporal_aggregation: false
+  temporal_ensemble_momentum: null
   # Training and loss computation.
   dropout: 0.1
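The ImageNet stats above are stored with shape (c,1,1) so they broadcast over an image tensor of shape (c,h,w) during normalization. A minimal sketch of that broadcast:

```
import torch

mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

image = torch.rand(3, 224, 224)    # any float image in [0, 1], shape (c, h, w)
normalized = (image - mean) / std  # broadcasts across h and w
```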

@@ -7,6 +7,20 @@
 seed: 100000
 dataset_repo_id: lerobot/pusht
 
+override_dataset_stats:
+  # TODO(rcadene, alexander-soare): should we remove image stats as well? do we use a pretrained vision model?
+  observation.image:
+    mean: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
+    std: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
+  # TODO(rcadene, alexander-soare): we override state and action stats to use the same as the pretrained model
+  # from the original codebase, but we should remove these and train our own pretrained model
+  observation.state:
+    min: [13.456424, 32.938293]
+    max: [496.14618, 510.9579]
+  action:
+    min: [12.0, 25.0]
+    max: [511.0, 511.0]
+
 training:
   offline_steps: 200000
   online_steps: 0
@@ -34,20 +48,6 @@ eval:
   n_episodes: 50
   batch_size: 50
 
-override_dataset_stats:
-  # TODO(rcadene, alexander-soare): should we remove image stats as well? do we use a pretrained vision model?
-  observation.image:
-    mean: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
-    std: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
-  # TODO(rcadene, alexander-soare): we override state and action stats to use the same as the pretrained model
-  # from the original codebase, but we should remove these and train our own pretrained model
-  observation.state:
-    min: [13.456424, 32.938293]
-    max: [496.14618, 510.9579]
-  action:
-    min: [12.0, 25.0]
-    max: [511.0, 511.0]
-
 policy:
   name: diffusion
@@ -85,6 +85,7 @@ policy:
   diffusion_step_embed_dim: 128
   use_film_scale_modulation: True
   # Noise scheduler.
+  noise_scheduler_type: DDPM
   num_train_timesteps: 100
   beta_schedule: squaredcos_cap_v2
   beta_start: 0.0001
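The new `noise_scheduler_type` key names the scheduler family explicitly. As a sketch (the policy's internal wiring may differ), these keys correspond to a `diffusers` DDPM scheduler along these lines:

```
from diffusers import DDPMScheduler

# Values taken from the config above; other scheduler kwargs keep their defaults.
noise_scheduler = DDPMScheduler(
    num_train_timesteps=100,
    beta_schedule="squaredcos_cap_v2",
    beta_start=0.0001,
)
```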

@@ -1,11 +1,12 @@
 # @package _global_
 
 seed: 1
-dataset_repo_id: lerobot/xarm_lift_medium_replay
+dataset_repo_id: lerobot/xarm_lift_medium
 
 training:
   offline_steps: 25000
-  online_steps: 25000
+  # TODO(alexander-soare): uncomment when online training gets reinstated
+  online_steps: 0 # 25000 not implemented yet
   eval_freq: 5000
   online_steps_between_rollouts: 1
   online_sampling_ratio: 0.5

@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import platform
 
 import huggingface_hub

@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """Evaluate a policy on an environment by running rollouts and computing metrics.
 
 Usage examples:
@@ -31,6 +46,7 @@ import json
 import logging
 import threading
 import time
+from contextlib import nullcontext
 from copy import deepcopy
 from datetime import datetime as dt
 from pathlib import Path
@@ -505,7 +521,7 @@ def eval(
         raise NotImplementedError()
 
     # Check device is available
-    get_safe_torch_device(hydra_cfg.device, log=True)
+    device = get_safe_torch_device(hydra_cfg.device, log=True)
 
     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
@@ -524,16 +540,17 @@ def eval(
     policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats)
     policy.eval()
 
-    info = eval_policy(
-        env,
-        policy,
-        hydra_cfg.eval.n_episodes,
-        max_episodes_rendered=10,
-        video_dir=Path(out_dir) / "eval",
-        start_seed=hydra_cfg.seed,
-        enable_progbar=True,
-        enable_inner_progbar=True,
-    )
+    with torch.no_grad(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext():
+        info = eval_policy(
+            env,
+            policy,
+            hydra_cfg.eval.n_episodes,
+            max_episodes_rendered=10,
+            video_dir=Path(out_dir) / "eval",
+            start_seed=hydra_cfg.seed,
+            enable_progbar=True,
+            enable_inner_progbar=True,
+        )
     print(info["aggregated"])
 
     # Save info
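The pattern above runs evaluation without gradients and, only when AMP is enabled, under autocast. A self-contained sketch of the same construction (`use_amp` stands in for `hydra_cfg.use_amp`, and the model is a toy placeholder):

```
from contextlib import nullcontext

import torch
from torch import nn

use_amp = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(8, 2).to(device)
x = torch.randn(1, 8, device=device)

with torch.no_grad(), torch.autocast(device_type=device.type) if use_amp else nullcontext():
    y = model(x)  # reduced precision on CUDA when use_amp is true
```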

@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """
 Use this script to convert your dataset into LeRobot dataset format and upload it to the Hugging Face hub,
 or store it locally. LeRobot dataset format is lightweight, fast to load from, and does not require any
@@ -10,7 +25,6 @@ python lerobot/scripts/push_dataset_to_hub.py \
     --dataset-id pusht \
     --raw-format pusht_zarr \
     --community-id lerobot \
-    --revision v1.2 \
     --dry-run 1 \
     --save-to-disk 1 \
     --save-tests-to-disk 0 \
@@ -21,7 +35,6 @@ python lerobot/scripts/push_dataset_to_hub.py \
     --dataset-id xarm_lift_medium \
     --raw-format xarm_pkl \
     --community-id lerobot \
-    --revision v1.2 \
     --dry-run 1 \
     --save-to-disk 1 \
     --save-tests-to-disk 0 \
@@ -32,7 +45,6 @@ python lerobot/scripts/push_dataset_to_hub.py \
     --dataset-id aloha_sim_insertion_scripted \
     --raw-format aloha_hdf5 \
     --community-id lerobot \
-    --revision v1.2 \
     --dry-run 1 \
     --save-to-disk 1 \
     --save-tests-to-disk 0 \
@@ -43,7 +55,6 @@ python lerobot/scripts/push_dataset_to_hub.py \
     --dataset-id umi_cup_in_the_wild \
     --raw-format umi_zarr \
     --community-id lerobot \
-    --revision v1.2 \
     --dry-run 1 \
     --save-to-disk 1 \
     --save-tests-to-disk 0 \
@@ -212,8 +223,7 @@ def push_dataset_to_hub(
         test_hf_dataset = test_hf_dataset.with_format(None)
         test_hf_dataset.save_to_disk(str(tests_out_dir / "train"))
 
-        # copy meta data to tests directory
-        shutil.copytree(meta_data_dir, tests_meta_data_dir)
+        save_meta_data(info, stats, episode_data_index, tests_meta_data_dir)
 
         # copy videos of first episode to tests directory
         episode_index = 0
@@ -222,6 +232,10 @@ def push_dataset_to_hub(
             fname = f"{key}_episode_{episode_index:06d}.mp4"
             shutil.copy(videos_dir / fname, tests_videos_dir / fname)
 
+    if not save_to_disk and out_dir.exists():
+        # remove possible temporary files remaining in the output directory
+        shutil.rmtree(out_dir)
+
 
 def main():
     parser = argparse.ArgumentParser()
@@ -299,7 +313,7 @@ def main():
     parser.add_argument(
         "--num-workers",
         type=int,
-        default=16,
+        default=8,
         help="Number of processes of Dataloader for computing the dataset statistics.",
     )
     parser.add_argument(

@@ -1,13 +1,28 @@
+#!/usr/bin/env python
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import logging
 import time
+from contextlib import nullcontext
 from copy import deepcopy
 from pathlib import Path
 
-import datasets
 import hydra
 import torch
-from datasets import concatenate_datasets
-from datasets.utils import disable_progress_bars, enable_progress_bars
+from omegaconf import DictConfig
+from torch.cuda.amp import GradScaler
 
 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.utils import cycle
@@ -15,6 +30,7 @@ from lerobot.common.envs.factory import make_env
 from lerobot.common.logger import Logger, log_output_dir
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.policies.policy_protocol import PolicyWithUpdate
+from lerobot.common.policies.utils import get_device_from_parameters
 from lerobot.common.utils.utils import (
     format_big_number,
     get_safe_torch_device,
@@ -53,7 +69,6 @@ def make_optimizer_and_scheduler(cfg, policy):
             cfg.training.adam_eps,
             cfg.training.adam_weight_decay,
         )
-        assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training."
         from diffusers.optimization import get_scheduler
 
         lr_scheduler = get_scheduler(
@@ -71,20 +86,40 @@ def make_optimizer_and_scheduler(cfg, policy):
     return optimizer, lr_scheduler
 
 
-def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
-    start_time = time.time()
+def update_policy(
+    policy,
+    batch,
+    optimizer,
+    grad_clip_norm,
+    grad_scaler: GradScaler,
+    lr_scheduler=None,
+    use_amp: bool = False,
+):
+    """Returns a dictionary of items for logging."""
+    start_time = time.perf_counter()
+    device = get_device_from_parameters(policy)
     policy.train()
-    output_dict = policy.forward(batch)
-    # TODO(rcadene): policy.unnormalize_outputs(out_dict)
-    loss = output_dict["loss"]
-    loss.backward()
+    with torch.autocast(device_type=device.type) if use_amp else nullcontext():
+        output_dict = policy.forward(batch)
+        # TODO(rcadene): policy.unnormalize_outputs(out_dict)
+        loss = output_dict["loss"]
+    grad_scaler.scale(loss).backward()
+
+    # Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
+    grad_scaler.unscale_(optimizer)
+
     grad_norm = torch.nn.utils.clip_grad_norm_(
         policy.parameters(),
         grad_clip_norm,
         error_if_nonfinite=False,
     )
-    optimizer.step()
+
+    # Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
+    # although it still skips optimizer.step() if the gradients contain infs or NaNs.
+    grad_scaler.step(optimizer)
+    # Updates the scale for next iteration.
+    grad_scaler.update()
+
     optimizer.zero_grad()
 
     if lr_scheduler is not None:
@@ -98,7 +133,8 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
         "loss": loss.item(),
         "grad_norm": float(grad_norm),
         "lr": optimizer.param_groups[0]["lr"],
-        "update_s": time.time() - start_time,
+        "update_s": time.perf_counter() - start_time,
+        **{k: v for k, v in output_dict.items() if k != "loss"},
    }
 
     return info
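The new `update_policy` wires AMP through a `GradScaler`. A condensed, self-contained sketch of the same scale, backward, unscale, clip, step, update sequence, using a toy model and placeholder values rather than the repo's API:

```
from contextlib import nullcontext

import torch
from torch import nn
from torch.cuda.amp import GradScaler

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = device.type == "cuda"

model = nn.Linear(8, 1).to(device)
optimizer = torch.optim.Adam(model.parameters())
grad_scaler = GradScaler(enabled=use_amp)  # no-op scaler when AMP is off
x = torch.randn(4, 8, device=device)
target = torch.randn(4, 1, device=device)

with torch.autocast(device_type=device.type) if use_amp else nullcontext():
    loss = nn.functional.mse_loss(model(x), target)

grad_scaler.scale(loss).backward()   # scaled backward pass
grad_scaler.unscale_(optimizer)      # unscale before clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
grad_scaler.step(optimizer)          # skips the step if grads contain inf/NaN
grad_scaler.update()                 # adjust the scale for the next iteration
optimizer.zero_grad()
```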
@@ -122,7 +158,7 @@ def train_notebook(out_dir=None, job_name=None, config_name="default", config_pa
     train(cfg, out_dir=out_dir, job_name=job_name)
 
 
-def log_train_info(logger, info, step, cfg, dataset, is_offline):
+def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
     loss = info["loss"]
     grad_norm = info["grad_norm"]
     lr = info["lr"]
@@ -193,104 +229,7 @@ def log_eval_info(logger, info, step, cfg, dataset, is_offline):
     logger.log_dict(info, step, mode="eval")
 
 
-def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float):
-    """
-    Calculate the sampling weight to be assigned to samples so that a specified percentage of the batch comes from online dataset (on average).
-
-    Parameters:
-    - n_off (int): Number of offline samples, each with a sampling weight of 1.
-    - n_on (int): Number of online samples.
-    - pc_on (float): Desired percentage of online samples in decimal form (e.g., 50% as 0.5).
-
-    The total weight of offline samples is n_off * 1.0.
-    The total weight of online samples is n_on * w.
-    The total combined weight of all samples is n_off + n_on * w.
-    The fraction of the weight that is online is n_on * w / (n_off + n_on * w).
-    We want this fraction to equal pc_on, so we set up the equation n_on * w / (n_off + n_on * w) = pc_on.
-    The solution is w = - (n_off * pc_on) / (n_on * (pc_on - 1))
-    """
-    assert 0.0 <= pc_on <= 1.0
-    return -(n_off * pc_on) / (n_on * (pc_on - 1))
-
-
-def add_episodes_inplace(
-    online_dataset: torch.utils.data.Dataset,
-    concat_dataset: torch.utils.data.ConcatDataset,
-    sampler: torch.utils.data.WeightedRandomSampler,
-    hf_dataset: datasets.Dataset,
-    episode_data_index: dict[str, torch.Tensor],
-    pc_online_samples: float,
-):
-    """
-    Modifies the online_dataset, concat_dataset, and sampler in place by integrating
-    new episodes from hf_dataset into the online_dataset, updating the concatenated
-    dataset's structure and adjusting the sampling strategy based on the specified
-    percentage of online samples.
-
-    Parameters:
-    - online_dataset (torch.utils.data.Dataset): The existing online dataset to be updated.
-    - concat_dataset (torch.utils.data.ConcatDataset): The concatenated dataset that combines
-      offline and online datasets, used for sampling purposes.
-    - sampler (torch.utils.data.WeightedRandomSampler): A sampler that will be updated to
-      reflect changes in the dataset sizes and specified sampling weights.
-    - hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added.
-    - episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices.
-      They indicate the start index and end index of each episode in the dataset.
-    - pc_online_samples (float): The target percentage of samples that should come from
-      the online dataset during sampling operations.
-
-    Raises:
-    - AssertionError: If the first episode_id or index in hf_dataset is not 0
-    """
-    first_episode_idx = hf_dataset.select_columns("episode_index")[0]["episode_index"].item()
-    last_episode_idx = hf_dataset.select_columns("episode_index")[-1]["episode_index"].item()
-    first_index = hf_dataset.select_columns("index")[0]["index"].item()
-    last_index = hf_dataset.select_columns("index")[-1]["index"].item()
-    # sanity check
-    assert first_episode_idx == 0, f"{first_episode_idx=} is not 0"
-    assert first_index == 0, f"{first_index=} is not 0"
-    assert first_index == episode_data_index["from"][first_episode_idx].item()
-    assert last_index == episode_data_index["to"][last_episode_idx].item() - 1
-
-    if len(online_dataset) == 0:
-        # initialize online dataset
-        online_dataset.hf_dataset = hf_dataset
-        online_dataset.episode_data_index = episode_data_index
-    else:
-        # get the starting indices of the new episodes and frames to be added
-        start_episode_idx = last_episode_idx + 1
-        start_index = last_index + 1
-
-        def shift_indices(episode_index, index):
-            # note: we dont shift "frame_index" since it represents the index of the frame in the episode it belongs to
-            example = {"episode_index": episode_index + start_episode_idx, "index": index + start_index}
-            return example
-
-        disable_progress_bars()  # map has a tqdm progress bar
-        hf_dataset = hf_dataset.map(shift_indices, input_columns=["episode_index", "index"])
-        enable_progress_bars()
-
-        episode_data_index["from"] += start_index
-        episode_data_index["to"] += start_index
-
-        # extend online dataset
-        online_dataset.hf_dataset = concatenate_datasets([online_dataset.hf_dataset, hf_dataset])
-
-    # update the concatenated dataset length used during sampling
-    concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
-
-    # update the sampling weights for each frame so that online frames get sampled a certain percentage of times
-    len_online = len(online_dataset)
-    len_offline = len(concat_dataset) - len_online
-    weight_offline = 1.0
-    weight_online = calculate_online_sample_weight(len_offline, len_online, pc_online_samples)
-    sampler.weights = torch.tensor([weight_offline] * len_offline + [weight_online] * len(online_dataset))
-
-    # update the total number of samples used during sampling
-    sampler.num_samples = len(concat_dataset)
-def train(cfg: dict, out_dir=None, job_name=None):
+def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
     if out_dir is None:
         raise NotImplementedError()
     if job_name is None:
@@ -298,11 +237,11 @@ def train(cfg: dict, out_dir=None, job_name=None):
 
     init_logging()
 
-    if cfg.training.online_steps > 0 and cfg.eval.batch_size > 1:
-        logging.warning("eval.batch_size > 1 not supported for online training steps")
+    if cfg.training.online_steps > 0:
+        raise NotImplementedError("Online training is not implemented yet.")
 
     # Check device is available
-    get_safe_torch_device(cfg.device, log=True)
+    device = get_safe_torch_device(cfg.device, log=True)
 
     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
@@ -320,6 +259,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
     # Create optimizer and scheduler
     # Temporary hack to move optimizer out of policy
     optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
+    grad_scaler = GradScaler(enabled=cfg.use_amp)
 
     num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())
@@ -340,14 +280,15 @@ def train(cfg: dict, out_dir=None, job_name=None):
     def evaluate_and_checkpoint_if_needed(step):
         if step % cfg.training.eval_freq == 0:
             logging.info(f"Eval policy at step {step}")
-            eval_info = eval_policy(
-                eval_env,
-                policy,
-                cfg.eval.n_episodes,
-                video_dir=Path(out_dir) / "eval",
-                max_episodes_rendered=4,
-                start_seed=cfg.seed,
-            )
+            with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
+                eval_info = eval_policy(
+                    eval_env,
+                    policy,
+                    cfg.eval.n_episodes,
+                    video_dir=Path(out_dir) / "eval",
+                    max_episodes_rendered=4,
+                    start_seed=cfg.seed,
+                )
             log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
             if cfg.wandb.enable:
                 logger.log_video(eval_info["video_paths"][0], step, mode="eval")
@@ -371,23 +312,30 @@ def train(cfg: dict, out_dir=None, job_name=None):
         num_workers=4,
         batch_size=cfg.training.batch_size,
         shuffle=True,
-        pin_memory=cfg.device != "cpu",
+        pin_memory=device.type != "cpu",
         drop_last=False,
     )
     dl_iter = cycle(dataloader)
 
     policy.train()
-    step = 0  # number of policy update (forward + backward + optim)
     is_offline = True
-    for offline_step in range(cfg.training.offline_steps):
-        if offline_step == 0:
+    for step in range(cfg.training.offline_steps):
+        if step == 0:
             logging.info("Start offline training on a fixed dataset")
         batch = next(dl_iter)
 
         for key in batch:
-            batch[key] = batch[key].to(cfg.device, non_blocking=True)
+            batch[key] = batch[key].to(device, non_blocking=True)
 
-        train_info = update_policy(policy, batch, optimizer, cfg.training.grad_clip_norm, lr_scheduler)
+        train_info = update_policy(
+            policy,
+            batch,
+            optimizer,
+            cfg.training.grad_clip_norm,
+            grad_scaler=grad_scaler,
+            lr_scheduler=lr_scheduler,
+            use_amp=cfg.use_amp,
+        )
 
         # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
         if step % cfg.training.log_freq == 0:
@@ -397,11 +345,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
         # so we pass in step + 1.
         evaluate_and_checkpoint_if_needed(step + 1)
 
-        step += 1
-
-    # create an env dedicated to online episodes collection from policy rollout
-    online_training_env = make_env(cfg, n_envs=1)
-
     # create an empty online dataset similar to offline dataset
     online_dataset = deepcopy(offline_dataset)
     online_dataset.hf_dataset = {}
@@ -418,58 +361,11 @@ def train(cfg: dict, out_dir=None, job_name=None):
         num_workers=4,
         batch_size=cfg.training.batch_size,
         sampler=sampler,
-        pin_memory=cfg.device != "cpu",
+        pin_memory=device.type != "cpu",
         drop_last=False,
     )
-    dl_iter = cycle(dataloader)
-
-    online_step = 0
-    is_offline = False
-    for env_step in range(cfg.training.online_steps):
-        if env_step == 0:
-            logging.info("Start online training by interacting with environment")
-
-        policy.eval()
-        with torch.no_grad():
-            eval_info = eval_policy(
-                online_training_env,
-                policy,
-                n_episodes=1,
-                return_episode_data=True,
-                start_seed=cfg.training.online_env_seed,
-                enable_progbar=True,
-            )
-
-        add_episodes_inplace(
-            online_dataset,
-            concat_dataset,
-            sampler,
-            hf_dataset=eval_info["episodes"]["hf_dataset"],
-            episode_data_index=eval_info["episodes"]["episode_data_index"],
-            pc_online_samples=cfg.training.online_sampling_ratio,
-        )
-        policy.train()
-
-        for _ in range(cfg.training.online_steps_between_rollouts):
-            batch = next(dl_iter)
-
-            for key in batch:
-                batch[key] = batch[key].to(cfg.device, non_blocking=True)
-
-            train_info = update_policy(policy, batch, optimizer, cfg.training.grad_clip_norm, lr_scheduler)
-
-            if step % cfg.training.log_freq == 0:
-                log_train_info(logger, train_info, step, cfg, online_dataset, is_offline)
-
-            # Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
-            # so we pass in step + 1.
-            evaluate_and_checkpoint_if_needed(step + 1)
-
-            step += 1
-            online_step += 1
 
     eval_env.close()
-    online_training_env.close()
     logging.info("End of training")

@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """ Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
 
 Note: The last frame of the episode doesn't always correspond to a final state.
@@ -47,6 +62,7 @@ local$ rerun ws://localhost:9087
 """
 
 import argparse
+import gc
 import logging
 import time
 from pathlib import Path
@@ -115,15 +131,17 @@ def visualize_dataset(
     spawn_local_viewer = mode == "local" and not save
     rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer)
+
+    # Manually call python garbage collector after `rr.init` to avoid hanging in a blocking flush
+    # when iterating on a dataloader with `num_workers` > 0
+    # TODO(rcadene): remove `gc.collect` when rerun version 0.16 is out, which includes a fix
+    gc.collect()
+
     if mode == "distant":
         rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port)
 
     logging.info("Logging to Rerun")
 
-    if num_workers > 0:
-        # TODO(rcadene): fix data workers hanging when `rr.init` is called
-        logging.warning("If data loader is hanging, try `--num-workers 0`.")
-
     for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
         # iterate over the batch
         for i in range(len(batch["index"])):
@@ -196,7 +214,7 @@ def main():
     parser.add_argument(
         "--num-workers",
         type=int,
-        default=0,
+        default=4,
         help="Number of processes of Dataloader for loading the data.",
     )
     parser.add_argument(

poetry.lock (generated): 1294 lines changed; file diff suppressed because it is too large.
@@ -28,37 +28,36 @@ packages = [{include = "lerobot"}]
 [tool.poetry.dependencies]
 python = ">=3.10,<3.13"
-termcolor = "^2.4.0"
-omegaconf = "^2.3.0"
-wandb = "^0.16.3"
-imageio = {extras = ["ffmpeg"], version = "^2.34.0"}
-gdown = "^5.1.0"
-hydra-core = "^1.3.2"
-einops = "^0.8.0"
-pymunk = "^6.6.0"
-zarr = "^2.17.0"
-numba = "^0.59.0"
+termcolor = ">=2.4.0"
+omegaconf = ">=2.3.0"
+wandb = ">=0.16.3"
+imageio = {extras = ["ffmpeg"], version = ">=2.34.0"}
+gdown = ">=5.1.0"
+hydra-core = ">=1.3.2"
+einops = ">=0.8.0"
+pymunk = ">=6.6.0"
+zarr = ">=2.17.0"
+numba = ">=0.59.0"
 torch = "^2.2.1"
-opencv-python = "^4.9.0.80"
+opencv-python = ">=4.9.0"
 diffusers = "^0.27.2"
-torchvision = "^0.18.0"
-h5py = "^3.10.0"
-huggingface-hub = "^0.21.4"
-robomimic = "0.2.0"
-gymnasium = "^0.29.1"
-cmake = "^3.29.0.1"
-gym-pusht = { version = "^0.1.0", optional = true}
-gym-xarm = { version = "^0.1.0", optional = true}
-gym-aloha = { version = "^0.1.0", optional = true}
-pre-commit = {version = "^3.7.0", optional = true}
-debugpy = {version = "^1.8.1", optional = true}
-pytest = {version = "^8.1.0", optional = true}
-pytest-cov = {version = "^5.0.0", optional = true}
+torchvision = ">=0.17.1"
+h5py = ">=3.10.0"
+huggingface-hub = {extras = ["hf-transfer"], version = "^0.23.0"}
+gymnasium = ">=0.29.1"
+cmake = ">=3.29.0.1"
+gym-pusht = { version = ">=0.1.3", optional = true}
+gym-xarm = { version = ">=0.1.1", optional = true}
+gym-aloha = { version = ">=0.1.1", optional = true}
+pre-commit = {version = ">=3.7.0", optional = true}
+debugpy = {version = ">=1.8.1", optional = true}
+pytest = {version = ">=8.1.0", optional = true}
+pytest-cov = {version = ">=5.0.0", optional = true}
 datasets = "^2.19.0"
-imagecodecs = { version = "^2024.1.1", optional = true }
-pyav = "^12.0.5"
-moviepy = "^1.0.3"
-rerun-sdk = "^0.15.1"
+imagecodecs = { version = ">=2024.1.1", optional = true }
+pyav = ">=12.0.5"
+moviepy = ">=1.0.3"
+rerun-sdk = ">=0.15.1"
 
 [tool.poetry.extras]
@@ -104,5 +103,5 @@ ignore-init-module-imports = true
 
 [build-system]
-requires = ["poetry-core>=1.5.0"]
+requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"
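Most pins above move from Poetry's caret operator to a bare lower bound. `^2.4.0` means roughly `>=2.4.0,<3.0.0`, so the change stops excluding future major versions. A small illustration of the difference using the `packaging` library:

```
from packaging.specifiers import SpecifierSet

caret = SpecifierSet(">=2.4.0,<3.0.0")  # what "^2.4.0" expands to
relaxed = SpecifierSet(">=2.4.0")       # the new constraint

print("3.1.0" in caret)    # False: caret excludes the next major version
print("3.1.0" in relaxed)  # True
```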

@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 from .utils import DEVICE

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9f9347c8d9ac90ee44e6dd86f65043438168df6bbe4bab2d2b875e55ef7376ef
+size 1488

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb
+size 33

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:02fc4ea25766269f65752a60b0594c43d799b0ae528cd773bf024b064b5aa329
+size 4344

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:55d7b1a06fe3e3051482752740074348bdb5fc98fb2e305b06d6203994117b27
+size 592448

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8b7fbedfdb3d536847bc6fadf2cbabb9f2b5492edf3e2c274a3e8ffb447105e8
+size 1166

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:98329e4b40e9be0d63f7d36da9d86c44bbe7eeeb1b10d3ba973c923f3be70867
+size 247

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:54e42cdfd016a0ced2ab1fe2966a8c15a2384e0dbe1a2fe87433a2d1b8209ac0
+size 5220057

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:af1ded2a244cb47a96255b75f584a643edf6967e13bb5464b330ffdd9d7ad859
+size 5284692

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:13d1bebabd79984fd6715971be758ef9a354495adea5e8d33f4e7904365e112b
+size 5258380

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f33bc6810f0b91817a42610364cb49ed1b99660f058f0f9407e6f5920d0aee02
+size 1008

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb
+size 33

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7b58d6c89e936a781a307805ebecf0dd473fbc02d52a7094da62e54bffb9454a
+size 4344

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a08be578285cbe2d35b78f150d464ff3e10604a9865398c976983e0d711774f9
+size 788528

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8b7fbedfdb3d536847bc6fadf2cbabb9f2b5492edf3e2c274a3e8ffb447105e8
+size 1166

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:34e36233477c8aa0b0840314ddace072062d4f486d06546bbd6550832c370065
+size 247

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:66e7349a4a82ca6042a7189608d01eb1cfa38d100d039b5445ae1a9e65d824ab
+size 14470946

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a2146f0c10c9f2611e57e617983aa4f91ad681b4fc50d91b992b97abd684f926
+size 11662185

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5affbaf1c48895ba3c626e0d8cf1309e5f4ec6bbaa135313096f52a22de66c05
+size 11410342

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6c2b195ca91b88fd16422128d386d2cabd808a1862c6d127e6bf2e83e1fe819a
+size 448

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb
+size 33

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b360b6b956d2adcb20589947c553348ef1eb6b70743c989dcbe95243d8592ce5
+size 4344

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3f5c3926b4d4da9271abefcdf6a8952bb1f13258a9c39fe0fd223f548dc89dcb
+size 887728

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8b7fbedfdb3d536847bc6fadf2cbabb9f2b5492edf3e2c274a3e8ffb447105e8
+size 1166

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4993b05fb026619eec5eb70db8cadaa041ba4ab92d38b4a387167ace03b1018b
+size 247

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bd25d17ef5b7500386761b5e32920879bbdcafe0e17a8a8845628525d861e644
+size 10231081

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5b557acbfeb0681c0a38e47263d945f6cd3a03461298d8b17209c81e3fd0aae8
+size 9701371

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:da8f3b4f9f965da63819652b2c042d4cf7e07d14631113ea072087d56370310e
+size 10473741

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a053506017d8a78cfd307b2912eeafa1ac1485a280cf90913985fcc40120b5ec
+size 416

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb
+size 33

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d6d172d1bca02face22ceb4c21ea2b054cf3463025485dce64711b6f36b31f8a
+size 4344

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7e5ce817a2c188041f57f8d4c465dab3b9c3e4e1aeb7a9fb270230d1b36df530
+size 1477064

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8b7fbedfdb3d536847bc6fadf2cbabb9f2b5492edf3e2c274a3e8ffb447105e8
+size 1166

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4eb2dc373e4ea7d474742590f9073d66a773f6ab94b9e73a8673df19f93fae6d
+size 247

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d2c55b146fabe78b18c8a28a7746ab56e1ee7a6918e9e3dad9bd196f97975895
+size 26158915

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:71e1958d77f56843acf1ec48da4f04311a5836c87a0e77dbe26aa47c27c6347e
+size 18786848

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:20780718399b5759ff9a3a79824986310524793066198e3b9a307222f11a93df
+size 17769988

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:279916f7689ae46af90e92a46eba9486a71fc762e3e2679ab5441eb37126827b
+size 928

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb
+size 33

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7a7731051b521694b52b5631470720a7f05331915f4ac4e7f8cd83f9ff459bce
+size 4344

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:99608258e8c9fe5191f1a12edc29b47d307790104149dffb6d3046ddad6aeb1b
+size 435600

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8b7fbedfdb3d536847bc6fadf2cbabb9f2b5492edf3e2c274a3e8ffb447105e8
+size 1166

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ae6735b7b394914824e974a7461019373a10f9e2d84ddf834bec8ea268d9ec1e
+size 247
Some files were not shown because too many files have changed in this diff.