Merge remote-tracking branch 'origin/main' into user/aliberts/2024_05_06_add_coverage
This commit is contained in:
commit
fc07f0e2bc
|
@ -1,2 +1,6 @@
|
||||||
*.memmap filter=lfs diff=lfs merge=lfs -text
|
*.memmap filter=lfs diff=lfs merge=lfs -text
|
||||||
*.stl filter=lfs diff=lfs merge=lfs -text
|
*.stl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.arrow filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.json filter=lfs diff=lfs merge=lfs -text
|
||||||
|
|
|
@ -10,7 +10,6 @@ on:
|
||||||
|
|
||||||
env:
|
env:
|
||||||
PYTHON_VERSION: "3.10"
|
PYTHON_VERSION: "3.10"
|
||||||
# CI_SLACK_CHANNEL: ${{ secrets.CI_DOCKER_CHANNEL }}
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
latest-cpu:
|
latest-cpu:
|
||||||
|
@ -51,30 +50,6 @@ jobs:
|
||||||
tags: huggingface/lerobot-cpu
|
tags: huggingface/lerobot-cpu
|
||||||
build-args: PYTHON_VERSION=${{ env.PYTHON_VERSION }}
|
build-args: PYTHON_VERSION=${{ env.PYTHON_VERSION }}
|
||||||
|
|
||||||
# - name: Post to a Slack channel
|
|
||||||
# id: slack
|
|
||||||
# #uses: slackapi/slack-github-action@v1.25.0
|
|
||||||
# uses: slackapi/slack-github-action@6c661ce58804a1a20f6dc5fbee7f0381b469e001
|
|
||||||
# with:
|
|
||||||
# # Slack channel id, channel name, or user id to post message.
|
|
||||||
# # See also: https://api.slack.com/methods/chat.postMessage#channels
|
|
||||||
# channel-id: ${{ env.CI_SLACK_CHANNEL }}
|
|
||||||
# # For posting a rich message using Block Kit
|
|
||||||
# payload: |
|
|
||||||
# {
|
|
||||||
# "text": "lerobot-cpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}",
|
|
||||||
# "blocks": [
|
|
||||||
# {
|
|
||||||
# "type": "section",
|
|
||||||
# "text": {
|
|
||||||
# "type": "mrkdwn",
|
|
||||||
# "text": "lerobot-cpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}"
|
|
||||||
# }
|
|
||||||
# }
|
|
||||||
# ]
|
|
||||||
# }
|
|
||||||
# env:
|
|
||||||
# SLACK_BOT_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
|
||||||
|
|
||||||
latest-cuda:
|
latest-cuda:
|
||||||
name: GPU
|
name: GPU
|
||||||
|
@ -113,27 +88,40 @@ jobs:
|
||||||
tags: huggingface/lerobot-gpu
|
tags: huggingface/lerobot-gpu
|
||||||
build-args: PYTHON_VERSION=${{ env.PYTHON_VERSION }}
|
build-args: PYTHON_VERSION=${{ env.PYTHON_VERSION }}
|
||||||
|
|
||||||
# - name: Post to a Slack channel
|
|
||||||
# id: slack
|
latest-cuda-dev:
|
||||||
# #uses: slackapi/slack-github-action@v1.25.0
|
name: GPU Dev
|
||||||
# uses: slackapi/slack-github-action@6c661ce58804a1a20f6dc5fbee7f0381b469e001
|
runs-on: ubuntu-latest
|
||||||
# with:
|
steps:
|
||||||
# # Slack channel id, channel name, or user id to post message.
|
- name: Cleanup disk
|
||||||
# # See also: https://api.slack.com/methods/chat.postMessage#channels
|
run: |
|
||||||
# channel-id: ${{ env.CI_SLACK_CHANNEL }}
|
sudo df -h
|
||||||
# # For posting a rich message using Block Kit
|
# sudo ls -l /usr/local/lib/
|
||||||
# payload: |
|
# sudo ls -l /usr/share/
|
||||||
# {
|
sudo du -sh /usr/local/lib/
|
||||||
# "text": "lerobot-gpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}",
|
sudo du -sh /usr/share/
|
||||||
# "blocks": [
|
sudo rm -rf /usr/local/lib/android
|
||||||
# {
|
sudo rm -rf /usr/share/dotnet
|
||||||
# "type": "section",
|
sudo du -sh /usr/local/lib/
|
||||||
# "text": {
|
sudo du -sh /usr/share/
|
||||||
# "type": "mrkdwn",
|
sudo df -h
|
||||||
# "text": "lerobot-gpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}"
|
- name: Set up Docker Buildx
|
||||||
# }
|
uses: docker/setup-buildx-action@v3
|
||||||
# }
|
|
||||||
# ]
|
- name: Check out code
|
||||||
# }
|
uses: actions/checkout@v4
|
||||||
# env:
|
|
||||||
# SLACK_BOT_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
- name: Login to DockerHub
|
||||||
|
uses: docker/login-action@v3
|
||||||
|
with:
|
||||||
|
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||||
|
password: ${{ secrets.DOCKERHUB_PASSWORD }}
|
||||||
|
|
||||||
|
- name: Build and Push GPU dev
|
||||||
|
uses: docker/build-push-action@v5
|
||||||
|
with:
|
||||||
|
context: .
|
||||||
|
file: ./docker/lerobot-gpu-dev/Dockerfile
|
||||||
|
push: true
|
||||||
|
tags: huggingface/lerobot-gpu:dev
|
||||||
|
build-args: PYTHON_VERSION=${{ env.PYTHON_VERSION }}
|
||||||
|
|
|
@ -29,6 +29,8 @@ jobs:
|
||||||
MUJOCO_GL: egl
|
MUJOCO_GL: egl
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
lfs: true # Ensure LFS files are pulled
|
||||||
|
|
||||||
- name: Install EGL
|
- name: Install EGL
|
||||||
run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev
|
run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev
|
||||||
|
@ -57,6 +59,40 @@ jobs:
|
||||||
&& rm -rf tests/outputs outputs
|
&& rm -rf tests/outputs outputs
|
||||||
|
|
||||||
|
|
||||||
|
pytest-minimal:
|
||||||
|
name: Pytest (minimal install)
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
env:
|
||||||
|
DATA_DIR: tests/data
|
||||||
|
MUJOCO_GL: egl
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
lfs: true # Ensure LFS files are pulled
|
||||||
|
|
||||||
|
- name: Install poetry
|
||||||
|
run: |
|
||||||
|
pipx install poetry && poetry config virtualenvs.in-project true
|
||||||
|
echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH
|
||||||
|
|
||||||
|
- name: Set up Python 3.10
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.10"
|
||||||
|
|
||||||
|
- name: Install poetry dependencies
|
||||||
|
run: |
|
||||||
|
poetry install --extras "test"
|
||||||
|
|
||||||
|
- name: Test with pytest
|
||||||
|
run: |
|
||||||
|
pytest tests -v --cov=./lerobot --durations=0 \
|
||||||
|
-W ignore::DeprecationWarning:imageio_ffmpeg._utils:7 \
|
||||||
|
-W ignore::UserWarning:torch.utils.data.dataloader:558 \
|
||||||
|
-W ignore::UserWarning:gymnasium.utils.env_checker:247 \
|
||||||
|
&& rm -rf tests/outputs outputs
|
||||||
|
|
||||||
|
|
||||||
end-to-end:
|
end-to-end:
|
||||||
name: End-to-end
|
name: End-to-end
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
@ -65,6 +101,8 @@ jobs:
|
||||||
MUJOCO_GL: egl
|
MUJOCO_GL: egl
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
lfs: true # Ensure LFS files are pulled
|
||||||
|
|
||||||
- name: Install EGL
|
- name: Install EGL
|
||||||
run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev
|
run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev
|
||||||
|
|
|
@ -2,12 +2,17 @@
|
||||||
logs
|
logs
|
||||||
tmp
|
tmp
|
||||||
wandb
|
wandb
|
||||||
|
|
||||||
|
# Data
|
||||||
data
|
data
|
||||||
outputs
|
outputs
|
||||||
.vscode
|
|
||||||
rl
|
# Apple
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
|
||||||
|
# VS Code
|
||||||
|
.vscode
|
||||||
|
|
||||||
# HPC
|
# HPC
|
||||||
nautilus/*.yaml
|
nautilus/*.yaml
|
||||||
*.key
|
*.key
|
||||||
|
@ -90,6 +95,7 @@ instance/
|
||||||
docs/_build/
|
docs/_build/
|
||||||
|
|
||||||
# PyBuilder
|
# PyBuilder
|
||||||
|
.pybuilder/
|
||||||
target/
|
target/
|
||||||
|
|
||||||
# Jupyter Notebook
|
# Jupyter Notebook
|
||||||
|
@ -102,13 +108,6 @@ ipython_config.py
|
||||||
# pyenv
|
# pyenv
|
||||||
.python-version
|
.python-version
|
||||||
|
|
||||||
# pipenv
|
|
||||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
|
||||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
|
||||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
|
||||||
# install all needed dependencies.
|
|
||||||
#Pipfile.lock
|
|
||||||
|
|
||||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||||
__pypackages__/
|
__pypackages__/
|
||||||
|
|
||||||
|
@ -119,6 +118,15 @@ celerybeat.pid
|
||||||
# SageMath parsed files
|
# SageMath parsed files
|
||||||
*.sage.py
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
# Spyder project settings
|
# Spyder project settings
|
||||||
.spyderproject
|
.spyderproject
|
||||||
.spyproject
|
.spyproject
|
||||||
|
@ -136,3 +144,9 @@ dmypy.json
|
||||||
|
|
||||||
# Pyre type checker
|
# Pyre type checker
|
||||||
.pyre/
|
.pyre/
|
||||||
|
|
||||||
|
# pytype static type analyzer
|
||||||
|
.pytype/
|
||||||
|
|
||||||
|
# Cython debug symbols
|
||||||
|
cython_debug/
|
||||||
|
|
|
@ -18,7 +18,7 @@ repos:
|
||||||
hooks:
|
hooks:
|
||||||
- id: pyupgrade
|
- id: pyupgrade
|
||||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
rev: v0.4.2
|
rev: v0.4.3
|
||||||
hooks:
|
hooks:
|
||||||
- id: ruff
|
- id: ruff
|
||||||
args: [--fix]
|
args: [--fix]
|
||||||
|
|
|
@ -195,6 +195,11 @@ Follow these steps to start contributing:
|
||||||
git commit
|
git commit
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Note, if you already commited some changes that have a wrong formatting, you can use:
|
||||||
|
```bash
|
||||||
|
pre-commit run --all-files
|
||||||
|
```
|
||||||
|
|
||||||
Please write [good commit messages](https://chris.beams.io/posts/git-commit/).
|
Please write [good commit messages](https://chris.beams.io/posts/git-commit/).
|
||||||
|
|
||||||
It is a good idea to sync your copy of the code with the original
|
It is a good idea to sync your copy of the code with the original
|
||||||
|
|
64
Makefile
64
Makefile
|
@ -20,16 +20,19 @@ build-gpu:
|
||||||
test-end-to-end:
|
test-end-to-end:
|
||||||
${MAKE} test-act-ete-train
|
${MAKE} test-act-ete-train
|
||||||
${MAKE} test-act-ete-eval
|
${MAKE} test-act-ete-eval
|
||||||
|
${MAKE} test-act-ete-train-amp
|
||||||
|
${MAKE} test-act-ete-eval-amp
|
||||||
${MAKE} test-diffusion-ete-train
|
${MAKE} test-diffusion-ete-train
|
||||||
${MAKE} test-diffusion-ete-eval
|
${MAKE} test-diffusion-ete-eval
|
||||||
# TODO(rcadene, alexander-soare): enable end-to-end tests for tdmpc
|
${MAKE} test-tdmpc-ete-train
|
||||||
# ${MAKE} test-tdmpc-ete-train
|
${MAKE} test-tdmpc-ete-eval
|
||||||
# ${MAKE} test-tdmpc-ete-eval
|
|
||||||
${MAKE} test-default-ete-eval
|
${MAKE} test-default-ete-eval
|
||||||
|
${MAKE} test-act-pusht-tutorial
|
||||||
|
|
||||||
test-act-ete-train:
|
test-act-ete-train:
|
||||||
python lerobot/scripts/train.py \
|
python lerobot/scripts/train.py \
|
||||||
policy=act \
|
policy=act \
|
||||||
|
policy.dim_model=64 \
|
||||||
env=aloha \
|
env=aloha \
|
||||||
wandb.enable=False \
|
wandb.enable=False \
|
||||||
training.offline_steps=2 \
|
training.offline_steps=2 \
|
||||||
|
@ -52,9 +55,40 @@ test-act-ete-eval:
|
||||||
env.episode_length=8 \
|
env.episode_length=8 \
|
||||||
device=cpu \
|
device=cpu \
|
||||||
|
|
||||||
|
test-act-ete-train-amp:
|
||||||
|
python lerobot/scripts/train.py \
|
||||||
|
policy=act \
|
||||||
|
policy.dim_model=64 \
|
||||||
|
env=aloha \
|
||||||
|
wandb.enable=False \
|
||||||
|
training.offline_steps=2 \
|
||||||
|
training.online_steps=0 \
|
||||||
|
eval.n_episodes=1 \
|
||||||
|
eval.batch_size=1 \
|
||||||
|
device=cpu \
|
||||||
|
training.save_model=true \
|
||||||
|
training.save_freq=2 \
|
||||||
|
policy.n_action_steps=20 \
|
||||||
|
policy.chunk_size=20 \
|
||||||
|
training.batch_size=2 \
|
||||||
|
hydra.run.dir=tests/outputs/act/ \
|
||||||
|
use_amp=true
|
||||||
|
|
||||||
|
test-act-ete-eval-amp:
|
||||||
|
python lerobot/scripts/eval.py \
|
||||||
|
-p tests/outputs/act/checkpoints/000002 \
|
||||||
|
eval.n_episodes=1 \
|
||||||
|
eval.batch_size=1 \
|
||||||
|
env.episode_length=8 \
|
||||||
|
device=cpu \
|
||||||
|
use_amp=true
|
||||||
|
|
||||||
test-diffusion-ete-train:
|
test-diffusion-ete-train:
|
||||||
python lerobot/scripts/train.py \
|
python lerobot/scripts/train.py \
|
||||||
policy=diffusion \
|
policy=diffusion \
|
||||||
|
policy.down_dims=\[64,128,256\] \
|
||||||
|
policy.diffusion_step_embed_dim=32 \
|
||||||
|
policy.num_inference_steps=10 \
|
||||||
env=pusht \
|
env=pusht \
|
||||||
wandb.enable=False \
|
wandb.enable=False \
|
||||||
training.offline_steps=2 \
|
training.offline_steps=2 \
|
||||||
|
@ -75,15 +109,16 @@ test-diffusion-ete-eval:
|
||||||
env.episode_length=8 \
|
env.episode_length=8 \
|
||||||
device=cpu \
|
device=cpu \
|
||||||
|
|
||||||
|
# TODO(alexander-soare): Restore online_steps to 2 when it is reinstated.
|
||||||
test-tdmpc-ete-train:
|
test-tdmpc-ete-train:
|
||||||
python lerobot/scripts/train.py \
|
python lerobot/scripts/train.py \
|
||||||
policy=tdmpc \
|
policy=tdmpc \
|
||||||
env=xarm \
|
env=xarm \
|
||||||
env.task=XarmLift-v0 \
|
env.task=XarmLift-v0 \
|
||||||
dataset_repo_id=lerobot/xarm_lift_medium_replay \
|
dataset_repo_id=lerobot/xarm_lift_medium \
|
||||||
wandb.enable=False \
|
wandb.enable=False \
|
||||||
training.offline_steps=2 \
|
training.offline_steps=2 \
|
||||||
training.online_steps=2 \
|
training.online_steps=0 \
|
||||||
eval.n_episodes=1 \
|
eval.n_episodes=1 \
|
||||||
eval.batch_size=1 \
|
eval.batch_size=1 \
|
||||||
env.episode_length=2 \
|
env.episode_length=2 \
|
||||||
|
@ -101,7 +136,6 @@ test-tdmpc-ete-eval:
|
||||||
env.episode_length=8 \
|
env.episode_length=8 \
|
||||||
device=cpu \
|
device=cpu \
|
||||||
|
|
||||||
|
|
||||||
test-default-ete-eval:
|
test-default-ete-eval:
|
||||||
python lerobot/scripts/eval.py \
|
python lerobot/scripts/eval.py \
|
||||||
--config lerobot/configs/default.yaml \
|
--config lerobot/configs/default.yaml \
|
||||||
|
@ -109,3 +143,21 @@ test-default-ete-eval:
|
||||||
eval.batch_size=1 \
|
eval.batch_size=1 \
|
||||||
env.episode_length=8 \
|
env.episode_length=8 \
|
||||||
device=cpu \
|
device=cpu \
|
||||||
|
|
||||||
|
|
||||||
|
test-act-pusht-tutorial:
|
||||||
|
cp examples/advanced/1_train_act_pusht/act_pusht.yaml lerobot/configs/policy/created_by_Makefile.yaml
|
||||||
|
python lerobot/scripts/train.py \
|
||||||
|
policy=created_by_Makefile.yaml \
|
||||||
|
env=pusht \
|
||||||
|
wandb.enable=False \
|
||||||
|
training.offline_steps=2 \
|
||||||
|
eval.n_episodes=1 \
|
||||||
|
eval.batch_size=1 \
|
||||||
|
env.episode_length=2 \
|
||||||
|
device=cpu \
|
||||||
|
training.save_model=true \
|
||||||
|
training.save_freq=2 \
|
||||||
|
training.batch_size=2 \
|
||||||
|
hydra.run.dir=tests/outputs/act_pusht/
|
||||||
|
rm lerobot/configs/policy/created_by_Makefile.yaml
|
||||||
|
|
43
README.md
43
README.md
|
@ -57,7 +57,6 @@
|
||||||
- Thanks to Tony Zaho, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io).
|
- Thanks to Tony Zaho, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io).
|
||||||
- Thanks to Cheng Chi, Zhenjia Xu and colleagues for open sourcing Diffusion policy, Pusht environment and datasets, as well as UMI datasets. Ours are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu) and [UMI Gripper](https://umi-gripper.github.io).
|
- Thanks to Cheng Chi, Zhenjia Xu and colleagues for open sourcing Diffusion policy, Pusht environment and datasets, as well as UMI datasets. Ours are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu) and [UMI Gripper](https://umi-gripper.github.io).
|
||||||
- Thanks to Nicklas Hansen, Yunhai Feng and colleagues for open sourcing TDMPC policy, Simxarm environments and datasets. Ours are adapted from [TDMPC](https://github.com/nicklashansen/tdmpc) and [FOWM](https://www.yunhaifeng.com/FOWM).
|
- Thanks to Nicklas Hansen, Yunhai Feng and colleagues for open sourcing TDMPC policy, Simxarm environments and datasets. Ours are adapted from [TDMPC](https://github.com/nicklashansen/tdmpc) and [FOWM](https://www.yunhaifeng.com/FOWM).
|
||||||
- Thanks to Vincent Moens and colleagues for open sourcing [TorchRL](https://github.com/pytorch/rl). It allowed for quick experimentations on the design of `LeRobot`.
|
|
||||||
- Thanks to Antonio Loquercio and Ashish Kumar for their early support.
|
- Thanks to Antonio Loquercio and Ashish Kumar for their early support.
|
||||||
|
|
||||||
|
|
||||||
|
@ -78,6 +77,10 @@ Install 🤗 LeRobot:
|
||||||
pip install .
|
pip install .
|
||||||
```
|
```
|
||||||
|
|
||||||
|
> **NOTE:** Depending on your platform, If you encounter any build errors during this step
|
||||||
|
you may need to install `cmake` and `build-essential` for building some of our dependencies.
|
||||||
|
On linux: `sudo apt-get install cmake build-essential`
|
||||||
|
|
||||||
For simulations, 🤗 LeRobot comes with gymnasium environments that can be installed as extras:
|
For simulations, 🤗 LeRobot comes with gymnasium environments that can be installed as extras:
|
||||||
- [aloha](https://github.com/huggingface/gym-aloha)
|
- [aloha](https://github.com/huggingface/gym-aloha)
|
||||||
- [xarm](https://github.com/huggingface/gym-xarm)
|
- [xarm](https://github.com/huggingface/gym-xarm)
|
||||||
|
@ -93,11 +96,14 @@ To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tra
|
||||||
wandb login
|
wandb login
|
||||||
```
|
```
|
||||||
|
|
||||||
|
(note: you will also need to enable WandB in the configuration. See below.)
|
||||||
|
|
||||||
## Walkthrough
|
## Walkthrough
|
||||||
|
|
||||||
```
|
```
|
||||||
.
|
.
|
||||||
├── examples # contains demonstration examples, start here to learn about LeRobot
|
├── examples # contains demonstration examples, start here to learn about LeRobot
|
||||||
|
| └── advanced # contains even more examples for those who have mastered the basics
|
||||||
├── lerobot
|
├── lerobot
|
||||||
| ├── configs # contains hydra yaml files with all options that you can override in the command line
|
| ├── configs # contains hydra yaml files with all options that you can override in the command line
|
||||||
| | ├── default.yaml # selected by default, it loads pusht environment and diffusion policy
|
| | ├── default.yaml # selected by default, it loads pusht environment and diffusion policy
|
||||||
|
@ -157,15 +163,16 @@ See `python lerobot/scripts/eval.py --help` for more instructions.
|
||||||
|
|
||||||
### Train your own policy
|
### Train your own policy
|
||||||
|
|
||||||
Check out [example 3](./examples/3_train_policy.py) that illustrates how to start training a model.
|
Check out [example 3](./examples/3_train_policy.py) that illustrates how to train a model using our core library in python, and [example 4](./examples/4_train_policy_with_script.md) that shows how to use our training script from command line.
|
||||||
|
|
||||||
|
In general, you can use our training script to easily train any policy. Here is an example of training the ACT policy on trajectories collected by humans on the Aloha simulation environment for the insertion task:
|
||||||
|
|
||||||
In general, you can use our training script to easily train any policy. To use wandb for logging training and evaluation curves, make sure you ran `wandb login`. Here is an example of training the ACT policy on trajectories collected by humans on the Aloha simulation environment for the insertion task:
|
|
||||||
```bash
|
```bash
|
||||||
python lerobot/scripts/train.py \
|
python lerobot/scripts/train.py \
|
||||||
policy=act \
|
policy=act \
|
||||||
env=aloha \
|
env=aloha \
|
||||||
env.task=AlohaInsertion-v0 \
|
env.task=AlohaInsertion-v0 \
|
||||||
dataset_repo_id=lerobot/aloha_sim_insertion_human
|
dataset_repo_id=lerobot/aloha_sim_insertion_human \
|
||||||
```
|
```
|
||||||
|
|
||||||
The experiment directory is automatically generated and will show up in yellow in your terminal. It looks like `outputs/train/2024-05-05/20-21-12_aloha_act_default`. You can manually specify an experiment directory by adding this argument to the `train.py` python command:
|
The experiment directory is automatically generated and will show up in yellow in your terminal. It looks like `outputs/train/2024-05-05/20-21-12_aloha_act_default`. You can manually specify an experiment directory by adding this argument to the `train.py` python command:
|
||||||
|
@ -173,17 +180,29 @@ The experiment directory is automatically generated and will show up in yellow i
|
||||||
hydra.run.dir=your/new/experiment/dir
|
hydra.run.dir=your/new/experiment/dir
|
||||||
```
|
```
|
||||||
|
|
||||||
A link to the wandb logs for the run will also show up in yellow in your terminal. Here is an example of logs from wandb:
|
To use wandb for logging training and evaluation curves, make sure you've run `wandb login` as a one-time setup step. Then, when running the training command above, enable WandB in the configuration by adding:
|
||||||

|
|
||||||
|
|
||||||
You can deactivate wandb by adding these arguments to the `train.py` python command:
|
|
||||||
```bash
|
```bash
|
||||||
wandb.disable_artifact=true \
|
wandb.enable=true
|
||||||
wandb.enable=false
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Note: For efficiency, during training every checkpoint is evaluated on a low number of episodes. After training, you may want to re-evaluate your best checkpoints on more episodes or change the evaluation settings. See `python lerobot/scripts/eval.py --help` for more instructions.
|
A link to the wandb logs for the run will also show up in yellow in your terminal. Here is an example of what they look like in your browser:
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
Note: For efficiency, during training every checkpoint is evaluated on a low number of episodes. You may use `eval.n_episodes=500` to evaluate on more episodes than the default. Or, after training, you may want to re-evaluate your best checkpoints on more episodes or change the evaluation settings. See `python lerobot/scripts/eval.py --help` for more instructions.
|
||||||
|
|
||||||
|
#### Reproduce state-of-the-art (SOTA)
|
||||||
|
|
||||||
|
We have organized our configuration files (found under [`lerobot/configs`](./lerobot/configs)) such that they reproduce SOTA results from a given model variant in their respective original works. Simply running:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/train.py policy=diffusion env=pusht
|
||||||
|
```
|
||||||
|
|
||||||
|
reproduces SOTA results for Diffusion Policy on the PushT task.
|
||||||
|
|
||||||
|
Pretrained policies, along with reproduction details, can be found under the "Models" section of https://huggingface.co/lerobot.
|
||||||
|
|
||||||
## Contribute
|
## Contribute
|
||||||
|
|
||||||
|
@ -196,11 +215,11 @@ To add a dataset to the hub, you need to login using a write-access token, which
|
||||||
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||||
```
|
```
|
||||||
|
|
||||||
Then move your dataset folder in `data` directory (e.g. `data/aloha_ping_pong`), and push your dataset to the hub with:
|
Then move your dataset folder in `data` directory (e.g. `data/aloha_static_pingpong_test`), and push your dataset to the hub with:
|
||||||
```bash
|
```bash
|
||||||
python lerobot/scripts/push_dataset_to_hub.py \
|
python lerobot/scripts/push_dataset_to_hub.py \
|
||||||
--data-dir data \
|
--data-dir data \
|
||||||
--dataset-id aloha_ping_ping \
|
--dataset-id aloha_static_pingpong_test \
|
||||||
--raw-format aloha_hdf5 \
|
--raw-format aloha_hdf5 \
|
||||||
--community-id lerobot
|
--community-id lerobot
|
||||||
```
|
```
|
||||||
|
|
|
@ -0,0 +1,40 @@
|
||||||
|
FROM nvidia/cuda:12.4.1-base-ubuntu22.04
|
||||||
|
|
||||||
|
# Configure image
|
||||||
|
ARG PYTHON_VERSION=3.10
|
||||||
|
ARG DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
|
# Install apt dependencies
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
build-essential cmake \
|
||||||
|
git git-lfs openssh-client \
|
||||||
|
nano vim less util-linux \
|
||||||
|
htop atop nvtop \
|
||||||
|
sed gawk grep curl wget \
|
||||||
|
tcpdump sysstat screen tmux \
|
||||||
|
libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
|
||||||
|
python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \
|
||||||
|
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Install gh cli tool
|
||||||
|
RUN (type -p wget >/dev/null || (apt update && apt-get install wget -y)) \
|
||||||
|
&& mkdir -p -m 755 /etc/apt/keyrings \
|
||||||
|
&& wget -qO- https://cli.github.com/packages/githubcli-archive-keyring.gpg | tee /etc/apt/keyrings/githubcli-archive-keyring.gpg > /dev/null \
|
||||||
|
&& chmod go+r /etc/apt/keyrings/githubcli-archive-keyring.gpg \
|
||||||
|
&& echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/githubcli-archive-keyring.gpg] https://cli.github.com/packages stable main" | tee /etc/apt/sources.list.d/github-cli.list > /dev/null \
|
||||||
|
&& apt update \
|
||||||
|
&& apt install gh -y \
|
||||||
|
&& apt clean && rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Setup `python`
|
||||||
|
RUN ln -s /usr/bin/python3 /usr/bin/python
|
||||||
|
|
||||||
|
# Install poetry
|
||||||
|
RUN curl -sSL https://install.python-poetry.org | python -
|
||||||
|
ENV PATH="/root/.local/bin:$PATH"
|
||||||
|
RUN echo 'if [ "$HOME" != "/root" ]; then ln -sf /root/.local/bin/poetry $HOME/.local/bin/poetry; fi' >> /root/.bashrc
|
||||||
|
RUN poetry config virtualenvs.create false
|
||||||
|
RUN poetry config virtualenvs.in-project true
|
||||||
|
|
||||||
|
# Set EGL as the rendering backend for MuJoCo
|
||||||
|
ENV MUJOCO_GL="egl"
|
|
@ -4,6 +4,7 @@ FROM nvidia/cuda:12.4.1-base-ubuntu22.04
|
||||||
ARG PYTHON_VERSION=3.10
|
ARG PYTHON_VERSION=3.10
|
||||||
ARG DEBIAN_FRONTEND=noninteractive
|
ARG DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
|
|
||||||
# Install apt dependencies
|
# Install apt dependencies
|
||||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
build-essential cmake \
|
build-essential cmake \
|
||||||
|
@ -11,6 +12,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \
|
python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \
|
||||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
|
||||||
# Create virtual environment
|
# Create virtual environment
|
||||||
RUN ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python
|
RUN ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python
|
||||||
RUN python -m venv /opt/venv
|
RUN python -m venv /opt/venv
|
||||||
|
|
|
@ -0,0 +1,165 @@
|
||||||
|
This tutorial will explain the training script, how to use it, and particularly the use of Hydra to configure everything needed for the training run.
|
||||||
|
|
||||||
|
## The training script
|
||||||
|
|
||||||
|
LeRobot offers a training script at [`lerobot/scripts/train.py`](../../lerobot/scripts/train.py). At a high level it does the following:
|
||||||
|
|
||||||
|
- Loads a Hydra configuration file for the following steps (more on Hydra in a moment).
|
||||||
|
- Makes a simulation environment.
|
||||||
|
- Makes a dataset corresponding to that simulation environment.
|
||||||
|
- Makes a policy.
|
||||||
|
- Runs a standard training loop with forward pass, backward pass, optimization step, and occasional logging, evaluation (of the policy on the environment), and checkpointing.
|
||||||
|
|
||||||
|
## Our use of Hydra
|
||||||
|
|
||||||
|
Explaining the ins and outs of [Hydra](https://hydra.cc/docs/intro/) is beyond the scope of this document, but here we'll share the main points you need to know.
|
||||||
|
|
||||||
|
First, `lerobot/configs` has a directory structure like this:
|
||||||
|
|
||||||
|
```
|
||||||
|
.
|
||||||
|
├── default.yaml
|
||||||
|
├── env
|
||||||
|
│ ├── aloha.yaml
|
||||||
|
│ ├── pusht.yaml
|
||||||
|
│ └── xarm.yaml
|
||||||
|
└── policy
|
||||||
|
├── act.yaml
|
||||||
|
├── diffusion.yaml
|
||||||
|
└── tdmpc.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
**_For brevity, in the rest of this document we'll drop the leading `lerobot/configs` path. So `default.yaml` really refers to `lerobot/configs/default.yaml`._**
|
||||||
|
|
||||||
|
When you run the training script with
|
||||||
|
|
||||||
|
```python
|
||||||
|
python lerobot/scripts/train.py
|
||||||
|
```
|
||||||
|
|
||||||
|
Hydra is set up to read `default.yaml` (via the `@hydra.main` decorator). If you take a look at the `@hydra.main`'s arguments you will see `config_path="../configs", config_name="default"`. At the top of `default.yaml`, is a `defaults` section which looks likes this:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
defaults:
|
||||||
|
- _self_
|
||||||
|
- env: pusht
|
||||||
|
- policy: diffusion
|
||||||
|
```
|
||||||
|
|
||||||
|
This logic tells Hydra to incorporate configuration parameters from `env/pusht.yaml` and `policy/diffusion.yaml`. _Note: Be aware of the order as any configuration parameters with the same name will be overidden. Thus, `default.yaml` is overriden by `env/pusht.yaml` which is overidden by `policy/diffusion.yaml`_.
|
||||||
|
|
||||||
|
Then, `default.yaml` also contains common configuration parameters such as `device: cuda` or `use_amp: false` (for enabling fp16 training). Some other parameters are set to `???` which indicates that they are expected to be set in additional yaml files. For instance, `training.offline_steps: ???` in `default.yaml` is set to `200000` in `diffusion.yaml`.
|
||||||
|
|
||||||
|
Thanks to this `defaults` section in `default.yaml`, if you want to train Diffusion Policy with PushT, you really only need to run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/train.py
|
||||||
|
```
|
||||||
|
|
||||||
|
However, you can be more explicit and launch the exact same Diffusion Policy training on PushT with:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/train.py policy=diffusion env=pusht
|
||||||
|
```
|
||||||
|
|
||||||
|
This way of overriding defaults via the CLI is especially useful when you want to change the policy and/or environment. For instance, you can train ACT on the default Aloha environment with:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/train.py policy=act env=aloha
|
||||||
|
```
|
||||||
|
|
||||||
|
There are two things to note here:
|
||||||
|
- Config overrides are passed as `param_name=param_value`.
|
||||||
|
- Here we have overridden the defaults section. `policy=act` tells Hydra to use `policy/act.yaml`, and `env=aloha` tells Hydra to use `env/pusht.yaml`.
|
||||||
|
|
||||||
|
_As an aside: we've set up all of our configurations so that they reproduce state-of-the-art results from papers in the literature._
|
||||||
|
|
||||||
|
## Overriding configuration parameters in the CLI
|
||||||
|
|
||||||
|
Now let's say that we want to train on a different task in the Aloha environment. If you look in `env/aloha.yaml` you will see something like:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# lerobot/configs/env/aloha.yaml
|
||||||
|
env:
|
||||||
|
task: AlohaInsertion-v0
|
||||||
|
```
|
||||||
|
|
||||||
|
And if you look in `policy/act.yaml` you will see something like:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# lerobot/configs/policy/act.yaml
|
||||||
|
dataset_repo_id: lerobot/aloha_sim_insertion_human
|
||||||
|
```
|
||||||
|
|
||||||
|
But our Aloha environment actually supports a cube transfer task as well. To train for this task, you could manually modify the two yaml configuration files respectively.
|
||||||
|
|
||||||
|
First, we'd need to switch to using the cube transfer task for the ALOHA environment.
|
||||||
|
|
||||||
|
```diff
|
||||||
|
# lerobot/configs/env/aloha.yaml
|
||||||
|
env:
|
||||||
|
- task: AlohaInsertion-v0
|
||||||
|
+ task: AlohaTransferCube-v0
|
||||||
|
```
|
||||||
|
|
||||||
|
Then, we'd also need to switch to using the cube transfer dataset.
|
||||||
|
|
||||||
|
```diff
|
||||||
|
# lerobot/configs/policy/act.yaml
|
||||||
|
-dataset_repo_id: lerobot/aloha_sim_insertion_human
|
||||||
|
+dataset_repo_id: lerobot/aloha_sim_transfer_cube_human
|
||||||
|
```
|
||||||
|
|
||||||
|
Then, you'd be able to run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/train.py policy=act env=aloha
|
||||||
|
```
|
||||||
|
|
||||||
|
and you'd be training and evaluating on the cube transfer task.
|
||||||
|
|
||||||
|
An alternative approach to editing the yaml configuration files, would be to override the defaults via the command line:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/train.py \
|
||||||
|
policy=act \
|
||||||
|
dataset_repo_id=lerobot/aloha_sim_transfer_cube_human \
|
||||||
|
env=aloha \
|
||||||
|
env.task=AlohaTransferCube-v0
|
||||||
|
```
|
||||||
|
|
||||||
|
There's something new here. Notice the `.` delimiter used to traverse the configuration hierarchy. _But be aware that the `defaults` section is an exception. As you saw above, we didn't need to write `defaults.policy=act` in the CLI. `policy=act` was enough._
|
||||||
|
|
||||||
|
Putting all that knowledge together, here's the command that was used to train https://huggingface.co/lerobot/act_aloha_sim_transfer_cube_human.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/train.py \
|
||||||
|
hydra.run.dir=outputs/train/act_aloha_sim_transfer_cube_human \
|
||||||
|
device=cuda
|
||||||
|
env=aloha \
|
||||||
|
env.task=AlohaTransferCube-v0 \
|
||||||
|
dataset_repo_id=lerobot/aloha_sim_transfer_cube_human \
|
||||||
|
policy=act \
|
||||||
|
training.eval_freq=10000 \
|
||||||
|
training.log_freq=250 \
|
||||||
|
training.offline_steps=100000 \
|
||||||
|
training.save_model=true \
|
||||||
|
training.save_freq=25000 \
|
||||||
|
eval.n_episodes=50 \
|
||||||
|
eval.batch_size=50 \
|
||||||
|
wandb.enable=false \
|
||||||
|
```
|
||||||
|
|
||||||
|
There's one new thing here: `hydra.run.dir=outputs/train/act_aloha_sim_transfer_cube_human`, which specifies where to save the training output.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
So far we've seen how to train Diffusion Policy for PushT and ACT for ALOHA. Now, what if we want to train ACT for PushT? Well, there are aspects of the ACT configuration that are specific to the ALOHA environments, and these happen to be incompatible with PushT. Therefore, trying to run the following will almost certainly raise an exception of sorts (eg: feature dimension mismatch):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/train.py policy=act env=pusht dataset_repo_id=lerobot/pusht
|
||||||
|
```
|
||||||
|
|
||||||
|
Please, head on over to our [advanced tutorial on adapting policy configuration to various environments](./advanced/train_act_pusht/train_act_pusht.md) to learn more.
|
||||||
|
|
||||||
|
Or in the meantime, happy coding! 🤗
|
|
@ -0,0 +1,87 @@
|
||||||
|
# @package _global_
|
||||||
|
|
||||||
|
# Change the seed to match what PushT eval uses
|
||||||
|
# (to avoid evaluating on seeds used for generating the training data).
|
||||||
|
seed: 100000
|
||||||
|
# Change the dataset repository to the PushT one.
|
||||||
|
dataset_repo_id: lerobot/pusht
|
||||||
|
|
||||||
|
override_dataset_stats:
|
||||||
|
observation.image:
|
||||||
|
# stats from imagenet, since we use a pretrained vision model
|
||||||
|
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||||
|
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||||
|
|
||||||
|
training:
|
||||||
|
offline_steps: 80000
|
||||||
|
online_steps: 0
|
||||||
|
eval_freq: 10000
|
||||||
|
save_freq: 100000
|
||||||
|
log_freq: 250
|
||||||
|
save_model: true
|
||||||
|
|
||||||
|
batch_size: 8
|
||||||
|
lr: 1e-5
|
||||||
|
lr_backbone: 1e-5
|
||||||
|
weight_decay: 1e-4
|
||||||
|
grad_clip_norm: 10
|
||||||
|
online_steps_between_rollouts: 1
|
||||||
|
|
||||||
|
delta_timestamps:
|
||||||
|
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
|
||||||
|
|
||||||
|
eval:
|
||||||
|
n_episodes: 50
|
||||||
|
batch_size: 50
|
||||||
|
|
||||||
|
# See `configuration_act.py` for more details.
|
||||||
|
policy:
|
||||||
|
name: act
|
||||||
|
|
||||||
|
# Input / output structure.
|
||||||
|
n_obs_steps: 1
|
||||||
|
chunk_size: 100 # chunk_size
|
||||||
|
n_action_steps: 100
|
||||||
|
|
||||||
|
input_shapes:
|
||||||
|
observation.image: [3, 96, 96]
|
||||||
|
observation.state: ["${env.state_dim}"]
|
||||||
|
output_shapes:
|
||||||
|
action: ["${env.action_dim}"]
|
||||||
|
|
||||||
|
# Normalization / Unnormalization
|
||||||
|
input_normalization_modes:
|
||||||
|
observation.image: mean_std
|
||||||
|
# Use min_max normalization just because it's more standard.
|
||||||
|
observation.state: min_max
|
||||||
|
output_normalization_modes:
|
||||||
|
# Use min_max normalization just because it's more standard.
|
||||||
|
action: min_max
|
||||||
|
|
||||||
|
# Architecture.
|
||||||
|
# Vision backbone.
|
||||||
|
vision_backbone: resnet18
|
||||||
|
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
|
||||||
|
replace_final_stride_with_dilation: false
|
||||||
|
# Transformer layers.
|
||||||
|
pre_norm: false
|
||||||
|
dim_model: 512
|
||||||
|
n_heads: 8
|
||||||
|
dim_feedforward: 3200
|
||||||
|
feedforward_activation: relu
|
||||||
|
n_encoder_layers: 4
|
||||||
|
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
|
||||||
|
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
|
||||||
|
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
|
||||||
|
n_decoder_layers: 1
|
||||||
|
# VAE.
|
||||||
|
use_vae: true
|
||||||
|
latent_dim: 32
|
||||||
|
n_vae_encoder_layers: 4
|
||||||
|
|
||||||
|
# Inference.
|
||||||
|
temporal_ensemble_momentum: null
|
||||||
|
|
||||||
|
# Training and loss computation.
|
||||||
|
dropout: 0.1
|
||||||
|
kl_weight: 10.0
|
|
@ -0,0 +1,70 @@
|
||||||
|
In this tutorial we will learn how to adapt a policy configuration to be compatible with a new environment and dataset. As a concrete example, we will adapt the default configuration for ACT to be compatible with the PushT environment and dataset.
|
||||||
|
|
||||||
|
If you haven't already read our tutorial on the [training script and configuration tooling](../4_train_policy_with_script.md) please do so prior to tackling this tutorial.
|
||||||
|
|
||||||
|
Let's get started!
|
||||||
|
|
||||||
|
Suppose we want to train ACT for PushT. Well, there are aspects of the ACT configuration that are specific to the ALOHA environments, and these happen to be incompatible with PushT. Therefore, trying to run the following will almost certainly raise an exception of sorts (eg: feature dimension mismatch):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/train.py policy=act env=pusht dataset_repo_id=lerobot/pusht
|
||||||
|
```
|
||||||
|
|
||||||
|
We need to adapt the parameters of the ACT policy configuration to the PushT environment. The most important ones are the image keys.
|
||||||
|
|
||||||
|
ALOHA's datasets and environments typically use a variable number of cameras. In `lerobot/configs/policy/act.yaml` you may notice two relevant sections. Here we show you the minimal diff needed to adjust to PushT:
|
||||||
|
|
||||||
|
```diff
|
||||||
|
override_dataset_stats:
|
||||||
|
- observation.images.top:
|
||||||
|
+ observation.image:
|
||||||
|
# stats from imagenet, since we use a pretrained vision model
|
||||||
|
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||||
|
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||||
|
|
||||||
|
policy:
|
||||||
|
input_shapes:
|
||||||
|
- observation.images.top: [3, 480, 640]
|
||||||
|
+ observation.image: [3, 96, 96]
|
||||||
|
observation.state: ["${env.state_dim}"]
|
||||||
|
output_shapes:
|
||||||
|
action: ["${env.action_dim}"]
|
||||||
|
|
||||||
|
input_normalization_modes:
|
||||||
|
- observation.images.top: mean_std
|
||||||
|
+ observation.image: mean_std
|
||||||
|
observation.state: min_max
|
||||||
|
output_normalization_modes:
|
||||||
|
action: min_max
|
||||||
|
```
|
||||||
|
|
||||||
|
Here we've accounted for the following:
|
||||||
|
- PushT uses "observation.image" for its image key.
|
||||||
|
- PushT provides smaller images.
|
||||||
|
|
||||||
|
_Side note: technically we could override these via the CLI, but with many changes it gets a bit messy, and we also have a bit of a challenge in that we're using `.` in our observation keys which is treated by Hydra as a hierarchical separator_.
|
||||||
|
|
||||||
|
For your convenience, we provide [`act_pusht.yaml`](./act_pusht.yaml) in this directory. It contains the diff above, plus some other (optional) ones that are explained within. Please copy it into `lerobot/configs/policy` with:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cp examples/advanced/1_train_act_pusht/act_pusht.yaml lerobot/configs/policy/act_pusht.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
(remember from a [previous tutorial](../4_train_policy_with_script.md) that Hydra will look in the `lerobot/configs` directory). Now try running the following.
|
||||||
|
|
||||||
|
<!-- Note to contributor: are you changing this command? Note that it's tested in `Makefile`, so change it there too! -->
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/train.py policy=act_pusht env=pusht
|
||||||
|
```
|
||||||
|
|
||||||
|
Notice that this is much the same as the command that failed at the start of the tutorial, only:
|
||||||
|
- Now we are using `policy=act_pusht` to point to our new configuration file.
|
||||||
|
- We can drop `dataset_repo_id=lerobot/pusht` as the change is incorporated in our new configuration file.
|
||||||
|
|
||||||
|
Hurrah! You're now training ACT for the PushT environment.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
The bottom line of this tutorial is that when training policies for different environments and datasets you will need to understand what parts of the policy configuration are specific to those and make changes accordingly.
|
||||||
|
|
||||||
|
Happy coding! 🤗
|
|
@ -0,0 +1,90 @@
|
||||||
|
"""This script demonstrates how to slice a dataset and calculate the loss on a subset of the data.
|
||||||
|
|
||||||
|
This technique can be useful for debugging and testing purposes, as well as identifying whether a policy
|
||||||
|
is learning effectively.
|
||||||
|
|
||||||
|
Furthermore, relying on validation loss to evaluate performance is generally not considered a good practice,
|
||||||
|
especially in the context of imitation learning. The most reliable approach is to evaluate the policy directly
|
||||||
|
on the target environment, whether that be in simulation or the real world.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||||
|
|
||||||
|
device = torch.device("cuda")
|
||||||
|
|
||||||
|
# Download the diffusion policy for pusht environment
|
||||||
|
pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))
|
||||||
|
# OR uncomment the following to evaluate a policy from the local outputs/train folder.
|
||||||
|
# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")
|
||||||
|
|
||||||
|
policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
|
||||||
|
policy.eval()
|
||||||
|
policy.to(device)
|
||||||
|
|
||||||
|
# Set up the dataset.
|
||||||
|
delta_timestamps = {
|
||||||
|
# Load the previous image and state at -0.1 seconds before current frame,
|
||||||
|
# then load current image and state corresponding to 0.0 second.
|
||||||
|
"observation.image": [-0.1, 0.0],
|
||||||
|
"observation.state": [-0.1, 0.0],
|
||||||
|
# Load the previous action (-0.1), the next action to be executed (0.0),
|
||||||
|
# and 14 future actions with a 0.1 seconds spacing. All these actions will be
|
||||||
|
# used to calculate the loss.
|
||||||
|
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Load the last 10% of episodes of the dataset as a validation set.
|
||||||
|
# - Load full dataset
|
||||||
|
full_dataset = LeRobotDataset("lerobot/pusht", split="train")
|
||||||
|
# - Calculate train and val subsets
|
||||||
|
num_train_episodes = math.floor(full_dataset.num_episodes * 90 / 100)
|
||||||
|
num_val_episodes = full_dataset.num_episodes - num_train_episodes
|
||||||
|
print(f"Number of episodes in full dataset: {full_dataset.num_episodes}")
|
||||||
|
print(f"Number of episodes in training dataset (90% subset): {num_train_episodes}")
|
||||||
|
print(f"Number of episodes in validation dataset (10% subset): {num_val_episodes}")
|
||||||
|
# - Get first frame index of the validation set
|
||||||
|
first_val_frame_index = full_dataset.episode_data_index["from"][num_train_episodes].item()
|
||||||
|
# - Load frames subset belonging to validation set using the `split` argument.
|
||||||
|
# It utilizes the `datasets` library's syntax for slicing datasets.
|
||||||
|
# For more information on the Slice API, please see:
|
||||||
|
# https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits
|
||||||
|
train_dataset = LeRobotDataset(
|
||||||
|
"lerobot/pusht", split=f"train[:{first_val_frame_index}]", delta_timestamps=delta_timestamps
|
||||||
|
)
|
||||||
|
val_dataset = LeRobotDataset(
|
||||||
|
"lerobot/pusht", split=f"train[{first_val_frame_index}:]", delta_timestamps=delta_timestamps
|
||||||
|
)
|
||||||
|
print(f"Number of frames in training dataset (90% subset): {len(train_dataset)}")
|
||||||
|
print(f"Number of frames in validation dataset (10% subset): {len(val_dataset)}")
|
||||||
|
|
||||||
|
# Create dataloader for evaluation.
|
||||||
|
val_dataloader = torch.utils.data.DataLoader(
|
||||||
|
val_dataset,
|
||||||
|
num_workers=4,
|
||||||
|
batch_size=64,
|
||||||
|
shuffle=False,
|
||||||
|
pin_memory=device != torch.device("cpu"),
|
||||||
|
drop_last=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run validation loop.
|
||||||
|
loss_cumsum = 0
|
||||||
|
n_examples_evaluated = 0
|
||||||
|
for batch in val_dataloader:
|
||||||
|
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
|
||||||
|
output_dict = policy.forward(batch)
|
||||||
|
|
||||||
|
loss_cumsum += output_dict["loss"].item()
|
||||||
|
n_examples_evaluated += batch["index"].shape[0]
|
||||||
|
|
||||||
|
# Calculate the average loss over the validation set.
|
||||||
|
average_loss = loss_cumsum / n_examples_evaluated
|
||||||
|
|
||||||
|
print(f"Average loss on validation set: {average_loss:.4f}")
|
|
@ -1,3 +1,18 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
"""
|
"""
|
||||||
This file contains lists of available environments, dataset and policies to reflect the current state of LeRobot library.
|
This file contains lists of available environments, dataset and policies to reflect the current state of LeRobot library.
|
||||||
We do not want to import all the dependencies, but instead we keep it lightweight to ensure fast access to these variables.
|
We do not want to import all the dependencies, but instead we keep it lightweight to ensure fast access to these variables.
|
||||||
|
@ -46,13 +61,21 @@ available_datasets_per_env = {
|
||||||
"lerobot/aloha_sim_insertion_scripted",
|
"lerobot/aloha_sim_insertion_scripted",
|
||||||
"lerobot/aloha_sim_transfer_cube_human",
|
"lerobot/aloha_sim_transfer_cube_human",
|
||||||
"lerobot/aloha_sim_transfer_cube_scripted",
|
"lerobot/aloha_sim_transfer_cube_scripted",
|
||||||
|
"lerobot/aloha_sim_insertion_human_image",
|
||||||
|
"lerobot/aloha_sim_insertion_scripted_image",
|
||||||
|
"lerobot/aloha_sim_transfer_cube_human_image",
|
||||||
|
"lerobot/aloha_sim_transfer_cube_scripted_image",
|
||||||
],
|
],
|
||||||
"pusht": ["lerobot/pusht"],
|
"pusht": ["lerobot/pusht", "lerobot/pusht_image"],
|
||||||
"xarm": [
|
"xarm": [
|
||||||
"lerobot/xarm_lift_medium",
|
"lerobot/xarm_lift_medium",
|
||||||
"lerobot/xarm_lift_medium_replay",
|
"lerobot/xarm_lift_medium_replay",
|
||||||
"lerobot/xarm_push_medium",
|
"lerobot/xarm_push_medium",
|
||||||
"lerobot/xarm_push_medium_replay",
|
"lerobot/xarm_push_medium_replay",
|
||||||
|
"lerobot/xarm_lift_medium_image",
|
||||||
|
"lerobot/xarm_lift_medium_replay_image",
|
||||||
|
"lerobot/xarm_push_medium_image",
|
||||||
|
"lerobot/xarm_push_medium_replay_image",
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,18 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
"""To enable `lerobot.__version__`"""
|
"""To enable `lerobot.__version__`"""
|
||||||
|
|
||||||
from importlib.metadata import PackageNotFoundError, version
|
from importlib.metadata import PackageNotFoundError, version
|
||||||
|
|
|
@ -1,3 +1,18 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
import shutil
|
import shutil
|
||||||
|
|
|
@ -1,3 +1,18 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
|
@ -1,3 +1,18 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
@ -5,17 +20,19 @@ import datasets
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.common.datasets.utils import (
|
from lerobot.common.datasets.utils import (
|
||||||
|
calculate_episode_data_index,
|
||||||
load_episode_data_index,
|
load_episode_data_index,
|
||||||
load_hf_dataset,
|
load_hf_dataset,
|
||||||
load_info,
|
load_info,
|
||||||
load_previous_and_future_frames,
|
load_previous_and_future_frames,
|
||||||
load_stats,
|
load_stats,
|
||||||
load_videos,
|
load_videos,
|
||||||
|
reset_episode_index,
|
||||||
)
|
)
|
||||||
from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos
|
from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos
|
||||||
|
|
||||||
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
||||||
CODEBASE_VERSION = "v1.3"
|
CODEBASE_VERSION = "v1.4"
|
||||||
|
|
||||||
|
|
||||||
class LeRobotDataset(torch.utils.data.Dataset):
|
class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
|
@ -39,7 +56,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
# TODO(rcadene, aliberts): implement faster transfer
|
# TODO(rcadene, aliberts): implement faster transfer
|
||||||
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
|
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
|
||||||
self.hf_dataset = load_hf_dataset(repo_id, version, root, split)
|
self.hf_dataset = load_hf_dataset(repo_id, version, root, split)
|
||||||
|
if split == "train":
|
||||||
self.episode_data_index = load_episode_data_index(repo_id, version, root)
|
self.episode_data_index = load_episode_data_index(repo_id, version, root)
|
||||||
|
else:
|
||||||
|
self.episode_data_index = calculate_episode_data_index(self.hf_dataset)
|
||||||
|
self.hf_dataset = reset_episode_index(self.hf_dataset)
|
||||||
self.stats = load_stats(repo_id, version, root)
|
self.stats = load_stats(repo_id, version, root)
|
||||||
self.info = load_info(repo_id, version, root)
|
self.info = load_info(repo_id, version, root)
|
||||||
if self.video:
|
if self.video:
|
||||||
|
|
|
@ -1,3 +1,18 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
"""Helper code for loading PushT dataset from Diffusion Policy (https://diffusion-policy.cs.columbia.edu/)
|
"""Helper code for loading PushT dataset from Diffusion Policy (https://diffusion-policy.cs.columbia.edu/)
|
||||||
|
|
||||||
Copied from the original Diffusion Policy repository and used in our `download_and_upload_dataset.py` script.
|
Copied from the original Diffusion Policy repository and used in our `download_and_upload_dataset.py` script.
|
||||||
|
|
|
@ -1,3 +1,18 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
"""
|
"""
|
||||||
This file contains all obsolete download scripts. They are centralized here to not have to load
|
This file contains all obsolete download scripts. They are centralized here to not have to load
|
||||||
useless dependencies when using datasets.
|
useless dependencies when using datasets.
|
||||||
|
@ -9,17 +24,16 @@ import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import tqdm
|
import tqdm
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
ALOHA_RAW_URLS_DIR = "lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls"
|
|
||||||
|
|
||||||
|
|
||||||
def download_raw(raw_dir, dataset_id):
|
def download_raw(raw_dir, dataset_id):
|
||||||
if "pusht" in dataset_id:
|
if "aloha" in dataset_id or "image" in dataset_id:
|
||||||
|
download_hub(raw_dir, dataset_id)
|
||||||
|
elif "pusht" in dataset_id:
|
||||||
download_pusht(raw_dir)
|
download_pusht(raw_dir)
|
||||||
elif "xarm" in dataset_id:
|
elif "xarm" in dataset_id:
|
||||||
download_xarm(raw_dir)
|
download_xarm(raw_dir)
|
||||||
elif "aloha" in dataset_id:
|
|
||||||
download_aloha(raw_dir, dataset_id)
|
|
||||||
elif "umi" in dataset_id:
|
elif "umi" in dataset_id:
|
||||||
download_umi(raw_dir)
|
download_umi(raw_dir)
|
||||||
else:
|
else:
|
||||||
|
@ -88,37 +102,13 @@ def download_xarm(raw_dir: Path):
|
||||||
zip_path.unlink()
|
zip_path.unlink()
|
||||||
|
|
||||||
|
|
||||||
def download_aloha(raw_dir: Path, dataset_id: str):
|
def download_hub(raw_dir: Path, dataset_id: str):
|
||||||
import gdown
|
|
||||||
|
|
||||||
subset_id = dataset_id.replace("aloha_", "")
|
|
||||||
urls_path = Path(ALOHA_RAW_URLS_DIR) / f"{subset_id}.txt"
|
|
||||||
assert urls_path.exists(), f"{subset_id}.txt not found in '{ALOHA_RAW_URLS_DIR}' directory."
|
|
||||||
|
|
||||||
with open(urls_path) as f:
|
|
||||||
# strip lines and ignore empty lines
|
|
||||||
urls = [url.strip() for url in f if url.strip()]
|
|
||||||
|
|
||||||
# sanity check
|
|
||||||
for url in urls:
|
|
||||||
assert (
|
|
||||||
"drive.google.com/drive/folders" in url or "drive.google.com/file" in url
|
|
||||||
), f"Wrong url provided '{url}' in file '{urls_path}'."
|
|
||||||
|
|
||||||
raw_dir = Path(raw_dir)
|
raw_dir = Path(raw_dir)
|
||||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
logging.info(f"Start downloading from google drive for {dataset_id}")
|
logging.info(f"Start downloading from huggingface.co/cadene for {dataset_id}")
|
||||||
for url in urls:
|
snapshot_download(f"cadene/{dataset_id}_raw", repo_type="dataset", local_dir=raw_dir)
|
||||||
if "drive.google.com/drive/folders" in url:
|
logging.info(f"Finish downloading from huggingface.co/cadene for {dataset_id}")
|
||||||
# when a folder url is given, download up to 50 files from the folder
|
|
||||||
gdown.download_folder(url, output=str(raw_dir), remaining_ok=True)
|
|
||||||
|
|
||||||
elif "drive.google.com/file" in url:
|
|
||||||
# because of the 50 files limit per folder, we download the remaining files (file by file)
|
|
||||||
gdown.download(url, output=str(raw_dir), fuzzy=True)
|
|
||||||
|
|
||||||
logging.info(f"End downloading from google drive for {dataset_id}")
|
|
||||||
|
|
||||||
|
|
||||||
def download_umi(raw_dir: Path):
|
def download_umi(raw_dir: Path):
|
||||||
|
@ -133,21 +123,30 @@ def download_umi(raw_dir: Path):
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
data_dir = Path("data")
|
data_dir = Path("data")
|
||||||
dataset_ids = [
|
dataset_ids = [
|
||||||
|
"pusht_image",
|
||||||
|
"xarm_lift_medium_image",
|
||||||
|
"xarm_lift_medium_replay_image",
|
||||||
|
"xarm_push_medium_image",
|
||||||
|
"xarm_push_medium_replay_image",
|
||||||
|
"aloha_sim_insertion_human_image",
|
||||||
|
"aloha_sim_insertion_scripted_image",
|
||||||
|
"aloha_sim_transfer_cube_human_image",
|
||||||
|
"aloha_sim_transfer_cube_scripted_image",
|
||||||
"pusht",
|
"pusht",
|
||||||
"xarm_lift_medium",
|
"xarm_lift_medium",
|
||||||
"xarm_lift_medium_replay",
|
"xarm_lift_medium_replay",
|
||||||
"xarm_push_medium",
|
"xarm_push_medium",
|
||||||
"xarm_push_medium_replay",
|
"xarm_push_medium_replay",
|
||||||
|
"aloha_sim_insertion_human",
|
||||||
|
"aloha_sim_insertion_scripted",
|
||||||
|
"aloha_sim_transfer_cube_human",
|
||||||
|
"aloha_sim_transfer_cube_scripted",
|
||||||
"aloha_mobile_cabinet",
|
"aloha_mobile_cabinet",
|
||||||
"aloha_mobile_chair",
|
"aloha_mobile_chair",
|
||||||
"aloha_mobile_elevator",
|
"aloha_mobile_elevator",
|
||||||
"aloha_mobile_shrimp",
|
"aloha_mobile_shrimp",
|
||||||
"aloha_mobile_wash_pan",
|
"aloha_mobile_wash_pan",
|
||||||
"aloha_mobile_wipe_wine",
|
"aloha_mobile_wipe_wine",
|
||||||
"aloha_sim_insertion_human",
|
|
||||||
"aloha_sim_insertion_scripted",
|
|
||||||
"aloha_sim_transfer_cube_human",
|
|
||||||
"aloha_sim_transfer_cube_scripted",
|
|
||||||
"aloha_static_battery",
|
"aloha_static_battery",
|
||||||
"aloha_static_candy",
|
"aloha_static_candy",
|
||||||
"aloha_static_coffee",
|
"aloha_static_coffee",
|
||||||
|
|
|
@ -1,3 +1,18 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
# imagecodecs/numcodecs.py
|
# imagecodecs/numcodecs.py
|
||||||
|
|
||||||
# Copyright (c) 2021-2022, Christoph Gohlke
|
# Copyright (c) 2021-2022, Christoph Gohlke
|
||||||
|
|
|
@ -1,8 +1,23 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
"""
|
"""
|
||||||
Contains utilities to process raw data format of HDF5 files like in: https://github.com/tonyzhaozh/act
|
Contains utilities to process raw data format of HDF5 files like in: https://github.com/tonyzhaozh/act
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import re
|
import gc
|
||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
@ -64,10 +79,8 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
|
||||||
episode_data_index = {"from": [], "to": []}
|
episode_data_index = {"from": [], "to": []}
|
||||||
|
|
||||||
id_from = 0
|
id_from = 0
|
||||||
|
for ep_idx, ep_path in tqdm.tqdm(enumerate(hdf5_files), total=len(hdf5_files)):
|
||||||
for ep_path in tqdm.tqdm(hdf5_files, total=len(hdf5_files)):
|
|
||||||
with h5py.File(ep_path, "r") as ep:
|
with h5py.File(ep_path, "r") as ep:
|
||||||
ep_idx = int(re.search(r"episode_(\d+)", ep_path.name).group(1))
|
|
||||||
num_frames = ep["/action"].shape[0]
|
num_frames = ep["/action"].shape[0]
|
||||||
|
|
||||||
# last step of demonstration is considered done
|
# last step of demonstration is considered done
|
||||||
|
@ -76,6 +89,10 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
|
||||||
|
|
||||||
state = torch.from_numpy(ep["/observations/qpos"][:])
|
state = torch.from_numpy(ep["/observations/qpos"][:])
|
||||||
action = torch.from_numpy(ep["/action"][:])
|
action = torch.from_numpy(ep["/action"][:])
|
||||||
|
if "/observations/qvel" in ep:
|
||||||
|
velocity = torch.from_numpy(ep["/observations/qvel"][:])
|
||||||
|
if "/observations/effort" in ep:
|
||||||
|
effort = torch.from_numpy(ep["/observations/effort"][:])
|
||||||
|
|
||||||
ep_dict = {}
|
ep_dict = {}
|
||||||
|
|
||||||
|
@ -116,6 +133,10 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
|
||||||
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
|
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
|
||||||
|
|
||||||
ep_dict["observation.state"] = state
|
ep_dict["observation.state"] = state
|
||||||
|
if "/observations/velocity" in ep:
|
||||||
|
ep_dict["observation.velocity"] = velocity
|
||||||
|
if "/observations/effort" in ep:
|
||||||
|
ep_dict["observation.effort"] = effort
|
||||||
ep_dict["action"] = action
|
ep_dict["action"] = action
|
||||||
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
|
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
|
||||||
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
|
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
|
||||||
|
@ -131,6 +152,8 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
|
||||||
|
|
||||||
id_from += num_frames
|
id_from += num_frames
|
||||||
|
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
# process first episode only
|
# process first episode only
|
||||||
if debug:
|
if debug:
|
||||||
break
|
break
|
||||||
|
@ -152,6 +175,14 @@ def to_hf_dataset(data_dict, video) -> Dataset:
|
||||||
features["observation.state"] = Sequence(
|
features["observation.state"] = Sequence(
|
||||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||||
)
|
)
|
||||||
|
if "observation.velocity" in data_dict:
|
||||||
|
features["observation.velocity"] = Sequence(
|
||||||
|
length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
|
||||||
|
)
|
||||||
|
if "observation.effort" in data_dict:
|
||||||
|
features["observation.effort"] = Sequence(
|
||||||
|
length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
|
||||||
|
)
|
||||||
features["action"] = Sequence(
|
features["action"] = Sequence(
|
||||||
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
|
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,3 +1,18 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from math import ceil
|
from math import ceil
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,18 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
"""Process zarr files formatted like in: https://github.com/real-stanford/diffusion_policy"""
|
"""Process zarr files formatted like in: https://github.com/real-stanford/diffusion_policy"""
|
||||||
|
|
||||||
import shutil
|
import shutil
|
||||||
|
|
|
@ -1,3 +1,18 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
"""Process UMI (Universal Manipulation Interface) data stored in Zarr format like in: https://github.com/real-stanford/universal_manipulation_interface"""
|
"""Process UMI (Universal Manipulation Interface) data stored in Zarr format like in: https://github.com/real-stanford/universal_manipulation_interface"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
|
@ -1,3 +1,18 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,18 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
"""Process pickle files formatted like in: https://github.com/fyhMer/fowm"""
|
"""Process pickle files formatted like in: https://github.com/fyhMer/fowm"""
|
||||||
|
|
||||||
import pickle
|
import pickle
|
||||||
|
|
|
@ -1,5 +1,22 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
import json
|
import json
|
||||||
|
import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
import torch
|
import torch
|
||||||
|
@ -64,7 +81,23 @@ def hf_transform_to_torch(items_dict):
|
||||||
def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset:
|
def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset:
|
||||||
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
||||||
if root is not None:
|
if root is not None:
|
||||||
hf_dataset = load_from_disk(str(Path(root) / repo_id / split))
|
hf_dataset = load_from_disk(str(Path(root) / repo_id / "train"))
|
||||||
|
# TODO(rcadene): clean this which enables getting a subset of dataset
|
||||||
|
if split != "train":
|
||||||
|
if "%" in split:
|
||||||
|
raise NotImplementedError(f"We dont support splitting based on percentage for now ({split}).")
|
||||||
|
match_from = re.search(r"train\[(\d+):\]", split)
|
||||||
|
match_to = re.search(r"train\[:(\d+)\]", split)
|
||||||
|
if match_from:
|
||||||
|
from_frame_index = int(match_from.group(1))
|
||||||
|
hf_dataset = hf_dataset.select(range(from_frame_index, len(hf_dataset)))
|
||||||
|
elif match_to:
|
||||||
|
to_frame_index = int(match_to.group(1))
|
||||||
|
hf_dataset = hf_dataset.select(range(to_frame_index))
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f'`split` ({split}) should either be "train", "train[INT:]", or "train[:INT]"'
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
hf_dataset = load_dataset(repo_id, revision=version, split=split)
|
hf_dataset = load_dataset(repo_id, revision=version, split=split)
|
||||||
hf_dataset.set_transform(hf_transform_to_torch)
|
hf_dataset.set_transform(hf_transform_to_torch)
|
||||||
|
@ -230,6 +263,84 @@ def load_previous_and_future_frames(
|
||||||
return item
|
return item
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- hf_dataset (datasets.Dataset): A HuggingFace dataset containing the episode index.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- episode_data_index: A dictionary containing the data index for each episode. The dictionary has two keys:
|
||||||
|
- "from": A tensor containing the starting index of each episode.
|
||||||
|
- "to": A tensor containing the ending index of each episode.
|
||||||
|
"""
|
||||||
|
episode_data_index = {"from": [], "to": []}
|
||||||
|
|
||||||
|
current_episode = None
|
||||||
|
"""
|
||||||
|
The episode_index is a list of integers, each representing the episode index of the corresponding example.
|
||||||
|
For instance, the following is a valid episode_index:
|
||||||
|
[0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]
|
||||||
|
|
||||||
|
Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and
|
||||||
|
ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this:
|
||||||
|
{
|
||||||
|
"from": [0, 3, 7],
|
||||||
|
"to": [3, 7, 12]
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
if len(hf_dataset) == 0:
|
||||||
|
episode_data_index = {
|
||||||
|
"from": torch.tensor([]),
|
||||||
|
"to": torch.tensor([]),
|
||||||
|
}
|
||||||
|
return episode_data_index
|
||||||
|
for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
|
||||||
|
if episode_idx != current_episode:
|
||||||
|
# We encountered a new episode, so we append its starting location to the "from" list
|
||||||
|
episode_data_index["from"].append(idx)
|
||||||
|
# If this is not the first episode, we append the ending location of the previous episode to the "to" list
|
||||||
|
if current_episode is not None:
|
||||||
|
episode_data_index["to"].append(idx)
|
||||||
|
# Let's keep track of the current episode index
|
||||||
|
current_episode = episode_idx
|
||||||
|
else:
|
||||||
|
# We are still in the same episode, so there is nothing for us to do here
|
||||||
|
pass
|
||||||
|
# We have reached the end of the dataset, so we append the ending location of the last episode to the "to" list
|
||||||
|
episode_data_index["to"].append(idx + 1)
|
||||||
|
|
||||||
|
for k in ["from", "to"]:
|
||||||
|
episode_data_index[k] = torch.tensor(episode_data_index[k])
|
||||||
|
|
||||||
|
return episode_data_index
|
||||||
|
|
||||||
|
|
||||||
|
def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset:
|
||||||
|
"""
|
||||||
|
Reset the `episode_index` of the provided HuggingFace Dataset.
|
||||||
|
|
||||||
|
`episode_data_index` (and related functionality such as `load_previous_and_future_frames`) requires the
|
||||||
|
`episode_index` to be sorted, continuous (1,1,1 and not 1,2,1) and start at 0.
|
||||||
|
|
||||||
|
This brings the `episode_index` to the required format.
|
||||||
|
"""
|
||||||
|
if len(hf_dataset) == 0:
|
||||||
|
return hf_dataset
|
||||||
|
unique_episode_idxs = torch.stack(hf_dataset["episode_index"]).unique().tolist()
|
||||||
|
episode_idx_to_reset_idx_mapping = {
|
||||||
|
ep_id: reset_ep_id for reset_ep_id, ep_id in enumerate(unique_episode_idxs)
|
||||||
|
}
|
||||||
|
|
||||||
|
def modify_ep_idx_func(example):
|
||||||
|
example["episode_index"] = episode_idx_to_reset_idx_mapping[example["episode_index"].item()]
|
||||||
|
return example
|
||||||
|
|
||||||
|
hf_dataset = hf_dataset.map(modify_ep_idx_func)
|
||||||
|
return hf_dataset
|
||||||
|
|
||||||
|
|
||||||
def cycle(iterable):
|
def cycle(iterable):
|
||||||
"""The equivalent of itertools.cycle, but safe for Pytorch dataloaders.
|
"""The equivalent of itertools.cycle, but safe for Pytorch dataloaders.
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,18 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
import subprocess
|
import subprocess
|
||||||
import warnings
|
import warnings
|
||||||
|
|
|
@ -1,3 +1,18 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
import importlib
|
import importlib
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
|
|
|
@ -1,3 +1,18 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
import einops
|
import einops
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
|
@ -1,3 +1,18 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
# TODO(rcadene, alexander-soare): clean this file
|
# TODO(rcadene, alexander-soare): clean this file
|
||||||
"""Borrowed from https://github.com/fyhMer/fowm/blob/main/src/logger.py"""
|
"""Borrowed from https://github.com/fyhMer/fowm/blob/main/src/logger.py"""
|
||||||
|
|
||||||
|
@ -17,9 +32,10 @@ def log_output_dir(out_dir):
|
||||||
|
|
||||||
|
|
||||||
def cfg_to_group(cfg, return_list=False):
|
def cfg_to_group(cfg, return_list=False):
|
||||||
"""Return a wandb-safe group name for logging. Optionally returns group name as list."""
|
"""Return a group name for logging. Optionally returns group name as list."""
|
||||||
# lst = [cfg.task, cfg.modality, re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)]
|
|
||||||
lst = [
|
lst = [
|
||||||
|
f"policy:{cfg.policy.name}",
|
||||||
|
f"dataset:{cfg.dataset_repo_id}",
|
||||||
f"env:{cfg.env.name}",
|
f"env:{cfg.env.name}",
|
||||||
f"seed:{cfg.seed}",
|
f"seed:{cfg.seed}",
|
||||||
]
|
]
|
||||||
|
@ -81,9 +97,9 @@ class Logger:
|
||||||
# Also save the full Hydra config for the env configuration.
|
# Also save the full Hydra config for the env configuration.
|
||||||
OmegaConf.save(self._cfg, save_dir / "config.yaml")
|
OmegaConf.save(self._cfg, save_dir / "config.yaml")
|
||||||
if self._wandb and not self._disable_wandb_artifact:
|
if self._wandb and not self._disable_wandb_artifact:
|
||||||
# note wandb artifact does not accept ":" in its name
|
# note wandb artifact does not accept ":" or "/" in its name
|
||||||
artifact = self._wandb.Artifact(
|
artifact = self._wandb.Artifact(
|
||||||
self._group.replace(":", "_") + "-" + str(self._seed) + "-" + str(identifier),
|
f"{self._group.replace(':', '_').replace('/', '_')}-{self._seed}-{identifier}",
|
||||||
type="model",
|
type="model",
|
||||||
)
|
)
|
||||||
artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
|
artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
|
||||||
|
@ -93,9 +109,10 @@ class Logger:
|
||||||
self._buffer_dir.mkdir(parents=True, exist_ok=True)
|
self._buffer_dir.mkdir(parents=True, exist_ok=True)
|
||||||
fp = self._buffer_dir / f"{str(identifier)}.pkl"
|
fp = self._buffer_dir / f"{str(identifier)}.pkl"
|
||||||
buffer.save(fp)
|
buffer.save(fp)
|
||||||
if self._wandb:
|
if self._wandb and not self._disable_wandb_artifact:
|
||||||
|
# note wandb artifact does not accept ":" or "/" in its name
|
||||||
artifact = self._wandb.Artifact(
|
artifact = self._wandb.Artifact(
|
||||||
self._group + "-" + str(self._seed) + "-" + str(identifier),
|
f"{self._group.replace(':', '_').replace('/', '_')}-{self._seed}-{identifier}",
|
||||||
type="buffer",
|
type="buffer",
|
||||||
)
|
)
|
||||||
artifact.add_file(fp)
|
artifact.add_file(fp)
|
||||||
|
@ -113,6 +130,11 @@ class Logger:
|
||||||
assert mode in {"train", "eval"}
|
assert mode in {"train", "eval"}
|
||||||
if self._wandb is not None:
|
if self._wandb is not None:
|
||||||
for k, v in d.items():
|
for k, v in d.items():
|
||||||
|
if not isinstance(v, (int, float, str)):
|
||||||
|
logging.warning(
|
||||||
|
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
|
||||||
|
)
|
||||||
|
continue
|
||||||
self._wandb.log({f"{mode}/{k}": v}, step=step)
|
self._wandb.log({f"{mode}/{k}": v}, step=step)
|
||||||
|
|
||||||
def log_video(self, video_path: str, step: int, mode: str = "train"):
|
def log_video(self, video_path: str, step: int, mode: str = "train"):
|
||||||
|
|
|
@ -1,3 +1,18 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
|
||||||
|
@ -51,8 +66,12 @@ class ACTConfig:
|
||||||
documentation in the policy class).
|
documentation in the policy class).
|
||||||
latent_dim: The VAE's latent dimension.
|
latent_dim: The VAE's latent dimension.
|
||||||
n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder.
|
n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder.
|
||||||
use_temporal_aggregation: Whether to blend the actions of multiple policy invocations for any given
|
temporal_ensemble_momentum: Exponential moving average (EMA) momentum parameter (α) for ensembling
|
||||||
environment step.
|
actions for a given time step over multiple policy invocations. Updates are calculated as:
|
||||||
|
x⁻ₙ = αx⁻ₙ₋₁ + (1-α)xₙ. Note that the ACT paper and original ACT code describes a different
|
||||||
|
parameter here: they refer to a weighting scheme wᵢ = exp(-m⋅i) and set m = 0.01. With our
|
||||||
|
formulation, this is equivalent to α = exp(-0.01) ≈ 0.99. When this parameter is provided, we
|
||||||
|
require `n_action_steps == 1` (since we need to query the policy every step anyway).
|
||||||
dropout: Dropout to use in the transformer layers (see code for details).
|
dropout: Dropout to use in the transformer layers (see code for details).
|
||||||
kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
|
kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
|
||||||
is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
|
is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
|
||||||
|
@ -100,6 +119,9 @@ class ACTConfig:
|
||||||
dim_feedforward: int = 3200
|
dim_feedforward: int = 3200
|
||||||
feedforward_activation: str = "relu"
|
feedforward_activation: str = "relu"
|
||||||
n_encoder_layers: int = 4
|
n_encoder_layers: int = 4
|
||||||
|
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
|
||||||
|
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
|
||||||
|
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
|
||||||
n_decoder_layers: int = 1
|
n_decoder_layers: int = 1
|
||||||
# VAE.
|
# VAE.
|
||||||
use_vae: bool = True
|
use_vae: bool = True
|
||||||
|
@ -107,7 +129,7 @@ class ACTConfig:
|
||||||
n_vae_encoder_layers: int = 4
|
n_vae_encoder_layers: int = 4
|
||||||
|
|
||||||
# Inference.
|
# Inference.
|
||||||
use_temporal_aggregation: bool = False
|
temporal_ensemble_momentum: float | None = None
|
||||||
|
|
||||||
# Training and loss computation.
|
# Training and loss computation.
|
||||||
dropout: float = 0.1
|
dropout: float = 0.1
|
||||||
|
@ -119,8 +141,11 @@ class ACTConfig:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||||||
)
|
)
|
||||||
if self.use_temporal_aggregation:
|
if self.temporal_ensemble_momentum is not None and self.n_action_steps > 1:
|
||||||
raise NotImplementedError("Temporal aggregation is not yet implemented.")
|
raise NotImplementedError(
|
||||||
|
"`n_action_steps` must be 1 when using temporal ensembling. This is "
|
||||||
|
"because the policy needs to be queried every step to compute the ensembled action."
|
||||||
|
)
|
||||||
if self.n_action_steps > self.chunk_size:
|
if self.n_action_steps > self.chunk_size:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
|
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
|
||||||
|
@ -130,10 +155,3 @@ class ACTConfig:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
|
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
|
||||||
)
|
)
|
||||||
# Check that there is only one image.
|
|
||||||
# TODO(alexander-soare): generalize this to multiple images.
|
|
||||||
if (
|
|
||||||
sum(k.startswith("observation.images.") for k in self.input_shapes) != 1
|
|
||||||
or "observation.images.top" not in self.input_shapes
|
|
||||||
):
|
|
||||||
raise ValueError('For now, only "observation.images.top" is accepted for an image input.')
|
|
||||||
|
|
|
@ -1,3 +1,18 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
"""Action Chunking Transformer Policy
|
"""Action Chunking Transformer Policy
|
||||||
|
|
||||||
As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705).
|
As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705).
|
||||||
|
@ -46,7 +61,8 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if config is None:
|
if config is None:
|
||||||
config = ACTConfig()
|
config = ACTConfig()
|
||||||
self.config = config
|
self.config: ACTConfig = config
|
||||||
|
|
||||||
self.normalize_inputs = Normalize(
|
self.normalize_inputs = Normalize(
|
||||||
config.input_shapes, config.input_normalization_modes, dataset_stats
|
config.input_shapes, config.input_normalization_modes, dataset_stats
|
||||||
)
|
)
|
||||||
|
@ -56,11 +72,18 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||||
self.unnormalize_outputs = Unnormalize(
|
self.unnormalize_outputs = Unnormalize(
|
||||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||||
)
|
)
|
||||||
|
|
||||||
self.model = ACT(config)
|
self.model = ACT(config)
|
||||||
|
|
||||||
|
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||||
|
|
||||||
|
self.reset()
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""This should be called whenever the environment is reset."""
|
"""This should be called whenever the environment is reset."""
|
||||||
if self.config.n_action_steps is not None:
|
if self.config.temporal_ensemble_momentum is not None:
|
||||||
|
self._ensembled_actions = None
|
||||||
|
else:
|
||||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||||
|
|
||||||
@torch.no_grad
|
@torch.no_grad
|
||||||
|
@ -71,37 +94,56 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||||
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
||||||
queue is empty.
|
queue is empty.
|
||||||
"""
|
"""
|
||||||
assert "observation.images.top" in batch
|
|
||||||
assert "observation.state" in batch
|
|
||||||
|
|
||||||
self.eval()
|
self.eval()
|
||||||
|
|
||||||
batch = self.normalize_inputs(batch)
|
batch = self.normalize_inputs(batch)
|
||||||
self._stack_images(batch)
|
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||||
|
|
||||||
|
# If we are doing temporal ensembling, keep track of the exponential moving average (EMA), and return
|
||||||
|
# the first action.
|
||||||
|
if self.config.temporal_ensemble_momentum is not None:
|
||||||
|
actions = self.model(batch)[0] # (batch_size, chunk_size, action_dim)
|
||||||
|
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||||
|
if self._ensembled_actions is None:
|
||||||
|
# Initializes `self._ensembled_action` to the sequence of actions predicted during the first
|
||||||
|
# time step of the episode.
|
||||||
|
self._ensembled_actions = actions.clone()
|
||||||
|
else:
|
||||||
|
# self._ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
|
||||||
|
# the EMA update for those entries.
|
||||||
|
alpha = self.config.temporal_ensemble_momentum
|
||||||
|
self._ensembled_actions = alpha * self._ensembled_actions + (1 - alpha) * actions[:, :-1]
|
||||||
|
# The last action, which has no prior moving average, needs to get concatenated onto the end.
|
||||||
|
self._ensembled_actions = torch.cat([self._ensembled_actions, actions[:, -1:]], dim=1)
|
||||||
|
# "Consume" the first action.
|
||||||
|
action, self._ensembled_actions = self._ensembled_actions[:, 0], self._ensembled_actions[:, 1:]
|
||||||
|
return action
|
||||||
|
|
||||||
|
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
|
||||||
|
# querying the policy.
|
||||||
if len(self._action_queue) == 0:
|
if len(self._action_queue) == 0:
|
||||||
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
|
actions = self.model(batch)[0][:, : self.config.n_action_steps]
|
||||||
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
|
|
||||||
actions = self.model(batch)[0][: self.config.n_action_steps]
|
|
||||||
|
|
||||||
# TODO(rcadene): make _forward return output dictionary?
|
# TODO(rcadene): make _forward return output dictionary?
|
||||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||||
|
|
||||||
|
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
|
||||||
|
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
|
||||||
self._action_queue.extend(actions.transpose(0, 1))
|
self._action_queue.extend(actions.transpose(0, 1))
|
||||||
return self._action_queue.popleft()
|
return self._action_queue.popleft()
|
||||||
|
|
||||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||||
"""Run the batch through the model and compute the loss for training or validation."""
|
"""Run the batch through the model and compute the loss for training or validation."""
|
||||||
batch = self.normalize_inputs(batch)
|
batch = self.normalize_inputs(batch)
|
||||||
|
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||||
batch = self.normalize_targets(batch)
|
batch = self.normalize_targets(batch)
|
||||||
self._stack_images(batch)
|
|
||||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
|
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
|
||||||
|
|
||||||
l1_loss = (
|
l1_loss = (
|
||||||
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
|
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
|
||||||
).mean()
|
).mean()
|
||||||
|
|
||||||
loss_dict = {"l1_loss": l1_loss}
|
loss_dict = {"l1_loss": l1_loss.item()}
|
||||||
if self.config.use_vae:
|
if self.config.use_vae:
|
||||||
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
|
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
|
||||||
# each dimension independently, we sum over the latent dimension to get the total
|
# each dimension independently, we sum over the latent dimension to get the total
|
||||||
|
@ -110,28 +152,13 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||||
mean_kld = (
|
mean_kld = (
|
||||||
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
|
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
|
||||||
)
|
)
|
||||||
loss_dict["kld_loss"] = mean_kld
|
loss_dict["kld_loss"] = mean_kld.item()
|
||||||
loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
|
loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
|
||||||
else:
|
else:
|
||||||
loss_dict["loss"] = l1_loss
|
loss_dict["loss"] = l1_loss
|
||||||
|
|
||||||
return loss_dict
|
return loss_dict
|
||||||
|
|
||||||
def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
|
||||||
"""Stacks all the images in a batch and puts them in a new key: "observation.images".
|
|
||||||
|
|
||||||
This function expects `batch` to have (at least):
|
|
||||||
{
|
|
||||||
"observation.state": (B, state_dim) batch of robot states.
|
|
||||||
"observation.images.{name}": (B, C, H, W) tensor of images.
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
# Stack images in the order dictated by input_shapes.
|
|
||||||
batch["observation.images"] = torch.stack(
|
|
||||||
[batch[k] for k in self.config.input_shapes if k.startswith("observation.images.")],
|
|
||||||
dim=-4,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ACT(nn.Module):
|
class ACT(nn.Module):
|
||||||
"""Action Chunking Transformer: The underlying neural network for ACTPolicy.
|
"""Action Chunking Transformer: The underlying neural network for ACTPolicy.
|
||||||
|
@ -161,10 +188,10 @@ class ACT(nn.Module):
|
||||||
│ encoder │ │ │ │Transf.│ │
|
│ encoder │ │ │ │Transf.│ │
|
||||||
│ │ │ │ │encoder│ │
|
│ │ │ │ │encoder│ │
|
||||||
└───▲─────┘ │ │ │ │ │
|
└───▲─────┘ │ │ │ │ │
|
||||||
│ │ │ └───▲───┘ │
|
│ │ │ └▲──▲─▲─┘ │
|
||||||
│ │ │ │ │
|
│ │ │ │ │ │ │
|
||||||
inputs └─────┼─────┘ │
|
inputs └─────┼──┘ │ image emb. │
|
||||||
│ │
|
│ state emb. │
|
||||||
└───────────────────────┘
|
└───────────────────────┘
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -306,18 +333,18 @@ class ACT(nn.Module):
|
||||||
all_cam_features.append(cam_features)
|
all_cam_features.append(cam_features)
|
||||||
all_cam_pos_embeds.append(cam_pos_embed)
|
all_cam_pos_embeds.append(cam_pos_embed)
|
||||||
# Concatenate camera observation feature maps and positional embeddings along the width dimension.
|
# Concatenate camera observation feature maps and positional embeddings along the width dimension.
|
||||||
encoder_in = torch.cat(all_cam_features, axis=3)
|
encoder_in = torch.cat(all_cam_features, axis=-1)
|
||||||
cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=3)
|
cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1)
|
||||||
|
|
||||||
# Get positional embeddings for robot state and latent.
|
# Get positional embeddings for robot state and latent.
|
||||||
robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])
|
robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"]) # (B, C)
|
||||||
latent_embed = self.encoder_latent_input_proj(latent_sample)
|
latent_embed = self.encoder_latent_input_proj(latent_sample) # (B, C)
|
||||||
|
|
||||||
# Stack encoder input and positional embeddings moving to (S, B, C).
|
# Stack encoder input and positional embeddings moving to (S, B, C).
|
||||||
encoder_in = torch.cat(
|
encoder_in = torch.cat(
|
||||||
[
|
[
|
||||||
torch.stack([latent_embed, robot_state_embed], axis=0),
|
torch.stack([latent_embed, robot_state_embed], axis=0),
|
||||||
encoder_in.flatten(2).permute(2, 0, 1),
|
einops.rearrange(encoder_in, "b c h w -> (h w) b c"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
pos_embed = torch.cat(
|
pos_embed = torch.cat(
|
||||||
|
|
|
@ -1,3 +1,19 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab,
|
||||||
|
# and The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
|
||||||
|
@ -51,6 +67,7 @@ class DiffusionConfig:
|
||||||
use_film_scale_modulation: FiLM (https://arxiv.org/abs/1709.07871) is used for the Unet conditioning.
|
use_film_scale_modulation: FiLM (https://arxiv.org/abs/1709.07871) is used for the Unet conditioning.
|
||||||
Bias modulation is used be default, while this parameter indicates whether to also use scale
|
Bias modulation is used be default, while this parameter indicates whether to also use scale
|
||||||
modulation.
|
modulation.
|
||||||
|
noise_scheduler_type: Name of the noise scheduler to use. Supported options: ["DDPM", "DDIM"].
|
||||||
num_train_timesteps: Number of diffusion steps for the forward diffusion schedule.
|
num_train_timesteps: Number of diffusion steps for the forward diffusion schedule.
|
||||||
beta_schedule: Name of the diffusion beta schedule as per DDPMScheduler from Hugging Face diffusers.
|
beta_schedule: Name of the diffusion beta schedule as per DDPMScheduler from Hugging Face diffusers.
|
||||||
beta_start: Beta value for the first forward-diffusion step.
|
beta_start: Beta value for the first forward-diffusion step.
|
||||||
|
@ -110,6 +127,7 @@ class DiffusionConfig:
|
||||||
diffusion_step_embed_dim: int = 128
|
diffusion_step_embed_dim: int = 128
|
||||||
use_film_scale_modulation: bool = True
|
use_film_scale_modulation: bool = True
|
||||||
# Noise scheduler.
|
# Noise scheduler.
|
||||||
|
noise_scheduler_type: str = "DDPM"
|
||||||
num_train_timesteps: int = 100
|
num_train_timesteps: int = 100
|
||||||
beta_schedule: str = "squaredcos_cap_v2"
|
beta_schedule: str = "squaredcos_cap_v2"
|
||||||
beta_start: float = 0.0001
|
beta_start: float = 0.0001
|
||||||
|
@ -130,17 +148,30 @@ class DiffusionConfig:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||||||
)
|
)
|
||||||
|
# There should only be one image key.
|
||||||
|
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
|
||||||
|
if len(image_keys) != 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
|
||||||
|
)
|
||||||
|
image_key = next(iter(image_keys))
|
||||||
if (
|
if (
|
||||||
self.crop_shape[0] > self.input_shapes["observation.image"][1]
|
self.crop_shape[0] > self.input_shapes[image_key][1]
|
||||||
or self.crop_shape[1] > self.input_shapes["observation.image"][2]
|
or self.crop_shape[1] > self.input_shapes[image_key][2]
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f'`crop_shape` should fit within `input_shapes["observation.image"]`. Got {self.crop_shape} '
|
f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} "
|
||||||
f'for `crop_shape` and {self.input_shapes["observation.image"]} for '
|
f"for `crop_shape` and {self.input_shapes[image_key]} for "
|
||||||
'`input_shapes["observation.image"]`.'
|
"`input_shapes[{image_key}]`."
|
||||||
)
|
)
|
||||||
supported_prediction_types = ["epsilon", "sample"]
|
supported_prediction_types = ["epsilon", "sample"]
|
||||||
if self.prediction_type not in supported_prediction_types:
|
if self.prediction_type not in supported_prediction_types:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`prediction_type` must be one of {supported_prediction_types}. Got {self.prediction_type}."
|
f"`prediction_type` must be one of {supported_prediction_types}. Got {self.prediction_type}."
|
||||||
)
|
)
|
||||||
|
supported_noise_schedulers = ["DDPM", "DDIM"]
|
||||||
|
if self.noise_scheduler_type not in supported_noise_schedulers:
|
||||||
|
raise ValueError(
|
||||||
|
f"`noise_scheduler_type` must be one of {supported_noise_schedulers}. "
|
||||||
|
f"Got {self.noise_scheduler_type}."
|
||||||
|
)
|
||||||
|
|
|
@ -1,8 +1,24 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab,
|
||||||
|
# and The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
"""Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
|
"""Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
|
||||||
|
|
||||||
TODO(alexander-soare):
|
TODO(alexander-soare):
|
||||||
- Remove reliance on Robomimic for SpatialSoftmax.
|
|
||||||
- Remove reliance on diffusers for DDPMScheduler and LR scheduler.
|
- Remove reliance on diffusers for DDPMScheduler and LR scheduler.
|
||||||
|
- Make compatible with multiple image keys.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
@ -10,12 +26,13 @@ from collections import deque
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
import einops
|
import einops
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
import torchvision
|
import torchvision
|
||||||
|
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||||
from huggingface_hub import PyTorchModelHubMixin
|
from huggingface_hub import PyTorchModelHubMixin
|
||||||
from robomimic.models.base_nets import SpatialSoftmax
|
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
|
||||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||||
|
@ -66,10 +83,18 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
||||||
|
|
||||||
self.diffusion = DiffusionModel(config)
|
self.diffusion = DiffusionModel(config)
|
||||||
|
|
||||||
|
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||||
|
# Note: This check is covered in the post-init of the config but have a sanity check just in case.
|
||||||
|
if len(image_keys) != 1:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
|
||||||
|
)
|
||||||
|
self.input_image_key = image_keys[0]
|
||||||
|
|
||||||
|
self.reset()
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""
|
"""Clear observation and action queues. Should be called on `env.reset()`"""
|
||||||
Clear observation and action queues. Should be called on `env.reset()`
|
|
||||||
"""
|
|
||||||
self._queues = {
|
self._queues = {
|
||||||
"observation.image": deque(maxlen=self.config.n_obs_steps),
|
"observation.image": deque(maxlen=self.config.n_obs_steps),
|
||||||
"observation.state": deque(maxlen=self.config.n_obs_steps),
|
"observation.state": deque(maxlen=self.config.n_obs_steps),
|
||||||
|
@ -98,16 +123,14 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
||||||
"horizon" may not the best name to describe what the variable actually means, because this period is
|
"horizon" may not the best name to describe what the variable actually means, because this period is
|
||||||
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
|
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
|
||||||
"""
|
"""
|
||||||
assert "observation.image" in batch
|
|
||||||
assert "observation.state" in batch
|
|
||||||
|
|
||||||
batch = self.normalize_inputs(batch)
|
batch = self.normalize_inputs(batch)
|
||||||
|
batch["observation.image"] = batch[self.input_image_key]
|
||||||
|
|
||||||
self._queues = populate_queues(self._queues, batch)
|
self._queues = populate_queues(self._queues, batch)
|
||||||
|
|
||||||
if len(self._queues["action"]) == 0:
|
if len(self._queues["action"]) == 0:
|
||||||
# stack n latest observations from the queue
|
# stack n latest observations from the queue
|
||||||
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
|
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
|
||||||
actions = self.diffusion.generate_actions(batch)
|
actions = self.diffusion.generate_actions(batch)
|
||||||
|
|
||||||
# TODO(rcadene): make above methods return output dictionary?
|
# TODO(rcadene): make above methods return output dictionary?
|
||||||
|
@ -121,11 +144,25 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
||||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||||
"""Run the batch through the model and compute the loss for training or validation."""
|
"""Run the batch through the model and compute the loss for training or validation."""
|
||||||
batch = self.normalize_inputs(batch)
|
batch = self.normalize_inputs(batch)
|
||||||
|
batch["observation.image"] = batch[self.input_image_key]
|
||||||
batch = self.normalize_targets(batch)
|
batch = self.normalize_targets(batch)
|
||||||
loss = self.diffusion.compute_loss(batch)
|
loss = self.diffusion.compute_loss(batch)
|
||||||
return {"loss": loss}
|
return {"loss": loss}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
|
||||||
|
"""
|
||||||
|
Factory for noise scheduler instances of the requested type. All kwargs are passed
|
||||||
|
to the scheduler.
|
||||||
|
"""
|
||||||
|
if name == "DDPM":
|
||||||
|
return DDPMScheduler(**kwargs)
|
||||||
|
elif name == "DDIM":
|
||||||
|
return DDIMScheduler(**kwargs)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported noise scheduler type {name}")
|
||||||
|
|
||||||
|
|
||||||
class DiffusionModel(nn.Module):
|
class DiffusionModel(nn.Module):
|
||||||
def __init__(self, config: DiffusionConfig):
|
def __init__(self, config: DiffusionConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -138,12 +175,12 @@ class DiffusionModel(nn.Module):
|
||||||
* config.n_obs_steps,
|
* config.n_obs_steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.noise_scheduler = DDPMScheduler(
|
self.noise_scheduler = _make_noise_scheduler(
|
||||||
|
config.noise_scheduler_type,
|
||||||
num_train_timesteps=config.num_train_timesteps,
|
num_train_timesteps=config.num_train_timesteps,
|
||||||
beta_start=config.beta_start,
|
beta_start=config.beta_start,
|
||||||
beta_end=config.beta_end,
|
beta_end=config.beta_end,
|
||||||
beta_schedule=config.beta_schedule,
|
beta_schedule=config.beta_schedule,
|
||||||
variance_type="fixed_small",
|
|
||||||
clip_sample=config.clip_sample,
|
clip_sample=config.clip_sample,
|
||||||
clip_sample_range=config.clip_sample_range,
|
clip_sample_range=config.clip_sample_range,
|
||||||
prediction_type=config.prediction_type,
|
prediction_type=config.prediction_type,
|
||||||
|
@ -185,13 +222,12 @@ class DiffusionModel(nn.Module):
|
||||||
|
|
||||||
def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
|
def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
"""
|
"""
|
||||||
This function expects `batch` to have (at least):
|
This function expects `batch` to have:
|
||||||
{
|
{
|
||||||
"observation.state": (B, n_obs_steps, state_dim)
|
"observation.state": (B, n_obs_steps, state_dim)
|
||||||
"observation.image": (B, n_obs_steps, C, H, W)
|
"observation.image": (B, n_obs_steps, C, H, W)
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
assert set(batch).issuperset({"observation.state", "observation.image"})
|
|
||||||
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
||||||
assert n_obs_steps == self.config.n_obs_steps
|
assert n_obs_steps == self.config.n_obs_steps
|
||||||
|
|
||||||
|
@ -275,6 +311,77 @@ class DiffusionModel(nn.Module):
|
||||||
return loss.mean()
|
return loss.mean()
|
||||||
|
|
||||||
|
|
||||||
|
class SpatialSoftmax(nn.Module):
|
||||||
|
"""
|
||||||
|
Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al.
|
||||||
|
(https://arxiv.org/pdf/1509.06113). A minimal port of the robomimic implementation.
|
||||||
|
|
||||||
|
At a high level, this takes 2D feature maps (from a convnet/ViT) and returns the "center of mass"
|
||||||
|
of activations of each channel, i.e., keypoints in the image space for the policy to focus on.
|
||||||
|
|
||||||
|
Example: take feature maps of size (512x10x12). We generate a grid of normalized coordinates (10x12x2):
|
||||||
|
-----------------------------------------------------
|
||||||
|
| (-1., -1.) | (-0.82, -1.) | ... | (1., -1.) |
|
||||||
|
| (-1., -0.78) | (-0.82, -0.78) | ... | (1., -0.78) |
|
||||||
|
| ... | ... | ... | ... |
|
||||||
|
| (-1., 1.) | (-0.82, 1.) | ... | (1., 1.) |
|
||||||
|
-----------------------------------------------------
|
||||||
|
This is achieved by applying channel-wise softmax over the activations (512x120) and computing the dot
|
||||||
|
product with the coordinates (120x2) to get expected points of maximal activation (512x2).
|
||||||
|
|
||||||
|
The example above results in 512 keypoints (corresponding to the 512 input channels). We can optionally
|
||||||
|
provide num_kp != None to control the number of keypoints. This is achieved by a first applying a learnable
|
||||||
|
linear mapping (in_channels, H, W) -> (num_kp, H, W).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, input_shape, num_kp=None):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
input_shape (list): (C, H, W) input feature map shape.
|
||||||
|
num_kp (int): number of keypoints in output. If None, output will have the same number of channels as input.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
assert len(input_shape) == 3
|
||||||
|
self._in_c, self._in_h, self._in_w = input_shape
|
||||||
|
|
||||||
|
if num_kp is not None:
|
||||||
|
self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
|
||||||
|
self._out_c = num_kp
|
||||||
|
else:
|
||||||
|
self.nets = None
|
||||||
|
self._out_c = self._in_c
|
||||||
|
|
||||||
|
# we could use torch.linspace directly but that seems to behave slightly differently than numpy
|
||||||
|
# and causes a small degradation in pc_success of pre-trained models.
|
||||||
|
pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
|
||||||
|
pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
|
||||||
|
pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
|
||||||
|
# register as buffer so it's moved to the correct device.
|
||||||
|
self.register_buffer("pos_grid", torch.cat([pos_x, pos_y], dim=1))
|
||||||
|
|
||||||
|
def forward(self, features: Tensor) -> Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
features: (B, C, H, W) input feature maps.
|
||||||
|
Returns:
|
||||||
|
(B, K, 2) image-space coordinates of keypoints.
|
||||||
|
"""
|
||||||
|
if self.nets is not None:
|
||||||
|
features = self.nets(features)
|
||||||
|
|
||||||
|
# [B, K, H, W] -> [B * K, H * W] where K is number of keypoints
|
||||||
|
features = features.reshape(-1, self._in_h * self._in_w)
|
||||||
|
# 2d softmax normalization
|
||||||
|
attention = F.softmax(features, dim=-1)
|
||||||
|
# [B * K, H * W] x [H * W, 2] -> [B * K, 2] for spatial coordinate mean in x and y dimensions
|
||||||
|
expected_xy = attention @ self.pos_grid
|
||||||
|
# reshape to [B, K, 2]
|
||||||
|
feature_keypoints = expected_xy.view(-1, self._out_c, 2)
|
||||||
|
|
||||||
|
return feature_keypoints
|
||||||
|
|
||||||
|
|
||||||
class DiffusionRgbEncoder(nn.Module):
|
class DiffusionRgbEncoder(nn.Module):
|
||||||
"""Encoder an RGB image into a 1D feature vector.
|
"""Encoder an RGB image into a 1D feature vector.
|
||||||
|
|
||||||
|
@ -315,11 +422,16 @@ class DiffusionRgbEncoder(nn.Module):
|
||||||
|
|
||||||
# Set up pooling and final layers.
|
# Set up pooling and final layers.
|
||||||
# Use a dry run to get the feature map shape.
|
# Use a dry run to get the feature map shape.
|
||||||
|
# The dummy input should take the number of image channels from `config.input_shapes` and it should
|
||||||
|
# use the height and width from `config.crop_shape`.
|
||||||
|
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||||
|
assert len(image_keys) == 1
|
||||||
|
image_key = image_keys[0]
|
||||||
|
dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *config.crop_shape))
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
feat_map_shape = tuple(
|
dummy_feature_map = self.backbone(dummy_input)
|
||||||
self.backbone(torch.zeros(size=(1, *config.input_shapes["observation.image"]))).shape[1:]
|
feature_map_shape = tuple(dummy_feature_map.shape[1:])
|
||||||
)
|
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
|
||||||
self.pool = SpatialSoftmax(feat_map_shape, num_kp=config.spatial_softmax_num_keypoints)
|
|
||||||
self.feature_dim = config.spatial_softmax_num_keypoints * 2
|
self.feature_dim = config.spatial_softmax_num_keypoints * 2
|
||||||
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
|
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
|
||||||
self.relu = nn.ReLU()
|
self.relu = nn.ReLU()
|
||||||
|
|
|
@ -1,4 +1,20 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
import inspect
|
import inspect
|
||||||
|
import logging
|
||||||
|
|
||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
|
|
||||||
|
@ -8,9 +24,10 @@ from lerobot.common.utils.utils import get_safe_torch_device
|
||||||
|
|
||||||
def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
|
def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
|
||||||
expected_kwargs = set(inspect.signature(policy_cfg_class).parameters)
|
expected_kwargs = set(inspect.signature(policy_cfg_class).parameters)
|
||||||
assert set(hydra_cfg.policy).issuperset(
|
if not set(hydra_cfg.policy).issuperset(expected_kwargs):
|
||||||
expected_kwargs
|
logging.warning(
|
||||||
), f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}"
|
f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}"
|
||||||
|
)
|
||||||
policy_cfg = policy_cfg_class(
|
policy_cfg = policy_cfg_class(
|
||||||
**{
|
**{
|
||||||
k: v
|
k: v
|
||||||
|
@ -62,11 +79,18 @@ def make_policy(
|
||||||
|
|
||||||
policy_cls, policy_cfg_class = get_policy_and_config_classes(hydra_cfg.policy.name)
|
policy_cls, policy_cfg_class = get_policy_and_config_classes(hydra_cfg.policy.name)
|
||||||
|
|
||||||
if pretrained_policy_name_or_path is None:
|
|
||||||
policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg)
|
policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg)
|
||||||
|
if pretrained_policy_name_or_path is None:
|
||||||
|
# Make a fresh policy.
|
||||||
policy = policy_cls(policy_cfg, dataset_stats)
|
policy = policy_cls(policy_cfg, dataset_stats)
|
||||||
else:
|
else:
|
||||||
policy = policy_cls.from_pretrained(pretrained_policy_name_or_path)
|
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
|
||||||
|
# hyperparameters that we want to vary).
|
||||||
|
# TODO(alexander-soare): This hack makes use of huggingface_hub's tooling to load the policy with, pretrained
|
||||||
|
# weights which are then loaded into a fresh policy with the desired config. This PR in huggingface_hub should
|
||||||
|
# make it possible to avoid the hack: https://github.com/huggingface/huggingface_hub/pull/2274.
|
||||||
|
policy = policy_cls(policy_cfg)
|
||||||
|
policy.load_state_dict(policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict())
|
||||||
|
|
||||||
policy.to(get_safe_torch_device(hydra_cfg.device))
|
policy.to(get_safe_torch_device(hydra_cfg.device))
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,18 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,18 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
"""A protocol that all policies should follow.
|
"""A protocol that all policies should follow.
|
||||||
|
|
||||||
This provides a mechanism for type-hinting and isinstance checks without requiring the policies classes
|
This provides a mechanism for type-hinting and isinstance checks without requiring the policies classes
|
||||||
|
@ -38,7 +53,8 @@ class Policy(Protocol):
|
||||||
def forward(self, batch: dict[str, Tensor]) -> dict:
|
def forward(self, batch: dict[str, Tensor]) -> dict:
|
||||||
"""Run the batch through the model and compute the loss for training or validation.
|
"""Run the batch through the model and compute the loss for training or validation.
|
||||||
|
|
||||||
Returns a dictionary with "loss" and maybe other information.
|
Returns a dictionary with "loss" and potentially other information. Apart from "loss" which is a Tensor, all
|
||||||
|
other items should be logging-friendly, native Python types.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def select_action(self, batch: dict[str, Tensor]):
|
def select_action(self, batch: dict[str, Tensor]):
|
||||||
|
|
|
@ -1,3 +1,19 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su,
|
||||||
|
# and The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
|
||||||
|
@ -47,7 +63,7 @@ class TDMPCConfig:
|
||||||
elite_weighting_temperature: The temperature to use for softmax weighting (by trajectory value) of the
|
elite_weighting_temperature: The temperature to use for softmax weighting (by trajectory value) of the
|
||||||
elites, when updating the gaussian parameters for CEM.
|
elites, when updating the gaussian parameters for CEM.
|
||||||
gaussian_mean_momentum: Momentum (α) used for EMA updates of the mean parameter μ of the gaussian
|
gaussian_mean_momentum: Momentum (α) used for EMA updates of the mean parameter μ of the gaussian
|
||||||
paramters optimized in CEM. Updates are calculated as μ⁻ ← αμ⁻ + (1-α)μ.
|
parameters optimized in CEM. Updates are calculated as μ⁻ ← αμ⁻ + (1-α)μ.
|
||||||
max_random_shift_ratio: Maximum random shift (as a proportion of the image size) to apply to the
|
max_random_shift_ratio: Maximum random shift (as a proportion of the image size) to apply to the
|
||||||
image(s) (in units of pixels) for training-time augmentation. If set to 0, no such augmentation
|
image(s) (in units of pixels) for training-time augmentation. If set to 0, no such augmentation
|
||||||
is applied. Note that the input images are assumed to be square for this augmentation.
|
is applied. Note that the input images are assumed to be square for this augmentation.
|
||||||
|
@ -131,12 +147,18 @@ class TDMPCConfig:
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
"""Input validation (not exhaustive)."""
|
"""Input validation (not exhaustive)."""
|
||||||
if self.input_shapes["observation.image"][-2] != self.input_shapes["observation.image"][-1]:
|
# There should only be one image key.
|
||||||
|
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
|
||||||
|
if len(image_keys) != 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
|
||||||
|
)
|
||||||
|
image_key = next(iter(image_keys))
|
||||||
|
if self.input_shapes[image_key][-2] != self.input_shapes[image_key][-1]:
|
||||||
# TODO(alexander-soare): This limitation is solely because of code in the random shift
|
# TODO(alexander-soare): This limitation is solely because of code in the random shift
|
||||||
# augmentation. It should be able to be removed.
|
# augmentation. It should be able to be removed.
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Only square images are handled now. Got image shape "
|
f"Only square images are handled now. Got image shape {self.input_shapes[image_key]}."
|
||||||
f"{self.input_shapes['observation.image']}."
|
|
||||||
)
|
)
|
||||||
if self.n_gaussian_samples <= 0:
|
if self.n_gaussian_samples <= 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|
|
@ -1,3 +1,19 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su,
|
||||||
|
# and The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
"""Implementation of Finetuning Offline World Models in the Real World.
|
"""Implementation of Finetuning Offline World Models in the Real World.
|
||||||
|
|
||||||
The comments in this code may sometimes refer to these references:
|
The comments in this code may sometimes refer to these references:
|
||||||
|
@ -96,13 +112,12 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||||
)
|
)
|
||||||
|
|
||||||
def save(self, fp):
|
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||||
"""Save state dict of TOLD model to filepath."""
|
# Note: This check is covered in the post-init of the config but have a sanity check just in case.
|
||||||
torch.save(self.state_dict(), fp)
|
assert len(image_keys) == 1
|
||||||
|
self.input_image_key = image_keys[0]
|
||||||
|
|
||||||
def load(self, fp):
|
self.reset()
|
||||||
"""Load a saved state dict from filepath into current agent."""
|
|
||||||
self.load_state_dict(torch.load(fp))
|
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""
|
"""
|
||||||
|
@ -121,10 +136,8 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def select_action(self, batch: dict[str, Tensor]):
|
def select_action(self, batch: dict[str, Tensor]):
|
||||||
"""Select a single action given environment observations."""
|
"""Select a single action given environment observations."""
|
||||||
assert "observation.image" in batch
|
|
||||||
assert "observation.state" in batch
|
|
||||||
|
|
||||||
batch = self.normalize_inputs(batch)
|
batch = self.normalize_inputs(batch)
|
||||||
|
batch["observation.image"] = batch[self.input_image_key]
|
||||||
|
|
||||||
self._queues = populate_queues(self._queues, batch)
|
self._queues = populate_queues(self._queues, batch)
|
||||||
|
|
||||||
|
@ -303,13 +316,11 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||||
device = get_device_from_parameters(self)
|
device = get_device_from_parameters(self)
|
||||||
|
|
||||||
batch = self.normalize_inputs(batch)
|
batch = self.normalize_inputs(batch)
|
||||||
|
batch["observation.image"] = batch[self.input_image_key]
|
||||||
batch = self.normalize_targets(batch)
|
batch = self.normalize_targets(batch)
|
||||||
|
|
||||||
info = {}
|
info = {}
|
||||||
|
|
||||||
# TODO(alexander-soare): Refactor TDMPC and make it comply with the policy interface documentation.
|
|
||||||
batch_size = batch["index"].shape[0]
|
|
||||||
|
|
||||||
# (b, t) -> (t, b)
|
# (b, t) -> (t, b)
|
||||||
for key in batch:
|
for key in batch:
|
||||||
if batch[key].ndim > 1:
|
if batch[key].ndim > 1:
|
||||||
|
@ -337,6 +348,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||||
# Run latent rollout using the latent dynamics model and policy model.
|
# Run latent rollout using the latent dynamics model and policy model.
|
||||||
# Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
|
# Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
|
||||||
# gives us a next `z`.
|
# gives us a next `z`.
|
||||||
|
batch_size = batch["index"].shape[0]
|
||||||
z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device)
|
z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device)
|
||||||
z_preds[0] = self.model.encode(current_observation)
|
z_preds[0] = self.model.encode(current_observation)
|
||||||
reward_preds = torch.empty_like(reward, device=device)
|
reward_preds = torch.empty_like(reward, device=device)
|
||||||
|
|
|
@ -1,9 +1,28 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
def populate_queues(queues, batch):
|
def populate_queues(queues, batch):
|
||||||
for key in batch:
|
for key in batch:
|
||||||
|
# Ignore keys not in the queues already (leaving the responsibility to the caller to make sure the
|
||||||
|
# queues have the keys they want).
|
||||||
|
if key not in queues:
|
||||||
|
continue
|
||||||
if len(queues[key]) != queues[key].maxlen:
|
if len(queues[key]) != queues[key].maxlen:
|
||||||
# initialize by copying the first observation several times until the queue is full
|
# initialize by copying the first observation several times until the queue is full
|
||||||
while len(queues[key]) != queues[key].maxlen:
|
while len(queues[key]) != queues[key].maxlen:
|
||||||
|
|
|
@ -1,3 +1,18 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
import importlib
|
import importlib
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,18 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
import imageio
|
import imageio
|
||||||
|
|
|
@ -1,8 +1,25 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
import random
|
import random
|
||||||
|
from contextlib import contextmanager
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Generator
|
||||||
|
|
||||||
import hydra
|
import hydra
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -39,6 +56,31 @@ def set_global_seed(seed):
|
||||||
torch.cuda.manual_seed_all(seed)
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def seeded_context(seed: int) -> Generator[None, None, None]:
|
||||||
|
"""Set the seed when entering a context, and restore the prior random state at exit.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
|
||||||
|
```
|
||||||
|
a = random.random() # produces some random number
|
||||||
|
with seeded_context(1337):
|
||||||
|
b = random.random() # produces some other random number
|
||||||
|
c = random.random() # produces yet another random number, but the same it would have if we never made `b`
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
random_state = random.getstate()
|
||||||
|
np_random_state = np.random.get_state()
|
||||||
|
torch_random_state = torch.random.get_rng_state()
|
||||||
|
torch_cuda_random_state = torch.cuda.random.get_rng_state()
|
||||||
|
set_global_seed(seed)
|
||||||
|
yield None
|
||||||
|
random.setstate(random_state)
|
||||||
|
np.random.set_state(np_random_state)
|
||||||
|
torch.random.set_rng_state(torch_random_state)
|
||||||
|
torch.cuda.random.set_rng_state(torch_cuda_random_state)
|
||||||
|
|
||||||
|
|
||||||
def init_logging():
|
def init_logging():
|
||||||
def custom_format(record):
|
def custom_format(record):
|
||||||
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
|
|
@ -10,6 +10,9 @@ hydra:
|
||||||
name: default
|
name: default
|
||||||
|
|
||||||
device: cuda # cpu
|
device: cuda # cpu
|
||||||
|
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
|
||||||
|
# automatic gradient scaling is used.
|
||||||
|
use_amp: false
|
||||||
# `seed` is used for training (eg: model initialization, dataset shuffling)
|
# `seed` is used for training (eg: model initialization, dataset shuffling)
|
||||||
# AND for the evaluation environments.
|
# AND for the evaluation environments.
|
||||||
seed: ???
|
seed: ???
|
||||||
|
@ -17,6 +20,7 @@ dataset_repo_id: lerobot/pusht
|
||||||
|
|
||||||
training:
|
training:
|
||||||
offline_steps: ???
|
offline_steps: ???
|
||||||
|
# NOTE: `online_steps` is not implemented yet. It's here as a placeholder.
|
||||||
online_steps: ???
|
online_steps: ???
|
||||||
online_steps_between_rollouts: ???
|
online_steps_between_rollouts: ???
|
||||||
online_sampling_ratio: 0.5
|
online_sampling_ratio: 0.5
|
||||||
|
@ -35,7 +39,7 @@ eval:
|
||||||
use_async_envs: false
|
use_async_envs: false
|
||||||
|
|
||||||
wandb:
|
wandb:
|
||||||
enable: true
|
enable: false
|
||||||
# Set to true to disable saving an artifact despite save_model == True
|
# Set to true to disable saving an artifact despite save_model == True
|
||||||
disable_artifact: false
|
disable_artifact: false
|
||||||
project: lerobot
|
project: lerobot
|
||||||
|
|
|
@ -3,6 +3,12 @@
|
||||||
seed: 1000
|
seed: 1000
|
||||||
dataset_repo_id: lerobot/aloha_sim_insertion_human
|
dataset_repo_id: lerobot/aloha_sim_insertion_human
|
||||||
|
|
||||||
|
override_dataset_stats:
|
||||||
|
observation.images.top:
|
||||||
|
# stats from imagenet, since we use a pretrained vision model
|
||||||
|
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||||
|
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||||
|
|
||||||
training:
|
training:
|
||||||
offline_steps: 80000
|
offline_steps: 80000
|
||||||
online_steps: 0
|
online_steps: 0
|
||||||
|
@ -18,12 +24,6 @@ training:
|
||||||
grad_clip_norm: 10
|
grad_clip_norm: 10
|
||||||
online_steps_between_rollouts: 1
|
online_steps_between_rollouts: 1
|
||||||
|
|
||||||
override_dataset_stats:
|
|
||||||
observation.images.top:
|
|
||||||
# stats from imagenet, since we use a pretrained vision model
|
|
||||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
|
||||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
|
||||||
|
|
||||||
delta_timestamps:
|
delta_timestamps:
|
||||||
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
|
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
|
||||||
|
|
||||||
|
@ -66,6 +66,9 @@ policy:
|
||||||
dim_feedforward: 3200
|
dim_feedforward: 3200
|
||||||
feedforward_activation: relu
|
feedforward_activation: relu
|
||||||
n_encoder_layers: 4
|
n_encoder_layers: 4
|
||||||
|
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
|
||||||
|
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
|
||||||
|
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
|
||||||
n_decoder_layers: 1
|
n_decoder_layers: 1
|
||||||
# VAE.
|
# VAE.
|
||||||
use_vae: true
|
use_vae: true
|
||||||
|
@ -73,7 +76,7 @@ policy:
|
||||||
n_vae_encoder_layers: 4
|
n_vae_encoder_layers: 4
|
||||||
|
|
||||||
# Inference.
|
# Inference.
|
||||||
use_temporal_aggregation: false
|
temporal_ensemble_momentum: null
|
||||||
|
|
||||||
# Training and loss computation.
|
# Training and loss computation.
|
||||||
dropout: 0.1
|
dropout: 0.1
|
||||||
|
|
|
@ -7,6 +7,20 @@
|
||||||
seed: 100000
|
seed: 100000
|
||||||
dataset_repo_id: lerobot/pusht
|
dataset_repo_id: lerobot/pusht
|
||||||
|
|
||||||
|
override_dataset_stats:
|
||||||
|
# TODO(rcadene, alexander-soare): should we remove image stats as well? do we use a pretrained vision model?
|
||||||
|
observation.image:
|
||||||
|
mean: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
|
||||||
|
std: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
|
||||||
|
# TODO(rcadene, alexander-soare): we override state and action stats to use the same as the pretrained model
|
||||||
|
# from the original codebase, but we should remove these and train our own pretrained model
|
||||||
|
observation.state:
|
||||||
|
min: [13.456424, 32.938293]
|
||||||
|
max: [496.14618, 510.9579]
|
||||||
|
action:
|
||||||
|
min: [12.0, 25.0]
|
||||||
|
max: [511.0, 511.0]
|
||||||
|
|
||||||
training:
|
training:
|
||||||
offline_steps: 200000
|
offline_steps: 200000
|
||||||
online_steps: 0
|
online_steps: 0
|
||||||
|
@ -34,20 +48,6 @@ eval:
|
||||||
n_episodes: 50
|
n_episodes: 50
|
||||||
batch_size: 50
|
batch_size: 50
|
||||||
|
|
||||||
override_dataset_stats:
|
|
||||||
# TODO(rcadene, alexander-soare): should we remove image stats as well? do we use a pretrained vision model?
|
|
||||||
observation.image:
|
|
||||||
mean: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
|
|
||||||
std: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
|
|
||||||
# TODO(rcadene, alexander-soare): we override state and action stats to use the same as the pretrained model
|
|
||||||
# from the original codebase, but we should remove these and train our own pretrained model
|
|
||||||
observation.state:
|
|
||||||
min: [13.456424, 32.938293]
|
|
||||||
max: [496.14618, 510.9579]
|
|
||||||
action:
|
|
||||||
min: [12.0, 25.0]
|
|
||||||
max: [511.0, 511.0]
|
|
||||||
|
|
||||||
policy:
|
policy:
|
||||||
name: diffusion
|
name: diffusion
|
||||||
|
|
||||||
|
@ -85,6 +85,7 @@ policy:
|
||||||
diffusion_step_embed_dim: 128
|
diffusion_step_embed_dim: 128
|
||||||
use_film_scale_modulation: True
|
use_film_scale_modulation: True
|
||||||
# Noise scheduler.
|
# Noise scheduler.
|
||||||
|
noise_scheduler_type: DDPM
|
||||||
num_train_timesteps: 100
|
num_train_timesteps: 100
|
||||||
beta_schedule: squaredcos_cap_v2
|
beta_schedule: squaredcos_cap_v2
|
||||||
beta_start: 0.0001
|
beta_start: 0.0001
|
||||||
|
|
|
@ -1,11 +1,12 @@
|
||||||
# @package _global_
|
# @package _global_
|
||||||
|
|
||||||
seed: 1
|
seed: 1
|
||||||
dataset_repo_id: lerobot/xarm_lift_medium_replay
|
dataset_repo_id: lerobot/xarm_lift_medium
|
||||||
|
|
||||||
training:
|
training:
|
||||||
offline_steps: 25000
|
offline_steps: 25000
|
||||||
online_steps: 25000
|
# TODO(alexander-soare): uncomment when online training gets reinstated
|
||||||
|
online_steps: 0 # 25000 not implemented yet
|
||||||
eval_freq: 5000
|
eval_freq: 5000
|
||||||
online_steps_between_rollouts: 1
|
online_steps_between_rollouts: 1
|
||||||
online_sampling_ratio: 0.5
|
online_sampling_ratio: 0.5
|
||||||
|
|
|
@ -1,3 +1,18 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
import platform
|
import platform
|
||||||
|
|
||||||
import huggingface_hub
|
import huggingface_hub
|
||||||
|
|
|
@ -1,3 +1,18 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
"""Evaluate a policy on an environment by running rollouts and computing metrics.
|
"""Evaluate a policy on an environment by running rollouts and computing metrics.
|
||||||
|
|
||||||
Usage examples:
|
Usage examples:
|
||||||
|
@ -31,6 +46,7 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
from contextlib import nullcontext
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from datetime import datetime as dt
|
from datetime import datetime as dt
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -505,7 +521,7 @@ def eval(
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
# Check device is available
|
# Check device is available
|
||||||
get_safe_torch_device(hydra_cfg.device, log=True)
|
device = get_safe_torch_device(hydra_cfg.device, log=True)
|
||||||
|
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
@ -524,6 +540,7 @@ def eval(
|
||||||
policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats)
|
policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats)
|
||||||
policy.eval()
|
policy.eval()
|
||||||
|
|
||||||
|
with torch.no_grad(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext():
|
||||||
info = eval_policy(
|
info = eval_policy(
|
||||||
env,
|
env,
|
||||||
policy,
|
policy,
|
||||||
|
|
|
@ -1,3 +1,18 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
"""
|
"""
|
||||||
Use this script to convert your dataset into LeRobot dataset format and upload it to the Hugging Face hub,
|
Use this script to convert your dataset into LeRobot dataset format and upload it to the Hugging Face hub,
|
||||||
or store it locally. LeRobot dataset format is lightweight, fast to load from, and does not require any
|
or store it locally. LeRobot dataset format is lightweight, fast to load from, and does not require any
|
||||||
|
@ -10,7 +25,6 @@ python lerobot/scripts/push_dataset_to_hub.py \
|
||||||
--dataset-id pusht \
|
--dataset-id pusht \
|
||||||
--raw-format pusht_zarr \
|
--raw-format pusht_zarr \
|
||||||
--community-id lerobot \
|
--community-id lerobot \
|
||||||
--revision v1.2 \
|
|
||||||
--dry-run 1 \
|
--dry-run 1 \
|
||||||
--save-to-disk 1 \
|
--save-to-disk 1 \
|
||||||
--save-tests-to-disk 0 \
|
--save-tests-to-disk 0 \
|
||||||
|
@ -21,7 +35,6 @@ python lerobot/scripts/push_dataset_to_hub.py \
|
||||||
--dataset-id xarm_lift_medium \
|
--dataset-id xarm_lift_medium \
|
||||||
--raw-format xarm_pkl \
|
--raw-format xarm_pkl \
|
||||||
--community-id lerobot \
|
--community-id lerobot \
|
||||||
--revision v1.2 \
|
|
||||||
--dry-run 1 \
|
--dry-run 1 \
|
||||||
--save-to-disk 1 \
|
--save-to-disk 1 \
|
||||||
--save-tests-to-disk 0 \
|
--save-tests-to-disk 0 \
|
||||||
|
@ -32,7 +45,6 @@ python lerobot/scripts/push_dataset_to_hub.py \
|
||||||
--dataset-id aloha_sim_insertion_scripted \
|
--dataset-id aloha_sim_insertion_scripted \
|
||||||
--raw-format aloha_hdf5 \
|
--raw-format aloha_hdf5 \
|
||||||
--community-id lerobot \
|
--community-id lerobot \
|
||||||
--revision v1.2 \
|
|
||||||
--dry-run 1 \
|
--dry-run 1 \
|
||||||
--save-to-disk 1 \
|
--save-to-disk 1 \
|
||||||
--save-tests-to-disk 0 \
|
--save-tests-to-disk 0 \
|
||||||
|
@ -43,7 +55,6 @@ python lerobot/scripts/push_dataset_to_hub.py \
|
||||||
--dataset-id umi_cup_in_the_wild \
|
--dataset-id umi_cup_in_the_wild \
|
||||||
--raw-format umi_zarr \
|
--raw-format umi_zarr \
|
||||||
--community-id lerobot \
|
--community-id lerobot \
|
||||||
--revision v1.2 \
|
|
||||||
--dry-run 1 \
|
--dry-run 1 \
|
||||||
--save-to-disk 1 \
|
--save-to-disk 1 \
|
||||||
--save-tests-to-disk 0 \
|
--save-tests-to-disk 0 \
|
||||||
|
@ -212,8 +223,7 @@ def push_dataset_to_hub(
|
||||||
test_hf_dataset = test_hf_dataset.with_format(None)
|
test_hf_dataset = test_hf_dataset.with_format(None)
|
||||||
test_hf_dataset.save_to_disk(str(tests_out_dir / "train"))
|
test_hf_dataset.save_to_disk(str(tests_out_dir / "train"))
|
||||||
|
|
||||||
# copy meta data to tests directory
|
save_meta_data(info, stats, episode_data_index, tests_meta_data_dir)
|
||||||
shutil.copytree(meta_data_dir, tests_meta_data_dir)
|
|
||||||
|
|
||||||
# copy videos of first episode to tests directory
|
# copy videos of first episode to tests directory
|
||||||
episode_index = 0
|
episode_index = 0
|
||||||
|
@ -222,6 +232,10 @@ def push_dataset_to_hub(
|
||||||
fname = f"{key}_episode_{episode_index:06d}.mp4"
|
fname = f"{key}_episode_{episode_index:06d}.mp4"
|
||||||
shutil.copy(videos_dir / fname, tests_videos_dir / fname)
|
shutil.copy(videos_dir / fname, tests_videos_dir / fname)
|
||||||
|
|
||||||
|
if not save_to_disk and out_dir.exists():
|
||||||
|
# remove possible temporary files remaining in the output directory
|
||||||
|
shutil.rmtree(out_dir)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
@ -299,7 +313,7 @@ def main():
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-workers",
|
"--num-workers",
|
||||||
type=int,
|
type=int,
|
||||||
default=16,
|
default=8,
|
||||||
help="Number of processes of Dataloader for computing the dataset statistics.",
|
help="Number of processes of Dataloader for computing the dataset statistics.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|
|
@ -1,13 +1,28 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
from contextlib import nullcontext
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import datasets
|
|
||||||
import hydra
|
import hydra
|
||||||
import torch
|
import torch
|
||||||
from datasets import concatenate_datasets
|
from omegaconf import DictConfig
|
||||||
from datasets.utils import disable_progress_bars, enable_progress_bars
|
from torch.cuda.amp import GradScaler
|
||||||
|
|
||||||
from lerobot.common.datasets.factory import make_dataset
|
from lerobot.common.datasets.factory import make_dataset
|
||||||
from lerobot.common.datasets.utils import cycle
|
from lerobot.common.datasets.utils import cycle
|
||||||
|
@ -15,6 +30,7 @@ from lerobot.common.envs.factory import make_env
|
||||||
from lerobot.common.logger import Logger, log_output_dir
|
from lerobot.common.logger import Logger, log_output_dir
|
||||||
from lerobot.common.policies.factory import make_policy
|
from lerobot.common.policies.factory import make_policy
|
||||||
from lerobot.common.policies.policy_protocol import PolicyWithUpdate
|
from lerobot.common.policies.policy_protocol import PolicyWithUpdate
|
||||||
|
from lerobot.common.policies.utils import get_device_from_parameters
|
||||||
from lerobot.common.utils.utils import (
|
from lerobot.common.utils.utils import (
|
||||||
format_big_number,
|
format_big_number,
|
||||||
get_safe_torch_device,
|
get_safe_torch_device,
|
||||||
|
@ -53,7 +69,6 @@ def make_optimizer_and_scheduler(cfg, policy):
|
||||||
cfg.training.adam_eps,
|
cfg.training.adam_eps,
|
||||||
cfg.training.adam_weight_decay,
|
cfg.training.adam_weight_decay,
|
||||||
)
|
)
|
||||||
assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training."
|
|
||||||
from diffusers.optimization import get_scheduler
|
from diffusers.optimization import get_scheduler
|
||||||
|
|
||||||
lr_scheduler = get_scheduler(
|
lr_scheduler = get_scheduler(
|
||||||
|
@ -71,20 +86,40 @@ def make_optimizer_and_scheduler(cfg, policy):
|
||||||
return optimizer, lr_scheduler
|
return optimizer, lr_scheduler
|
||||||
|
|
||||||
|
|
||||||
def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
|
def update_policy(
|
||||||
start_time = time.time()
|
policy,
|
||||||
|
batch,
|
||||||
|
optimizer,
|
||||||
|
grad_clip_norm,
|
||||||
|
grad_scaler: GradScaler,
|
||||||
|
lr_scheduler=None,
|
||||||
|
use_amp: bool = False,
|
||||||
|
):
|
||||||
|
"""Returns a dictionary of items for logging."""
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
device = get_device_from_parameters(policy)
|
||||||
policy.train()
|
policy.train()
|
||||||
|
with torch.autocast(device_type=device.type) if use_amp else nullcontext():
|
||||||
output_dict = policy.forward(batch)
|
output_dict = policy.forward(batch)
|
||||||
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
|
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
|
||||||
loss = output_dict["loss"]
|
loss = output_dict["loss"]
|
||||||
loss.backward()
|
grad_scaler.scale(loss).backward()
|
||||||
|
|
||||||
|
# Unscale the graident of the optimzer's assigned params in-place **prior to gradient clipping**.
|
||||||
|
grad_scaler.unscale_(optimizer)
|
||||||
|
|
||||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||||
policy.parameters(),
|
policy.parameters(),
|
||||||
grad_clip_norm,
|
grad_clip_norm,
|
||||||
error_if_nonfinite=False,
|
error_if_nonfinite=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
optimizer.step()
|
# Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
|
||||||
|
# although it still skips optimizer.step() if the gradients contain infs or NaNs.
|
||||||
|
grad_scaler.step(optimizer)
|
||||||
|
# Updates the scale for next iteration.
|
||||||
|
grad_scaler.update()
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
if lr_scheduler is not None:
|
if lr_scheduler is not None:
|
||||||
|
@ -98,7 +133,8 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
|
||||||
"loss": loss.item(),
|
"loss": loss.item(),
|
||||||
"grad_norm": float(grad_norm),
|
"grad_norm": float(grad_norm),
|
||||||
"lr": optimizer.param_groups[0]["lr"],
|
"lr": optimizer.param_groups[0]["lr"],
|
||||||
"update_s": time.time() - start_time,
|
"update_s": time.perf_counter() - start_time,
|
||||||
|
**{k: v for k, v in output_dict.items() if k != "loss"},
|
||||||
}
|
}
|
||||||
|
|
||||||
return info
|
return info
|
||||||
|
@ -122,7 +158,7 @@ def train_notebook(out_dir=None, job_name=None, config_name="default", config_pa
|
||||||
train(cfg, out_dir=out_dir, job_name=job_name)
|
train(cfg, out_dir=out_dir, job_name=job_name)
|
||||||
|
|
||||||
|
|
||||||
def log_train_info(logger, info, step, cfg, dataset, is_offline):
|
def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
|
||||||
loss = info["loss"]
|
loss = info["loss"]
|
||||||
grad_norm = info["grad_norm"]
|
grad_norm = info["grad_norm"]
|
||||||
lr = info["lr"]
|
lr = info["lr"]
|
||||||
|
@ -193,104 +229,7 @@ def log_eval_info(logger, info, step, cfg, dataset, is_offline):
|
||||||
logger.log_dict(info, step, mode="eval")
|
logger.log_dict(info, step, mode="eval")
|
||||||
|
|
||||||
|
|
||||||
def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float):
|
def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
|
||||||
"""
|
|
||||||
Calculate the sampling weight to be assigned to samples so that a specified percentage of the batch comes from online dataset (on average).
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
- n_off (int): Number of offline samples, each with a sampling weight of 1.
|
|
||||||
- n_on (int): Number of online samples.
|
|
||||||
- pc_on (float): Desired percentage of online samples in decimal form (e.g., 50% as 0.5).
|
|
||||||
|
|
||||||
The total weight of offline samples is n_off * 1.0.
|
|
||||||
The total weight of offline samples is n_on * w.
|
|
||||||
The total combined weight of all samples is n_off + n_on * w.
|
|
||||||
The fraction of the weight that is online is n_on * w / (n_off + n_on * w).
|
|
||||||
We want this fraction to equal pc_on, so we set up the equation n_on * w / (n_off + n_on * w) = pc_on.
|
|
||||||
The solution is w = - (n_off * pc_on) / (n_on * (pc_on - 1))
|
|
||||||
"""
|
|
||||||
assert 0.0 <= pc_on <= 1.0
|
|
||||||
return -(n_off * pc_on) / (n_on * (pc_on - 1))
|
|
||||||
|
|
||||||
|
|
||||||
def add_episodes_inplace(
|
|
||||||
online_dataset: torch.utils.data.Dataset,
|
|
||||||
concat_dataset: torch.utils.data.ConcatDataset,
|
|
||||||
sampler: torch.utils.data.WeightedRandomSampler,
|
|
||||||
hf_dataset: datasets.Dataset,
|
|
||||||
episode_data_index: dict[str, torch.Tensor],
|
|
||||||
pc_online_samples: float,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Modifies the online_dataset, concat_dataset, and sampler in place by integrating
|
|
||||||
new episodes from hf_dataset into the online_dataset, updating the concatenated
|
|
||||||
dataset's structure and adjusting the sampling strategy based on the specified
|
|
||||||
percentage of online samples.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
- online_dataset (torch.utils.data.Dataset): The existing online dataset to be updated.
|
|
||||||
- concat_dataset (torch.utils.data.ConcatDataset): The concatenated dataset that combines
|
|
||||||
offline and online datasets, used for sampling purposes.
|
|
||||||
- sampler (torch.utils.data.WeightedRandomSampler): A sampler that will be updated to
|
|
||||||
reflect changes in the dataset sizes and specified sampling weights.
|
|
||||||
- hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added.
|
|
||||||
- episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices.
|
|
||||||
They indicate the start index and end index of each episode in the dataset.
|
|
||||||
- pc_online_samples (float): The target percentage of samples that should come from
|
|
||||||
the online dataset during sampling operations.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
- AssertionError: If the first episode_id or index in hf_dataset is not 0
|
|
||||||
"""
|
|
||||||
first_episode_idx = hf_dataset.select_columns("episode_index")[0]["episode_index"].item()
|
|
||||||
last_episode_idx = hf_dataset.select_columns("episode_index")[-1]["episode_index"].item()
|
|
||||||
first_index = hf_dataset.select_columns("index")[0]["index"].item()
|
|
||||||
last_index = hf_dataset.select_columns("index")[-1]["index"].item()
|
|
||||||
# sanity check
|
|
||||||
assert first_episode_idx == 0, f"{first_episode_idx=} is not 0"
|
|
||||||
assert first_index == 0, f"{first_index=} is not 0"
|
|
||||||
assert first_index == episode_data_index["from"][first_episode_idx].item()
|
|
||||||
assert last_index == episode_data_index["to"][last_episode_idx].item() - 1
|
|
||||||
|
|
||||||
if len(online_dataset) == 0:
|
|
||||||
# initialize online dataset
|
|
||||||
online_dataset.hf_dataset = hf_dataset
|
|
||||||
online_dataset.episode_data_index = episode_data_index
|
|
||||||
else:
|
|
||||||
# get the starting indices of the new episodes and frames to be added
|
|
||||||
start_episode_idx = last_episode_idx + 1
|
|
||||||
start_index = last_index + 1
|
|
||||||
|
|
||||||
def shift_indices(episode_index, index):
|
|
||||||
# note: we dont shift "frame_index" since it represents the index of the frame in the episode it belongs to
|
|
||||||
example = {"episode_index": episode_index + start_episode_idx, "index": index + start_index}
|
|
||||||
return example
|
|
||||||
|
|
||||||
disable_progress_bars() # map has a tqdm progress bar
|
|
||||||
hf_dataset = hf_dataset.map(shift_indices, input_columns=["episode_index", "index"])
|
|
||||||
enable_progress_bars()
|
|
||||||
|
|
||||||
episode_data_index["from"] += start_index
|
|
||||||
episode_data_index["to"] += start_index
|
|
||||||
|
|
||||||
# extend online dataset
|
|
||||||
online_dataset.hf_dataset = concatenate_datasets([online_dataset.hf_dataset, hf_dataset])
|
|
||||||
|
|
||||||
# update the concatenated dataset length used during sampling
|
|
||||||
concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
|
|
||||||
|
|
||||||
# update the sampling weights for each frame so that online frames get sampled a certain percentage of times
|
|
||||||
len_online = len(online_dataset)
|
|
||||||
len_offline = len(concat_dataset) - len_online
|
|
||||||
weight_offline = 1.0
|
|
||||||
weight_online = calculate_online_sample_weight(len_offline, len_online, pc_online_samples)
|
|
||||||
sampler.weights = torch.tensor([weight_offline] * len_offline + [weight_online] * len(online_dataset))
|
|
||||||
|
|
||||||
# update the total number of samples used during sampling
|
|
||||||
sampler.num_samples = len(concat_dataset)
|
|
||||||
|
|
||||||
|
|
||||||
def train(cfg: dict, out_dir=None, job_name=None):
|
|
||||||
if out_dir is None:
|
if out_dir is None:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
if job_name is None:
|
if job_name is None:
|
||||||
|
@ -298,11 +237,11 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||||
|
|
||||||
init_logging()
|
init_logging()
|
||||||
|
|
||||||
if cfg.training.online_steps > 0 and cfg.eval.batch_size > 1:
|
if cfg.training.online_steps > 0:
|
||||||
logging.warning("eval.batch_size > 1 not supported for online training steps")
|
raise NotImplementedError("Online training is not implemented yet.")
|
||||||
|
|
||||||
# Check device is available
|
# Check device is available
|
||||||
get_safe_torch_device(cfg.device, log=True)
|
device = get_safe_torch_device(cfg.device, log=True)
|
||||||
|
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
@ -320,6 +259,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||||
# Create optimizer and scheduler
|
# Create optimizer and scheduler
|
||||||
# Temporary hack to move optimizer out of policy
|
# Temporary hack to move optimizer out of policy
|
||||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||||
|
grad_scaler = GradScaler(enabled=cfg.use_amp)
|
||||||
|
|
||||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||||
|
@ -340,6 +280,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||||
def evaluate_and_checkpoint_if_needed(step):
|
def evaluate_and_checkpoint_if_needed(step):
|
||||||
if step % cfg.training.eval_freq == 0:
|
if step % cfg.training.eval_freq == 0:
|
||||||
logging.info(f"Eval policy at step {step}")
|
logging.info(f"Eval policy at step {step}")
|
||||||
|
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
|
||||||
eval_info = eval_policy(
|
eval_info = eval_policy(
|
||||||
eval_env,
|
eval_env,
|
||||||
policy,
|
policy,
|
||||||
|
@ -371,23 +312,30 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||||
num_workers=4,
|
num_workers=4,
|
||||||
batch_size=cfg.training.batch_size,
|
batch_size=cfg.training.batch_size,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
pin_memory=cfg.device != "cpu",
|
pin_memory=device.type != "cpu",
|
||||||
drop_last=False,
|
drop_last=False,
|
||||||
)
|
)
|
||||||
dl_iter = cycle(dataloader)
|
dl_iter = cycle(dataloader)
|
||||||
|
|
||||||
policy.train()
|
policy.train()
|
||||||
step = 0 # number of policy update (forward + backward + optim)
|
|
||||||
is_offline = True
|
is_offline = True
|
||||||
for offline_step in range(cfg.training.offline_steps):
|
for step in range(cfg.training.offline_steps):
|
||||||
if offline_step == 0:
|
if step == 0:
|
||||||
logging.info("Start offline training on a fixed dataset")
|
logging.info("Start offline training on a fixed dataset")
|
||||||
batch = next(dl_iter)
|
batch = next(dl_iter)
|
||||||
|
|
||||||
for key in batch:
|
for key in batch:
|
||||||
batch[key] = batch[key].to(cfg.device, non_blocking=True)
|
batch[key] = batch[key].to(device, non_blocking=True)
|
||||||
|
|
||||||
train_info = update_policy(policy, batch, optimizer, cfg.training.grad_clip_norm, lr_scheduler)
|
train_info = update_policy(
|
||||||
|
policy,
|
||||||
|
batch,
|
||||||
|
optimizer,
|
||||||
|
cfg.training.grad_clip_norm,
|
||||||
|
grad_scaler=grad_scaler,
|
||||||
|
lr_scheduler=lr_scheduler,
|
||||||
|
use_amp=cfg.use_amp,
|
||||||
|
)
|
||||||
|
|
||||||
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
|
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
|
||||||
if step % cfg.training.log_freq == 0:
|
if step % cfg.training.log_freq == 0:
|
||||||
|
@ -397,11 +345,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||||
# so we pass in step + 1.
|
# so we pass in step + 1.
|
||||||
evaluate_and_checkpoint_if_needed(step + 1)
|
evaluate_and_checkpoint_if_needed(step + 1)
|
||||||
|
|
||||||
step += 1
|
|
||||||
|
|
||||||
# create an env dedicated to online episodes collection from policy rollout
|
|
||||||
online_training_env = make_env(cfg, n_envs=1)
|
|
||||||
|
|
||||||
# create an empty online dataset similar to offline dataset
|
# create an empty online dataset similar to offline dataset
|
||||||
online_dataset = deepcopy(offline_dataset)
|
online_dataset = deepcopy(offline_dataset)
|
||||||
online_dataset.hf_dataset = {}
|
online_dataset.hf_dataset = {}
|
||||||
|
@ -418,58 +361,11 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||||
num_workers=4,
|
num_workers=4,
|
||||||
batch_size=cfg.training.batch_size,
|
batch_size=cfg.training.batch_size,
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
pin_memory=cfg.device != "cpu",
|
pin_memory=device.type != "cpu",
|
||||||
drop_last=False,
|
drop_last=False,
|
||||||
)
|
)
|
||||||
dl_iter = cycle(dataloader)
|
|
||||||
|
|
||||||
online_step = 0
|
|
||||||
is_offline = False
|
|
||||||
for env_step in range(cfg.training.online_steps):
|
|
||||||
if env_step == 0:
|
|
||||||
logging.info("Start online training by interacting with environment")
|
|
||||||
|
|
||||||
policy.eval()
|
|
||||||
with torch.no_grad():
|
|
||||||
eval_info = eval_policy(
|
|
||||||
online_training_env,
|
|
||||||
policy,
|
|
||||||
n_episodes=1,
|
|
||||||
return_episode_data=True,
|
|
||||||
start_seed=cfg.training.online_env_seed,
|
|
||||||
enable_progbar=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
add_episodes_inplace(
|
|
||||||
online_dataset,
|
|
||||||
concat_dataset,
|
|
||||||
sampler,
|
|
||||||
hf_dataset=eval_info["episodes"]["hf_dataset"],
|
|
||||||
episode_data_index=eval_info["episodes"]["episode_data_index"],
|
|
||||||
pc_online_samples=cfg.training.online_sampling_ratio,
|
|
||||||
)
|
|
||||||
|
|
||||||
policy.train()
|
|
||||||
for _ in range(cfg.training.online_steps_between_rollouts):
|
|
||||||
batch = next(dl_iter)
|
|
||||||
|
|
||||||
for key in batch:
|
|
||||||
batch[key] = batch[key].to(cfg.device, non_blocking=True)
|
|
||||||
|
|
||||||
train_info = update_policy(policy, batch, optimizer, cfg.training.grad_clip_norm, lr_scheduler)
|
|
||||||
|
|
||||||
if step % cfg.training.log_freq == 0:
|
|
||||||
log_train_info(logger, train_info, step, cfg, online_dataset, is_offline)
|
|
||||||
|
|
||||||
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
|
|
||||||
# so we pass in step + 1.
|
|
||||||
evaluate_and_checkpoint_if_needed(step + 1)
|
|
||||||
|
|
||||||
step += 1
|
|
||||||
online_step += 1
|
|
||||||
|
|
||||||
eval_env.close()
|
eval_env.close()
|
||||||
online_training_env.close()
|
|
||||||
logging.info("End of training")
|
logging.info("End of training")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,18 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
""" Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
|
""" Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
|
||||||
|
|
||||||
Note: The last frame of the episode doesnt always correspond to a final state.
|
Note: The last frame of the episode doesnt always correspond to a final state.
|
||||||
|
@ -47,6 +62,7 @@ local$ rerun ws://localhost:9087
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import gc
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -115,15 +131,17 @@ def visualize_dataset(
|
||||||
|
|
||||||
spawn_local_viewer = mode == "local" and not save
|
spawn_local_viewer = mode == "local" and not save
|
||||||
rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer)
|
rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer)
|
||||||
|
|
||||||
|
# Manually call python garbage collector after `rr.init` to avoid hanging in a blocking flush
|
||||||
|
# when iterating on a dataloader with `num_workers` > 0
|
||||||
|
# TODO(rcadene): remove `gc.collect` when rerun version 0.16 is out, which includes a fix
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
if mode == "distant":
|
if mode == "distant":
|
||||||
rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port)
|
rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port)
|
||||||
|
|
||||||
logging.info("Logging to Rerun")
|
logging.info("Logging to Rerun")
|
||||||
|
|
||||||
if num_workers > 0:
|
|
||||||
# TODO(rcadene): fix data workers hanging when `rr.init` is called
|
|
||||||
logging.warning("If data loader is hanging, try `--num-workers 0`.")
|
|
||||||
|
|
||||||
for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
|
for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
|
||||||
# iterate over the batch
|
# iterate over the batch
|
||||||
for i in range(len(batch["index"])):
|
for i in range(len(batch["index"])):
|
||||||
|
@ -196,7 +214,7 @@ def main():
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-workers",
|
"--num-workers",
|
||||||
type=int,
|
type=int,
|
||||||
default=0,
|
default=4,
|
||||||
help="Number of processes of Dataloader for loading the data.",
|
help="Number of processes of Dataloader for loading the data.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -28,37 +28,36 @@ packages = [{include = "lerobot"}]
|
||||||
|
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = ">=3.10,<3.13"
|
python = ">=3.10,<3.13"
|
||||||
termcolor = "^2.4.0"
|
termcolor = ">=2.4.0"
|
||||||
omegaconf = "^2.3.0"
|
omegaconf = ">=2.3.0"
|
||||||
wandb = "^0.16.3"
|
wandb = ">=0.16.3"
|
||||||
imageio = {extras = ["ffmpeg"], version = "^2.34.0"}
|
imageio = {extras = ["ffmpeg"], version = ">=2.34.0"}
|
||||||
gdown = "^5.1.0"
|
gdown = ">=5.1.0"
|
||||||
hydra-core = "^1.3.2"
|
hydra-core = ">=1.3.2"
|
||||||
einops = "^0.8.0"
|
einops = ">=0.8.0"
|
||||||
pymunk = "^6.6.0"
|
pymunk = ">=6.6.0"
|
||||||
zarr = "^2.17.0"
|
zarr = ">=2.17.0"
|
||||||
numba = "^0.59.0"
|
numba = ">=0.59.0"
|
||||||
torch = "^2.2.1"
|
torch = "^2.2.1"
|
||||||
opencv-python = "^4.9.0.80"
|
opencv-python = ">=4.9.0"
|
||||||
diffusers = "^0.27.2"
|
diffusers = "^0.27.2"
|
||||||
torchvision = "^0.18.0"
|
torchvision = ">=0.17.1"
|
||||||
h5py = "^3.10.0"
|
h5py = ">=3.10.0"
|
||||||
huggingface-hub = "^0.21.4"
|
huggingface-hub = {extras = ["hf-transfer"], version = "^0.23.0"}
|
||||||
robomimic = "0.2.0"
|
gymnasium = ">=0.29.1"
|
||||||
gymnasium = "^0.29.1"
|
cmake = ">=3.29.0.1"
|
||||||
cmake = "^3.29.0.1"
|
gym-pusht = { version = ">=0.1.3", optional = true}
|
||||||
gym-pusht = { version = "^0.1.0", optional = true}
|
gym-xarm = { version = ">=0.1.1", optional = true}
|
||||||
gym-xarm = { version = "^0.1.0", optional = true}
|
gym-aloha = { version = ">=0.1.1", optional = true}
|
||||||
gym-aloha = { version = "^0.1.0", optional = true}
|
pre-commit = {version = ">=3.7.0", optional = true}
|
||||||
pre-commit = {version = "^3.7.0", optional = true}
|
debugpy = {version = ">=1.8.1", optional = true}
|
||||||
debugpy = {version = "^1.8.1", optional = true}
|
pytest = {version = ">=8.1.0", optional = true}
|
||||||
pytest = {version = "^8.1.0", optional = true}
|
pytest-cov = {version = ">=5.0.0", optional = true}
|
||||||
pytest-cov = {version = "^5.0.0", optional = true}
|
|
||||||
datasets = "^2.19.0"
|
datasets = "^2.19.0"
|
||||||
imagecodecs = { version = "^2024.1.1", optional = true }
|
imagecodecs = { version = ">=2024.1.1", optional = true }
|
||||||
pyav = "^12.0.5"
|
pyav = ">=12.0.5"
|
||||||
moviepy = "^1.0.3"
|
moviepy = ">=1.0.3"
|
||||||
rerun-sdk = "^0.15.1"
|
rerun-sdk = ">=0.15.1"
|
||||||
|
|
||||||
|
|
||||||
[tool.poetry.extras]
|
[tool.poetry.extras]
|
||||||
|
@ -104,5 +103,5 @@ ignore-init-module-imports = true
|
||||||
|
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["poetry-core>=1.5.0"]
|
requires = ["poetry-core"]
|
||||||
build-backend = "poetry.core.masonry.api"
|
build-backend = "poetry.core.masonry.api"
|
||||||
|
|
|
@ -1,3 +1,18 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
from .utils import DEVICE
|
from .utils import DEVICE
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:9f9347c8d9ac90ee44e6dd86f65043438168df6bbe4bab2d2b875e55ef7376ef
|
||||||
|
size 1488
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb
|
||||||
|
size 33
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:02fc4ea25766269f65752a60b0594c43d799b0ae528cd773bf024b064b5aa329
|
||||||
|
size 4344
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:55d7b1a06fe3e3051482752740074348bdb5fc98fb2e305b06d6203994117b27
|
||||||
|
size 592448
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:8b7fbedfdb3d536847bc6fadf2cbabb9f2b5492edf3e2c274a3e8ffb447105e8
|
||||||
|
size 1166
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:98329e4b40e9be0d63f7d36da9d86c44bbe7eeeb1b10d3ba973c923f3be70867
|
||||||
|
size 247
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:54e42cdfd016a0ced2ab1fe2966a8c15a2384e0dbe1a2fe87433a2d1b8209ac0
|
||||||
|
size 5220057
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:af1ded2a244cb47a96255b75f584a643edf6967e13bb5464b330ffdd9d7ad859
|
||||||
|
size 5284692
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:13d1bebabd79984fd6715971be758ef9a354495adea5e8d33f4e7904365e112b
|
||||||
|
size 5258380
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:f33bc6810f0b91817a42610364cb49ed1b99660f058f0f9407e6f5920d0aee02
|
||||||
|
size 1008
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb
|
||||||
|
size 33
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:7b58d6c89e936a781a307805ebecf0dd473fbc02d52a7094da62e54bffb9454a
|
||||||
|
size 4344
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:a08be578285cbe2d35b78f150d464ff3e10604a9865398c976983e0d711774f9
|
||||||
|
size 788528
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:8b7fbedfdb3d536847bc6fadf2cbabb9f2b5492edf3e2c274a3e8ffb447105e8
|
||||||
|
size 1166
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:34e36233477c8aa0b0840314ddace072062d4f486d06546bbd6550832c370065
|
||||||
|
size 247
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:66e7349a4a82ca6042a7189608d01eb1cfa38d100d039b5445ae1a9e65d824ab
|
||||||
|
size 14470946
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:a2146f0c10c9f2611e57e617983aa4f91ad681b4fc50d91b992b97abd684f926
|
||||||
|
size 11662185
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:5affbaf1c48895ba3c626e0d8cf1309e5f4ec6bbaa135313096f52a22de66c05
|
||||||
|
size 11410342
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:6c2b195ca91b88fd16422128d386d2cabd808a1862c6d127e6bf2e83e1fe819a
|
||||||
|
size 448
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb
|
||||||
|
size 33
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:b360b6b956d2adcb20589947c553348ef1eb6b70743c989dcbe95243d8592ce5
|
||||||
|
size 4344
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:3f5c3926b4d4da9271abefcdf6a8952bb1f13258a9c39fe0fd223f548dc89dcb
|
||||||
|
size 887728
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:8b7fbedfdb3d536847bc6fadf2cbabb9f2b5492edf3e2c274a3e8ffb447105e8
|
||||||
|
size 1166
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:4993b05fb026619eec5eb70db8cadaa041ba4ab92d38b4a387167ace03b1018b
|
||||||
|
size 247
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:bd25d17ef5b7500386761b5e32920879bbdcafe0e17a8a8845628525d861e644
|
||||||
|
size 10231081
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:5b557acbfeb0681c0a38e47263d945f6cd3a03461298d8b17209c81e3fd0aae8
|
||||||
|
size 9701371
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:da8f3b4f9f965da63819652b2c042d4cf7e07d14631113ea072087d56370310e
|
||||||
|
size 10473741
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:a053506017d8a78cfd307b2912eeafa1ac1485a280cf90913985fcc40120b5ec
|
||||||
|
size 416
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb
|
||||||
|
size 33
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:d6d172d1bca02face22ceb4c21ea2b054cf3463025485dce64711b6f36b31f8a
|
||||||
|
size 4344
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:7e5ce817a2c188041f57f8d4c465dab3b9c3e4e1aeb7a9fb270230d1b36df530
|
||||||
|
size 1477064
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:8b7fbedfdb3d536847bc6fadf2cbabb9f2b5492edf3e2c274a3e8ffb447105e8
|
||||||
|
size 1166
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:4eb2dc373e4ea7d474742590f9073d66a773f6ab94b9e73a8673df19f93fae6d
|
||||||
|
size 247
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:d2c55b146fabe78b18c8a28a7746ab56e1ee7a6918e9e3dad9bd196f97975895
|
||||||
|
size 26158915
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:71e1958d77f56843acf1ec48da4f04311a5836c87a0e77dbe26aa47c27c6347e
|
||||||
|
size 18786848
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:20780718399b5759ff9a3a79824986310524793066198e3b9a307222f11a93df
|
||||||
|
size 17769988
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:279916f7689ae46af90e92a46eba9486a71fc762e3e2679ab5441eb37126827b
|
||||||
|
size 928
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb
|
||||||
|
size 33
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:7a7731051b521694b52b5631470720a7f05331915f4ac4e7f8cd83f9ff459bce
|
||||||
|
size 4344
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:99608258e8c9fe5191f1a12edc29b47d307790104149dffb6d3046ddad6aeb1b
|
||||||
|
size 435600
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:8b7fbedfdb3d536847bc6fadf2cbabb9f2b5492edf3e2c274a3e8ffb447105e8
|
||||||
|
size 1166
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:ae6735b7b394914824e974a7461019373a10f9e2d84ddf834bec8ea268d9ec1e
|
||||||
|
size 247
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue