Merge remote-tracking branch 'origin/2025_02_20_add_dexvla' into 2025_02_20_add_dexvla
# Conflicts:
#	lerobot/common/policies/dexvla/configuration_dexvla.py
#	lerobot/common/policies/dexvla/modeling_dexvla.py
commit 628ba6e545
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 # Misc
 .git
 tmp
@@ -59,7 +73,7 @@ pip-log.txt
 pip-delete-this-directory.txt

 # Unit test / coverage reports
-!tests/data
+!tests/artifacts
 htmlcov/
 .tox/
 .nox/
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 *.memmap filter=lfs diff=lfs merge=lfs -text
 *.stl filter=lfs diff=lfs merge=lfs -text
 *.safetensors filter=lfs diff=lfs merge=lfs -text
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 name: "\U0001F41B Bug Report"
 description: Submit a bug report to help us improve LeRobot
 body:
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 # Inspired by
 # https://github.com/huggingface/peft/blob/main/.github/workflows/build_docker_images.yml
 name: Builds
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 # Inspired by
 # https://github.com/huggingface/peft/blob/main/.github/workflows/nightly.yml
 name: Nightly
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 name: Quality

 on:
@@ -32,13 +46,27 @@ jobs:
         id: get-ruff-version
         run: |
           RUFF_VERSION=$(awk '/repo: https:\/\/github.com\/astral-sh\/ruff-pre-commit/{flag=1;next}/rev:/{if(flag){print $2;exit}}' .pre-commit-config.yaml)
-          echo "RUFF_VERSION=${RUFF_VERSION}" >> $GITHUB_ENV
+          echo "ruff_version=${RUFF_VERSION}" >> $GITHUB_OUTPUT

       - name: Install Ruff
-        run: python -m pip install "ruff==${{ env.RUFF_VERSION }}"
+        env:
+          RUFF_VERSION: ${{ steps.get-ruff-version.outputs.ruff_version }}
+        run: python -m pip install "ruff==${RUFF_VERSION}"

       - name: Ruff check
         run: ruff check --output-format=github

       - name: Ruff format
         run: ruff format --diff

+  typos:
+    name: Typos
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout Repository
+        uses: actions/checkout@v4
+        with:
+          persist-credentials: false
+
+      - name: typos-action
+        uses: crate-ci/typos@v1.29.10
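For readers puzzling over the `awk` one-liner in the hunk above: it simply pulls the `rev:` pinned under the ruff-pre-commit entry out of `.pre-commit-config.yaml` so the workflow installs the same Ruff version as the pre-commit hook. A rough Python equivalent is sketched below; it is illustrative only (not part of the workflow) and assumes PyYAML is available.

```python
# Sketch: read the pinned Ruff revision from .pre-commit-config.yaml.
import yaml


def get_ruff_rev(path: str = ".pre-commit-config.yaml") -> str | None:
    with open(path) as f:
        config = yaml.safe_load(f)
    for repo in config.get("repos", []):
        # Same repo the awk one-liner matches on.
        if repo.get("repo") == "https://github.com/astral-sh/ruff-pre-commit":
            return repo.get("rev")
    return None


if __name__ == "__main__":
    print(get_ruff_rev())  # e.g. "v0.9.10"
```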
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 # Inspired by
 # https://github.com/huggingface/peft/blob/main/.github/workflows/test-docker-build.yml
 name: Test Dockerfiles
@@ -43,7 +57,7 @@ jobs:
     needs: get_changed_files
     runs-on:
       group: aws-general-8-plus
-    if: ${{ needs.get_changed_files.outputs.matrix }} != ''
+    if: needs.get_changed_files.outputs.matrix != ''
     strategy:
       fail-fast: false
       matrix:
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 name: Tests

 on:
@@ -112,7 +126,7 @@ jobs:
         # portaudio19-dev is needed to install pyaudio
         run: |
           sudo apt-get update && \
-          sudo apt-get install -y libegl1-mesa-dev portaudio19-dev
+          sudo apt-get install -y libegl1-mesa-dev ffmpeg portaudio19-dev

       - name: Install uv and python
         uses: astral-sh/setup-uv@v5
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 on:
   push:
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 # Logging
 logs
 tmp
@@ -64,7 +78,7 @@ pip-log.txt
 pip-delete-this-directory.txt

 # Unit test / coverage reports
-!tests/data
+!tests/artifacts
 htmlcov/
 .tox/
 .nox/
@@ -1,7 +1,29 @@
-exclude: ^(tests/data)
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+exclude: "tests/artifacts/.*\\.safetensors$"
 default_language_version:
   python: python3.10
 repos:
+  ##### Meta #####
+  - repo: meta
+    hooks:
+      - id: check-useless-excludes
+      - id: check-hooks-apply
+
+
+  ##### Style / Misc. #####
   - repo: https://github.com/pre-commit/pre-commit-hooks
     rev: v5.0.0
     hooks:
@@ -13,21 +35,40 @@ repos:
       - id: check-toml
       - id: end-of-file-fixer
       - id: trailing-whitespace

+  - repo: https://github.com/crate-ci/typos
+    rev: v1.30.2
+    hooks:
+      - id: typos
+        args: [--force-exclude]
+
   - repo: https://github.com/asottile/pyupgrade
     rev: v3.19.1
     hooks:
       - id: pyupgrade

   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.9.6
+    rev: v0.9.10
     hooks:
       - id: ruff
         args: [--fix]
       - id: ruff-format

+
+  ##### Security #####
   - repo: https://github.com/gitleaks/gitleaks
-    rev: v8.23.3
+    rev: v8.24.0
     hooks:
       - id: gitleaks

   - repo: https://github.com/woodruffw/zizmor-pre-commit
-    rev: v1.3.1
+    rev: v1.4.1
     hooks:
       - id: zizmor
+
+  - repo: https://github.com/PyCQA/bandit
+    rev: 1.8.3
+    hooks:
+      - id: bandit
+        args: ["-c", "pyproject.toml"]
+        additional_dependencies: ["bandit[toml]"]
@@ -228,7 +228,7 @@ Follow these steps to start contributing:
    git commit
    ```

-Note, if you already commited some changes that have a wrong formatting, you can use:
+Note, if you already committed some changes that have a wrong formatting, you can use:
 ```bash
 pre-commit run --all-files
 ```
@@ -291,7 +291,7 @@ sudo apt-get install git-lfs
 git lfs install
 ```

-Pull artifacts if they're not in [tests/data](tests/data)
+Pull artifacts if they're not in [tests/artifacts](tests/artifacts)
 ```bash
 git lfs pull
 ```
Makefile
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 .PHONY: tests

 PYTHON_PATH := $(shell which python)
@@ -33,6 +47,7 @@ test-act-ete-train:
   --policy.dim_model=64 \
   --policy.n_action_steps=20 \
   --policy.chunk_size=20 \
+  --policy.device=$(DEVICE) \
   --env.type=aloha \
   --env.episode_length=5 \
   --dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \
@@ -47,7 +62,6 @@ test-act-ete-train:
   --save_checkpoint=true \
   --log_freq=1 \
   --wandb.enable=false \
-  --device=$(DEVICE) \
   --output_dir=tests/outputs/act/

 test-act-ete-train-resume:
@@ -58,11 +72,11 @@ test-act-ete-train-resume:
 test-act-ete-eval:
  python lerobot/scripts/eval.py \
   --policy.path=tests/outputs/act/checkpoints/000004/pretrained_model \
+  --policy.device=$(DEVICE) \
   --env.type=aloha \
   --env.episode_length=5 \
   --eval.n_episodes=1 \
-  --eval.batch_size=1 \
-  --device=$(DEVICE)
+  --eval.batch_size=1

 test-diffusion-ete-train:
  python lerobot/scripts/train.py \
@@ -70,6 +84,7 @@ test-diffusion-ete-train:
   --policy.down_dims='[64,128,256]' \
   --policy.diffusion_step_embed_dim=32 \
   --policy.num_inference_steps=10 \
+  --policy.device=$(DEVICE) \
   --env.type=pusht \
   --env.episode_length=5 \
   --dataset.repo_id=lerobot/pusht \
@@ -84,21 +99,21 @@ test-diffusion-ete-train:
   --save_freq=2 \
   --log_freq=1 \
   --wandb.enable=false \
-  --device=$(DEVICE) \
   --output_dir=tests/outputs/diffusion/

 test-diffusion-ete-eval:
  python lerobot/scripts/eval.py \
   --policy.path=tests/outputs/diffusion/checkpoints/000002/pretrained_model \
+  --policy.device=$(DEVICE) \
   --env.type=pusht \
   --env.episode_length=5 \
   --eval.n_episodes=1 \
-  --eval.batch_size=1 \
-  --device=$(DEVICE)
+  --eval.batch_size=1

 test-tdmpc-ete-train:
  python lerobot/scripts/train.py \
   --policy.type=tdmpc \
+  --policy.device=$(DEVICE) \
   --env.type=xarm \
   --env.task=XarmLift-v0 \
   --env.episode_length=5 \
@@ -114,15 +129,14 @@ test-tdmpc-ete-train:
   --save_freq=2 \
   --log_freq=1 \
   --wandb.enable=false \
-  --device=$(DEVICE) \
   --output_dir=tests/outputs/tdmpc/

 test-tdmpc-ete-eval:
  python lerobot/scripts/eval.py \
   --policy.path=tests/outputs/tdmpc/checkpoints/000002/pretrained_model \
+  --policy.device=$(DEVICE) \
   --env.type=xarm \
   --env.episode_length=5 \
   --env.task=XarmLift-v0 \
   --eval.n_episodes=1 \
-  --eval.batch_size=1 \
-  --device=$(DEVICE)
+  --eval.batch_size=1
README.md
@@ -23,15 +23,24 @@
 </div>

 <h2 align="center">
-    <p><a href="https://github.com/huggingface/lerobot/blob/main/examples/10_use_so100.md">New robot in town: SO-100</a></p>
+    <p><a href="https://github.com/huggingface/lerobot/blob/main/examples/10_use_so100.md">
+    Build Your Own SO-100 Robot!</a></p>
 </h2>

 <div align="center">
     <img src="media/so100/leader_follower.webp?raw=true" alt="SO-100 leader and follower arms" title="SO-100 leader and follower arms" width="50%">
-    <p>We just added a new tutorial on how to build a more affordable robot, at the price of $110 per arm!</p>
-    <p>Teach it new skills by showing it a few moves with just a laptop.</p>
-    <p>Then watch your homemade robot act autonomously 🤯</p>
-    <p>Follow the link to the <a href="https://github.com/huggingface/lerobot/blob/main/examples/10_use_so100.md">full tutorial for SO-100</a>.</p>
+    <p><strong>Meet the SO-100 – Just $110 per arm!</strong></p>
+    <p>Train it in minutes with a few simple moves on your laptop.</p>
+    <p>Then sit back and watch your creation act autonomously! 🤯</p>
+
+    <p><a href="https://github.com/huggingface/lerobot/blob/main/examples/10_use_so100.md">
+    Get the full SO-100 tutorial here.</a></p>
+
+    <p>Want to take it to the next level? Make your SO-100 mobile by building LeKiwi!</p>
+    <p>Check out the <a href="https://github.com/huggingface/lerobot/blob/main/examples/11_use_lekiwi.md">LeKiwi tutorial</a> and bring your robot to life on wheels.</p>
+
+    <img src="media/lekiwi/kiwi.webp?raw=true" alt="LeKiwi mobile robot" title="LeKiwi mobile robot" width="50%">
 </div>

 <br/>
@@ -210,7 +219,7 @@ A `LeRobotDataset` is serialised using several widespread file formats for each
 - videos are stored in mp4 format to save space
 - metadata are stored in plain json/jsonl files

-Dataset can be uploaded/downloaded from the HuggingFace hub seamlessly. To work on a local dataset, you can use the `local_files_only` argument and specify its location with the `root` argument if it's not in the default `~/.cache/huggingface/lerobot` location.
+Dataset can be uploaded/downloaded from the HuggingFace hub seamlessly. To work on a local dataset, you can specify its location with the `root` argument if it's not in the default `~/.cache/huggingface/lerobot` location.

 ### Evaluate a pretrained policy

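As a side note on the `root` argument mentioned in the hunk above, here is a minimal sketch of opening a dataset from a non-default local location. It assumes the import path used elsewhere in this branch (`lerobot.common.datasets.lerobot_dataset`); the repo id and path are placeholders.

```python
# Sketch only: work on a local copy of a dataset instead of the default cache.
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset(
    repo_id="lerobot/pusht",                 # example repo id
    root="path/to/local/dataset",            # instead of ~/.cache/huggingface/lerobot
)
print(len(dataset), dataset.meta.video_keys)  # frames and any video-backed camera keys
```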
@@ -223,8 +232,8 @@ python lerobot/scripts/eval.py \
     --env.type=pusht \
     --eval.batch_size=10 \
     --eval.n_episodes=10 \
-    --use_amp=false \
-    --device=cuda
+    --policy.use_amp=false \
+    --policy.device=cuda
 ```

 Note: After training your own policy, you can re-evaluate the checkpoints with:
@@ -375,3 +384,6 @@ Additionally, if you are using any of the particular policy architecture, pretra
   year={2024}
 }
 ```
+
+## Star History
+
+[](https://star-history.com/#huggingface/lerobot&Timeline)
@@ -114,7 +114,7 @@ We tried to measure the most impactful parameters for both encoding and decoding

 Additional encoding parameters exist that are not included in this benchmark. In particular:
 - `-preset` which allows for selecting encoding presets. This represents a collection of options that will provide a certain encoding speed to compression ratio. By leaving this parameter unspecified, it is considered to be `medium` for libx264 and libx265 and `8` for libsvtav1.
-- `-tune` which allows to optimize the encoding for certains aspects (e.g. film quality, fast decoding, etc.).
+- `-tune` which allows to optimize the encoding for certain aspects (e.g. film quality, fast decoding, etc.).

 See the documentation mentioned above for more detailed info on these settings and for a more comprehensive list of other parameters.
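To make the `-preset` / `-tune` discussion above concrete, here is an illustrative encoding call. The input pattern, output name, and chosen values are arbitrary examples, not settings used by the benchmark.

```python
# Illustrative only: encode an image sequence with an explicit preset and tune.
import subprocess

cmd = [
    "ffmpeg",
    "-framerate", "30",
    "-i", "frame_%06d.png",   # placeholder input pattern
    "-c:v", "libx264",
    "-preset", "medium",      # the implicit default for libx264/libx265
    "-tune", "fastdecode",    # bias the encoder toward fast decoding
    "-pix_fmt", "yuv420p",
    "output.mp4",             # placeholder output file
]
subprocess.run(cmd, check=True)
```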
@@ -67,7 +67,7 @@ def parse_int_or_none(value) -> int | None:
 def check_datasets_formats(repo_ids: list) -> None:
     for repo_id in repo_ids:
         dataset = LeRobotDataset(repo_id)
-        if dataset.video:
+        if len(dataset.meta.video_keys) > 0:
             raise ValueError(
                 f"Use only image dataset for running this benchmark. Video dataset provided: {repo_id}"
             )
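The change above replaces the removed `dataset.video` attribute with a check on the dataset metadata: a dataset that stores camera streams as mp4 exposes them through `meta.video_keys`, while an image-only dataset has an empty list. A hedged usage sketch of the guard from the hunk above (the repo ids are placeholders):

```python
# Sketch: fail fast if any benchmarked dataset is video-backed.
repo_ids = ["user/my_image_dataset_1", "user/my_image_dataset_2"]  # placeholders
check_datasets_formats(repo_ids)  # raises ValueError for video datasets
```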
@@ -1,33 +1,29 @@
 # Configure image
 ARG PYTHON_VERSION=3.10

 FROM python:${PYTHON_VERSION}-slim
-ARG PYTHON_VERSION
-ARG DEBIAN_FRONTEND=noninteractive

-# Install apt dependencies
+# Configure environment variables
+ARG PYTHON_VERSION
+ENV DEBIAN_FRONTEND=noninteractive
+ENV MUJOCO_GL="egl"
+ENV PATH="/opt/venv/bin:$PATH"
+
+# Install dependencies and set up Python in a single layer
 RUN apt-get update && apt-get install -y --no-install-recommends \
-    build-essential cmake git git-lfs \
+    build-essential cmake git \
     libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
     speech-dispatcher libgeos-dev \
-    && apt-get clean && rm -rf /var/lib/apt/lists/*
+    && ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python \
+    && python -m venv /opt/venv \
+    && apt-get clean && rm -rf /var/lib/apt/lists/* \
+    && echo "source /opt/venv/bin/activate" >> /root/.bashrc

-# Create virtual environment
-RUN ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python
-RUN python -m venv /opt/venv
-ENV PATH="/opt/venv/bin:$PATH"
-RUN echo "source /opt/venv/bin/activate" >> /root/.bashrc
-
-# Install LeRobot
-RUN git lfs install
-RUN git clone https://github.com/huggingface/lerobot.git /lerobot
+# Clone repository and install LeRobot in a single layer
+COPY . /lerobot
 WORKDIR /lerobot
-RUN pip install --upgrade --no-cache-dir pip
-RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]" \
+RUN /opt/venv/bin/pip install --upgrade --no-cache-dir pip \
+    && /opt/venv/bin/pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]" \
     --extra-index-url https://download.pytorch.org/whl/cpu

-# Set EGL as the rendering backend for MuJoCo
-ENV MUJOCO_GL="egl"
-
 # Execute in bash shell rather than python
 CMD ["/bin/bash"]
@@ -1,31 +1,24 @@
 FROM nvidia/cuda:12.4.1-base-ubuntu22.04

-# Configure image
+# Configure environment variables
 ARG PYTHON_VERSION=3.10
-ARG DEBIAN_FRONTEND=noninteractive
+ENV DEBIAN_FRONTEND=noninteractive
+ENV MUJOCO_GL="egl"
+ENV PATH="/opt/venv/bin:$PATH"

-# Install apt dependencies
+# Install dependencies and set up Python in a single layer
 RUN apt-get update && apt-get install -y --no-install-recommends \
-    build-essential cmake git git-lfs \
+    build-essential cmake git \
     libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
     speech-dispatcher libgeos-dev \
     python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
-    && apt-get clean && rm -rf /var/lib/apt/lists/*
+    && ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python \
+    && python -m venv /opt/venv \
+    && apt-get clean && rm -rf /var/lib/apt/lists/* \
+    && echo "source /opt/venv/bin/activate" >> /root/.bashrc

-# Create virtual environment
-RUN ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python
-RUN python -m venv /opt/venv
-ENV PATH="/opt/venv/bin:$PATH"
-RUN echo "source /opt/venv/bin/activate" >> /root/.bashrc
-
-# Install LeRobot
-RUN git lfs install
-RUN git clone https://github.com/huggingface/lerobot.git /lerobot
+# Clone repository and install LeRobot in a single layer
+COPY . /lerobot
 WORKDIR /lerobot
-RUN pip install --upgrade --no-cache-dir pip
-RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]"
+RUN /opt/venv/bin/pip install --upgrade --no-cache-dir pip \
+    && /opt/venv/bin/pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]"

-# Set EGL as the rendering backend for MuJoCo
-ENV MUJOCO_GL="egl"
-
@@ -4,8 +4,8 @@

 - [A. Source the parts](#a-source-the-parts)
 - [B. Install LeRobot](#b-install-lerobot)
-- [C. Configure the motors](#c-configure-the-motors)
-- [D. Assemble the arms](#d-assemble-the-arms)
+- [C. Configure the Motors](#c-configure-the-motors)
+- [D. Step-by-Step Assembly Instructions](#d-step-by-step-assembly-instructions)
 - [E. Calibrate](#e-calibrate)
 - [F. Teleoperate](#f-teleoperate)
 - [G. Record a dataset](#g-record-a-dataset)
@@ -70,6 +70,7 @@ conda install -y -c conda-forge "opencv>=4.10.0"
 ```
 Great :hugs:! You are now done installing LeRobot and we can begin assembling the SO100 arms :robot:.
 Every time you now want to use LeRobot you can go to the `~/lerobot` folder where we installed LeRobot and run one of the commands.

 ## C. Configure the motors

 > [!NOTE]
@@ -98,22 +99,22 @@ Example output when identifying the leader arm's port (e.g., `/dev/tty.usbmodem5
 ```
 Finding all available ports for the MotorBus.
 ['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
-Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
+Remove the usb cable from your MotorsBus and press Enter when done.

 [...Disconnect leader arm and press Enter...]

-The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0031751
+The port of this MotorsBus is /dev/tty.usbmodem575E0031751
 Reconnect the usb cable.
 ```
 Example output when identifying the follower arm's port (e.g., `/dev/tty.usbmodem575E0032081`, or possibly `/dev/ttyACM1` on Linux):
 ```
 Finding all available ports for the MotorBus.
 ['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
-Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
+Remove the usb cable from your MotorsBus and press Enter when done.

 [...Disconnect follower arm and press Enter...]

-The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0032081
+The port of this MotorsBus is /dev/tty.usbmodem575E0032081
 Reconnect the usb cable.
 ```

@@ -221,19 +222,13 @@ Redo the process for all your motors until ID 6. Do the same for the 6 motors of

 Follow the video for removing gears. You need to remove the gear for the motors of the leader arm. As a result, you will only use the position encoding of the motor and reduce friction to more easily operate the leader arm.

-#### c. Add motor horn to all 12 motors
+## D. Step-by-Step Assembly Instructions

-<details>
-<summary><strong>Video adding motor horn</strong></summary>
+**Step 1: Clean Parts**
+- Remove all support material from the 3D-printed parts.

-<video src="https://github.com/user-attachments/assets/ef3391a4-ad05-4100-b2bd-1699bf86c969"></video>
+---

-</details>
+### Additional Guidance

-Follow the video for adding the motor horn. For SO-100, you need to align the holes on the motor horn to the motor spline to be approximately 1:30, 4:30, 7:30 and 10:30.
-Try to avoid rotating the motor while doing so to keep position 2048 set during configuration. It is especially tricky for the leader motors as it is more sensible without the gears, but it's ok if it's a bit rotated.
-
-## D. Assemble the arms
-
 <details>
 <summary><strong>Video assembling arms</strong></summary>
@@ -242,7 +237,211 @@ Try to avoid rotating the motor while doing so to keep position 2048 set during

 </details>

-Follow the video for assembling the arms. It is important to insert the cables into the motor that is being assembled before you assemble the motor into the arm! Inserting the cables beforehand is much easier than doing this afterward. The first arm should take a bit more than 1 hour to assemble, but once you get used to it, you can do it under 1 hour for the second arm.
+**Note:**
+This video provides visual guidance for assembling the arms, but it doesn't specify when or how to do the wiring. Inserting the cables beforehand is much easier than doing it afterward. The first arm may take a bit more than 1 hour to assemble, but once you get used to it, you can assemble the second arm in under 1 hour.
+
+---
+
+### First Motor
+
+**Step 2: Insert Wires**
+- Insert two wires into the first motor.
+
+<img src="../media/tutorial/img1.jpg" style="height:300px;">
+
+**Step 3: Install in Base**
+- Place the first motor into the base.
+
+<img src="../media/tutorial/img2.jpg" style="height:300px;">
+
+**Step 4: Secure Motor**
+- Fasten the motor with 4 screws. Two from the bottom and two from top.
+
+**Step 5: Attach Motor Holder**
+- Slide over the first motor holder and fasten it using two screws (one on each side).
+
+<img src="../media/tutorial/img4.jpg" style="height:300px;">
+
+**Step 6: Attach Motor Horns**
+- Install both motor horns, securing the top horn with a screw. Try not to move the motor position when attaching the motor horn, especially for the leader arms, where we removed the gears.
+
+<img src="../media/tutorial/img5.jpg" style="height:300px;">
+<details>
+<summary><strong>Video adding motor horn</strong></summary>
+<video src="https://github.com/user-attachments/assets/ef3391a4-ad05-4100-b2bd-1699bf86c969"></video>
+</details>
+
+**Step 7: Attach Shoulder Part**
+- Route one wire to the back of the robot and the other to the left or in photo towards you (see photo).
+- Attach the shoulder part.
+
+<img src="../media/tutorial/img6.jpg" style="height:300px;">
+
+**Step 8: Secure Shoulder**
+- Tighten the shoulder part with 4 screws on top and 4 on the bottom
+*(access bottom holes by turning the shoulder).*
+
+---
+
+### Second Motor Assembly
+
+**Step 9: Install Motor 2**
+- Slide the second motor in from the top and link the wire from motor 1 to motor 2.
+
+<img src="../media/tutorial/img8.jpg" style="height:300px;">
+
+**Step 10: Attach Shoulder Holder**
+- Add the shoulder motor holder.
+- Ensure the wire from motor 1 to motor 2 goes behind the holder while the other wire is routed upward (see photo).
+- This part can be tight to assemble, you can use a workbench like the image or a similar setup to push the part around the motor.
+
+<div style="display: flex;">
+  <img src="../media/tutorial/img9.jpg" style="height:250px;">
+  <img src="../media/tutorial/img10.jpg" style="height:250px;">
+  <img src="../media/tutorial/img12.jpg" style="height:250px;">
+</div>
+
+**Step 11: Secure Motor 2**
+- Fasten the second motor with 4 screws.
+
+**Step 12: Attach Motor Horn**
+- Attach both motor horns to motor 2, again use the horn screw.
+
+**Step 13: Attach Base**
+- Install the base attachment using 2 screws.
+
+<img src="../media/tutorial/img11.jpg" style="height:300px;">
+
+**Step 14: Attach Upper Arm**
+- Attach the upper arm with 4 screws on each side.
+
+<img src="../media/tutorial/img13.jpg" style="height:300px;">
+
+---
+
+### Third Motor Assembly
+
+**Step 15: Install Motor 3**
+- Route the motor cable from motor 2 through the cable holder to motor 3, then secure motor 3 with 4 screws.
+
+**Step 16: Attach Motor Horn**
+- Attach both motor horns to motor 3 and secure one again with a horn screw.
+
+<img src="../media/tutorial/img14.jpg" style="height:300px;">
+
+**Step 17: Attach Forearm**
+- Connect the forearm to motor 3 using 4 screws on each side.
+
+<img src="../media/tutorial/img15.jpg" style="height:300px;">
+
+---
+
+### Fourth Motor Assembly
+
+**Step 18: Install Motor 4**
+- Slide in motor 4, attach the cable from motor 3, and secure the cable in its holder with a screw.
+
+<div style="display: flex;">
+  <img src="../media/tutorial/img16.jpg" style="height:300px;">
+  <img src="../media/tutorial/img19.jpg" style="height:300px;">
+</div>
+
+**Step 19: Attach Motor Holder 4**
+- Install the fourth motor holder (a tight fit). Ensure one wire is routed upward and the wire from motor 3 is routed downward (see photo).
+
+<img src="../media/tutorial/img17.jpg" style="height:300px;">
+
+**Step 20: Secure Motor 4 & Attach Horn**
+- Fasten motor 4 with 4 screws and attach its motor horns, use for one a horn screw.
+
+<img src="../media/tutorial/img18.jpg" style="height:300px;">
+
+---
+
+### Wrist Assembly
+
+**Step 21: Install Motor 5**
+- Insert motor 5 into the wrist holder and secure it with 2 front screws.
+
+<img src="../media/tutorial/img20.jpg" style="height:300px;">
+
+**Step 22: Attach Wrist**
+- Connect the wire from motor 4 to motor 5. And already insert the other wire for the gripper.
+- Secure the wrist to motor 4 using 4 screws on both sides.
+
+<img src="../media/tutorial/img22.jpg" style="height:300px;">
+
+**Step 23: Attach Wrist Horn**
+- Install only one motor horn on the wrist motor and secure it with a horn screw.
+
+<img src="../media/tutorial/img23.jpg" style="height:300px;">
+
+---
+
+### Follower Configuration
+
+**Step 24: Attach Gripper**
+- Attach the gripper to motor 5.
+
+<img src="../media/tutorial/img24.jpg" style="height:300px;">
+
+**Step 25: Install Gripper Motor**
+- Insert the gripper motor, connect the motor wire from motor 5 to motor 6, and secure it with 3 screws on each side.
+
+<img src="../media/tutorial/img25.jpg" style="height:300px;">
+
+**Step 26: Attach Gripper Horn & Claw**
+- Attach the motor horns and again use a horn screw.
+- Install the gripper claw and secure it with 4 screws on both sides.
+
+<img src="../media/tutorial/img26.jpg" style="height:300px;">
+
+**Step 27: Mount Controller**
+- Attach the motor controller on the back.
+
+<div style="display: flex;">
+  <img src="../media/tutorial/img27.jpg" style="height:300px;">
+  <img src="../media/tutorial/img28.jpg" style="height:300px;">
+</div>
+
+*Assembly complete – proceed to Leader arm assembly.*
+
+---
+
+### Leader Configuration
+
+For the leader configuration, perform **Steps 1–23**. Make sure that you removed the motor gears from the motors.
+
+**Step 24: Attach Leader Holder**
+- Mount the leader holder onto the wrist and secure it with a screw.
+
+<img src="../media/tutorial/img29.jpg" style="height:300px;">
+
+**Step 25: Attach Handle**
+- Attach the handle to motor 5 using 4 screws.
+
+<img src="../media/tutorial/img30.jpg" style="height:300px;">
+
+**Step 26: Install Gripper Motor**
+- Insert the gripper motor, secure it with 3 screws on each side, attach a motor horn using a horn screw, and connect the motor wire.
+
+<img src="../media/tutorial/img31.jpg" style="height:300px;">
+
+**Step 27: Attach Trigger**
+- Attach the follower trigger with 4 screws.
+
+<img src="../media/tutorial/img32.jpg" style="height:300px;">
+
+**Step 28: Mount Controller**
+- Attach the motor controller on the back.
+
+<div style="display: flex;">
+  <img src="../media/tutorial/img27.jpg" style="height:300px;">
+  <img src="../media/tutorial/img28.jpg" style="height:300px;">
+</div>
+
+*Assembly complete – proceed to calibration.*
+

 ## E. Calibrate
@@ -256,7 +455,7 @@ Next, you'll need to calibrate your SO-100 robot to ensure that the leader and f
 You will need to move the follower arm to these positions sequentially:

 | 1. Zero position | 2. Rotated position | 3. Rest position |
-|---|---|---|
+| ------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------ |
 | <img src="../media/so100/follower_zero.webp?raw=true" alt="SO-100 follower arm zero position" title="SO-100 follower arm zero position" style="width:100%;"> | <img src="../media/so100/follower_rotated.webp?raw=true" alt="SO-100 follower arm rotated position" title="SO-100 follower arm rotated position" style="width:100%;"> | <img src="../media/so100/follower_rest.webp?raw=true" alt="SO-100 follower arm rest position" title="SO-100 follower arm rest position" style="width:100%;"> |

 Make sure both arms are connected and run this script to launch manual calibration:
@@ -272,7 +471,7 @@ python lerobot/scripts/control_robot.py \
 Follow step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724) which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially:

 | 1. Zero position | 2. Rotated position | 3. Rest position |
-|---|---|---|
+| ------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------ |
 | <img src="../media/so100/leader_zero.webp?raw=true" alt="SO-100 leader arm zero position" title="SO-100 leader arm zero position" style="width:100%;"> | <img src="../media/so100/leader_rotated.webp?raw=true" alt="SO-100 leader arm rotated position" title="SO-100 leader arm rotated position" style="width:100%;"> | <img src="../media/so100/leader_rest.webp?raw=true" alt="SO-100 leader arm rest position" title="SO-100 leader arm rest position" style="width:100%;"> |

 Run this script to launch manual calibration:
@@ -335,7 +534,7 @@ python lerobot/scripts/control_robot.py \
   --control.push_to_hub=true
 ```

-Note: You can resume recording by adding `--control.resume=true`. Also if you didn't push your dataset yet, add `--control.local_files_only=true`.
+Note: You can resume recording by adding `--control.resume=true`.

 ## H. Visualize a dataset

@@ -344,7 +543,7 @@ If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you c
 echo ${HF_USER}/so100_test
 ```

-If you didn't upload with `--control.push_to_hub=false`, you can also visualize it locally with:
+If you didn't upload with `--control.push_to_hub=false`, you can also visualize it locally with (a window can be opened in the browser `http://127.0.0.1:9090` with the visualization tool):
 ```bash
 python lerobot/scripts/visualize_dataset_html.py \
   --repo-id ${HF_USER}/so100_test \
@@ -363,8 +562,6 @@ python lerobot/scripts/control_robot.py \
   --control.episode=0
 ```

-Note: If you didn't push your dataset yet, add `--control.local_files_only=true`.
-
 ## J. Train a policy

 To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
@@ -374,20 +571,25 @@ python lerobot/scripts/train.py \
   --policy.type=act \
   --output_dir=outputs/train/act_so100_test \
   --job_name=act_so100_test \
-  --device=cuda \
+  --policy.device=cuda \
   --wandb.enable=true
 ```

-Note: If you didn't push your dataset yet, add `--control.local_files_only=true`.
-
 Let's explain it:
 1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/so100_test`.
 2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
-4. We provided `device=cuda` since we are training on a Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
+4. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
 5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.

 Training should take several hours. You will find checkpoints in `outputs/train/act_so100_test/checkpoints`.

+To resume training from a checkpoint, below is an example command to resume from `last` checkpoint of the `act_so100_test` policy:
+```bash
+python lerobot/scripts/train.py \
+  --config_path=outputs/train/act_so100_test/checkpoints/last/pretrained_model/train_config.json \
+  --resume=true
+```
+
 ## K. Evaluate your policy

 You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
@@ -416,4 +618,4 @@ As you can see, it's almost the same command as previously used to record your t
 Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth tutorial on controlling real robots with LeRobot.

 > [!TIP]
-> If you have any questions or need help, please reach out on Discord in the channel [`#so100-arm`](https://discord.com/channels/1216765309076115607/1237741463832363039).
+> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb) in the channel [`#so100-arm`](https://discord.com/channels/1216765309076115607/1237741463832363039).
@ -0,0 +1,585 @@
|
||||||
|
# Using the [LeKiwi](https://github.com/SIGRobotics-UIUC/LeKiwi) Robot with LeRobot
|
||||||
|
|
||||||
|
## Table of Contents
|
||||||
|
|
||||||
|
- [A. Source the parts](#a-source-the-parts)
|
||||||
|
- [B. Install software on Pi](#b-install-software-on-pi)
|
||||||
|
- [C. Install LeRobot on laptop](#c-install-lerobot-on-laptop)
|
||||||
|
- [D. Assemble the arms](#d-assembly)
|
||||||
|
- [E. Calibrate](#e-calibration)
|
||||||
|
- [F. Teleoperate](#f-teleoperate)
|
||||||
|
- [G. Record a dataset](#g-record-a-dataset)
|
||||||
|
- [H. Visualize a dataset](#h-visualize-a-dataset)
|
||||||
|
- [I. Replay an episode](#i-replay-an-episode)
|
||||||
|
- [J. Train a policy](#j-train-a-policy)
|
||||||
|
- [K. Evaluate your policy](#k-evaluate-your-policy)
|
||||||
|
|
||||||
|
> [!TIP]
|
||||||
|
> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb) in the channel [`#mobile-so-100-arm`](https://discord.com/channels/1216765309076115607/1318390825528332371).
|
||||||
|
|
||||||
|
## A. Source the parts
|
||||||
|
|
||||||
|
Follow this [README](https://github.com/SIGRobotics-UIUC/LeKiwi). It contains the bill of materials, with a link to source the parts, as well as the instructions to 3D print the parts, and advice if it's your first time printing or if you don't own a 3D printer.
|
||||||
|
|
||||||
|
Before assembling, you will first need to configure your motors. To this end, we provide a nice script, so let's first install LeRobot. After configuration, we will also guide you through assembly.
|
||||||
|
|
||||||
|
### Wired version
|
||||||
|
If you have the **wired** LeKiwi version, you can skip the Raspberry Pi installation and the SSH setup, and run all commands directly on your PC, for both the LeKiwi scripts and the leader arm teleoperation scripts.
|
||||||
|
|
||||||
|
## B. Install software on Pi
|
||||||
|
Now we have to set up the remote PC that will run on the LeKiwi robot. This is normally a Raspberry Pi, but it can be any PC that runs on 5V and has enough USB ports (2 or more) for the cameras and the motor control board.
|
||||||
|
|
||||||
|
### Install OS
|
||||||
|
To set up the Raspberry Pi and its SD card, see: [Setup Pi](https://www.raspberrypi.com/documentation/computers/getting-started.html). It explains how to download the [Imager](https://www.raspberrypi.com/software/) and install Raspberry Pi OS or Ubuntu.
|
||||||
|
|
||||||
|
### Setup SSH
|
||||||
|
After setting up your Pi, you should enable and set up [SSH](https://www.raspberrypi.com/news/coding-on-raspberry-pi-remotely-with-visual-studio-code/) (Secure Shell Protocol) so you can log into the Pi from your laptop without needing a screen, keyboard, and mouse attached to the Pi. A great tutorial on how to do this can be found [here](https://www.raspberrypi.com/documentation/computers/remote-access.html#ssh). You can log into your Pi from your Command Prompt (cmd), or if you use VSCode, with [this](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-ssh) extension.
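Once SSH is enabled, connecting from your laptop looks like this (the username and address are placeholders for your own values; `raspberrypi.local` only resolves if mDNS is available on your network):
```bash
ssh <your_pi_user_name>@<your_pi_ip_address>
# or, if your network resolves the default hostname
ssh pi@raspberrypi.local
```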
|
||||||
|
|
||||||
|
### Install LeRobot
|
||||||
|
|
||||||
|
On your Raspberry Pi:
|
||||||
|
|
||||||
|
#### 1. [Install Miniconda](https://docs.anaconda.com/miniconda/install/#quick-command-line-install):
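For reference, here is a minimal sketch of the quick command-line install on a 64-bit Raspberry Pi OS; the installer file name is an assumption, so check the linked page for the current one for your platform:
```bash
mkdir -p ~/miniconda3
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-aarch64.sh -O ~/miniconda3/miniconda.sh
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
rm ~/miniconda3/miniconda.sh
~/miniconda3/bin/conda init bash
```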
|
||||||
|
|
||||||
|
#### 2. Restart shell
|
||||||
|
Copy and paste into your shell: `source ~/.bashrc`, or for Mac: `source ~/.bash_profile`, or `source ~/.zshrc` if you're using zsh.
|
||||||
|
|
||||||
|
#### 3. Create and activate a fresh conda environment for lerobot
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>Video install instructions</strong></summary>
|
||||||
|
|
||||||
|
<video src="https://github.com/user-attachments/assets/17172d3b-3b64-4b80-9cf1-b2b7c5cbd236"></video>
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
```bash
|
||||||
|
conda create -y -n lerobot python=3.10
|
||||||
|
```
|
||||||
|
|
||||||
|
Then activate your conda environment (do this each time you open a shell to use lerobot!):
|
||||||
|
```bash
|
||||||
|
conda activate lerobot
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 4. Clone LeRobot:
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 5. Install LeRobot with dependencies for the feetech motors:
|
||||||
|
```bash
|
||||||
|
cd ~/lerobot && pip install -e ".[feetech]"
|
||||||
|
```
|
||||||
|
|
||||||
|
## C. Install LeRobot on laptop
|
||||||
|
If you have already installed LeRobot on your laptop, you can skip this step; otherwise, please follow along as we repeat the same steps we did on the Pi.
|
||||||
|
|
||||||
|
> [!TIP]
|
||||||
|
> We use the Command Prompt (cmd) quite a lot. If you are not comfortable using the cmd or want to brush up using the command line you can have a look here: [Command line crash course](https://developer.mozilla.org/en-US/docs/Learn_web_development/Getting_started/Environment_setup/Command_line)
|
||||||
|
|
||||||
|
On your computer:
|
||||||
|
|
||||||
|
#### 1. [Install Miniconda](https://docs.anaconda.com/miniconda/install/#quick-command-line-install):
|
||||||
|
|
||||||
|
#### 2. Restart shell
|
||||||
|
Copy and paste into your shell: `source ~/.bashrc`, or for Mac: `source ~/.bash_profile`, or `source ~/.zshrc` if you're using zsh.
|
||||||
|
|
||||||
|
#### 3. Create and activate a fresh conda environment for lerobot
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>Video install instructions</strong></summary>
|
||||||
|
|
||||||
|
<video src="https://github.com/user-attachments/assets/17172d3b-3b64-4b80-9cf1-b2b7c5cbd236"></video>
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
```bash
|
||||||
|
conda create -y -n lerobot python=3.10
|
||||||
|
```
|
||||||
|
|
||||||
|
Then activate your conda environment (do this each time you open a shell to use lerobot!):
|
||||||
|
```bash
|
||||||
|
conda activate lerobot
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 4. Clone LeRobot:
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 5. Install LeRobot with dependencies for the feetech motors:
|
||||||
|
```bash
|
||||||
|
cd ~/lerobot && pip install -e ".[feetech]"
|
||||||
|
```
|
||||||
|
|
||||||
|
*EXTRA: For Linux only (not Mac)*: install extra dependencies for recording datasets:
|
||||||
|
```bash
|
||||||
|
conda install -y -c conda-forge ffmpeg
|
||||||
|
pip uninstall -y opencv-python
|
||||||
|
conda install -y -c conda-forge "opencv>=4.10.0"
|
||||||
|
```
|
||||||
|
Great :hugs:! You are now done installing LeRobot and we can begin assembling the SO100 arms and Mobile base :robot:.
|
||||||
|
Every time you want to use LeRobot, go to the `~/lerobot` folder where we installed it and run one of the commands.
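A typical session therefore starts like this:
```bash
cd ~/lerobot
conda activate lerobot
```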
|
||||||
|
|
||||||
|
# D. Assembly
|
||||||
|
|
||||||
|
First we will assemble the two SO100 arms: one to attach to the mobile base and one for teleoperation. Then we will assemble the mobile base.
|
||||||
|
|
||||||
|
## SO100 Arms
|
||||||
|
### Configure motors
|
||||||
|
The instructions for configuring the motors can be found [here](https://github.com/huggingface/lerobot/blob/main/examples/10_use_so100.md#c-configure-the-motors) in step C of the SO100 tutorial. Besides the IDs for the arm motors, we also need to set the motor IDs for the mobile base. These need to be in a specific order to work. Below is an image of the motor IDs and motor mounting positions for the mobile base. Note that we only use one motor control board on LeKiwi, so the motor IDs for the wheels are 7, 8 and 9.
|
||||||
|
|
||||||
|
<img src="../media/lekiwi/motor_ids.webp?raw=true" alt="Motor ID's for mobile robot" title="Motor ID's for mobile robot" width="60%">
|
||||||
|
|
||||||
|
### Assemble arms
|
||||||
|
[Assemble arms instruction](https://github.com/huggingface/lerobot/blob/main/examples/10_use_so100.md#d-assemble-the-arms)
|
||||||
|
|
||||||
|
## Mobile base (LeKiwi)
|
||||||
|
[Assemble LeKiwi](https://github.com/SIGRobotics-UIUC/LeKiwi)
|
||||||
|
|
||||||
|
### Update config
|
||||||
|
The config files in the LeRobot install on the LeKiwi (Raspberry Pi) and on the laptop should be the same. First, we need to find the IP address of the Raspberry Pi of the mobile manipulator. This is the same IP address used for SSH. We also need the USB port of the leader arm's control board on the laptop and the port of the control board on LeKiwi. We can find these ports with the following script.
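To get the Pi's IP address, you can run this on the Pi (the same command appears in the communication troubleshooting section further below):
```bash
hostname -I
```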
|
||||||
|
|
||||||
|
#### a. Run the script to find port
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>Video finding port</strong></summary>
|
||||||
|
<video src="https://github.com/user-attachments/assets/4a21a14d-2046-4805-93c4-ee97a30ba33f"></video>
|
||||||
|
<video src="https://github.com/user-attachments/assets/1cc3aecf-c16d-4ff9-aec7-8c175afbbce2"></video>
|
||||||
|
</details>
|
||||||
|
|
||||||
|
To find the port for each bus servo adapter, run the utility script:
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/find_motors_bus_port.py
|
||||||
|
```
|
||||||
|
|
||||||
|
#### b. Example outputs
|
||||||
|
|
||||||
|
Example output when identifying the leader arm's port (e.g., `/dev/tty.usbmodem575E0031751` on Mac, or possibly `/dev/ttyACM0` on Linux):
|
||||||
|
```
|
||||||
|
Finding all available ports for the MotorBus.
|
||||||
|
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
||||||
|
Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
|
||||||
|
|
||||||
|
[...Disconnect leader arm and press Enter...]
|
||||||
|
|
||||||
|
The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0031751
|
||||||
|
Reconnect the usb cable.
|
||||||
|
```
|
||||||
|
Example output when identifying the follower arm's port (e.g., `/dev/tty.usbmodem575E0032081`, or possibly `/dev/ttyACM1` on Linux):
|
||||||
|
```
|
||||||
|
Finding all available ports for the MotorBus.
|
||||||
|
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
||||||
|
Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
|
||||||
|
|
||||||
|
[...Disconnect follower arm and press Enter...]
|
||||||
|
|
||||||
|
The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0032081
|
||||||
|
Reconnect the usb cable.
|
||||||
|
```
|
||||||
|
|
||||||
|
#### c. Troubleshooting
|
||||||
|
On Linux, you might need to give access to the USB ports by running:
|
||||||
|
```bash
|
||||||
|
sudo chmod 666 /dev/ttyACM0
|
||||||
|
sudo chmod 666 /dev/ttyACM1
|
||||||
|
```
|
||||||
|
|
||||||
|
#### d. Update config file
|
||||||
|
|
||||||
|
IMPORTANT: Now that you have the ports of the leader and follower arms and the IP address of the mobile SO100, update the **ip** in the network configuration, the **port** in leader_arms, and the **port** in lekiwi in the [`LeKiwiRobotConfig`](../lerobot/common/robot_devices/robots/configs.py) file, where you will find something like:
|
||||||
|
```python
|
||||||
|
@RobotConfig.register_subclass("lekiwi")
|
||||||
|
@dataclass
|
||||||
|
class LeKiwiRobotConfig(RobotConfig):
|
||||||
|
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||||
|
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||||
|
# the number of motors in your follower arms.
|
||||||
|
max_relative_target: int | None = None
|
||||||
|
|
||||||
|
# Network Configuration
|
||||||
|
ip: str = "172.17.133.91"
|
||||||
|
port: int = 5555
|
||||||
|
video_port: int = 5556
|
||||||
|
|
||||||
|
cameras: dict[str, CameraConfig] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"mobile": OpenCVCameraConfig(camera_index="/dev/video0", fps=30, width=640, height=480),
|
||||||
|
"mobile2": OpenCVCameraConfig(camera_index="/dev/video2", fps=30, width=640, height=480),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
calibration_dir: str = ".cache/calibration/lekiwi"
|
||||||
|
|
||||||
|
leader_arms: dict[str, MotorsBusConfig] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"main": FeetechMotorsBusConfig(
|
||||||
|
port="/dev/tty.usbmodem585A0077581",
|
||||||
|
motors={
|
||||||
|
# name: (index, model)
|
||||||
|
"shoulder_pan": [1, "sts3215"],
|
||||||
|
"shoulder_lift": [2, "sts3215"],
|
||||||
|
"elbow_flex": [3, "sts3215"],
|
||||||
|
"wrist_flex": [4, "sts3215"],
|
||||||
|
"wrist_roll": [5, "sts3215"],
|
||||||
|
"gripper": [6, "sts3215"],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
follower_arms: dict[str, MotorsBusConfig] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"main": FeetechMotorsBusConfig(
|
||||||
|
port="/dev/ttyACM0",
|
||||||
|
motors={
|
||||||
|
# name: (index, model)
|
||||||
|
"shoulder_pan": [1, "sts3215"],
|
||||||
|
"shoulder_lift": [2, "sts3215"],
|
||||||
|
"elbow_flex": [3, "sts3215"],
|
||||||
|
"wrist_flex": [4, "sts3215"],
|
||||||
|
"wrist_roll": [5, "sts3215"],
|
||||||
|
"gripper": [6, "sts3215"],
|
||||||
|
"left_wheel": (7, "sts3215"),
|
||||||
|
"back_wheel": (8, "sts3215"),
|
||||||
|
"right_wheel": (9, "sts3215"),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
teleop_keys: dict[str, str] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
# Movement
|
||||||
|
"forward": "w",
|
||||||
|
"backward": "s",
|
||||||
|
"left": "a",
|
||||||
|
"right": "d",
|
||||||
|
"rotate_left": "z",
|
||||||
|
"rotate_right": "x",
|
||||||
|
# Speed control
|
||||||
|
"speed_up": "r",
|
||||||
|
"speed_down": "f",
|
||||||
|
# quit teleop
|
||||||
|
"quit": "q",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
mock: bool = False
|
||||||
|
```
|
||||||
|
|
||||||
|
## Wired version
|
||||||
|
|
||||||
|
For the wired LeKiwi version, the configured IP address should refer to your own laptop (127.0.0.1), because in this case both the leader arm and LeKiwi are connected to your laptop. Below is an example configuration for this wired setup:
|
||||||
|
```python
|
||||||
|
@RobotConfig.register_subclass("lekiwi")
|
||||||
|
@dataclass
|
||||||
|
class LeKiwiRobotConfig(RobotConfig):
|
||||||
|
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||||
|
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||||
|
# the number of motors in your follower arms.
|
||||||
|
max_relative_target: int | None = None
|
||||||
|
|
||||||
|
# Network Configuration
|
||||||
|
ip: str = "127.0.0.1"
|
||||||
|
port: int = 5555
|
||||||
|
video_port: int = 5556
|
||||||
|
|
||||||
|
cameras: dict[str, CameraConfig] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"front": OpenCVCameraConfig(
|
||||||
|
camera_index=0, fps=30, width=640, height=480, rotation=90
|
||||||
|
),
|
||||||
|
"wrist": OpenCVCameraConfig(
|
||||||
|
camera_index=1, fps=30, width=640, height=480, rotation=180
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
calibration_dir: str = ".cache/calibration/lekiwi"
|
||||||
|
|
||||||
|
leader_arms: dict[str, MotorsBusConfig] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"main": FeetechMotorsBusConfig(
|
||||||
|
port="/dev/tty.usbmodem585A0077581",
|
||||||
|
motors={
|
||||||
|
# name: (index, model)
|
||||||
|
"shoulder_pan": [1, "sts3215"],
|
||||||
|
"shoulder_lift": [2, "sts3215"],
|
||||||
|
"elbow_flex": [3, "sts3215"],
|
||||||
|
"wrist_flex": [4, "sts3215"],
|
||||||
|
"wrist_roll": [5, "sts3215"],
|
||||||
|
"gripper": [6, "sts3215"],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
follower_arms: dict[str, MotorsBusConfig] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"main": FeetechMotorsBusConfig(
|
||||||
|
port="/dev/tty.usbmodem58760431061",
|
||||||
|
motors={
|
||||||
|
# name: (index, model)
|
||||||
|
"shoulder_pan": [1, "sts3215"],
|
||||||
|
"shoulder_lift": [2, "sts3215"],
|
||||||
|
"elbow_flex": [3, "sts3215"],
|
||||||
|
"wrist_flex": [4, "sts3215"],
|
||||||
|
"wrist_roll": [5, "sts3215"],
|
||||||
|
"gripper": [6, "sts3215"],
|
||||||
|
"left_wheel": (7, "sts3215"),
|
||||||
|
"back_wheel": (8, "sts3215"),
|
||||||
|
"right_wheel": (9, "sts3215"),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
teleop_keys: dict[str, str] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
# Movement
|
||||||
|
"forward": "w",
|
||||||
|
"backward": "s",
|
||||||
|
"left": "a",
|
||||||
|
"right": "d",
|
||||||
|
"rotate_left": "z",
|
||||||
|
"rotate_right": "x",
|
||||||
|
# Speed control
|
||||||
|
"speed_up": "r",
|
||||||
|
"speed_down": "f",
|
||||||
|
# quit teleop
|
||||||
|
"quit": "q",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
mock: bool = False
|
||||||
|
```
|
||||||
|
|
||||||
|
# E. Calibration
|
||||||
|
Now we have to calibrate the leader arm and the follower arm. The wheel motors don't have to be calibrated.
|
||||||
|
|
||||||
|
|
||||||
|
### Calibrate follower arm (on mobile base)
|
||||||
|
> [!IMPORTANT]
|
||||||
|
> Contrary to step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724), which illustrates the auto calibration, we will actually do manual calibration of the follower for now.
|
||||||
|
|
||||||
|
You will need to move the follower arm to these positions sequentially:
|
||||||
|
|
||||||
|
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||||
|
| ----------------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
|
| <img src="../media/lekiwi/mobile_calib_zero.webp?raw=true" alt="SO-100 follower arm zero position" title="SO-100 follower arm zero position" style="width:100%;"> | <img src="../media/lekiwi/mobile_calib_rotated.webp?raw=true" alt="SO-100 follower arm rotated position" title="SO-100 follower arm rotated position" style="width:100%;"> | <img src="../media/lekiwi/mobile_calib_rest.webp?raw=true" alt="SO-100 follower arm rest position" title="SO-100 follower arm rest position" style="width:100%;"> |
|
||||||
|
|
||||||
|
Make sure the arm is connected to the Raspberry Pi and run this script (on the Raspberry Pi) to launch manual calibration:
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/control_robot.py \
|
||||||
|
--robot.type=lekiwi \
|
||||||
|
--robot.cameras='{}' \
|
||||||
|
--control.type=calibrate \
|
||||||
|
--control.arms='["main_follower"]'
|
||||||
|
```
|
||||||
|
|
||||||
|
### Wired version
|
||||||
|
If you have the **wired** LeKiwi version please run all commands including this calibration command on your laptop.
|
||||||
|
|
||||||
|
### Calibrate leader arm
|
||||||
|
Then calibrate the leader arm (which is attached to the laptop/pc). You will need to move the leader arm to these positions sequentially:
|
||||||
|
|
||||||
|
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||||
|
| ------------------------------------------------------------------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||||
|
| <img src="../media/so100/leader_zero.webp?raw=true" alt="SO-100 leader arm zero position" title="SO-100 leader arm zero position" style="width:100%;"> | <img src="../media/so100/leader_rotated.webp?raw=true" alt="SO-100 leader arm rotated position" title="SO-100 leader arm rotated position" style="width:100%;"> | <img src="../media/so100/leader_rest.webp?raw=true" alt="SO-100 leader arm rest position" title="SO-100 leader arm rest position" style="width:100%;"> |
|
||||||
|
|
||||||
|
Run this script (on your laptop/pc) to launch manual calibration:
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/control_robot.py \
|
||||||
|
--robot.type=lekiwi \
|
||||||
|
--robot.cameras='{}' \
|
||||||
|
--control.type=calibrate \
|
||||||
|
--control.arms='["main_leader"]'
|
||||||
|
```
|
||||||
|
|
||||||
|
# F. Teleoperate
|
||||||
|
To teleoperate, SSH into your Raspberry Pi, then run `conda activate lerobot` and this script:
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/control_robot.py \
|
||||||
|
--robot.type=lekiwi \
|
||||||
|
--control.type=remote_robot
|
||||||
|
```
|
||||||
|
|
||||||
|
Then on your laptop, also run `conda activate lerobot` and this script:
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/control_robot.py \
|
||||||
|
--robot.type=lekiwi \
|
||||||
|
--control.type=teleoperate \
|
||||||
|
--control.fps=30
|
||||||
|
```
|
||||||
|
|
||||||
|
You should see something like this on your laptop: ```[INFO] Connected to remote robot at tcp://172.17.133.91:5555 and video stream at tcp://172.17.133.91:5556.``` Now you can move the leader arm and use the keyboard (w,a,s,d) to drive forward, left, backwards, and right, (z,x) to turn left or right, and (r,f) to increase and decrease the speed of the mobile robot. There are three speed modes; see the table below:
|
||||||
|
| Speed Mode | Linear Speed (m/s) | Rotation Speed (deg/s) |
|
||||||
|
| ---------- | ------------------ | ---------------------- |
|
||||||
|
| Fast | 0.4 | 90 |
|
||||||
|
| Medium | 0.25 | 60 |
|
||||||
|
| Slow | 0.1 | 30 |
|
||||||
|
|
||||||
|
|
||||||
|
| Key | Action |
|
||||||
|
| --- | -------------- |
|
||||||
|
| W | Move forward |
|
||||||
|
| A | Move left |
|
||||||
|
| S | Move backward |
|
||||||
|
| D | Move right |
|
||||||
|
| Z | Turn left |
|
||||||
|
| X | Turn right |
|
||||||
|
| R | Increase speed |
|
||||||
|
| F | Decrease speed |
|
||||||
|
|
||||||
|
> [!TIP]
|
||||||
|
> If you use a different keyboard you can change the keys for each command in the [`LeKiwiRobotConfig`](../lerobot/common/robot_devices/robots/configs.py).
|
||||||
|
|
||||||
|
### Wired version
|
||||||
|
If you have the **wired** LeKiwi version please run all commands including both these teleoperation commands on your laptop.
|
||||||
|
|
||||||
|
## Troubleshoot communication
|
||||||
|
|
||||||
|
If you are having trouble connecting to the Mobile SO100, follow these steps to diagnose and resolve the issue.
|
||||||
|
|
||||||
|
### 1. Verify IP Address Configuration
|
||||||
|
Make sure that the correct IP for the Pi is set in the configuration file. To check the Raspberry Pi's IP address, run (on the Pi command line):
|
||||||
|
```bash
|
||||||
|
hostname -I
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Check if Pi is reachable from laptop/pc
|
||||||
|
Try pinging the Raspberry Pi from your laptop:
|
||||||
|
```bash
|
||||||
|
ping <your_pi_ip_address>
|
||||||
|
```
|
||||||
|
|
||||||
|
If the ping fails:
|
||||||
|
- Ensure the Pi is powered on and connected to the same network.
|
||||||
|
- Check if SSH is enabled on the Pi.
|
||||||
|
|
||||||
|
### 3. Try SSH connection
|
||||||
|
If you can't SSH into the Pi, it might not be properly connected. Use:
|
||||||
|
```bash
|
||||||
|
ssh <your_pi_user_name>@<your_pi_ip_address>
|
||||||
|
```
|
||||||
|
If you get a connection error:
|
||||||
|
- Ensure SSH is enabled on the Pi by running:
|
||||||
|
```bash
|
||||||
|
sudo raspi-config
|
||||||
|
```
|
||||||
|
Then navigate to: **Interfacing Options -> SSH** and enable it.
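Alternatively, on a systemd-based Raspberry Pi OS you should be able to enable and start the SSH service non-interactively:
```bash
sudo systemctl enable --now ssh
```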
|
||||||
|
|
||||||
|
### 4. Same config file
|
||||||
|
Make sure the configuration file on both your laptop/pc and the Raspberry Pi is the same.
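For example, assuming both machines have LeRobot cloned in `~/lerobot`, you can copy the config from your laptop to the Pi over SSH (username, address and paths are placeholders; adjust them to your setup):
```bash
# Run from the ~/lerobot folder on your laptop
scp lerobot/common/robot_devices/robots/configs.py \
<your_pi_user_name>@<your_pi_ip_address>:~/lerobot/lerobot/common/robot_devices/robots/configs.py
```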
|
||||||
|
|
||||||
|
# G. Record a dataset
|
||||||
|
Once you're familiar with teleoperation, you can record your first dataset with LeKiwi.
|
||||||
|
|
||||||
|
To start the program on LeKiwi, SSH into your Raspberry Pi, and run `conda activate lerobot` and this script:
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/control_robot.py \
|
||||||
|
--robot.type=lekiwi \
|
||||||
|
--control.type=remote_robot
|
||||||
|
```
|
||||||
|
|
||||||
|
If you want to use the Hugging Face hub features for uploading your dataset and you haven't previously done it, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
|
||||||
|
```bash
|
||||||
|
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||||
|
```
|
||||||
|
|
||||||
|
Store your Hugging Face repository name in a variable to run these commands:
|
||||||
|
```bash
|
||||||
|
HF_USER=$(huggingface-cli whoami | head -n 1)
|
||||||
|
echo $HF_USER
|
||||||
|
```
|
||||||
|
Then, on your laptop, run this command to record 2 episodes and upload your dataset to the hub:
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/control_robot.py \
|
||||||
|
--robot.type=lekiwi \
|
||||||
|
--control.type=record \
|
||||||
|
--control.fps=30 \
|
||||||
|
--control.single_task="Grasp a lego block and put it in the bin." \
|
||||||
|
--control.repo_id=${HF_USER}/lekiwi_test \
|
||||||
|
--control.tags='["tutorial"]' \
|
||||||
|
--control.warmup_time_s=5 \
|
||||||
|
--control.episode_time_s=30 \
|
||||||
|
--control.reset_time_s=30 \
|
||||||
|
--control.num_episodes=2 \
|
||||||
|
--control.push_to_hub=true
|
||||||
|
```
|
||||||
|
|
||||||
|
Note: You can resume recording by adding `--control.resume=true`.
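For instance, a sketch of resuming the recording above into the same dataset (all other arguments stay as before):
```bash
python lerobot/scripts/control_robot.py \
--robot.type=lekiwi \
--control.type=record \
--control.fps=30 \
--control.single_task="Grasp a lego block and put it in the bin." \
--control.repo_id=${HF_USER}/lekiwi_test \
--control.tags='["tutorial"]' \
--control.warmup_time_s=5 \
--control.episode_time_s=30 \
--control.reset_time_s=30 \
--control.num_episodes=2 \
--control.push_to_hub=true \
--control.resume=true
```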
|
||||||
|
|
||||||
|
### Wired version
|
||||||
|
If you have the **wired** LeKiwi version please run all commands including both these record dataset commands on your laptop.
|
||||||
|
|
||||||
|
# H. Visualize a dataset
|
||||||
|
|
||||||
|
If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by:
|
||||||
|
```bash
|
||||||
|
echo ${HF_USER}/lekiwi_test
|
||||||
|
```
|
||||||
|
|
||||||
|
If you didn't upload it (i.e. you used `--control.push_to_hub=false`), you can also visualize the dataset locally (a window will open in the browser at `http://127.0.0.1:9090` with the visualization tool):
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/visualize_dataset_html.py \
|
||||||
|
--repo-id ${HF_USER}/lekiwi_test \
|
||||||
|
--local-files-only 1
|
||||||
|
```
|
||||||
|
|
||||||
|
# I. Replay an episode
|
||||||
|
Now try to replay the first episode on your robot:
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/control_robot.py \
|
||||||
|
--robot.type=lekiwi \
|
||||||
|
--control.type=replay \
|
||||||
|
--control.fps=30 \
|
||||||
|
--control.repo_id=${HF_USER}/lekiwi_test \
|
||||||
|
--control.episode=0
|
||||||
|
```
|
||||||
|
|
||||||
|
## J. Train a policy
|
||||||
|
|
||||||
|
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/train.py \
|
||||||
|
--dataset.repo_id=${HF_USER}/lekiwi_test \
|
||||||
|
--policy.type=act \
|
||||||
|
--output_dir=outputs/train/act_lekiwi_test \
|
||||||
|
--job_name=act_lekiwi_test \
|
||||||
|
--policy.device=cuda \
|
||||||
|
--wandb.enable=true
|
||||||
|
```
|
||||||
|
|
||||||
|
Let's explain it:
|
||||||
|
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/lekiwi_test`.
|
||||||
|
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||||
|
4. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
|
||||||
|
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||||
|
|
||||||
|
Training should take several hours. You will find checkpoints in `outputs/train/act_lekiwi_test/checkpoints`.
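To resume training from a checkpoint, here is an example command to resume from the `last` checkpoint of the `act_lekiwi_test` run:
```bash
python lerobot/scripts/train.py \
--config_path=outputs/train/act_lekiwi_test/checkpoints/last/pretrained_model/train_config.json \
--resume=true
```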
|
||||||
|
|
||||||
|
## K. Evaluate your policy
|
||||||
|
|
||||||
|
You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/control_robot.py \
|
||||||
|
--robot.type=lekiwi \
|
||||||
|
--control.type=record \
|
||||||
|
--control.fps=30 \
|
||||||
|
--control.single_task="Drive to the red block and pick it up" \
|
||||||
|
--control.repo_id=${HF_USER}/eval_act_lekiwi_test \
|
||||||
|
--control.tags='["tutorial"]' \
|
||||||
|
--control.warmup_time_s=5 \
|
||||||
|
--control.episode_time_s=30 \
|
||||||
|
--control.reset_time_s=30 \
|
||||||
|
--control.num_episodes=10 \
|
||||||
|
--control.push_to_hub=true \
|
||||||
|
--control.policy.path=outputs/train/act_lekiwi_test/checkpoints/last/pretrained_model
|
||||||
|
```
|
||||||
|
|
||||||
|
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
|
||||||
|
1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint (e.g. `outputs/train/act_lekiwi_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_lekiwi_test`).
|
||||||
|
2. The name of the dataset begins with `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_lekiwi_test`).
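You can then browse these evaluation episodes with the same visualization tool used earlier, for example (assuming you pushed the eval dataset to the hub):
```bash
python lerobot/scripts/visualize_dataset_html.py \
--repo-id ${HF_USER}/eval_act_lekiwi_test
```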
|
|
@ -2,7 +2,7 @@ This tutorial explains how to use [Moss v1](https://github.com/jess-moss/moss-ro
|
||||||
|
|
||||||
## Source the parts
|
## Source the parts
|
||||||
|
|
||||||
Follow this [README](https://github.com/jess-moss/moss-robot-arms). It contains the bill of materials, with link to source the parts, as well as the instructions to 3D print the parts, and advices if it's your first time printing or if you don't own a 3D printer already.
|
Follow this [README](https://github.com/jess-moss/moss-robot-arms). It contains the bill of materials with link to source the parts, as well as the instructions to 3D print the parts and advice if it's your first time printing or if you don't own a 3D printer already.
|
||||||
|
|
||||||
**Important**: Before assembling, you will first need to configure your motors. To this end, we provide a nice script, so let's first install LeRobot. After configuration, we will also guide you through assembly.
|
**Important**: Before assembling, you will first need to configure your motors. To this end, we provide a nice script, so let's first install LeRobot. After configuration, we will also guide you through assembly.
|
||||||
|
|
||||||
|
@ -177,7 +177,7 @@ Next, you'll need to calibrate your Moss v1 robot to ensure that the leader and
|
||||||
You will need to move the follower arm to these positions sequentially:
|
You will need to move the follower arm to these positions sequentially:
|
||||||
|
|
||||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||||
|---|---|---|
|
| ------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| <img src="../media/moss/follower_zero.webp?raw=true" alt="Moss v1 follower arm zero position" title="Moss v1 follower arm zero position" style="width:100%;"> | <img src="../media/moss/follower_rotated.webp?raw=true" alt="Moss v1 follower arm rotated position" title="Moss v1 follower arm rotated position" style="width:100%;"> | <img src="../media/moss/follower_rest.webp?raw=true" alt="Moss v1 follower arm rest position" title="Moss v1 follower arm rest position" style="width:100%;"> |
|
| <img src="../media/moss/follower_zero.webp?raw=true" alt="Moss v1 follower arm zero position" title="Moss v1 follower arm zero position" style="width:100%;"> | <img src="../media/moss/follower_rotated.webp?raw=true" alt="Moss v1 follower arm rotated position" title="Moss v1 follower arm rotated position" style="width:100%;"> | <img src="../media/moss/follower_rest.webp?raw=true" alt="Moss v1 follower arm rest position" title="Moss v1 follower arm rest position" style="width:100%;"> |
|
||||||
|
|
||||||
Make sure both arms are connected and run this script to launch manual calibration:
|
Make sure both arms are connected and run this script to launch manual calibration:
|
||||||
|
@ -193,7 +193,7 @@ python lerobot/scripts/control_robot.py \
|
||||||
Follow step 6 of the [assembly video](https://www.youtube.com/watch?v=DA91NJOtMic) which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially:
|
Follow step 6 of the [assembly video](https://www.youtube.com/watch?v=DA91NJOtMic) which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially:
|
||||||
|
|
||||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||||
|---|---|---|
|
| ------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| <img src="../media/moss/leader_zero.webp?raw=true" alt="Moss v1 leader arm zero position" title="Moss v1 leader arm zero position" style="width:100%;"> | <img src="../media/moss/leader_rotated.webp?raw=true" alt="Moss v1 leader arm rotated position" title="Moss v1 leader arm rotated position" style="width:100%;"> | <img src="../media/moss/leader_rest.webp?raw=true" alt="Moss v1 leader arm rest position" title="Moss v1 leader arm rest position" style="width:100%;"> |
|
| <img src="../media/moss/leader_zero.webp?raw=true" alt="Moss v1 leader arm zero position" title="Moss v1 leader arm zero position" style="width:100%;"> | <img src="../media/moss/leader_rotated.webp?raw=true" alt="Moss v1 leader arm rotated position" title="Moss v1 leader arm rotated position" style="width:100%;"> | <img src="../media/moss/leader_rest.webp?raw=true" alt="Moss v1 leader arm rest position" title="Moss v1 leader arm rest position" style="width:100%;"> |
|
||||||
|
|
||||||
Run this script to launch manual calibration:
|
Run this script to launch manual calibration:
|
||||||
|
@ -256,7 +256,7 @@ python lerobot/scripts/control_robot.py \
|
||||||
--control.push_to_hub=true
|
--control.push_to_hub=true
|
||||||
```
|
```
|
||||||
|
|
||||||
Note: You can resume recording by adding `--control.resume=true`. Also if you didn't push your dataset yet, add `--control.local_files_only=true`.
|
Note: You can resume recording by adding `--control.resume=true`.
|
||||||
|
|
||||||
## Visualize a dataset
|
## Visualize a dataset
|
||||||
|
|
||||||
|
@ -284,8 +284,6 @@ python lerobot/scripts/control_robot.py \
|
||||||
--control.episode=0
|
--control.episode=0
|
||||||
```
|
```
|
||||||
|
|
||||||
Note: If you didn't push your dataset yet, add `--control.local_files_only=true`.
|
|
||||||
|
|
||||||
## Train a policy
|
## Train a policy
|
||||||
|
|
||||||
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||||
|
@ -295,16 +293,14 @@ python lerobot/scripts/train.py \
|
||||||
--policy.type=act \
|
--policy.type=act \
|
||||||
--output_dir=outputs/train/act_moss_test \
|
--output_dir=outputs/train/act_moss_test \
|
||||||
--job_name=act_moss_test \
|
--job_name=act_moss_test \
|
||||||
--device=cuda \
|
--policy.device=cuda \
|
||||||
--wandb.enable=true
|
--wandb.enable=true
|
||||||
```
|
```
|
||||||
|
|
||||||
Note: If you didn't push your dataset yet, add `--control.local_files_only=true`.
|
|
||||||
|
|
||||||
Let's explain it:
|
Let's explain it:
|
||||||
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/moss_test`.
|
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/moss_test`.
|
||||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||||
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
|
4. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
|
||||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||||
|
|
||||||
Training should take several hours. You will find checkpoints in `outputs/train/act_moss_test/checkpoints`.
|
Training should take several hours. You will find checkpoints in `outputs/train/act_moss_test/checkpoints`.
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
This script demonstrates the use of `LeRobotDataset` class for handling and processing robotic datasets from Hugging Face.
|
This script demonstrates the use of `LeRobotDataset` class for handling and processing robotic datasets from Hugging Face.
|
||||||
It illustrates how to load datasets, manipulate them, and apply transformations suitable for machine learning tasks in PyTorch.
|
It illustrates how to load datasets, manipulate them, and apply transformations suitable for machine learning tasks in PyTorch.
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
This script demonstrates how to evaluate a pretrained policy from the HuggingFace Hub or from your local
|
This script demonstrates how to evaluate a pretrained policy from the HuggingFace Hub or from your local
|
||||||
training outputs directory. In the latter case, you might want to run examples/3_train_policy.py first.
|
training outputs directory. In the latter case, you might want to run examples/3_train_policy.py first.
|
||||||
|
@ -30,7 +44,7 @@ pretrained_policy_path = "lerobot/diffusion_pusht"
|
||||||
# OR a path to a local outputs/train folder.
|
# OR a path to a local outputs/train folder.
|
||||||
# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")
|
# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")
|
||||||
|
|
||||||
policy = DiffusionPolicy.from_pretrained(pretrained_policy_path, map_location=device)
|
policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
|
||||||
|
|
||||||
# Initialize evaluation environment to render two observation types:
|
# Initialize evaluation environment to render two observation types:
|
||||||
# an image of the scene and state/position of the agent. The environment
|
# an image of the scene and state/position of the agent. The environment
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
"""This scripts demonstrates how to train Diffusion Policy on the PushT environment.
|
"""This scripts demonstrates how to train Diffusion Policy on the PushT environment.
|
||||||
|
|
||||||
Once you have trained a model with this script, you can try to evaluate it on
|
Once you have trained a model with this script, you can try to evaluate it on
|
||||||
|
@ -85,7 +99,7 @@ def main():
|
||||||
done = False
|
done = False
|
||||||
while not done:
|
while not done:
|
||||||
for batch in dataloader:
|
for batch in dataloader:
|
||||||
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
|
batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
|
||||||
loss, _ = policy.forward(batch)
|
loss, _ = policy.forward(batch)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
This tutorial will explain the training script, how to use it, and particularly how to configure everything needed for the training run.
|
This tutorial will explain the training script, how to use it, and particularly how to configure everything needed for the training run.
|
||||||
> **Note:** The following assume you're running these commands on a machine equipped with a cuda GPU. If you don't have one (or if you're using a Mac), you can add `--device=cpu` (`--device=mps` respectively). However, be advised that the code executes much slower on cpu.
|
> **Note:** The following assume you're running these commands on a machine equipped with a cuda GPU. If you don't have one (or if you're using a Mac), you can add `--policy.device=cpu` (`--policy.device=mps` respectively). However, be advised that the code executes much slower on cpu.
|
||||||
|
|
||||||
|
|
||||||
## The training script
|
## The training script
|
||||||
|
|
|
@ -387,18 +387,18 @@ When you connect your robot for the first time, the [`ManipulatorRobot`](../lero
|
||||||
Here are the positions you'll move the follower arm to:
|
Here are the positions you'll move the follower arm to:
|
||||||
|
|
||||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||||
|---|---|---|
|
| ----------------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| <img src="../media/koch/follower_zero.webp?raw=true" alt="Koch v1.1 follower arm zero position" title="Koch v1.1 follower arm zero position" style="width:100%;"> | <img src="../media/koch/follower_rotated.webp?raw=true" alt="Koch v1.1 follower arm rotated position" title="Koch v1.1 follower arm rotated position" style="width:100%;"> | <img src="../media/koch/follower_rest.webp?raw=true" alt="Koch v1.1 follower arm rest position" title="Koch v1.1 follower arm rest position" style="width:100%;"> |
|
| <img src="../media/koch/follower_zero.webp?raw=true" alt="Koch v1.1 follower arm zero position" title="Koch v1.1 follower arm zero position" style="width:100%;"> | <img src="../media/koch/follower_rotated.webp?raw=true" alt="Koch v1.1 follower arm rotated position" title="Koch v1.1 follower arm rotated position" style="width:100%;"> | <img src="../media/koch/follower_rest.webp?raw=true" alt="Koch v1.1 follower arm rest position" title="Koch v1.1 follower arm rest position" style="width:100%;"> |
|
||||||
|
|
||||||
And here are the corresponding positions for the leader arm:
|
And here are the corresponding positions for the leader arm:
|
||||||
|
|
||||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||||
|---|---|---|
|
| ----------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| <img src="../media/koch/leader_zero.webp?raw=true" alt="Koch v1.1 leader arm zero position" title="Koch v1.1 leader arm zero position" style="width:100%;"> | <img src="../media/koch/leader_rotated.webp?raw=true" alt="Koch v1.1 leader arm rotated position" title="Koch v1.1 leader arm rotated position" style="width:100%;"> | <img src="../media/koch/leader_rest.webp?raw=true" alt="Koch v1.1 leader arm rest position" title="Koch v1.1 leader arm rest position" style="width:100%;"> |
|
| <img src="../media/koch/leader_zero.webp?raw=true" alt="Koch v1.1 leader arm zero position" title="Koch v1.1 leader arm zero position" style="width:100%;"> | <img src="../media/koch/leader_rotated.webp?raw=true" alt="Koch v1.1 leader arm rotated position" title="Koch v1.1 leader arm rotated position" style="width:100%;"> | <img src="../media/koch/leader_rest.webp?raw=true" alt="Koch v1.1 leader arm rest position" title="Koch v1.1 leader arm rest position" style="width:100%;"> |
|
||||||
|
|
||||||
You can watch a [video tutorial of the calibration procedure](https://youtu.be/8drnU9uRY24) for more details.
|
You can watch a [video tutorial of the calibration procedure](https://youtu.be/8drnU9uRY24) for more details.
|
||||||
|
|
||||||
During calibration, we count the number of full 360-degree rotations your motors have made since they were first used. That's why we ask you to move to this arbitrary "zero" position. We don't actually "set" the zero position, so you don't need to be accurate. After calculating these "offsets" to shift the motor values around 0, we need to assess the rotation direction of each motor, which might differ. That's why we ask you to rotate all motors to roughly 90 degrees, to mesure if the values changed negatively or positively.
|
During calibration, we count the number of full 360-degree rotations your motors have made since they were first used. That's why we ask you to move to this arbitrary "zero" position. We don't actually "set" the zero position, so you don't need to be accurate. After calculating these "offsets" to shift the motor values around 0, we need to assess the rotation direction of each motor, which might differ. That's why we ask you to rotate all motors to roughly 90 degrees, to measure if the values changed negatively or positively.
|
||||||
|
|
||||||
Finally, the rest position ensures that the follower and leader arms are roughly aligned after calibration, preventing sudden movements that could damage the motors when starting teleoperation.
|
Finally, the rest position ensures that the follower and leader arms are roughly aligned after calibration, preventing sudden movements that could damage the motors when starting teleoperation.
|
||||||
|
|
||||||
|
@ -626,7 +626,7 @@ Finally, run this code to instantiate and connect your camera:
|
||||||
from lerobot.common.robot_devices.cameras.configs import OpenCVCameraConfig
|
from lerobot.common.robot_devices.cameras.configs import OpenCVCameraConfig
|
||||||
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
|
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
|
||||||
|
|
||||||
camera_config = OpenCVCameraConfig(camera_index=0)
|
config = OpenCVCameraConfig(camera_index=0)
|
||||||
camera = OpenCVCamera(config)
|
camera = OpenCVCamera(config)
|
||||||
camera.connect()
|
camera.connect()
|
||||||
color_image = camera.read()
|
color_image = camera.read()
|
||||||
|
@ -663,11 +663,12 @@ camera.disconnect()
|
||||||
|
|
||||||
**Instantiate your robot with cameras**
|
**Instantiate your robot with cameras**
|
||||||
|
|
||||||
Additionaly, you can set up your robot to work with your cameras.
|
Additionally, you can set up your robot to work with your cameras.
|
||||||
|
|
||||||
Modify the following Python code with the appropriate camera names and configurations:
|
Modify the following Python code with the appropriate camera names and configurations:
|
||||||
```python
|
```python
|
||||||
robot = ManipulatorRobot(
|
robot = ManipulatorRobot(
|
||||||
|
KochRobotConfig(
|
||||||
leader_arms={"main": leader_arm},
|
leader_arms={"main": leader_arm},
|
||||||
follower_arms={"main": follower_arm},
|
follower_arms={"main": follower_arm},
|
||||||
calibration_dir=".cache/calibration/koch",
|
calibration_dir=".cache/calibration/koch",
|
||||||
|
@ -676,6 +677,7 @@ robot = ManipulatorRobot(
|
||||||
"phone": OpenCVCameraConfig(1, fps=30, width=640, height=480),
|
"phone": OpenCVCameraConfig(1, fps=30, width=640, height=480),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
)
|
||||||
robot.connect()
|
robot.connect()
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -711,7 +713,7 @@ python lerobot/scripts/control_robot.py \
|
||||||
|
|
||||||
You will see a lot of lines appearing like this one:
|
You will see a lot of lines appearing like this one:
|
||||||
```
|
```
|
||||||
INFO 2024-08-10 11:15:03 ol_robot.py:209 dt: 5.12 (195.1hz) dtRlead: 4.93 (203.0hz) dtRfoll: 0.19 (5239.0hz)
|
INFO 2024-08-10 11:15:03 ol_robot.py:209 dt: 5.12 (195.1hz) dtRlead: 4.93 (203.0hz) dtWfoll: 0.19 (5239.0hz)
|
||||||
```
|
```
|
||||||
|
|
||||||
It contains
|
It contains
|
||||||
|
@ -768,7 +770,7 @@ You can use the `record` function from [`lerobot/scripts/control_robot.py`](../l
|
||||||
1. Frames from cameras are saved on disk in threads, and encoded into videos at the end of each episode recording.
|
1. Frames from cameras are saved on disk in threads, and encoded into videos at the end of each episode recording.
|
||||||
2. Video streams from cameras are displayed in a window so that you can verify them.
|
2. Video streams from cameras are displayed in a window so that you can verify them.
|
||||||
3. Data is stored with [`LeRobotDataset`](../lerobot/common/datasets/lerobot_dataset.py) format which is pushed to your Hugging Face page (unless `--control.push_to_hub=false` is provided).
|
3. Data is stored with [`LeRobotDataset`](../lerobot/common/datasets/lerobot_dataset.py) format which is pushed to your Hugging Face page (unless `--control.push_to_hub=false` is provided).
|
||||||
4. Checkpoints are done during recording, so if any issue occurs, you can resume recording by re-running the same command again with `--control.resume=true`. You might need to add `--control.local_files_only=true` if your dataset was not uploaded to hugging face hub. Also you will need to manually delete the dataset directory to start recording from scratch.
|
4. Checkpoints are done during recording, so if any issue occurs, you can resume recording by re-running the same command again with `--control.resume=true`. You will need to manually delete the dataset directory if you want to start recording from scratch.
|
||||||
5. Set the flow of data recording using command line arguments:
|
5. Set the flow of data recording using command line arguments:
|
||||||
- `--control.warmup_time_s=10` defines the number of seconds before starting data collection. It allows the robot devices to warmup and synchronize (10 seconds by default).
|
- `--control.warmup_time_s=10` defines the number of seconds before starting data collection. It allows the robot devices to warmup and synchronize (10 seconds by default).
|
||||||
- `--control.episode_time_s=60` defines the number of seconds for data recording for each episode (60 seconds by default).
|
- `--control.episode_time_s=60` defines the number of seconds for data recording for each episode (60 seconds by default).
|
||||||
|
@ -823,8 +825,8 @@ It contains:
|
||||||
- `dtRlead: 5.06 (197.5hz)` which is the delta time of reading the present position of the leader arm.
|
- `dtRlead: 5.06 (197.5hz)` which is the delta time of reading the present position of the leader arm.
|
||||||
- `dtWfoll: 0.25 (3963.7hz)` which is the delta time of writing the goal position on the follower arm; writing is asynchronous so it takes less time than reading.
|
- `dtWfoll: 0.25 (3963.7hz)` which is the delta time of writing the goal position on the follower arm; writing is asynchronous so it takes less time than reading.
|
||||||
- `dtRfoll: 6.22 (160.7hz)` which is the delta time of reading the present position on the follower arm.
|
- `dtRfoll: 6.22 (160.7hz)` which is the delta time of reading the present position on the follower arm.
|
||||||
- `dtRlaptop:32.57 (30.7hz) ` which is the delta time of capturing an image from the laptop camera in the thread running asynchrously.
|
- `dtRlaptop:32.57 (30.7hz) ` which is the delta time of capturing an image from the laptop camera in the thread running asynchronously.
|
||||||
- `dtRphone:33.84 (29.5hz)` which is the delta time of capturing an image from the phone camera in the thread running asynchrously.
|
- `dtRphone:33.84 (29.5hz)` which is the delta time of capturing an image from the phone camera in the thread running asynchronously.
|
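To make the link between each millisecond delta and the frequency shown in parentheses concrete, here is a minimal, hypothetical sketch (the `timed` helper and the stand-in read/write calls are illustrative, not LeRobot's actual control loop):

```python
import time

def timed(fn):
    """Run fn and return (result, elapsed time in milliseconds)."""
    start = time.perf_counter()
    result = fn()
    return result, (time.perf_counter() - start) * 1000

# Stand-ins for reading the leader arm and writing the follower goal position.
_, dt_r_lead = timed(lambda: [0.0] * 6)
_, dt_w_foll = timed(lambda: None)

# The frequency is simply 1000 / dt_ms, e.g. 4.93 ms -> ~203 Hz in the log line above.
print(
    f"dtRlead: {dt_r_lead:.2f} ({1000 / max(dt_r_lead, 1e-6):.1f}hz) "
    f"dtWfoll: {dt_w_foll:.2f} ({1000 / max(dt_w_foll, 1e-6):.1f}hz)"
)
```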
||||||
|
|
||||||
Troubleshooting:
|
Troubleshooting:
|
||||||
- On Linux, if you encounter a hanging issue when using cameras, uninstall opencv and re-install it with conda:
|
- On Linux, if you encounter a hanging issue when using cameras, uninstall opencv and re-install it with conda:
|
||||||
|
@ -844,7 +846,7 @@ At the end of data recording, your dataset will be uploaded on your Hugging Face
|
||||||
echo https://huggingface.co/datasets/${HF_USER}/koch_test
|
echo https://huggingface.co/datasets/${HF_USER}/koch_test
|
||||||
```
|
```
|
||||||
|
|
||||||
### b. Advices for recording dataset
|
### b. Advice for recording dataset
|
||||||
|
|
||||||
Once you're comfortable with data recording, it's time to create a larger dataset for training. A good starting task is grasping an object at different locations and placing it in a bin. We suggest recording at least 50 episodes, with 10 episodes per location. Keep the cameras fixed and maintain consistent grasping behavior throughout the recordings.
|
Once you're comfortable with data recording, it's time to create a larger dataset for training. A good starting task is grasping an object at different locations and placing it in a bin. We suggest recording at least 50 episodes, with 10 episodes per location. Keep the cameras fixed and maintain consistent grasping behavior throughout the recordings.
|
||||||
|
|
||||||
|
@ -883,8 +885,6 @@ python lerobot/scripts/control_robot.py \
|
||||||
--control.episode=0
|
--control.episode=0
|
||||||
```
|
```
|
||||||
|
|
||||||
Note: You might need to add `--control.local_files_only=true` if your dataset was not uploaded to hugging face hub.
|
|
||||||
|
|
||||||
Your robot should replicate movements similar to those you recorded. For example, check out [this video](https://x.com/RemiCadene/status/1793654950905680090) where we use `replay` on an Aloha robot from [Trossen Robotics](https://www.trossenrobotics.com).
|
Your robot should replicate movements similar to those you recorded. For example, check out [this video](https://x.com/RemiCadene/status/1793654950905680090) where we use `replay` on an Aloha robot from [Trossen Robotics](https://www.trossenrobotics.com).
|
||||||
|
|
||||||
## 4. Train a policy on your data
|
## 4. Train a policy on your data
|
||||||
|
@ -898,16 +898,14 @@ python lerobot/scripts/train.py \
|
||||||
--policy.type=act \
|
--policy.type=act \
|
||||||
--output_dir=outputs/train/act_koch_test \
|
--output_dir=outputs/train/act_koch_test \
|
||||||
--job_name=act_koch_test \
|
--job_name=act_koch_test \
|
||||||
--device=cuda \
|
--policy.device=cuda \
|
||||||
--wandb.enable=true
|
--wandb.enable=true
|
||||||
```
|
```
|
||||||
|
|
||||||
Note: You might need to add `--dataset.local_files_only=true` if your dataset was not uploaded to hugging face hub.
|
|
||||||
|
|
||||||
Let's explain it:
|
Let's explain it:
|
||||||
1. We provided the dataset as an argument with `--dataset.repo_id=${HF_USER}/koch_test`.
|
1. We provided the dataset as an argument with `--dataset.repo_id=${HF_USER}/koch_test`.
|
||||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||||
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
|
4. We provided `policy.device=cuda` since we are training on an Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
|
||||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||||
|
|
||||||
For more information on the `train` script see the previous tutorial: [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md)
|
For more information on the `train` script see the previous tutorial: [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md)
|
||||||
|
|
|
@ -98,7 +98,7 @@ python lerobot/scripts/control_robot.py \
|
||||||
```
|
```
|
||||||
This is equivalent to running `stretch_robot_home.py`
|
This is equivalent to running `stretch_robot_home.py`
|
||||||
|
|
||||||
> **Note:** If you run any of the LeRobot scripts below and Stretch is not poperly homed, it will automatically home/calibrate first.
|
> **Note:** If you run any of the LeRobot scripts below and Stretch is not properly homed, it will automatically home/calibrate first.
|
||||||
|
|
||||||
**Teleoperate**
|
**Teleoperate**
|
||||||
Before trying teleoperation, you need to activate the gamepad controller by pressing the middle button. For more info, see Stretch's [doc](https://docs.hello-robot.com/0.3/getting_started/hello_robot/#gamepad-teleoperation).
|
Before trying teleoperation, you need to activate the gamepad controller by pressing the middle button. For more info, see Stretch's [doc](https://docs.hello-robot.com/0.3/getting_started/hello_robot/#gamepad-teleoperation).
|
||||||
|
|
|
@ -2,7 +2,7 @@ This tutorial explains how to use [Aloha and Aloha 2 stationary](https://www.tro
|
||||||
|
|
||||||
## Setup
|
## Setup
|
||||||
|
|
||||||
Follow the [documentation from Trossen Robotics](https://docs.trossenrobotics.com/aloha_docs/getting_started/stationary/hardware_setup.html) for setting up the hardware and plugging the 4 arms and 4 cameras to your computer.
|
Follow the [documentation from Trossen Robotics](https://docs.trossenrobotics.com/aloha_docs/2.0/getting_started/stationary/hardware_setup.html) for setting up the hardware and plugging the 4 arms and 4 cameras to your computer.
|
||||||
|
|
||||||
|
|
||||||
## Install LeRobot
|
## Install LeRobot
|
||||||
|
@ -135,14 +135,14 @@ python lerobot/scripts/train.py \
|
||||||
--policy.type=act \
|
--policy.type=act \
|
||||||
--output_dir=outputs/train/act_aloha_test \
|
--output_dir=outputs/train/act_aloha_test \
|
||||||
--job_name=act_aloha_test \
|
--job_name=act_aloha_test \
|
||||||
--device=cuda \
|
--policy.device=cuda \
|
||||||
--wandb.enable=true
|
--wandb.enable=true
|
||||||
```
|
```
|
||||||
|
|
||||||
Let's explain it:
|
Let's explain it:
|
||||||
1. We provided the dataset as an argument with `--dataset.repo_id=${HF_USER}/aloha_test`.
|
1. We provided the dataset as an argument with `--dataset.repo_id=${HF_USER}/aloha_test`.
|
||||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||||
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
|
4. We provided `policy.device=cuda` since we are training on an Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
|
||||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||||
|
|
||||||
For more information on the `train` script see the previous tutorial: [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md)
|
For more information on the `train` script see the previous tutorial: [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md)
|
||||||
|
@ -172,10 +172,10 @@ python lerobot/scripts/control_robot.py \
|
||||||
As you can see, it's almost the same command as previously used to record your training dataset. Three things changed:
|
As you can see, it's almost the same command as previously used to record your training dataset. Three things changed:
|
||||||
1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint (e.g. `outputs/train/eval_act_aloha_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_aloha_test`).
|
1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint (e.g. `outputs/train/eval_act_aloha_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_aloha_test`).
|
||||||
2. The name of the dataset begins with `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_aloha_test`).
|
2. The name of the dataset begins with `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_aloha_test`).
|
||||||
3. We use `--control.num_image_writer_processes=1` instead of the default value (`0`). On our computer, using a dedicated process to write images from the 4 cameras on disk allows to reach constent 30 fps during inference. Feel free to explore different values for `--control.num_image_writer_processes`.
|
3. We use `--control.num_image_writer_processes=1` instead of the default value (`0`). On our computer, using a dedicated process to write images from the 4 cameras on disk allows reaching a constant 30 fps during inference. Feel free to explore different values for `--control.num_image_writer_processes`.
|
||||||
|
|
||||||
## More
|
## More
|
||||||
|
|
||||||
Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth explaination.
|
Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth explanation.
|
||||||
|
|
||||||
If you have any questions or need help, please reach out on Discord in the channel `#aloha-arm`.
|
If you have any questions or need help, please reach out on Discord in the channel `#aloha-arm`.
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
This script demonstrates how to use torchvision's image transformation with LeRobotDataset for data
|
This script demonstrates how to use torchvision's image transformation with LeRobotDataset for data
|
||||||
augmentation purposes. The transformations are passed to the dataset as an argument upon creation, and
|
augmentation purposes. The transformations are passed to the dataset as an argument upon creation, and
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
"""This script demonstrates how to slice a dataset and calculate the loss on a subset of the data.
|
"""This script demonstrates how to slice a dataset and calculate the loss on a subset of the data.
|
||||||
|
|
||||||
This technique can be useful for debugging and testing purposes, as well as identifying whether a policy
|
This technique can be useful for debugging and testing purposes, as well as identifying whether a policy
|
||||||
|
|
|
@ -1,10 +1,25 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
from huggingface_hub import HfApi
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset
|
from lerobot.common.constants import HF_LEROBOT_HOME
|
||||||
|
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||||
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
|
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
|
||||||
|
|
||||||
PUSHT_TASK = "Push the T-shaped blue block onto the T-shaped green target surface."
|
PUSHT_TASK = "Push the T-shaped blue block onto the T-shaped green target surface."
|
||||||
|
@ -89,9 +104,9 @@ def calculate_coverage(zarr_data):
|
||||||
|
|
||||||
num_frames = len(block_pos)
|
num_frames = len(block_pos)
|
||||||
|
|
||||||
coverage = np.zeros((num_frames,))
|
coverage = np.zeros((num_frames,), dtype=np.float32)
|
||||||
# 8 keypoints with 2 coords each
|
# 8 keypoints with 2 coords each
|
||||||
keypoints = np.zeros((num_frames, 16))
|
keypoints = np.zeros((num_frames, 16), dtype=np.float32)
|
||||||
|
|
||||||
# Set x, y, theta (in radians)
|
# Set x, y, theta (in radians)
|
||||||
goal_pos_angle = np.array([256, 256, np.pi / 4])
|
goal_pos_angle = np.array([256, 256, np.pi / 4])
|
||||||
|
@ -117,7 +132,7 @@ def calculate_coverage(zarr_data):
|
||||||
intersection_area = goal_geom.intersection(block_geom).area
|
intersection_area = goal_geom.intersection(block_geom).area
|
||||||
goal_area = goal_geom.area
|
goal_area = goal_geom.area
|
||||||
coverage[i] = intersection_area / goal_area
|
coverage[i] = intersection_area / goal_area
|
||||||
keypoints[i] = torch.from_numpy(PushTEnv.get_keypoints(block_shapes).flatten())
|
keypoints[i] = PushTEnv.get_keypoints(block_shapes).flatten()
|
||||||
|
|
||||||
return coverage, keypoints
|
return coverage, keypoints
|
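For intuition, the coverage computed above is the fraction of the goal area overlapped by the block. A hedged, self-contained sketch using shapely (the polygons are made up; the real script builds the geometries from the block/goal poses stored in the zarr data):

```python
from shapely.geometry import Polygon

goal = Polygon([(0, 0), (4, 0), (4, 4), (0, 4)])     # area 16
block = Polygon([(2, 2), (6, 2), (6, 6), (2, 6)])    # overlaps a 2x2 corner of the goal
coverage = goal.intersection(block).area / goal.area
print(coverage)  # 0.25
```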
||||||
|
|
||||||
|
@ -134,8 +149,8 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
|
||||||
if mode not in ["video", "image", "keypoints"]:
|
if mode not in ["video", "image", "keypoints"]:
|
||||||
raise ValueError(mode)
|
raise ValueError(mode)
|
||||||
|
|
||||||
if (LEROBOT_HOME / repo_id).exists():
|
if (HF_LEROBOT_HOME / repo_id).exists():
|
||||||
shutil.rmtree(LEROBOT_HOME / repo_id)
|
shutil.rmtree(HF_LEROBOT_HOME / repo_id)
|
||||||
|
|
||||||
if not raw_dir.exists():
|
if not raw_dir.exists():
|
||||||
download_raw(raw_dir, repo_id="lerobot-raw/pusht_raw")
|
download_raw(raw_dir, repo_id="lerobot-raw/pusht_raw")
|
||||||
|
@ -148,6 +163,10 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
|
||||||
action = zarr_data["action"][:]
|
action = zarr_data["action"][:]
|
||||||
image = zarr_data["img"] # (b, h, w, c)
|
image = zarr_data["img"] # (b, h, w, c)
|
||||||
|
|
||||||
|
if image.dtype == np.float32 and image.max() == np.float32(255):
|
||||||
|
# HACK: images are loaded as float32 but they actually encode uint8 data
|
||||||
|
image = image.astype(np.uint8)
|
||||||
|
|
||||||
episode_data_index = {
|
episode_data_index = {
|
||||||
"from": np.concatenate(([0], zarr_data.meta["episode_ends"][:-1])),
|
"from": np.concatenate(([0], zarr_data.meta["episode_ends"][:-1])),
|
||||||
"to": zarr_data.meta["episode_ends"],
|
"to": zarr_data.meta["episode_ends"],
|
||||||
|
@ -175,28 +194,30 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
|
||||||
|
|
||||||
for frame_idx in range(num_frames):
|
for frame_idx in range(num_frames):
|
||||||
i = from_idx + frame_idx
|
i = from_idx + frame_idx
|
||||||
|
idx = i + (frame_idx < num_frames - 1)
|
||||||
frame = {
|
frame = {
|
||||||
"action": torch.from_numpy(action[i]),
|
"action": action[i],
|
||||||
# Shift reward and success by +1 until the last item of the episode
|
# Shift reward and success by +1 until the last item of the episode
|
||||||
"next.reward": reward[i + (frame_idx < num_frames - 1)],
|
"next.reward": reward[idx : idx + 1],
|
||||||
"next.success": success[i + (frame_idx < num_frames - 1)],
|
"next.success": success[idx : idx + 1],
|
||||||
|
"task": PUSHT_TASK,
|
||||||
}
|
}
|
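The `idx = i + (frame_idx < num_frames - 1)` trick above relies on Python booleans being 0 or 1: every frame stores the reward/success of the next step, except the last frame of the episode, which reuses its own. A tiny illustrative check (not from the source):

```python
num_frames = 3
from_idx = 0
for frame_idx in range(num_frames):
    i = from_idx + frame_idx
    idx = i + (frame_idx < num_frames - 1)
    print(i, "->", idx)
# 0 -> 1, 1 -> 2, 2 -> 2  (the last frame points to itself)
```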
||||||
|
|
||||||
frame["observation.state"] = torch.from_numpy(agent_pos[i])
|
frame["observation.state"] = agent_pos[i]
|
||||||
|
|
||||||
if mode == "keypoints":
|
if mode == "keypoints":
|
||||||
frame["observation.environment_state"] = torch.from_numpy(keypoints[i])
|
frame["observation.environment_state"] = keypoints[i]
|
||||||
else:
|
else:
|
||||||
frame["observation.image"] = torch.from_numpy(image[i])
|
frame["observation.image"] = image[i]
|
||||||
|
|
||||||
dataset.add_frame(frame)
|
dataset.add_frame(frame)
|
||||||
|
|
||||||
dataset.save_episode(task=PUSHT_TASK)
|
dataset.save_episode()
|
||||||
|
|
||||||
dataset.consolidate()
|
|
||||||
|
|
||||||
if push_to_hub:
|
if push_to_hub:
|
||||||
dataset.push_to_hub()
|
dataset.push_to_hub()
|
||||||
|
hub_api = HfApi()
|
||||||
|
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -218,5 +239,5 @@ if __name__ == "__main__":
|
||||||
main(raw_dir, repo_id=repo_id, mode=mode)
|
main(raw_dir, repo_id=repo_id, mode=mode)
|
||||||
|
|
||||||
# Uncomment if you want to load the local dataset and explore it
|
# Uncomment if you want to load the local dataset and explore it
|
||||||
# dataset = LeRobotDataset(repo_id=repo_id, local_files_only=True)
|
# dataset = LeRobotDataset(repo_id=repo_id)
|
||||||
# breakpoint()
|
# breakpoint()
|
||||||
|
|
|
@ -1,4 +1,22 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
# keys
|
# keys
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from huggingface_hub.constants import HF_HOME
|
||||||
|
|
||||||
OBS_ENV = "observation.environment_state"
|
OBS_ENV = "observation.environment_state"
|
||||||
OBS_ROBOT = "observation.state"
|
OBS_ROBOT = "observation.state"
|
||||||
OBS_IMAGE = "observation.image"
|
OBS_IMAGE = "observation.image"
|
||||||
|
@ -15,3 +33,13 @@ TRAINING_STEP = "training_step.json"
|
||||||
OPTIMIZER_STATE = "optimizer_state.safetensors"
|
OPTIMIZER_STATE = "optimizer_state.safetensors"
|
||||||
OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json"
|
OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json"
|
||||||
SCHEDULER_STATE = "scheduler_state.json"
|
SCHEDULER_STATE = "scheduler_state.json"
|
||||||
|
|
||||||
|
# cache dir
|
||||||
|
default_cache_path = Path(HF_HOME) / "lerobot"
|
||||||
|
HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expanduser()
|
||||||
|
|
||||||
|
if "LEROBOT_HOME" in os.environ:
|
||||||
|
raise ValueError(
|
||||||
|
f"You have a 'LEROBOT_HOME' environment variable set to '{os.getenv('LEROBOT_HOME')}'.\n"
|
||||||
|
"'LEROBOT_HOME' is deprecated, please use 'HF_LEROBOT_HOME' instead."
|
||||||
|
)
|
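A quick illustration of how the new cache directory resolves (this mirrors the constants above; the printed default assumes `HF_HOME` and `HF_LEROBOT_HOME` are both unset):

```python
import os
from pathlib import Path

from huggingface_hub.constants import HF_HOME

# HF_LEROBOT_HOME overrides; otherwise fall back to $HF_HOME/lerobot,
# which is typically ~/.cache/huggingface/lerobot.
resolved = Path(os.getenv("HF_LEROBOT_HOME", Path(HF_HOME) / "lerobot")).expanduser()
print(resolved)
```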
||||||
|
|
|
@ -0,0 +1,68 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import packaging.version
|
||||||
|
|
||||||
|
V2_MESSAGE = """
|
||||||
|
The dataset you requested ({repo_id}) is in {version} format.
|
||||||
|
|
||||||
|
We introduced a new format since v2.0 which is not backward compatible with v1.x.
|
||||||
|
Please use our conversion script. Modify the following command with your own task description:
|
||||||
|
```
|
||||||
|
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\
|
||||||
|
--repo-id {repo_id} \\
|
||||||
|
--single-task "TASK DESCRIPTION." # <---- /!\\ Replace TASK DESCRIPTION /!\\
|
||||||
|
```
|
||||||
|
|
||||||
|
A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.", "Insert the
|
||||||
|
peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.", "Open the top
|
||||||
|
cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped
|
||||||
|
target.", "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the
|
||||||
|
sweatshirt.", ...
|
||||||
|
|
||||||
|
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
|
||||||
|
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
|
||||||
|
"""
|
||||||
|
|
||||||
|
V21_MESSAGE = """
|
||||||
|
The dataset you requested ({repo_id}) is in {version} format.
|
||||||
|
While the current version of LeRobot is backward-compatible with it, your dataset version still uses global
|
||||||
|
stats instead of per-episode stats. Update your dataset stats to the new format using this command:
|
||||||
|
```
|
||||||
|
python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py --repo-id={repo_id}
|
||||||
|
```
|
||||||
|
|
||||||
|
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
|
||||||
|
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
|
||||||
|
"""
|
||||||
|
|
||||||
|
FUTURE_MESSAGE = """
|
||||||
|
The dataset you requested ({repo_id}) is only available in {version} format.
|
||||||
|
As we cannot ensure forward compatibility with it, please update your current version of lerobot.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class CompatibilityError(Exception): ...
|
||||||
|
|
||||||
|
|
||||||
|
class BackwardCompatibilityError(CompatibilityError):
|
||||||
|
def __init__(self, repo_id: str, version: packaging.version.Version):
|
||||||
|
message = V2_MESSAGE.format(repo_id=repo_id, version=version)
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class ForwardCompatibilityError(CompatibilityError):
|
||||||
|
def __init__(self, repo_id: str, version: packaging.version.Version):
|
||||||
|
message = FUTURE_MESSAGE.format(repo_id=repo_id, version=version)
|
||||||
|
super().__init__(message)
|
|
@ -13,202 +13,164 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from copy import deepcopy
|
import numpy as np
|
||||||
from math import ceil
|
|
||||||
|
|
||||||
import einops
|
from lerobot.common.datasets.utils import load_image_as_numpy
|
||||||
import torch
|
|
||||||
import tqdm
|
|
||||||
|
|
||||||
|
|
||||||
def get_stats_einops_patterns(dataset, num_workers=0):
|
def estimate_num_samples(
|
||||||
"""These einops patterns will be used to aggregate batches and compute statistics.
|
dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
|
||||||
|
) -> int:
|
||||||
|
"""Heuristic to estimate the number of samples based on dataset size.
|
||||||
|
The power controls the sample growth relative to dataset size.
|
||||||
|
Lower the power for fewer samples.
|
||||||
|
|
||||||
Note: We assume the images are in channel first format
|
For default arguments, we have:
|
||||||
|
- from 1 to ~500, num_samples=100
|
||||||
|
- at 1000, num_samples=177
|
||||||
|
- at 2000, num_samples=299
|
||||||
|
- at 5000, num_samples=594
|
||||||
|
- at 10000, num_samples=1000
|
||||||
|
- at 20000, num_samples=1681
|
||||||
"""
|
"""
|
||||||
|
if dataset_len < min_num_samples:
|
||||||
dataloader = torch.utils.data.DataLoader(
|
min_num_samples = dataset_len
|
||||||
dataset,
|
return max(min_num_samples, min(int(dataset_len**power), max_num_samples))
|
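As a quick check of the sample counts listed in the docstring, the heuristic can be reproduced standalone (the function body below is copied from this diff so the snippet runs on its own):

```python
def estimate_num_samples(
    dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
) -> int:
    if dataset_len < min_num_samples:
        min_num_samples = dataset_len
    return max(min_num_samples, min(int(dataset_len**power), max_num_samples))

for n in [1_000, 2_000, 5_000, 10_000, 20_000]:
    print(n, estimate_num_samples(n))
# Expected, per the docstring above: 177, 299, 594, 1000, 1681
```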
||||||
num_workers=num_workers,
|
|
||||||
batch_size=2,
|
|
||||||
shuffle=False,
|
|
||||||
)
|
|
||||||
batch = next(iter(dataloader))
|
|
||||||
|
|
||||||
stats_patterns = {}
|
|
||||||
|
|
||||||
for key in dataset.features:
|
|
||||||
# sanity check that tensors are not float64
|
|
||||||
assert batch[key].dtype != torch.float64
|
|
||||||
|
|
||||||
# if isinstance(feats_type, (VideoFrame, Image)):
|
|
||||||
if key in dataset.meta.camera_keys:
|
|
||||||
# sanity check that images are channel first
|
|
||||||
_, c, h, w = batch[key].shape
|
|
||||||
assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
|
|
||||||
|
|
||||||
# sanity check that images are float32 in range [0,1]
|
|
||||||
assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}"
|
|
||||||
assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}"
|
|
||||||
assert batch[key].min() >= 0, f"expect pixels greater than 1, but instead {batch[key].min()=}"
|
|
||||||
|
|
||||||
stats_patterns[key] = "b c h w -> c 1 1"
|
|
||||||
elif batch[key].ndim == 2:
|
|
||||||
stats_patterns[key] = "b c -> c "
|
|
||||||
elif batch[key].ndim == 1:
|
|
||||||
stats_patterns[key] = "b -> 1"
|
|
||||||
else:
|
|
||||||
raise ValueError(f"{key}, {batch[key].shape}")
|
|
||||||
|
|
||||||
return stats_patterns
|
|
||||||
|
|
||||||
|
|
||||||
def compute_stats(dataset, batch_size=8, num_workers=8, max_num_samples=None):
|
def sample_indices(data_len: int) -> list[int]:
|
||||||
"""Compute mean/std and min/max statistics of all data keys in a LeRobotDataset."""
|
num_samples = estimate_num_samples(data_len)
|
||||||
if max_num_samples is None:
|
return np.round(np.linspace(0, data_len - 1, num_samples)).astype(int).tolist()
|
||||||
max_num_samples = len(dataset)
|
|
||||||
|
|
||||||
# for more info on why we need to set the same number of workers, see `load_from_videos`
|
|
||||||
stats_patterns = get_stats_einops_patterns(dataset, num_workers)
|
|
||||||
|
|
||||||
# mean and std will be computed incrementally while max and min will track the running value.
|
def auto_downsample_height_width(img: np.ndarray, target_size: int = 150, max_size_threshold: int = 300):
|
||||||
mean, std, max, min = {}, {}, {}, {}
|
_, height, width = img.shape
|
||||||
for key in stats_patterns:
|
|
||||||
mean[key] = torch.tensor(0.0).float()
|
|
||||||
std[key] = torch.tensor(0.0).float()
|
|
||||||
max[key] = torch.tensor(-float("inf")).float()
|
|
||||||
min[key] = torch.tensor(float("inf")).float()
|
|
||||||
|
|
||||||
def create_seeded_dataloader(dataset, batch_size, seed):
|
if max(width, height) < max_size_threshold:
|
||||||
generator = torch.Generator()
|
# no downsampling needed
|
||||||
generator.manual_seed(seed)
|
return img
|
||||||
dataloader = torch.utils.data.DataLoader(
|
|
||||||
dataset,
|
|
||||||
num_workers=num_workers,
|
|
||||||
batch_size=batch_size,
|
|
||||||
shuffle=True,
|
|
||||||
drop_last=False,
|
|
||||||
generator=generator,
|
|
||||||
)
|
|
||||||
return dataloader
|
|
||||||
|
|
||||||
# Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
|
downsample_factor = int(width / target_size) if width > height else int(height / target_size)
|
||||||
# surprises when rerunning the sampler.
|
return img[:, ::downsample_factor, ::downsample_factor]
|
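For intuition, a hedged usage sketch of the downsampling helper above (the input shape is invented for illustration, and the call assumes `auto_downsample_height_width` as defined in this diff is in scope):

```python
import numpy as np

# A 640x480 channel-first image exceeds the 300 px threshold, so it is strided
# down by a factor of int(640 / 150) = 4, giving shape (3, 120, 160).
img = np.zeros((3, 480, 640), dtype=np.uint8)
small = auto_downsample_height_width(img)
print(small.shape)  # (3, 120, 160)
```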
||||||
first_batch = None
|
|
||||||
running_item_count = 0 # for online mean computation
|
|
||||||
dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
|
|
||||||
for i, batch in enumerate(
|
|
||||||
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
|
|
||||||
):
|
|
||||||
this_batch_size = len(batch["index"])
|
|
||||||
running_item_count += this_batch_size
|
|
||||||
if first_batch is None:
|
|
||||||
first_batch = deepcopy(batch)
|
|
||||||
for key, pattern in stats_patterns.items():
|
|
||||||
batch[key] = batch[key].float()
|
|
||||||
# Numerically stable update step for mean computation.
|
|
||||||
batch_mean = einops.reduce(batch[key], pattern, "mean")
|
|
||||||
# Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
|
|
||||||
# the update step, N is the running item count, B is this batch size, x̄ is the running mean,
|
|
||||||
# and x is the current batch mean. Some rearrangement is then required to avoid risking
|
|
||||||
# numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
|
|
||||||
# x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
|
|
||||||
mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
|
|
||||||
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
|
|
||||||
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
|
|
||||||
|
|
||||||
if i == ceil(max_num_samples / batch_size) - 1:
|
|
||||||
break
|
|
||||||
|
|
||||||
first_batch_ = None
|
def sample_images(image_paths: list[str]) -> np.ndarray:
|
||||||
running_item_count = 0 # for online std computation
|
sampled_indices = sample_indices(len(image_paths))
|
||||||
dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
|
|
||||||
for i, batch in enumerate(
|
|
||||||
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
|
|
||||||
):
|
|
||||||
this_batch_size = len(batch["index"])
|
|
||||||
running_item_count += this_batch_size
|
|
||||||
# Sanity check to make sure the batches are still in the same order as before.
|
|
||||||
if first_batch_ is None:
|
|
||||||
first_batch_ = deepcopy(batch)
|
|
||||||
for key in stats_patterns:
|
|
||||||
assert torch.equal(first_batch_[key], first_batch[key])
|
|
||||||
for key, pattern in stats_patterns.items():
|
|
||||||
batch[key] = batch[key].float()
|
|
||||||
# Numerically stable update step for mean computation (where the mean is over squared
|
|
||||||
# residuals).See notes in the mean computation loop above.
|
|
||||||
batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
|
|
||||||
std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
|
|
||||||
|
|
||||||
if i == ceil(max_num_samples / batch_size) - 1:
|
images = None
|
||||||
break
|
for i, idx in enumerate(sampled_indices):
|
||||||
|
path = image_paths[idx]
|
||||||
|
# we load as uint8 to reduce memory usage
|
||||||
|
img = load_image_as_numpy(path, dtype=np.uint8, channel_first=True)
|
||||||
|
img = auto_downsample_height_width(img)
|
||||||
|
|
||||||
for key in stats_patterns:
|
if images is None:
|
||||||
std[key] = torch.sqrt(std[key])
|
images = np.empty((len(sampled_indices), *img.shape), dtype=np.uint8)
|
||||||
|
|
||||||
stats = {}
|
images[i] = img
|
||||||
for key in stats_patterns:
|
|
||||||
stats[key] = {
|
return images
|
||||||
"mean": mean[key],
|
|
||||||
"std": std[key],
|
|
||||||
"max": max[key],
|
def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
|
||||||
"min": min[key],
|
return {
|
||||||
|
"min": np.min(array, axis=axis, keepdims=keepdims),
|
||||||
|
"max": np.max(array, axis=axis, keepdims=keepdims),
|
||||||
|
"mean": np.mean(array, axis=axis, keepdims=keepdims),
|
||||||
|
"std": np.std(array, axis=axis, keepdims=keepdims),
|
||||||
|
"count": np.array([len(array)]),
|
||||||
}
|
}
|
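A small illustrative call (not from the source, and assuming `get_feature_stats` as defined above is in scope) showing why `keepdims=True` is used for image features: the per-channel stats come out as `(1, 3, 1, 1)` arrays, and `compute_episode_stats` below then squeezes the leading batch dim (and divides by 255) to reach the `(3, 1, 1)` shape enforced by `_assert_type_and_shape`:

```python
import numpy as np

imgs = np.random.randint(0, 256, size=(10, 3, 64, 64), dtype=np.uint8)
stats = get_feature_stats(imgs, axis=(0, 2, 3), keepdims=True)
print(stats["mean"].shape, stats["count"])  # (1, 3, 1, 1) [10]
```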
||||||
return stats
|
|
||||||
|
|
||||||
|
|
||||||
def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]:
|
def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict:
|
||||||
"""Aggregate stats of multiple LeRobot datasets into one set of stats without recomputing from scratch.
|
ep_stats = {}
|
||||||
|
for key, data in episode_data.items():
|
||||||
|
if features[key]["dtype"] == "string":
|
||||||
|
continue # HACK: we should receive np.arrays of strings
|
||||||
|
elif features[key]["dtype"] in ["image", "video"]:
|
||||||
|
ep_ft_array = sample_images(data) # data is a list of image paths
|
||||||
|
axes_to_reduce = (0, 2, 3) # keep channel dim
|
||||||
|
keepdims = True
|
||||||
|
else:
|
||||||
|
ep_ft_array = data # data is already a np.ndarray
|
||||||
|
axes_to_reduce = 0 # compute stats over the first axis
|
||||||
|
keepdims = data.ndim == 1 # keep as np.array
|
||||||
|
|
||||||
The final stats will have the union of all data keys from each of the datasets.
|
ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)
|
||||||
|
|
||||||
The final stats will have the union of all data keys from each of the datasets. For instance:
|
# finally, we normalize and remove batch dim for images
|
||||||
- new_max = max(max_dataset_0, max_dataset_1, ...)
|
if features[key]["dtype"] in ["image", "video"]:
|
||||||
|
ep_stats[key] = {
|
||||||
|
k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
|
||||||
|
}
|
||||||
|
|
||||||
|
return ep_stats
|
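A hedged usage sketch of `compute_episode_stats` (the feature spec and shapes are invented for illustration, and the call assumes the function defined above): a per-frame state vector takes the non-image branch, reduces over the frame axis, and keeps one statistic per dimension:

```python
import numpy as np

episode_data = {"observation.state": np.random.rand(50, 6).astype(np.float32)}
features = {"observation.state": {"dtype": "float32"}}

ep_stats = compute_episode_stats(episode_data, features)
print(ep_stats["observation.state"]["mean"].shape)  # (6,)
print(ep_stats["observation.state"]["count"])       # [50]
```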
||||||
|
|
||||||
|
|
||||||
|
def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
|
||||||
|
for i in range(len(stats_list)):
|
||||||
|
for fkey in stats_list[i]:
|
||||||
|
for k, v in stats_list[i][fkey].items():
|
||||||
|
if not isinstance(v, np.ndarray):
|
||||||
|
raise ValueError(
|
||||||
|
f"Stats must be composed of numpy array, but key '{k}' of feature '{fkey}' is of type '{type(v)}' instead."
|
||||||
|
)
|
||||||
|
if v.ndim == 0:
|
||||||
|
raise ValueError("Number of dimensions must be at least 1, and is 0 instead.")
|
||||||
|
if k == "count" and v.shape != (1,):
|
||||||
|
raise ValueError(f"Shape of 'count' must be (1), but is {v.shape} instead.")
|
||||||
|
if "image" in fkey and k != "count" and v.shape != (3, 1, 1):
|
||||||
|
raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.")
|
||||||
|
|
||||||
|
|
||||||
|
def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
|
||||||
|
"""Aggregates stats for a single feature."""
|
||||||
|
means = np.stack([s["mean"] for s in stats_ft_list])
|
||||||
|
variances = np.stack([s["std"] ** 2 for s in stats_ft_list])
|
||||||
|
counts = np.stack([s["count"] for s in stats_ft_list])
|
||||||
|
total_count = counts.sum(axis=0)
|
||||||
|
|
||||||
|
# Prepare weighted mean by matching number of dimensions
|
||||||
|
while counts.ndim < means.ndim:
|
||||||
|
counts = np.expand_dims(counts, axis=-1)
|
||||||
|
|
||||||
|
# Compute the weighted mean
|
||||||
|
weighted_means = means * counts
|
||||||
|
total_mean = weighted_means.sum(axis=0) / total_count
|
||||||
|
|
||||||
|
# Compute the variance using the parallel algorithm
|
||||||
|
delta_means = means - total_mean
|
||||||
|
weighted_variances = (variances + delta_means**2) * counts
|
||||||
|
total_variance = weighted_variances.sum(axis=0) / total_count
|
||||||
|
|
||||||
|
return {
|
||||||
|
"min": np.min(np.stack([s["min"] for s in stats_ft_list]), axis=0),
|
||||||
|
"max": np.max(np.stack([s["max"] for s in stats_ft_list]), axis=0),
|
||||||
|
"mean": total_mean,
|
||||||
|
"std": np.sqrt(total_variance),
|
||||||
|
"count": total_count,
|
||||||
|
}
|
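For reference, the weighted aggregation above is the standard pooled mean/variance formula (the same one spelled out in the comment removed from the old `aggregate_stats`); with per-episode counts $n_i$, means $\mu_i$ and standard deviations $\sigma_i$:

$$
\mu = \frac{\sum_i n_i \mu_i}{\sum_i n_i},
\qquad
\sigma = \sqrt{\frac{\sum_i n_i \left(\sigma_i^2 + (\mu_i - \mu)^2\right)}{\sum_i n_i}}.
$$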
||||||
|
|
||||||
|
|
||||||
|
def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
|
||||||
|
"""Aggregate stats from multiple compute_stats outputs into a single set of stats.
|
||||||
|
|
||||||
|
The final stats will have the union of all data keys from each of the stats dicts.
|
||||||
|
|
||||||
|
For instance:
|
||||||
- new_min = min(min_dataset_0, min_dataset_1, ...)
|
- new_min = min(min_dataset_0, min_dataset_1, ...)
|
||||||
- new_mean = (mean of all data)
|
- new_max = max(max_dataset_0, max_dataset_1, ...)
|
||||||
|
- new_mean = (mean of all data, weighted by counts)
|
||||||
- new_std = (std of all data)
|
- new_std = (std of all data)
|
||||||
"""
|
"""
|
||||||
data_keys = set()
|
|
||||||
for dataset in ls_datasets:
|
_assert_type_and_shape(stats_list)
|
||||||
data_keys.update(dataset.meta.stats.keys())
|
|
||||||
stats = {k: {} for k in data_keys}
|
data_keys = {key for stats in stats_list for key in stats}
|
||||||
for data_key in data_keys:
|
aggregated_stats = {key: {} for key in data_keys}
|
||||||
for stat_key in ["min", "max"]:
|
|
||||||
# compute `max(dataset_0["max"], dataset_1["max"], ...)`
|
for key in data_keys:
|
||||||
stats[data_key][stat_key] = einops.reduce(
|
stats_with_key = [stats[key] for stats in stats_list if key in stats]
|
||||||
torch.stack(
|
aggregated_stats[key] = aggregate_feature_stats(stats_with_key)
|
||||||
[ds.meta.stats[data_key][stat_key] for ds in ls_datasets if data_key in ds.meta.stats],
|
|
||||||
dim=0,
|
return aggregated_stats
|
||||||
),
|
|
||||||
"n ... -> ...",
|
|
||||||
stat_key,
|
|
||||||
)
|
|
||||||
total_samples = sum(d.num_frames for d in ls_datasets if data_key in d.meta.stats)
|
|
||||||
# Compute the "sum" statistic by multiplying each mean by the number of samples in the respective
|
|
||||||
# dataset, then divide by total_samples to get the overall "mean".
|
|
||||||
# NOTE: the brackets around (d.num_frames / total_samples) are needed tor minimize the risk of
|
|
||||||
# numerical overflow!
|
|
||||||
stats[data_key]["mean"] = sum(
|
|
||||||
d.meta.stats[data_key]["mean"] * (d.num_frames / total_samples)
|
|
||||||
for d in ls_datasets
|
|
||||||
if data_key in d.meta.stats
|
|
||||||
)
|
|
||||||
# The derivation for standard deviation is a little more involved but is much in the same spirit as
|
|
||||||
# the computation of the mean.
|
|
||||||
# Given two sets of data where the statistics are known:
|
|
||||||
# σ_combined = sqrt[ (n1 * (σ1^2 + d1^2) + n2 * (σ2^2 + d2^2)) / (n1 + n2) ]
|
|
||||||
# where d1 = μ1 - μ_combined, d2 = μ2 - μ_combined
|
|
||||||
# NOTE: the brackets around (d.num_frames / total_samples) are needed tor minimize the risk of
|
|
||||||
# numerical overflow!
|
|
||||||
stats[data_key]["std"] = torch.sqrt(
|
|
||||||
sum(
|
|
||||||
(
|
|
||||||
d.meta.stats[data_key]["std"] ** 2
|
|
||||||
+ (d.meta.stats[data_key]["mean"] - stats[data_key]["mean"]) ** 2
|
|
||||||
)
|
|
||||||
* (d.num_frames / total_samples)
|
|
||||||
for d in ls_datasets
|
|
||||||
if data_key in d.meta.stats
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return stats
|
|
||||||
|
|
|
@ -83,15 +83,18 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(cfg.dataset.repo_id, str):
|
if isinstance(cfg.dataset.repo_id, str):
|
||||||
ds_meta = LeRobotDatasetMetadata(cfg.dataset.repo_id, local_files_only=cfg.dataset.local_files_only)
|
ds_meta = LeRobotDatasetMetadata(
|
||||||
|
cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
|
||||||
|
)
|
||||||
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
|
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
|
||||||
dataset = LeRobotDataset(
|
dataset = LeRobotDataset(
|
||||||
cfg.dataset.repo_id,
|
cfg.dataset.repo_id,
|
||||||
|
root=cfg.dataset.root,
|
||||||
episodes=cfg.dataset.episodes,
|
episodes=cfg.dataset.episodes,
|
||||||
delta_timestamps=delta_timestamps,
|
delta_timestamps=delta_timestamps,
|
||||||
image_transforms=image_transforms,
|
image_transforms=image_transforms,
|
||||||
|
revision=cfg.dataset.revision,
|
||||||
video_backend=cfg.dataset.video_backend,
|
video_backend=cfg.dataset.video_backend,
|
||||||
local_files_only=cfg.dataset.local_files_only,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
|
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
|
||||||
|
|
|
@ -38,22 +38,40 @@ def safe_stop_image_writer(func):
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def image_array_to_image(image_array: np.ndarray) -> PIL.Image.Image:
|
def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
|
||||||
# TODO(aliberts): handle 1 channel and 4 for depth images
|
# TODO(aliberts): handle 1 channel and 4 for depth images
|
||||||
if image_array.ndim == 3 and image_array.shape[0] in [1, 3]:
|
if image_array.ndim != 3:
|
||||||
|
raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.")
|
||||||
|
|
||||||
|
if image_array.shape[0] == 3:
|
||||||
# Transpose from pytorch convention (C, H, W) to (H, W, C)
|
# Transpose from pytorch convention (C, H, W) to (H, W, C)
|
||||||
image_array = image_array.transpose(1, 2, 0)
|
image_array = image_array.transpose(1, 2, 0)
|
||||||
|
|
||||||
|
elif image_array.shape[-1] != 3:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"The image has {image_array.shape[-1]} channels, but 3 is required for now."
|
||||||
|
)
|
||||||
|
|
||||||
if image_array.dtype != np.uint8:
|
if image_array.dtype != np.uint8:
|
||||||
# Assume the image is in [0, 1] range for floating-point data
|
if range_check:
|
||||||
image_array = np.clip(image_array, 0, 1)
|
max_ = image_array.max().item()
|
||||||
|
min_ = image_array.min().item()
|
||||||
|
if max_ > 1.0 or min_ < 0.0:
|
||||||
|
raise ValueError(
|
||||||
|
"The image data type is float, which requires values in the range [0.0, 1.0]. "
|
||||||
|
f"However, the provided range is [{min_}, {max_}]. Please adjust the range or "
|
||||||
|
"provide a uint8 image with values in the range [0, 255]."
|
||||||
|
)
|
||||||
|
|
||||||
image_array = (image_array * 255).astype(np.uint8)
|
image_array = (image_array * 255).astype(np.uint8)
|
||||||
|
|
||||||
return PIL.Image.fromarray(image_array)
|
return PIL.Image.fromarray(image_array)
|
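A hedged usage sketch of the conversion above (the import path follows this diff's module layout and should be treated as an assumption): a channel-first float image in [0, 1] becomes a uint8 RGB `PIL.Image`, while an out-of-range float array now raises instead of being silently clipped:

```python
import numpy as np

from lerobot.common.datasets.image_writer import image_array_to_pil_image

img_chw = np.random.rand(3, 96, 96).astype(np.float32)  # (C, H, W), values in [0, 1]
pil_img = image_array_to_pil_image(img_chw)
print(pil_img.size, pil_img.mode)  # (96, 96) 'RGB'

bad = img_chw * 300.0  # values above 1.0 -> ValueError with range_check=True (the default)
try:
    image_array_to_pil_image(bad)
except ValueError as err:
    print("rejected:", err)
```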
||||||
|
|
||||||
|
|
||||||
def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path):
|
def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path):
|
||||||
try:
|
try:
|
||||||
if isinstance(image, np.ndarray):
|
if isinstance(image, np.ndarray):
|
||||||
img = image_array_to_image(image)
|
img = image_array_to_pil_image(image)
|
||||||
elif isinstance(image, PIL.Image.Image):
|
elif isinstance(image, PIL.Image.Image):
|
||||||
img = image
|
img = image
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -13,62 +13,67 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import contextlib
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import shutil
|
import shutil
|
||||||
from functools import cached_property
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import packaging.version
|
||||||
import PIL.Image
|
import PIL.Image
|
||||||
import torch
|
import torch
|
||||||
import torch.utils
|
import torch.utils
|
||||||
from datasets import load_dataset
|
from datasets import concatenate_datasets, load_dataset
|
||||||
from huggingface_hub import create_repo, snapshot_download, upload_folder
|
from huggingface_hub import HfApi, snapshot_download
|
||||||
|
from huggingface_hub.constants import REPOCARD_NAME
|
||||||
|
from huggingface_hub.errors import RevisionNotFoundError
|
||||||
|
|
||||||
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats
|
from lerobot.common.constants import HF_LEROBOT_HOME
|
||||||
|
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
|
||||||
from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
|
from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
|
||||||
from lerobot.common.datasets.utils import (
|
from lerobot.common.datasets.utils import (
|
||||||
DEFAULT_FEATURES,
|
DEFAULT_FEATURES,
|
||||||
DEFAULT_IMAGE_PATH,
|
DEFAULT_IMAGE_PATH,
|
||||||
EPISODES_PATH,
|
|
||||||
INFO_PATH,
|
INFO_PATH,
|
||||||
STATS_PATH,
|
|
||||||
TASKS_PATH,
|
TASKS_PATH,
|
||||||
append_jsonlines,
|
append_jsonlines,
|
||||||
|
backward_compatible_episodes_stats,
|
||||||
check_delta_timestamps,
|
check_delta_timestamps,
|
||||||
check_timestamps_sync,
|
check_timestamps_sync,
|
||||||
check_version_compatibility,
|
check_version_compatibility,
|
||||||
create_branch,
|
|
||||||
create_empty_dataset_info,
|
create_empty_dataset_info,
|
||||||
create_lerobot_dataset_card,
|
create_lerobot_dataset_card,
|
||||||
|
embed_images,
|
||||||
get_delta_indices,
|
get_delta_indices,
|
||||||
get_episode_data_index,
|
get_episode_data_index,
|
||||||
get_features_from_robot,
|
get_features_from_robot,
|
||||||
get_hf_features_from_features,
|
get_hf_features_from_features,
|
||||||
get_hub_safe_version,
|
get_safe_version,
|
||||||
hf_transform_to_torch,
|
hf_transform_to_torch,
|
||||||
|
is_valid_version,
|
||||||
load_episodes,
|
load_episodes,
|
||||||
|
load_episodes_stats,
|
||||||
load_info,
|
load_info,
|
||||||
load_stats,
|
load_stats,
|
||||||
load_tasks,
|
load_tasks,
|
||||||
serialize_dict,
|
validate_episode_buffer,
|
||||||
|
validate_frame,
|
||||||
|
write_episode,
|
||||||
|
write_episode_stats,
|
||||||
|
write_info,
|
||||||
write_json,
|
write_json,
|
||||||
write_parquet,
|
|
||||||
)
|
)
|
||||||
from lerobot.common.datasets.video_utils import (
|
from lerobot.common.datasets.video_utils import (
|
||||||
VideoFrame,
|
VideoFrame,
|
||||||
decode_video_frames_torchvision,
|
decode_video_frames,
|
||||||
encode_video_frames,
|
encode_video_frames,
|
||||||
get_video_info,
|
get_video_info,
|
||||||
)
|
)
|
||||||
from lerobot.common.robot_devices.robots.utils import Robot
|
from lerobot.common.robot_devices.robots.utils import Robot
|
||||||
|
|
||||||
# For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
|
CODEBASE_VERSION = "v2.1"
|
||||||
CODEBASE_VERSION = "v2.0"
|
|
||||||
LEROBOT_HOME = Path(os.getenv("LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser()
|
|
||||||
|
|
||||||
|
|
||||||
class LeRobotDatasetMetadata:
|
class LeRobotDatasetMetadata:
|
||||||
|
@ -76,19 +81,36 @@ class LeRobotDatasetMetadata:
|
||||||
self,
|
self,
|
||||||
repo_id: str,
|
repo_id: str,
|
||||||
root: str | Path | None = None,
|
root: str | Path | None = None,
|
||||||
local_files_only: bool = False,
|
revision: str | None = None,
|
||||||
|
force_cache_sync: bool = False,
|
||||||
):
|
):
|
||||||
self.repo_id = repo_id
|
self.repo_id = repo_id
|
||||||
self.root = Path(root) if root is not None else LEROBOT_HOME / repo_id
|
self.revision = revision if revision else CODEBASE_VERSION
|
||||||
self.local_files_only = local_files_only
|
self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
|
||||||
|
|
||||||
|
try:
|
||||||
|
if force_cache_sync:
|
||||||
|
raise FileNotFoundError
|
||||||
|
self.load_metadata()
|
||||||
|
except (FileNotFoundError, NotADirectoryError):
|
||||||
|
if is_valid_version(self.revision):
|
||||||
|
self.revision = get_safe_version(self.repo_id, self.revision)
|
||||||
|
|
||||||
# Load metadata
|
|
||||||
(self.root / "meta").mkdir(exist_ok=True, parents=True)
|
(self.root / "meta").mkdir(exist_ok=True, parents=True)
|
||||||
self.pull_from_repo(allow_patterns="meta/")
|
self.pull_from_repo(allow_patterns="meta/")
|
||||||
|
self.load_metadata()
|
||||||
|
|
||||||
|
def load_metadata(self):
|
||||||
self.info = load_info(self.root)
|
self.info = load_info(self.root)
|
||||||
self.stats = load_stats(self.root)
|
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
||||||
self.tasks = load_tasks(self.root)
|
self.tasks, self.task_to_task_index = load_tasks(self.root)
|
||||||
self.episodes = load_episodes(self.root)
|
self.episodes = load_episodes(self.root)
|
||||||
|
if self._version < packaging.version.parse("v2.1"):
|
||||||
|
self.stats = load_stats(self.root)
|
||||||
|
self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
|
||||||
|
else:
|
||||||
|
self.episodes_stats = load_episodes_stats(self.root)
|
||||||
|
self.stats = aggregate_stats(list(self.episodes_stats.values()))
|
||||||
|
|
||||||
def pull_from_repo(
|
def pull_from_repo(
|
||||||
self,
|
self,
|
||||||
|
@ -98,21 +120,16 @@ class LeRobotDatasetMetadata:
|
||||||
snapshot_download(
|
snapshot_download(
|
||||||
self.repo_id,
|
self.repo_id,
|
||||||
repo_type="dataset",
|
repo_type="dataset",
|
||||||
revision=self._hub_version,
|
revision=self.revision,
|
||||||
local_dir=self.root,
|
local_dir=self.root,
|
||||||
allow_patterns=allow_patterns,
|
allow_patterns=allow_patterns,
|
||||||
ignore_patterns=ignore_patterns,
|
ignore_patterns=ignore_patterns,
|
||||||
local_files_only=self.local_files_only,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def _hub_version(self) -> str | None:
|
|
||||||
return None if self.local_files_only else get_hub_safe_version(self.repo_id, CODEBASE_VERSION)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _version(self) -> str:
|
def _version(self) -> packaging.version.Version:
|
||||||
"""Codebase version used to create this dataset."""
|
"""Codebase version used to create this dataset."""
|
||||||
return self.info["codebase_version"]
|
return packaging.version.parse(self.info["codebase_version"])
|
||||||
|
|
||||||
def get_data_file_path(self, ep_index: int) -> Path:
|
def get_data_file_path(self, ep_index: int) -> Path:
|
||||||
ep_chunk = self.get_episode_chunk(ep_index)
|
ep_chunk = self.get_episode_chunk(ep_index)
|
||||||
|
@ -202,54 +219,65 @@ class LeRobotDatasetMetadata:
|
||||||
"""Max number of episodes per chunk."""
|
"""Max number of episodes per chunk."""
|
||||||
return self.info["chunks_size"]
|
return self.info["chunks_size"]
|
||||||
|
|
||||||
@property
|
def get_task_index(self, task: str) -> int | None:
|
||||||
def task_to_task_index(self) -> dict:
|
|
||||||
return {task: task_idx for task_idx, task in self.tasks.items()}
|
|
||||||
|
|
||||||
def get_task_index(self, task: str) -> int:
|
|
||||||
"""
|
"""
|
||||||
Given a task in natural language, returns its task_index if the task already exists in the dataset,
|
Given a task in natural language, returns its task_index if the task already exists in the dataset,
|
||||||
otherwise creates a new task_index.
|
otherwise returns None.
|
||||||
"""
|
"""
|
||||||
task_index = self.task_to_task_index.get(task, None)
|
return self.task_to_task_index.get(task, None)
|
||||||
return task_index if task_index is not None else self.total_tasks
|
|
||||||
|
|
||||||
def save_episode(self, episode_index: int, episode_length: int, task: str, task_index: int) -> None:
|
def add_task(self, task: str):
|
||||||
self.info["total_episodes"] += 1
|
"""
|
||||||
self.info["total_frames"] += episode_length
|
Given a task in natural language, add it to the dictionary of tasks.
|
||||||
|
"""
|
||||||
|
if task in self.task_to_task_index:
|
||||||
|
raise ValueError(f"The task '{task}' already exists and can't be added twice.")
|
||||||
|
|
||||||
if task_index not in self.tasks:
|
task_index = self.info["total_tasks"]
|
||||||
self.info["total_tasks"] += 1
|
self.task_to_task_index[task] = task_index
|
||||||
self.tasks[task_index] = task
|
self.tasks[task_index] = task
|
||||||
|
self.info["total_tasks"] += 1
|
||||||
|
|
||||||
task_dict = {
|
task_dict = {
|
||||||
"task_index": task_index,
|
"task_index": task_index,
|
||||||
"task": task,
|
"task": task,
|
||||||
}
|
}
|
||||||
append_jsonlines(task_dict, self.root / TASKS_PATH)
|
append_jsonlines(task_dict, self.root / TASKS_PATH)
|
||||||
|
|
||||||
|
def save_episode(
|
||||||
|
self,
|
||||||
|
episode_index: int,
|
||||||
|
episode_length: int,
|
||||||
|
episode_tasks: list[str],
|
||||||
|
episode_stats: dict[str, dict],
|
||||||
|
) -> None:
|
||||||
|
self.info["total_episodes"] += 1
|
||||||
|
self.info["total_frames"] += episode_length
|
||||||
|
|
||||||
chunk = self.get_episode_chunk(episode_index)
|
chunk = self.get_episode_chunk(episode_index)
|
||||||
if chunk >= self.total_chunks:
|
if chunk >= self.total_chunks:
|
||||||
self.info["total_chunks"] += 1
|
self.info["total_chunks"] += 1
|
||||||
|
|
||||||
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
|
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
|
||||||
self.info["total_videos"] += len(self.video_keys)
|
self.info["total_videos"] += len(self.video_keys)
|
||||||
write_json(self.info, self.root / INFO_PATH)
|
if len(self.video_keys) > 0:
|
||||||
|
self.update_video_info()
|
||||||
|
|
||||||
|
write_info(self.info, self.root)
|
||||||
|
|
||||||
episode_dict = {
|
episode_dict = {
|
||||||
"episode_index": episode_index,
|
"episode_index": episode_index,
|
||||||
"tasks": [task],
|
"tasks": episode_tasks,
|
||||||
"length": episode_length,
|
"length": episode_length,
|
||||||
}
|
}
|
||||||
self.episodes.append(episode_dict)
|
self.episodes[episode_index] = episode_dict
|
||||||
append_jsonlines(episode_dict, self.root / EPISODES_PATH)
|
write_episode(episode_dict, self.root)
|
||||||
|
|
||||||
# TODO(aliberts): refactor stats in save_episodes
|
self.episodes_stats[episode_index] = episode_stats
|
||||||
# image_sampling = int(self.fps / 2) # sample 2 img/s for the stats
|
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats
|
||||||
# ep_stats = compute_episode_stats(episode_buffer, self.features, episode_length, image_sampling=image_sampling)
|
write_episode_stats(episode_index, episode_stats, self.root)
|
||||||
# ep_stats = serialize_dict(ep_stats)
|
|
||||||
# append_jsonlines(ep_stats, self.root / STATS_PATH)
|
|
||||||
|
|
||||||
def write_video_info(self) -> None:
|
def update_video_info(self) -> None:
|
||||||
"""
|
"""
|
||||||
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
|
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
|
||||||
been encoded the same way. Also, this means it assumes the first episode exists.
|
been encoded the same way. Also, this means it assumes the first episode exists.
|
||||||
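For illustration, a minimal sketch of how the reworked task API is meant to be used; it assumes `meta` is an already-instantiated `LeRobotDatasetMetadata`, and the task string is only an example.

```python
# Sketch only: `meta` is assumed to be a LeRobotDatasetMetadata instance.
task = "Pick up the blue cube and place it into the bin."  # example task string

task_index = meta.get_task_index(task)  # now returns None instead of allocating a new index
if task_index is None:
    meta.add_task(task)                 # registers the task and appends it to meta/tasks.jsonl
    task_index = meta.get_task_index(task)
```

Index allocation now happens only in `add_task`, while `meta.save_episode` receives the episode's task strings and per-episode statistics instead of a single `(task, task_index)` pair.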
@@ -259,8 +287,6 @@ class LeRobotDatasetMetadata:
            video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
            self.info["features"][key]["info"] = get_video_info(video_path)

-        write_json(self.info, self.root / INFO_PATH)

    def __repr__(self):
        feature_keys = list(self.features)
        return (
@@ -286,7 +312,7 @@ class LeRobotDatasetMetadata:
        """Creates metadata for a LeRobotDataset."""
        obj = cls.__new__(cls)
        obj.repo_id = repo_id
-        obj.root = Path(root) if root is not None else LEROBOT_HOME / repo_id
+        obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id

        obj.root.mkdir(parents=True, exist_ok=False)

@@ -304,6 +330,7 @@ class LeRobotDatasetMetadata:
            )
        else:
            # TODO(aliberts, rcadene): implement sanity check for features
+            features = {**features, **DEFAULT_FEATURES}

            # check if none of the features contains a "/" in their names,
            # as this would break the dict flattening in the stats computation, which uses '/' as separator
@@ -313,12 +340,13 @@ class LeRobotDatasetMetadata:

        features = {**features, **DEFAULT_FEATURES}

-        obj.tasks, obj.stats, obj.episodes = {}, {}, []
+        obj.tasks, obj.task_to_task_index = {}, {}
+        obj.episodes_stats, obj.stats, obj.episodes = {}, {}, {}
        obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
        if len(obj.video_keys) > 0 and not use_videos:
            raise ValueError()
        write_json(obj.info, obj.root / INFO_PATH)
-        obj.local_files_only = True
+        obj.revision = None
        return obj


@@ -331,8 +359,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
        image_transforms: Callable | None = None,
        delta_timestamps: dict[list[float]] | None = None,
        tolerance_s: float = 1e-4,
+        revision: str | None = None,
+        force_cache_sync: bool = False,
        download_videos: bool = True,
-        local_files_only: bool = False,
        video_backend: str | None = None,
    ):
        """
@@ -342,7 +371,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
        - On your local disk in the 'root' folder. This is typically the case when you recorded your
            dataset locally and you may or may not have pushed it to the hub yet. Instantiating this class
            with 'root' will load your dataset directly from disk. This can happen while you're offline (no
-            internet connection), in that case, use local_files_only=True.
+            internet connection).

        - On the Hugging Face Hub at the address https://huggingface.co/datasets/{repo_id} and not on
            your local disk in the 'root' folder. Instantiating this class with this 'repo_id' will download
@@ -362,7 +391,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
        - info contains various information about the dataset like shapes, keys, fps etc.
        - stats stores the dataset statistics of the different modalities for normalization
        - tasks contains the prompts for each task of the dataset, which can be used for
-            task-conditionned training.
+            task-conditioned training.
        - hf_dataset (from datasets.Dataset), which will read any values from parquet files.
        - videos (optional) from which frames are loaded to be synchronous with data from parquet files.

@@ -424,24 +453,28 @@ class LeRobotDataset(torch.utils.data.Dataset):
                timestamps is separated to the next by 1/fps +/- tolerance_s. This also applies to frames
                decoded from video files. It is also used to check that `delta_timestamps` (when provided) are
                multiples of 1/fps. Defaults to 1e-4.
+            revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a
+                commit hash. Defaults to current codebase version tag.
+            sync_cache_first (bool, optional): Flag to sync and refresh local files first. If True and files
+                are already present in the local cache, this will be faster. However, files loaded might not
+                be in sync with the version on the hub, especially if you specified 'revision'. Defaults to
+                False.
            download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
                video files are already present on local disk, they won't be downloaded again. Defaults to
                True.
-            local_files_only (bool, optional): Flag to use local files only. If True, no requests to the hub
-                will be made. Defaults to False.
-            video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
-                a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
+            video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec.
+                You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
        """
        super().__init__()
        self.repo_id = repo_id
-        self.root = Path(root) if root else LEROBOT_HOME / repo_id
+        self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
        self.image_transforms = image_transforms
        self.delta_timestamps = delta_timestamps
        self.episodes = episodes
        self.tolerance_s = tolerance_s
-        self.video_backend = video_backend if video_backend else "pyav"
+        self.revision = revision if revision else CODEBASE_VERSION
+        self.video_backend = video_backend if video_backend else "torchcodec"
        self.delta_indices = None
-        self.local_files_only = local_files_only

        # Unused attributes
        self.image_writer = None
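A hedged usage sketch of the new constructor arguments; the repo id and revision value below are examples only, not part of this change.

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Pin the dataset to a Hub revision instead of toggling the removed local_files_only flag.
ds = LeRobotDataset(
    "lerobot/pusht",          # example repo_id
    revision="v2.1",          # branch, tag, or commit hash; defaults to the codebase version tag
    force_cache_sync=False,   # set True to re-check the local cache against the Hub
    video_backend="pyav",     # optional: override the new "torchcodec" default
)
print(ds.meta.total_episodes, ds.fps)
```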
@@ -450,64 +483,92 @@ class LeRobotDataset(torch.utils.data.Dataset):
        self.root.mkdir(exist_ok=True, parents=True)

        # Load metadata
-        self.meta = LeRobotDatasetMetadata(self.repo_id, self.root, self.local_files_only)
-        # Check version
-        check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
+        self.meta = LeRobotDatasetMetadata(
+            self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
+        )
+        if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
+            episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
+            self.stats = aggregate_stats(episodes_stats)

        # Load actual data
+        try:
+            if force_cache_sync:
+                raise FileNotFoundError
+            assert all((self.root / fpath).is_file() for fpath in self.get_episodes_file_paths())
+            self.hf_dataset = self.load_hf_dataset()
+        except (AssertionError, FileNotFoundError, NotADirectoryError):
+            self.revision = get_safe_version(self.repo_id, self.revision)
            self.download_episodes(download_videos)
            self.hf_dataset = self.load_hf_dataset()

        self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)

        # Check timestamps
-        check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
+        timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
+        episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
+        ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
+        check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)

        # Setup delta_indices
        if self.delta_timestamps is not None:
            check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
            self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)

-        # Available stats implies all videos have been encoded and dataset is iterable
-        self.consolidated = self.meta.stats is not None

    def push_to_hub(
        self,
+        branch: str | None = None,
        tags: list | None = None,
        license: str | None = "apache-2.0",
+        tag_version: bool = True,
        push_videos: bool = True,
        private: bool = False,
+        allow_patterns: list[str] | str | None = None,
+        upload_large_folder: bool = False,
        **card_kwargs,
    ) -> None:
-        if not self.consolidated:
-            logging.warning(
-                "You are trying to upload to the hub a LeRobotDataset that has not been consolidated yet. "
-                "Consolidating first."
-            )
-            self.consolidate()

        ignore_patterns = ["images/"]
        if not push_videos:
            ignore_patterns.append("videos/")

-        create_repo(
+        hub_api = HfApi()
+        hub_api.create_repo(
            repo_id=self.repo_id,
            private=private,
            repo_type="dataset",
            exist_ok=True,
        )
-        upload_folder(
+        if branch:
+            hub_api.create_branch(
                repo_id=self.repo_id,
-            folder_path=self.root,
+                branch=branch,
+                revision=self.revision,
                repo_type="dataset",
-            ignore_patterns=ignore_patterns,
+                exist_ok=True,
            )

+        upload_kwargs = {
+            "repo_id": self.repo_id,
+            "folder_path": self.root,
+            "repo_type": "dataset",
+            "revision": branch,
+            "allow_patterns": allow_patterns,
+            "ignore_patterns": ignore_patterns,
+        }
+        if upload_large_folder:
+            hub_api.upload_large_folder(**upload_kwargs)
+        else:
+            hub_api.upload_folder(**upload_kwargs)

+        if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch):
            card = create_lerobot_dataset_card(
                tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
            )
-            card.push_to_hub(repo_id=self.repo_id, repo_type="dataset")
-        create_branch(repo_id=self.repo_id, branch=CODEBASE_VERSION, repo_type="dataset")
+            card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)

+        if tag_version:
+            with contextlib.suppress(RevisionNotFoundError):
+                hub_api.delete_tag(self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
+            hub_api.create_tag(self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")

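A short sketch of the new upload options; the values shown are illustrative, and `ds` is assumed to be a `LeRobotDataset` instance as in the earlier sketch.

```python
# Sketch only: push to a branch, then re-tag the repo with the codebase version.
ds.push_to_hub(
    branch="my-experiments",     # optional: upload to a branch instead of the default revision
    tag_version=True,            # (re)create the CODEBASE_VERSION tag after the upload
    upload_large_folder=False,   # switch to HfApi.upload_large_folder for very large datasets
    private=False,
)
```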
    def pull_from_repo(
        self,
@@ -517,11 +578,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
        snapshot_download(
            self.repo_id,
            repo_type="dataset",
-            revision=self.meta._hub_version,
+            revision=self.revision,
            local_dir=self.root,
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
-            local_files_only=self.local_files_only,
        )

    def download_episodes(self, download_videos: bool = True) -> None:
@@ -535,16 +595,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
        files = None
        ignore_patterns = None if download_videos else "videos/"
        if self.episodes is not None:
-            files = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
-            if len(self.meta.video_keys) > 0 and download_videos:
+            files = self.get_episodes_file_paths()
+
+        self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)

+    def get_episodes_file_paths(self) -> list[Path]:
+        episodes = self.episodes if self.episodes is not None else list(range(self.meta.total_episodes))
+        fpaths = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in episodes]
+        if len(self.meta.video_keys) > 0:
            video_files = [
                str(self.meta.get_video_file_path(ep_idx, vid_key))
                for vid_key in self.meta.video_keys
-                for ep_idx in self.episodes
+                for ep_idx in episodes
            ]
-            files += video_files
+            fpaths += video_files

-        self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
+        return fpaths

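For context, a sketch of what the new helper returns; this is how `__init__` now decides whether the local cache is complete before falling back to a Hub download.

```python
# Sketch only: `ds` is a LeRobotDataset instance; paths are relative to ds.root.
paths = ds.get_episodes_file_paths()                       # parquet files, plus mp4 files if any
missing = [p for p in paths if not (ds.root / p).is_file()]
```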
    def load_hf_dataset(self) -> datasets.Dataset:
        """hf_dataset contains all the observations, states, actions, rewards, etc."""
@@ -557,7 +623,15 @@ class LeRobotDataset(torch.utils.data.Dataset):

        # TODO(aliberts): hf_dataset.set_format("torch")
        hf_dataset.set_transform(hf_transform_to_torch)
+        return hf_dataset

+    def create_hf_dataset(self) -> datasets.Dataset:
+        features = get_hf_features_from_features(self.features)
+        ft_dict = {col: [] for col in features}
+        hf_dataset = datasets.Dataset.from_dict(ft_dict, features=features, split="train")

+        # TODO(aliberts): hf_dataset.set_format("torch")
+        hf_dataset.set_transform(hf_transform_to_torch)
        return hf_dataset

    @property
@@ -624,7 +698,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
            if key not in self.meta.video_keys
        }

-    def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict:
+    def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
        """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
        in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a
        Segmentation Fault. This probably happens because a memory reference to the video loader is created in
@@ -633,9 +707,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
        item = {}
        for vid_key, query_ts in query_timestamps.items():
            video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
-            frames = decode_video_frames_torchvision(
-                video_path, query_ts, self.tolerance_s, self.video_backend
-            )
+            frames = decode_video_frames(video_path, query_ts, self.tolerance_s, self.video_backend)
            item[vid_key] = frames.squeeze(0)

        return item
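A hedged sketch of how the backend choice surfaces to users: the backend is selected once at construction and then used by the backend-agnostic `decode_video_frames` call above whenever an item needs video frames. The repo id is an example.

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Sketch only: "torchcodec" is the new default; "pyav" and "video_reader" remain selectable.
ds_pyav = LeRobotDataset("lerobot/pusht", video_backend="pyav")
item = ds_pyav[0]  # frames for this index are decoded lazily with the chosen backend
```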
@@ -654,8 +726,7 @@ class LeRobotDataset(torch.utils.data.Dataset):

        query_indices = None
        if self.delta_indices is not None:
-            current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx
-            query_indices, padding = self._get_query_indices(idx, current_ep_idx)
+            query_indices, padding = self._get_query_indices(idx, ep_idx)
            query_result = self._query_hf_dataset(query_indices)
            item = {**item, **padding}
            for key, val in query_result.items():
@@ -691,10 +762,13 @@ class LeRobotDataset(torch.utils.data.Dataset):

    def create_episode_buffer(self, episode_index: int | None = None) -> dict:
        current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
-        return {
-            "size": 0,
-            **{key: current_ep_idx if key == "episode_index" else [] for key in self.features},
-        }
+        ep_buffer = {}
+        # size and task are special cases that are not in self.features
+        ep_buffer["size"] = 0
+        ep_buffer["task"] = []
+        for key in self.features:
+            ep_buffer[key] = current_ep_idx if key == "episode_index" else []
+        return ep_buffer

    def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
        fpath = DEFAULT_IMAGE_PATH.format(
@@ -716,25 +790,35 @@ class LeRobotDataset(torch.utils.data.Dataset):
        temporary directory — nothing is written to disk. To save those frames, the 'save_episode()' method
        then needs to be called.
        """
-        # TODO(aliberts, rcadene): Add sanity check for the input, check it's numpy or torch,
-        # check the dtype and shape matches, etc.
+        # Convert torch to numpy if needed
+        for name in frame:
+            if isinstance(frame[name], torch.Tensor):
+                frame[name] = frame[name].numpy()

+        validate_frame(frame, self.features)

        if self.episode_buffer is None:
            self.episode_buffer = self.create_episode_buffer()

+        # Automatically add frame_index and timestamp to episode buffer
        frame_index = self.episode_buffer["size"]
        timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
        self.episode_buffer["frame_index"].append(frame_index)
        self.episode_buffer["timestamp"].append(timestamp)

+        # Add frame features to episode_buffer
        for key in frame:
-            if key not in self.features:
-                raise ValueError(key)
+            if key == "task":
+                # Note: we associate the task in natural language to its task index during `save_episode`
+                self.episode_buffer["task"].append(frame["task"])
+                continue

-            if self.features[key]["dtype"] not in ["image", "video"]:
-                item = frame[key].numpy() if isinstance(frame[key], torch.Tensor) else frame[key]
-                self.episode_buffer[key].append(item)
-            elif self.features[key]["dtype"] in ["image", "video"]:
+            if key not in self.features:
+                raise ValueError(
+                    f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'."
+                )

+            if self.features[key]["dtype"] in ["image", "video"]:
                img_path = self._get_image_file_path(
                    episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
                )
@@ -742,80 +826,95 @@ class LeRobotDataset(torch.utils.data.Dataset):
                img_path.parent.mkdir(parents=True, exist_ok=True)
                self._save_image(frame[key], img_path)
                self.episode_buffer[key].append(str(img_path))
+            else:
+                self.episode_buffer[key].append(frame[key])

        self.episode_buffer["size"] += 1

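A sketch of the new per-frame convention; the feature names and shapes are illustrative, and `dataset` is assumed to have been created with matching features. Torch tensors are converted to numpy internally, and the `"task"` string is stored per frame rather than passed to `save_episode`.

```python
import numpy as np

# Sketch only: `dataset` is a LeRobotDataset created with these (hypothetical) features.
dataset.add_frame(
    {
        "observation.state": np.zeros(6, dtype=np.float32),
        "action": np.zeros(6, dtype=np.float32),
        "task": "Pick up the blue cube and place it into the bin.",
    }
)
```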
-    def save_episode(self, task: str, encode_videos: bool = True, episode_data: dict | None = None) -> None:
+    def save_episode(self, episode_data: dict | None = None) -> None:
        """
-        This will save to disk the current episode in self.episode_buffer. Note that since it affects files on
-        disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
-        the hub.
+        This will save to disk the current episode in self.episode_buffer.

-        Use 'encode_videos' if you want to encode videos during the saving of this episode. Otherwise,
-        you can do it later with dataset.consolidate(). This is to give more flexibility on when to spend
-        time for video encoding.
+        Args:
+            episode_data (dict | None, optional): Dict containing the episode data to save. If None, this will
+                save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to
+                None.
        """
        if not episode_data:
            episode_buffer = self.episode_buffer

+        validate_episode_buffer(episode_buffer, self.meta.total_episodes, self.features)

+        # size and task are special cases that won't be added to hf_dataset
        episode_length = episode_buffer.pop("size")
+        tasks = episode_buffer.pop("task")
+        episode_tasks = list(set(tasks))
        episode_index = episode_buffer["episode_index"]
-        if episode_index != self.meta.total_episodes:
-            # TODO(aliberts): Add option to use existing episode_index
-            raise NotImplementedError(
-                "You might have manually provided the episode_buffer with an episode_index that doesn't "
-                "match the total number of episodes in the dataset. This is not supported for now."
-            )

-        if episode_length == 0:
-            raise ValueError(
-                "You must add one or several frames with `add_frame` before calling `add_episode`."
-            )
+        episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
+        episode_buffer["episode_index"] = np.full((episode_length,), episode_index)

+        # Add new tasks to the tasks dictionary
+        for task in episode_tasks:
            task_index = self.meta.get_task_index(task)
+            if task_index is None:
+                self.meta.add_task(task)

-        if not set(episode_buffer.keys()) == set(self.features):
-            raise ValueError()
+        # Given tasks in natural language, find their corresponding task indices
+        episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks])

        for key, ft in self.features.items():
-            if key == "index":
-                episode_buffer[key] = np.arange(
-                    self.meta.total_frames, self.meta.total_frames + episode_length
-                )
+            # index, episode_index, task_index are already processed above, and image and video
+            # are processed separately by storing image path and frame info as meta data
+            if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
-            elif key == "episode_index":
-                episode_buffer[key] = np.full((episode_length,), episode_index)
-            elif key == "task_index":
-                episode_buffer[key] = np.full((episode_length,), task_index)
-            elif ft["dtype"] in ["image", "video"]:
                continue
-            elif len(ft["shape"]) == 1 and ft["shape"][0] == 1:
-                episode_buffer[key] = np.array(episode_buffer[key], dtype=ft["dtype"])
-            elif len(ft["shape"]) == 1 and ft["shape"][0] > 1:
                episode_buffer[key] = np.stack(episode_buffer[key])
-            else:
-                raise ValueError(key)

        self._wait_image_writer()
        self._save_episode_table(episode_buffer, episode_index)
+        ep_stats = compute_episode_stats(episode_buffer, self.features)

-        self.meta.save_episode(episode_index, episode_length, task, task_index)
+        if len(self.meta.video_keys) > 0:

-        if encode_videos and len(self.meta.video_keys) > 0:
            video_paths = self.encode_episode_videos(episode_index)
            for key in self.meta.video_keys:
                episode_buffer[key] = video_paths[key]

+        # `meta.save_episode` be executed after encoding the videos
+        self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)

+        ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index])
+        ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()}
+        check_timestamps_sync(
+            episode_buffer["timestamp"],
+            episode_buffer["episode_index"],
+            ep_data_index_np,
+            self.fps,
+            self.tolerance_s,
+        )

+        video_files = list(self.root.rglob("*.mp4"))
+        assert len(video_files) == self.num_episodes * len(self.meta.video_keys)

+        parquet_files = list(self.root.rglob("*.parquet"))
+        assert len(parquet_files) == self.num_episodes

+        # delete images
+        img_dir = self.root / "images"
+        if img_dir.is_dir():
+            shutil.rmtree(self.root / "images")

        if not episode_data:  # Reset the buffer
            self.episode_buffer = self.create_episode_buffer()

-        self.consolidated = False

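In the new flow, saving an episode is a single call: `save_episode()` resolves task indices, computes per-episode stats, encodes videos and updates the metadata itself, so the old `task`/`encode_videos` arguments are gone.

```python
# Sketch only: `dataset` as in the add_frame sketch above.
dataset.save_episode()
```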
    def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None:
        episode_dict = {key: episode_buffer[key] for key in self.hf_features}
        ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train")
+        ep_dataset = embed_images(ep_dataset)
+        self.hf_dataset = concatenate_datasets([self.hf_dataset, ep_dataset])
+        self.hf_dataset.set_transform(hf_transform_to_torch)
        ep_data_path = self.root / self.meta.get_data_file_path(ep_index=episode_index)
        ep_data_path.parent.mkdir(parents=True, exist_ok=True)
-        write_parquet(ep_dataset, ep_data_path)
+        ep_dataset.to_parquet(ep_data_path)

    def clear_episode_buffer(self) -> None:
        episode_index = self.episode_buffer["episode_index"]
@@ -884,38 +983,6 @@ class LeRobotDataset(torch.utils.data.Dataset):

        return video_paths

-    def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None:
-        self.hf_dataset = self.load_hf_dataset()
-        self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
-        check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
-
-        if len(self.meta.video_keys) > 0:
-            self.encode_videos()
-            self.meta.write_video_info()
-
-        if not keep_image_files:
-            img_dir = self.root / "images"
-            if img_dir.is_dir():
-                shutil.rmtree(self.root / "images")
-
-        video_files = list(self.root.rglob("*.mp4"))
-        assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
-
-        parquet_files = list(self.root.rglob("*.parquet"))
-        assert len(parquet_files) == self.num_episodes
-
-        if run_compute_stats:
-            self.stop_image_writer()
-            # TODO(aliberts): refactor stats in save_episodes
-            self.meta.stats = compute_stats(self)
-            serialized_stats = serialize_dict(self.meta.stats)
-            write_json(serialized_stats, self.root / STATS_PATH)
-            self.consolidated = True
-        else:
-            logging.warning(
-                "Skipping computation of the dataset statistics, dataset is not fully consolidated."
-            )

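With `consolidate()` removed, the end-to-end recording flow becomes record, save, push. The sketch below is hedged: the repo id and feature spec are hypothetical, and it assumes the `create()` classmethod keeps its `fps`, `features` and `use_videos` arguments as in the surrounding code.

```python
import numpy as np
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

features = {  # hypothetical minimal feature spec
    "observation.state": {"dtype": "float32", "shape": (6,), "names": None},
    "action": {"dtype": "float32", "shape": (6,), "names": None},
}
dataset = LeRobotDataset.create("my_user/my_dataset", fps=30, features=features, use_videos=False)

for _ in range(10):
    dataset.add_frame(
        {
            "observation.state": np.zeros(6, dtype=np.float32),
            "action": np.zeros(6, dtype=np.float32),
            "task": "Push the T-shaped block onto the T-shaped target.",
        }
    )
dataset.save_episode()            # no consolidate() step anymore
dataset.push_to_hub(tags=["example"])
```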
    @classmethod
    def create(
        cls,
@@ -944,7 +1011,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
        )
        obj.repo_id = obj.meta.repo_id
        obj.root = obj.meta.root
-        obj.local_files_only = obj.meta.local_files_only
+        obj.revision = None
        obj.tolerance_s = tolerance_s
        obj.image_writer = None

@@ -954,19 +1021,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
        # TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
        obj.episode_buffer = obj.create_episode_buffer()

-        # This bool indicates that the current LeRobotDataset instance is in sync with the files on disk. It
-        # is used to know when certain operations are need (for instance, computing dataset statistics). In
-        # order to be able to push the dataset to the hub, it needs to be consolidated first by calling
-        # self.consolidate().
-        obj.consolidated = True

        obj.episodes = None
-        obj.hf_dataset = None
+        obj.hf_dataset = obj.create_hf_dataset()
        obj.image_transforms = None
        obj.delta_timestamps = None
        obj.delta_indices = None
        obj.episode_data_index = None
-        obj.video_backend = video_backend if video_backend is not None else "pyav"
+        obj.video_backend = video_backend if video_backend is not None else "torchcodec"
        return obj


@@ -986,12 +1047,11 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
        delta_timestamps: dict[list[float]] | None = None,
        tolerances_s: dict | None = None,
        download_videos: bool = True,
-        local_files_only: bool = False,
        video_backend: str | None = None,
    ):
        super().__init__()
        self.repo_ids = repo_ids
-        self.root = Path(root) if root else LEROBOT_HOME
+        self.root = Path(root) if root else HF_LEROBOT_HOME
        self.tolerances_s = tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids}
        # Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
        # are handled by this class.
@@ -1004,7 +1064,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
                delta_timestamps=delta_timestamps,
                tolerance_s=self.tolerances_s[repo_id],
                download_videos=download_videos,
-                local_files_only=local_files_only,
                video_backend=video_backend,
            )
            for repo_id in repo_ids
@@ -1032,7 +1091,10 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):

        self.image_transforms = image_transforms
        self.delta_timestamps = delta_timestamps
-        self.stats = aggregate_stats(self._datasets)
+        # TODO(rcadene, aliberts): We should not perform this aggregation for datasets
+        # with multiple robots of different ranges. Instead we should have one normalization
+        # per robot.
+        self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets])

    @property
    def repo_id_to_index(self):

@@ -1,56 +0,0 @@
-## Using / Updating `CODEBASE_VERSION` (for maintainers)
-
-Since our dataset pushed to the hub are decoupled with the evolution of this repo, we ensure compatibility of
-the datasets with our code, we use a `CODEBASE_VERSION` (defined in
-lerobot/common/datasets/lerobot_dataset.py) variable.
-
-For instance, [`lerobot/pusht`](https://huggingface.co/datasets/lerobot/pusht) has many versions to maintain backward compatibility between LeRobot codebase versions:
-- [v1.0](https://huggingface.co/datasets/lerobot/pusht/tree/v1.0)
-- [v1.1](https://huggingface.co/datasets/lerobot/pusht/tree/v1.1)
-- [v1.2](https://huggingface.co/datasets/lerobot/pusht/tree/v1.2)
-- [v1.3](https://huggingface.co/datasets/lerobot/pusht/tree/v1.3)
-- [v1.4](https://huggingface.co/datasets/lerobot/pusht/tree/v1.4)
-- [v1.5](https://huggingface.co/datasets/lerobot/pusht/tree/v1.5)
-- [v1.6](https://huggingface.co/datasets/lerobot/pusht/tree/v1.6) <-- last version
-- [main](https://huggingface.co/datasets/lerobot/pusht/tree/main) <-- points to the last version
-
-Starting with v1.6, every dataset pushed to the hub or saved locally also have this version number in their
-`info.json` metadata.
-
-### Uploading a new dataset
-If you are pushing a new dataset, you don't need to worry about any of the instructions below, nor to be
-compatible with previous codebase versions. The `push_dataset_to_hub.py` script will automatically tag your
-dataset with the current `CODEBASE_VERSION`.
-
-### Updating an existing dataset
-If you want to update an existing dataset, you need to change the `CODEBASE_VERSION` from `lerobot_dataset.py`
-before running `push_dataset_to_hub.py`. This is especially useful if you introduce a breaking change
-intentionally or not (i.e. something not backward compatible such as modifying the reward functions used,
-deleting some frames at the end of an episode, etc.). That way, people running a previous version of the
-codebase won't be affected by your change and backward compatibility is maintained.
-
-However, you will need to update the version of ALL the other datasets so that they have the new
-`CODEBASE_VERSION` as a branch in their hugging face dataset repository. Don't worry, there is an easy way
-that doesn't require to run `push_dataset_to_hub.py`. You can just "branch-out" from the `main` branch on HF
-dataset repo by running this script which corresponds to a `git checkout -b` (so no copy or upload needed):
-
-```python
-from huggingface_hub import HfApi
-
-from lerobot import available_datasets
-from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
-
-api = HfApi()
-
-for repo_id in available_datasets:
-    dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
-    branches = [b.name for b in dataset_info.branches]
-    if CODEBASE_VERSION in branches:
-        print(f"{repo_id} already @{CODEBASE_VERSION}, skipping.")
-        continue
-    else:
-        # Now create a branch named after the new version by branching out from "main"
-        # which is expected to be the preceding version
-        api.create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION, revision="main")
-        print(f"{repo_id} successfully updated @{CODEBASE_VERSION}")
-```
@@ -152,7 +152,7 @@ def download_raw(raw_dir: Path, repo_id: str):
        stacklevel=1,
    )

-    # Send warning if raw_dir isn't well formated
+    # Send warning if raw_dir isn't well formatted
    if raw_dir.parts[-2] != user_id or raw_dir.parts[-1] != dataset_id:
        warnings.warn(
            f"""`raw_dir` ({raw_dir}) doesn't contain a community or user id `/` the name of the dataset that

@@ -68,9 +68,9 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
        modality_df,
        on="timestamp_utc",
        # "nearest" is the best option over "backward", since the latter can desynchronizes camera timestamps by
-        # matching timestamps that are too far appart, in order to fit the backward constraints. It's not the case for "nearest".
+        # matching timestamps that are too far apart, in order to fit the backward constraints. It's not the case for "nearest".
        # However, note that "nearest" might synchronize the reference camera with other cameras on slightly future timestamps.
-        # are too far appart.
+        # are too far apart.
        direction="nearest",
        tolerance=pd.Timedelta(f"{1 / fps} seconds"),
    )
@@ -126,7 +126,7 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
        videos_dir.parent.mkdir(parents=True, exist_ok=True)
        videos_dir.symlink_to((raw_dir / "videos").absolute())

-        # sanity check the video paths are well formated
+        # sanity check the video paths are well formatted
        for key in df:
            if "observation.images." not in key:
                continue
@@ -143,7 +143,7 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
        # it is the case for video_frame dictionary = [{"path": ..., "timestamp": ...}]
        data_dict[key] = [video_frame[0] for video_frame in df[key].values]

-        # sanity check the video path is well formated
+        # sanity check the video path is well formatted
        video_path = videos_dir.parent / data_dict[key][0]["path"]
        if not video_path.exists():
            raise ValueError(f"Video file not found in {video_path}")

@@ -17,7 +17,7 @@
For all datasets in the RLDS format.
For https://github.com/google-deepmind/open_x_embodiment (OPENX) datasets.

-NOTE: You need to install tensorflow and tensorflow_datsets before running this script.
+NOTE: You need to install tensorflow and tensorflow_datasets before running this script.

Example:
    python lerobot/scripts/push_dataset_to_hub.py \
@@ -13,10 +13,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import contextlib
import importlib.resources
import json
import logging
-import textwrap
from collections.abc import Iterator
from itertools import accumulate
from pathlib import Path
@@ -27,14 +27,21 @@ from typing import Any
import datasets
import jsonlines
import numpy as np
-import pyarrow.compute as pc
+import packaging.version
import torch
from datasets.table import embed_table_storage
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
+from huggingface_hub.errors import RevisionNotFoundError
from PIL import Image as PILImage
from torchvision import transforms

+from lerobot.common.datasets.backward_compatibility import (
+    V21_MESSAGE,
+    BackwardCompatibilityError,
+    ForwardCompatibilityError,
+)
from lerobot.common.robot_devices.robots.utils import Robot
+from lerobot.common.utils.utils import is_valid_numpy_dtype_string
from lerobot.configs.types import DictLike, FeatureType, PolicyFeature

DEFAULT_CHUNK_SIZE = 1000  # Max number of episodes per chunk
@@ -42,6 +49,7 @@ DEFAULT_CHUNK_SIZE = 1000  # Max number of episodes per chunk
INFO_PATH = "meta/info.json"
EPISODES_PATH = "meta/episodes.jsonl"
STATS_PATH = "meta/stats.json"
+EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
TASKS_PATH = "meta/tasks.jsonl"

DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
@@ -112,17 +120,26 @@ def get_nested_item(obj: DictLike, flattened_key: str, sep: str = "/") -> Any:


def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
-    serialized_dict = {key: value.tolist() for key, value in flatten_dict(stats).items()}
+    serialized_dict = {}
+    for key, value in flatten_dict(stats).items():
+        if isinstance(value, (torch.Tensor, np.ndarray)):
+            serialized_dict[key] = value.tolist()
+        elif isinstance(value, np.generic):
+            serialized_dict[key] = value.item()
+        elif isinstance(value, (int, float)):
+            serialized_dict[key] = value
+        else:
+            raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.")
    return unflatten_dict(serialized_dict)


-def write_parquet(dataset: datasets.Dataset, fpath: Path) -> None:
+def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
    # Embed image bytes into the table before saving to parquet
    format = dataset.format
    dataset = dataset.with_format("arrow")
    dataset = dataset.map(embed_table_storage, batched=False)
    dataset = dataset.with_format(**format)
-    dataset.to_parquet(fpath)
+    return dataset

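A small sketch of what the broadened `serialize_dict` now accepts (arrays, numpy scalars and plain numbers); it assumes these helpers are importable from `lerobot.common.datasets.utils`, the module shown in this hunk.

```python
import numpy as np
from lerobot.common.datasets.utils import serialize_dict

stats = {"action": {"mean": np.zeros(2), "count": np.int64(3)}, "fps": 30}
print(serialize_dict(stats))
# {'action': {'mean': [0.0, 0.0], 'count': 3}, 'fps': 30}
```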
def load_json(fpath: Path) -> Any:
@@ -153,6 +170,10 @@ def append_jsonlines(data: dict, fpath: Path) -> None:
        writer.write(data)


+def write_info(info: dict, local_dir: Path):
+    write_json(info, local_dir / INFO_PATH)
+
+
def load_info(local_dir: Path) -> dict:
    info = load_json(local_dir / INFO_PATH)
    for ft in info["features"].values():
@@ -160,29 +181,76 @@ def load_info(local_dir: Path) -> dict:
    return info


-def load_stats(local_dir: Path) -> dict:
-    if not (local_dir / STATS_PATH).exists():
-        return None
-    stats = load_json(local_dir / STATS_PATH)
-    stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
+def write_stats(stats: dict, local_dir: Path):
+    serialized_stats = serialize_dict(stats)
+    write_json(serialized_stats, local_dir / STATS_PATH)
+
+
+def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]:
+    stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
    return unflatten_dict(stats)


-def load_tasks(local_dir: Path) -> dict:
+def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]:
+    if not (local_dir / STATS_PATH).exists():
+        return None
+    stats = load_json(local_dir / STATS_PATH)
+    return cast_stats_to_numpy(stats)

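A hedged round-trip sketch of the new stats helpers; the directory path is hypothetical and is created here only so the snippet can run on its own.

```python
from pathlib import Path

import numpy as np

from lerobot.common.datasets.utils import load_stats, write_stats

root = Path("/tmp/example_dataset")              # hypothetical local dataset root
(root / "meta").mkdir(parents=True, exist_ok=True)
write_stats({"action": {"mean": np.zeros(2), "std": np.ones(2)}}, root)
stats = load_stats(root)                          # values come back as np.ndarray via cast_stats_to_numpy
```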

+def write_task(task_index: int, task: dict, local_dir: Path):
+    task_dict = {
+        "task_index": task_index,
+        "task": task,
+    }
+    append_jsonlines(task_dict, local_dir / TASKS_PATH)
+
+
+def load_tasks(local_dir: Path) -> tuple[dict, dict]:
    tasks = load_jsonlines(local_dir / TASKS_PATH)
-    return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
+    tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
+    task_to_task_index = {task: task_index for task_index, task in tasks.items()}
+    return tasks, task_to_task_index

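Callers of `load_tasks` now unpack two mappings instead of one dict; a one-line sketch, assuming `root` points at a dataset that already has a `meta/tasks.jsonl`:

```python
from lerobot.common.datasets.utils import load_tasks

tasks, task_to_task_index = load_tasks(root)  # {task_index: task}, {task: task_index}
```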
+def write_episode(episode: dict, local_dir: Path):
+    append_jsonlines(episode, local_dir / EPISODES_PATH)
+
+
def load_episodes(local_dir: Path) -> dict:
-    return load_jsonlines(local_dir / EPISODES_PATH)
+    episodes = load_jsonlines(local_dir / EPISODES_PATH)
+    return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}


-def load_image_as_numpy(fpath: str | Path, dtype="float32", channel_first: bool = True) -> np.ndarray:
+def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
+    # We wrap episode_stats in a dictionary since `episode_stats["episode_index"]`
+    # is a dictionary of stats and not an integer.
+    episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)}
+    append_jsonlines(episode_stats, local_dir / EPISODES_STATS_PATH)
+
+
+def load_episodes_stats(local_dir: Path) -> dict:
+    episodes_stats = load_jsonlines(local_dir / EPISODES_STATS_PATH)
+    return {
+        item["episode_index"]: cast_stats_to_numpy(item["stats"])
+        for item in sorted(episodes_stats, key=lambda x: x["episode_index"])
+    }
+
+
+def backward_compatible_episodes_stats(
+    stats: dict[str, dict[str, np.ndarray]], episodes: list[int]
+) -> dict[str, dict[str, np.ndarray]]:
+    return {ep_idx: stats for ep_idx in episodes}

|
def load_image_as_numpy(
|
||||||
|
fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True
|
||||||
|
) -> np.ndarray:
|
||||||
img = PILImage.open(fpath).convert("RGB")
|
img = PILImage.open(fpath).convert("RGB")
|
||||||
img_array = np.array(img, dtype=dtype)
|
img_array = np.array(img, dtype=dtype)
|
||||||
if channel_first: # (H, W, C) -> (C, H, W)
|
if channel_first: # (H, W, C) -> (C, H, W)
|
||||||
img_array = np.transpose(img_array, (2, 0, 1))
|
img_array = np.transpose(img_array, (2, 0, 1))
|
||||||
if "float" in dtype:
|
if np.issubdtype(dtype, np.floating):
|
||||||
img_array /= 255.0
|
img_array /= 255.0
|
||||||
return img_array
|
return img_array
|
||||||
|
|
||||||
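For orientation, a minimal sketch of how the new per-episode metadata writers pair with their loaders. The `local_dir` path is hypothetical, and the sketch assumes the `append_jsonlines` helper creates any missing parent directories; the function names and `TASKS_PATH` constant are the ones defined in the hunk above.

```python
from pathlib import Path

local_dir = Path("data/my_dataset")  # hypothetical local dataset root

# Tasks are appended as jsonlines under TASKS_PATH and read back as
# {task_index: task} plus the reverse {task: task_index} mapping.
write_task(0, "Pick up the blue cube and place it into the bin.", local_dir)
tasks, task_to_task_index = load_tasks(local_dir)
assert task_to_task_index[tasks[0]] == 0
```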
@@ -201,77 +269,95 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
         elif first_item is None:
             pass
         else:
-            items_dict[key] = [torch.tensor(x) for x in items_dict[key]]
+            items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
     return items_dict


-def _get_major_minor(version: str) -> tuple[int]:
-    split = version.strip("v").split(".")
-    return int(split[0]), int(split[1])
-
-
-class BackwardCompatibilityError(Exception):
-    def __init__(self, repo_id, version):
-        message = textwrap.dedent(f"""
-            BackwardCompatibilityError: The dataset you requested ({repo_id}) is in {version} format.
-
-            We introduced a new format since v2.0 which is not backward compatible with v1.x.
-            Please, use our conversion script. Modify the following command with your own task description:
-            ```
-            python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\
-                --repo-id {repo_id} \\
-                --single-task "TASK DESCRIPTION."  # <---- /!\\ Replace TASK DESCRIPTION /!\\
-            ```
-
-            A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.",
-            "Insert the peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.",
-            "Open the top cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped target.",
-            "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the sweatshirt.", ...
-
-            If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
-            or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
-        """)
-        super().__init__(message)
+def is_valid_version(version: str) -> bool:
+    try:
+        packaging.version.parse(version)
+        return True
+    except packaging.version.InvalidVersion:
+        return False


 def check_version_compatibility(
-    repo_id: str, version_to_check: str, current_version: str, enforce_breaking_major: bool = True
+    repo_id: str,
+    version_to_check: str | packaging.version.Version,
+    current_version: str | packaging.version.Version,
+    enforce_breaking_major: bool = True,
 ) -> None:
-    current_major, _ = _get_major_minor(current_version)
-    major_to_check, _ = _get_major_minor(version_to_check)
-    if major_to_check < current_major and enforce_breaking_major:
-        raise BackwardCompatibilityError(repo_id, version_to_check)
-    elif float(version_to_check.strip("v")) < float(current_version.strip("v")):
-        logging.warning(
-            f"""The dataset you requested ({repo_id}) was created with a previous version ({version_to_check}) of the
-            codebase. The current codebase version is {current_version}. You should be fine since
-            backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
-            Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
-        )
+    v_check = (
+        packaging.version.parse(version_to_check)
+        if not isinstance(version_to_check, packaging.version.Version)
+        else version_to_check
+    )
+    v_current = (
+        packaging.version.parse(current_version)
+        if not isinstance(current_version, packaging.version.Version)
+        else current_version
+    )
+    if v_check.major < v_current.major and enforce_breaking_major:
+        raise BackwardCompatibilityError(repo_id, v_check)
+    elif v_check.minor < v_current.minor:
+        logging.warning(V21_MESSAGE.format(repo_id=repo_id, version=v_check))


-def get_hub_safe_version(repo_id: str, version: str) -> str:
+def get_repo_versions(repo_id: str) -> list[packaging.version.Version]:
+    """Returns available valid versions (branches and tags) on given repo."""
     api = HfApi()
-    dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
-    branches = [b.name for b in dataset_info.branches]
-    if version not in branches:
-        num_version = float(version.strip("v"))
-        hub_num_versions = [float(v.strip("v")) for v in branches if v.startswith("v")]
-        if num_version >= 2.0 and all(v < 2.0 for v in hub_num_versions):
-            raise BackwardCompatibilityError(repo_id, version)
-
-        logging.warning(
-            f"""You are trying to load a dataset from {repo_id} created with a previous version of the
-            codebase. The following versions are available: {branches}.
-            The requested version ('{version}') is not found. You should be fine since
-            backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
-            Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
-        )
-        if "main" not in branches:
-            raise ValueError(f"Version 'main' not found on {repo_id}")
-        return "main"
-    else:
-        return version
+    repo_refs = api.list_repo_refs(repo_id, repo_type="dataset")
+    repo_refs = [b.name for b in repo_refs.branches + repo_refs.tags]
+    repo_versions = []
+    for ref in repo_refs:
+        with contextlib.suppress(packaging.version.InvalidVersion):
+            repo_versions.append(packaging.version.parse(ref))
+
+    return repo_versions
+
+
+def get_safe_version(repo_id: str, version: str | packaging.version.Version) -> str:
+    """
+    Returns the version if available on repo or the latest compatible one.
+    Otherwise, will throw a `CompatibilityError`.
+    """
+    target_version = (
+        packaging.version.parse(version) if not isinstance(version, packaging.version.Version) else version
+    )
+    hub_versions = get_repo_versions(repo_id)
+
+    if not hub_versions:
+        raise RevisionNotFoundError(
+            f"""Your dataset must be tagged with a codebase version.
+            Assuming _version_ is the codebase_version value in the info.json, you can run this:
+            ```python
+            from huggingface_hub import HfApi
+
+            hub_api = HfApi()
+            hub_api.create_tag("{repo_id}", tag="_version_", repo_type="dataset")
+            ```
+            """
+        )
+
+    if target_version in hub_versions:
+        return f"v{target_version}"
+
+    compatibles = [
+        v for v in hub_versions if v.major == target_version.major and v.minor <= target_version.minor
+    ]
+    if compatibles:
+        return_version = max(compatibles)
+        if return_version < target_version:
+            logging.warning(f"Revision {version} for {repo_id} not found, using version v{return_version}")
+        return f"v{return_version}"
+
+    lower_major = [v for v in hub_versions if v.major < target_version.major]
+    if lower_major:
+        raise BackwardCompatibilityError(repo_id, max(lower_major))
+
+    upper_versions = [v for v in hub_versions if v > target_version]
+    assert len(upper_versions) > 0
+    raise ForwardCompatibilityError(repo_id, min(upper_versions))
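A usage sketch for the new version-resolution helpers; the repo id is borrowed from examples elsewhere in this diff, and the fallback behavior is the one described in `get_safe_version` above.

```python
from lerobot.common.datasets.utils import get_safe_version, is_valid_version

assert is_valid_version("v2.1")

# Returns "v2.1" if that tag/branch exists on the repo; otherwise it falls back to the
# latest compatible v2.x revision with a warning, or raises a compatibility error.
revision = get_safe_version("lerobot/pusht", "v2.1")
```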
 def get_hf_features_from_features(features: dict) -> datasets.Features:
@@ -283,11 +369,20 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
             hf_features[key] = datasets.Image()
         elif ft["shape"] == (1,):
             hf_features[key] = datasets.Value(dtype=ft["dtype"])
-        else:
-            assert len(ft["shape"]) == 1
+        elif len(ft["shape"]) == 1:
             hf_features[key] = datasets.Sequence(
                 length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"])
             )
+        elif len(ft["shape"]) == 2:
+            hf_features[key] = datasets.Array2D(shape=ft["shape"], dtype=ft["dtype"])
+        elif len(ft["shape"]) == 3:
+            hf_features[key] = datasets.Array3D(shape=ft["shape"], dtype=ft["dtype"])
+        elif len(ft["shape"]) == 4:
+            hf_features[key] = datasets.Array4D(shape=ft["shape"], dtype=ft["dtype"])
+        elif len(ft["shape"]) == 5:
+            hf_features[key] = datasets.Array5D(shape=ft["shape"], dtype=ft["dtype"])
+        else:
+            raise ValueError(f"Corresponding feature is not valid: {ft}")

     return datasets.Features(hf_features)
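For reference, a small sketch of the shape-to-feature mapping added above; the feature spec is hypothetical but follows the `{"dtype", "shape"}` layout used in this file.

```python
import datasets

features = {
    "reward": {"dtype": "float32", "shape": (1,)},               # -> datasets.Value("float32")
    "observation.state": {"dtype": "float32", "shape": (14,)},   # -> datasets.Sequence(length=14)
    "observation.depth": {"dtype": "float32", "shape": (8, 8)},  # -> datasets.Array2D((8, 8), "float32")
}
hf_features = get_hf_features_from_features(features)
assert isinstance(hf_features, datasets.Features)
```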
@@ -358,88 +453,85 @@ def create_empty_dataset_info(


 def get_episode_data_index(
-    episode_dicts: list[dict], episodes: list[int] | None = None
+    episode_dicts: dict[dict], episodes: list[int] | None = None
 ) -> dict[str, torch.Tensor]:
-    episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(episode_dicts)}
+    episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()}
     if episodes is not None:
         episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}

-    cumulative_lenghts = list(accumulate(episode_lengths.values()))
+    cumulative_lengths = list(accumulate(episode_lengths.values()))
     return {
-        "from": torch.LongTensor([0] + cumulative_lenghts[:-1]),
-        "to": torch.LongTensor(cumulative_lenghts),
-    }
-
-
-def calculate_total_episode(
-    hf_dataset: datasets.Dataset, raise_if_not_contiguous: bool = True
-) -> dict[str, torch.Tensor]:
-    episode_indices = sorted(hf_dataset.unique("episode_index"))
-    total_episodes = len(episode_indices)
-    if raise_if_not_contiguous and episode_indices != list(range(total_episodes)):
-        raise ValueError("episode_index values are not sorted and contiguous.")
-    return total_episodes
-
-
-def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, torch.Tensor]:
-    episode_lengths = []
-    table = hf_dataset.data.table
-    total_episodes = calculate_total_episode(hf_dataset)
-    for ep_idx in range(total_episodes):
-        ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
-        episode_lengths.insert(ep_idx, len(ep_table))
-
-    cumulative_lenghts = list(accumulate(episode_lengths))
-    return {
-        "from": torch.LongTensor([0] + cumulative_lenghts[:-1]),
-        "to": torch.LongTensor(cumulative_lenghts),
+        "from": torch.LongTensor([0] + cumulative_lengths[:-1]),
+        "to": torch.LongTensor(cumulative_lengths),
     }


 def check_timestamps_sync(
-    hf_dataset: datasets.Dataset,
-    episode_data_index: dict[str, torch.Tensor],
+    timestamps: np.ndarray,
+    episode_indices: np.ndarray,
+    episode_data_index: dict[str, np.ndarray],
     fps: int,
     tolerance_s: float,
     raise_value_error: bool = True,
 ) -> bool:
     """
-    This check is to make sure that each timestamps is separated to the next by 1/fps +/- tolerance to
-    account for possible numerical error.
-    """
-    timestamps = torch.stack(hf_dataset["timestamp"])
-    diffs = torch.diff(timestamps)
-    within_tolerance = torch.abs(diffs - 1 / fps) <= tolerance_s
-
-    # We mask differences between the timestamp at the end of an episode
-    # and the one at the start of the next episode since these are expected
-    # to be outside tolerance.
-    mask = torch.ones(len(diffs), dtype=torch.bool)
-    ignored_diffs = episode_data_index["to"][:-1] - 1
+    This check is to make sure that each timestamp is separated from the next by (1/fps) +/- tolerance
+    to account for possible numerical error.
+
+    Args:
+        timestamps (np.ndarray): Array of timestamps in seconds.
+        episode_indices (np.ndarray): Array indicating the episode index for each timestamp.
+        episode_data_index (dict[str, np.ndarray]): A dictionary that includes 'to',
+            which identifies indices for the end of each episode.
+        fps (int): Frames per second. Used to check the expected difference between consecutive timestamps.
+        tolerance_s (float): Allowed deviation from the expected (1/fps) difference.
+        raise_value_error (bool): Whether to raise a ValueError if the check fails.
+
+    Returns:
+        bool: True if all checked timestamp differences lie within tolerance, False otherwise.
+
+    Raises:
+        ValueError: If the check fails and `raise_value_error` is True.
+    """
+    if timestamps.shape != episode_indices.shape:
+        raise ValueError(
+            "timestamps and episode_indices should have the same shape. "
+            f"Found {timestamps.shape=} and {episode_indices.shape=}."
+        )
+
+    # Consecutive differences
+    diffs = np.diff(timestamps)
+    within_tolerance = np.abs(diffs - (1.0 / fps)) <= tolerance_s
+
+    # Mask to ignore differences at the boundaries between episodes
+    mask = np.ones(len(diffs), dtype=bool)
+    ignored_diffs = episode_data_index["to"][:-1] - 1  # indices at the end of each episode
     mask[ignored_diffs] = False
     filtered_within_tolerance = within_tolerance[mask]

-    if not torch.all(filtered_within_tolerance):
+    # Check if all remaining diffs are within tolerance
+    if not np.all(filtered_within_tolerance):
         # Track original indices before masking
-        original_indices = torch.arange(len(diffs))
+        original_indices = np.arange(len(diffs))
         filtered_indices = original_indices[mask]
-        outside_tolerance_filtered_indices = torch.nonzero(~filtered_within_tolerance)  # .squeeze()
+        outside_tolerance_filtered_indices = np.nonzero(~filtered_within_tolerance)[0]
         outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices]
-        episode_indices = torch.stack(hf_dataset["episode_index"])

         outside_tolerances = []
         for idx in outside_tolerance_indices:
             entry = {
                 "timestamps": [timestamps[idx], timestamps[idx + 1]],
                 "diff": diffs[idx],
-                "episode_index": episode_indices[idx].item(),
+                "episode_index": episode_indices[idx].item()
+                if hasattr(episode_indices[idx], "item")
+                else episode_indices[idx],
             }
             outside_tolerances.append(entry)

         if raise_value_error:
             raise ValueError(
                 f"""One or several timestamps unexpectedly violate the tolerance inside episode range.
-                This might be due to synchronization issues with timestamps during data collection.
+                This might be due to synchronization issues during data collection.
                 \n{pformat(outside_tolerances)}"""
             )
         return False
@ -604,3 +696,118 @@ class IterableNamespace(SimpleNamespace):
|
||||||
|
|
||||||
def keys(self):
|
def keys(self):
|
||||||
return vars(self).keys()
|
return vars(self).keys()
|
||||||
|
|
||||||
|
|
||||||
|
def validate_frame(frame: dict, features: dict):
|
||||||
|
optional_features = {"timestamp"}
|
||||||
|
expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"}
|
||||||
|
actual_features = set(frame.keys())
|
||||||
|
|
||||||
|
error_message = validate_features_presence(actual_features, expected_features, optional_features)
|
||||||
|
|
||||||
|
if "task" in frame:
|
||||||
|
error_message += validate_feature_string("task", frame["task"])
|
||||||
|
|
||||||
|
common_features = actual_features & (expected_features | optional_features)
|
||||||
|
for name in common_features - {"task"}:
|
||||||
|
error_message += validate_feature_dtype_and_shape(name, features[name], frame[name])
|
||||||
|
|
||||||
|
if error_message:
|
||||||
|
raise ValueError(error_message)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_features_presence(
|
||||||
|
actual_features: set[str], expected_features: set[str], optional_features: set[str]
|
||||||
|
):
|
||||||
|
error_message = ""
|
||||||
|
missing_features = expected_features - actual_features
|
||||||
|
extra_features = actual_features - (expected_features | optional_features)
|
||||||
|
|
||||||
|
if missing_features or extra_features:
|
||||||
|
error_message += "Feature mismatch in `frame` dictionary:\n"
|
||||||
|
if missing_features:
|
||||||
|
error_message += f"Missing features: {missing_features}\n"
|
||||||
|
if extra_features:
|
||||||
|
error_message += f"Extra features: {extra_features}\n"
|
||||||
|
|
||||||
|
return error_message
|
||||||
|
|
||||||
|
|
||||||
|
def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
|
||||||
|
expected_dtype = feature["dtype"]
|
||||||
|
expected_shape = feature["shape"]
|
||||||
|
if is_valid_numpy_dtype_string(expected_dtype):
|
||||||
|
return validate_feature_numpy_array(name, expected_dtype, expected_shape, value)
|
||||||
|
elif expected_dtype in ["image", "video"]:
|
||||||
|
return validate_feature_image_or_video(name, expected_shape, value)
|
||||||
|
elif expected_dtype == "string":
|
||||||
|
return validate_feature_string(name, value)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
|
||||||
|
|
||||||
|
|
||||||
|
def validate_feature_numpy_array(
|
||||||
|
name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray
|
||||||
|
):
|
||||||
|
error_message = ""
|
||||||
|
if isinstance(value, np.ndarray):
|
||||||
|
actual_dtype = value.dtype
|
||||||
|
actual_shape = value.shape
|
||||||
|
|
||||||
|
if actual_dtype != np.dtype(expected_dtype):
|
||||||
|
error_message += f"The feature '{name}' of dtype '{actual_dtype}' is not of the expected dtype '{expected_dtype}'.\n"
|
||||||
|
|
||||||
|
if actual_shape != expected_shape:
|
||||||
|
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{expected_shape}'.\n"
|
||||||
|
else:
|
||||||
|
error_message += f"The feature '{name}' is not a 'np.ndarray'. Expected type is '{expected_dtype}', but type '{type(value)}' provided instead.\n"
|
||||||
|
|
||||||
|
return error_message
|
||||||
|
|
||||||
|
|
||||||
|
def validate_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
|
||||||
|
# Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
|
||||||
|
error_message = ""
|
||||||
|
if isinstance(value, np.ndarray):
|
||||||
|
actual_shape = value.shape
|
||||||
|
c, h, w = expected_shape
|
||||||
|
if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
|
||||||
|
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
|
||||||
|
elif isinstance(value, PILImage.Image):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
error_message += f"The feature '{name}' is expected to be of type 'PIL.Image' or 'np.ndarray' channel first or channel last, but type '{type(value)}' provided instead.\n"
|
||||||
|
|
||||||
|
return error_message
|
||||||
|
|
||||||
|
|
||||||
|
def validate_feature_string(name: str, value: str):
|
||||||
|
if not isinstance(value, str):
|
||||||
|
return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict):
|
||||||
|
if "size" not in episode_buffer:
|
||||||
|
raise ValueError("size key not found in episode_buffer")
|
||||||
|
|
||||||
|
if "task" not in episode_buffer:
|
||||||
|
raise ValueError("task key not found in episode_buffer")
|
||||||
|
|
||||||
|
if episode_buffer["episode_index"] != total_episodes:
|
||||||
|
# TODO(aliberts): Add option to use existing episode_index
|
||||||
|
raise NotImplementedError(
|
||||||
|
"You might have manually provided the episode_buffer with an episode_index that doesn't "
|
||||||
|
"match the total number of episodes already in the dataset. This is not supported for now."
|
||||||
|
)
|
||||||
|
|
||||||
|
if episode_buffer["size"] == 0:
|
||||||
|
raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.")
|
||||||
|
|
||||||
|
buffer_keys = set(episode_buffer.keys()) - {"task", "size"}
|
||||||
|
if not buffer_keys == set(features):
|
||||||
|
raise ValueError(
|
||||||
|
f"Features from `episode_buffer` don't match the ones in `features`."
|
||||||
|
f"In episode_buffer not in features: {buffer_keys - set(features)}"
|
||||||
|
f"In features not in episode_buffer: {set(features) - buffer_keys}"
|
||||||
|
)
|
||||||
|
|
|
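A minimal sketch of the new frame-validation path. The feature spec and frame are hypothetical, and `DEFAULT_FEATURES` is assumed to hold only bookkeeping keys (indices, timestamps), as `validate_frame` implies.

```python
import numpy as np

features = {"observation.state": {"dtype": "float32", "shape": (2,)}}
frame = {
    "observation.state": np.zeros(2, dtype=np.float32),
    "task": "Fold the sweatshirt.",
}

# Raises ValueError listing missing/extra keys or dtype/shape mismatches; passes silently otherwise.
validate_frame(frame, features)
```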
@@ -31,6 +31,7 @@ from lerobot.common.robot_devices.robots.configs import AlohaRobotConfig

 LOCAL_DIR = Path("data/")

+# spellchecker:off
 ALOHA_MOBILE_INFO = {
     "robot_config": AlohaRobotConfig(),
     "license": "mit",
@@ -856,6 +857,7 @@ DATASETS = {
     }""").lstrip(),
     },
 }
+# spellchecker:on


 def batch_convert():
@@ -17,7 +17,7 @@
 """
 This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 1.6 to
 2.0. You will be required to provide the 'tasks', which is a short but accurate description in plain English
-for each of the task performed in the dataset. This will allow to easily train models with task-conditionning.
+for each of the task performed in the dataset. This will allow to easily train models with task-conditioning.

 We support 3 different scenarios for these tasks (see instructions below):
 1. Single task dataset: all episodes of your dataset have the same single task.
@@ -130,7 +130,7 @@ from lerobot.common.datasets.utils import (
     create_branch,
     create_lerobot_dataset_card,
     flatten_dict,
-    get_hub_safe_version,
+    get_safe_version,
     load_json,
     unflatten_dict,
     write_json,
@@ -443,7 +443,7 @@ def convert_dataset(
     test_branch: str | None = None,
     **card_kwargs,
 ):
-    v1 = get_hub_safe_version(repo_id, V16)
+    v1 = get_safe_version(repo_id, V16)
     v1x_dir = local_dir / V16 / repo_id
     v20_dir = local_dir / V20 / repo_id
     v1x_dir.mkdir(parents=True, exist_ok=True)
@@ -0,0 +1,87 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# [standard Apache-2.0 license header, identical to the other files in this commit]
+
+import logging
+import traceback
+from pathlib import Path
+
+from datasets import get_dataset_config_info
+from huggingface_hub import HfApi
+
+from lerobot import available_datasets
+from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
+from lerobot.common.datasets.utils import INFO_PATH, write_info
+from lerobot.common.datasets.v21.convert_dataset_v20_to_v21 import V20, SuppressWarnings
+
+LOCAL_DIR = Path("data/")
+
+hub_api = HfApi()
+
+
+def fix_dataset(repo_id: str) -> str:
+    if not hub_api.revision_exists(repo_id, V20, repo_type="dataset"):
+        return f"{repo_id}: skipped (not in {V20})."
+
+    dataset_info = get_dataset_config_info(repo_id, "default")
+    with SuppressWarnings():
+        lerobot_metadata = LeRobotDatasetMetadata(repo_id, revision=V20, force_cache_sync=True)
+
+    meta_features = {key for key, ft in lerobot_metadata.features.items() if ft["dtype"] != "video"}
+    parquet_features = set(dataset_info.features)
+
+    diff_parquet_meta = parquet_features - meta_features
+    diff_meta_parquet = meta_features - parquet_features
+
+    if diff_parquet_meta:
+        raise ValueError(f"In parquet not in info.json: {parquet_features - meta_features}")
+
+    if not diff_meta_parquet:
+        return f"{repo_id}: skipped (no diff)"
+
+    if diff_meta_parquet:
+        logging.warning(f"In info.json not in parquet: {meta_features - parquet_features}")
+        assert diff_meta_parquet == {"language_instruction"}
+        lerobot_metadata.features.pop("language_instruction")
+        write_info(lerobot_metadata.info, lerobot_metadata.root)
+        commit_info = hub_api.upload_file(
+            path_or_fileobj=lerobot_metadata.root / INFO_PATH,
+            path_in_repo=INFO_PATH,
+            repo_id=repo_id,
+            repo_type="dataset",
+            revision=V20,
+            commit_message="Remove 'language_instruction'",
+            create_pr=True,
+        )
+        return f"{repo_id}: success - PR: {commit_info.pr_url}"
+
+
+def batch_fix():
+    status = {}
+    LOCAL_DIR.mkdir(parents=True, exist_ok=True)
+    logfile = LOCAL_DIR / "fix_features_v20.txt"
+    for num, repo_id in enumerate(available_datasets):
+        print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
+        print("---------------------------------------------------------")
+        try:
+            status = fix_dataset(repo_id)
+        except Exception:
+            status = f"{repo_id}: failed\n    {traceback.format_exc()}"
+
+        logging.info(status)
+        with open(logfile, "a") as file:
+            file.write(status + "\n")
+
+
+if __name__ == "__main__":
+    batch_fix()
@@ -0,0 +1,54 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# [standard Apache-2.0 license header, identical to the other files in this commit]
+
+"""
+This script is for internal use to convert all datasets under the 'lerobot' hub user account to v2.1.
+"""
+
+import traceback
+from pathlib import Path
+
+from huggingface_hub import HfApi
+
+from lerobot import available_datasets
+from lerobot.common.datasets.v21.convert_dataset_v20_to_v21 import V21, convert_dataset
+
+LOCAL_DIR = Path("data/")
+
+
+def batch_convert():
+    status = {}
+    LOCAL_DIR.mkdir(parents=True, exist_ok=True)
+    logfile = LOCAL_DIR / "conversion_log_v21.txt"
+    hub_api = HfApi()
+    for num, repo_id in enumerate(available_datasets):
+        print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
+        print("---------------------------------------------------------")
+        try:
+            if hub_api.revision_exists(repo_id, V21, repo_type="dataset"):
+                status = f"{repo_id}: success (already in {V21})."
+            else:
+                convert_dataset(repo_id)
+                status = f"{repo_id}: success."
+        except Exception:
+            status = f"{repo_id}: failed\n    {traceback.format_exc()}"
+
+        with open(logfile, "a") as file:
+            file.write(status + "\n")
+
+
+if __name__ == "__main__":
+    batch_convert()
@@ -0,0 +1,114 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# [standard Apache-2.0 license header, identical to the other files in this commit]
+
+"""
+This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
+2.1. It will:
+
+- Generate per-episodes stats and writes them in `episodes_stats.jsonl`
+- Check consistency between these new stats and the old ones.
+- Remove the deprecated `stats.json`.
+- Update codebase_version in `info.json`.
+- Push this new version to the hub on the 'main' branch and tags it with "v2.1".
+
+Usage:
+
+```bash
+python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py \
+    --repo-id=aliberts/koch_tutorial
+```
+
+"""
+
+import argparse
+import logging
+
+from huggingface_hub import HfApi
+
+from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
+from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
+from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
+
+V20 = "v2.0"
+V21 = "v2.1"
+
+
+class SuppressWarnings:
+    def __enter__(self):
+        self.previous_level = logging.getLogger().getEffectiveLevel()
+        logging.getLogger().setLevel(logging.ERROR)
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        logging.getLogger().setLevel(self.previous_level)
+
+
+def convert_dataset(
+    repo_id: str,
+    branch: str | None = None,
+    num_workers: int = 4,
+):
+    with SuppressWarnings():
+        dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
+
+    if (dataset.root / EPISODES_STATS_PATH).is_file():
+        (dataset.root / EPISODES_STATS_PATH).unlink()
+
+    convert_stats(dataset, num_workers=num_workers)
+    ref_stats = load_stats(dataset.root)
+    check_aggregate_stats(dataset, ref_stats)
+
+    dataset.meta.info["codebase_version"] = CODEBASE_VERSION
+    write_info(dataset.meta.info, dataset.root)
+
+    dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/")
+
+    # delete old stats.json file
+    if (dataset.root / STATS_PATH).is_file:
+        (dataset.root / STATS_PATH).unlink()
+
+    hub_api = HfApi()
+    if hub_api.file_exists(
+        repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
+    ):
+        hub_api.delete_file(
+            path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
+        )
+
+    hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--repo-id",
+        type=str,
+        required=True,
+        help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
+        "(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
+    )
+    parser.add_argument(
+        "--branch",
+        type=str,
+        default=None,
+        help="Repo branch to push your dataset. Defaults to the main branch.",
+    )
+    parser.add_argument(
+        "--num-workers",
+        type=int,
+        default=4,
+        help="Number of workers for parallelizing stats compute. Defaults to 4.",
+    )
+
+    args = parser.parse_args()
+    convert_dataset(**vars(args))
@@ -0,0 +1,99 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# [standard Apache-2.0 license header, identical to the other files in this commit]
+
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+import numpy as np
+from tqdm import tqdm
+
+from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.utils import write_episode_stats
+
+
+def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:
+    ep_len = dataset.meta.episodes[episode_index]["length"]
+    sampled_indices = sample_indices(ep_len)
+    query_timestamps = dataset._get_query_timestamps(0.0, {ft_key: sampled_indices})
+    video_frames = dataset._query_videos(query_timestamps, episode_index)
+    return video_frames[ft_key].numpy()
+
+
+def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
+    ep_start_idx = dataset.episode_data_index["from"][ep_idx]
+    ep_end_idx = dataset.episode_data_index["to"][ep_idx]
+    ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx))
+
+    ep_stats = {}
+    for key, ft in dataset.features.items():
+        if ft["dtype"] == "video":
+            # We sample only for videos
+            ep_ft_data = sample_episode_video_frames(dataset, ep_idx, key)
+        else:
+            ep_ft_data = np.array(ep_data[key])
+
+        axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0
+        keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1
+        ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims)
+
+        if ft["dtype"] in ["image", "video"]:  # remove batch dim
+            ep_stats[key] = {
+                k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
+            }
+
+    dataset.meta.episodes_stats[ep_idx] = ep_stats
+
+
+def convert_stats(dataset: LeRobotDataset, num_workers: int = 0):
+    assert dataset.episodes is None
+    print("Computing episodes stats")
+    total_episodes = dataset.meta.total_episodes
+    if num_workers > 0:
+        with ThreadPoolExecutor(max_workers=num_workers) as executor:
+            futures = {
+                executor.submit(convert_episode_stats, dataset, ep_idx): ep_idx
+                for ep_idx in range(total_episodes)
+            }
+            for future in tqdm(as_completed(futures), total=total_episodes):
+                future.result()
+    else:
+        for ep_idx in tqdm(range(total_episodes)):
+            convert_episode_stats(dataset, ep_idx)
+
+    for ep_idx in tqdm(range(total_episodes)):
+        write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root)
+
+
+def check_aggregate_stats(
+    dataset: LeRobotDataset,
+    reference_stats: dict[str, dict[str, np.ndarray]],
+    video_rtol_atol: tuple[float] = (1e-2, 1e-2),
+    default_rtol_atol: tuple[float] = (5e-6, 6e-5),
+):
+    """Verifies that the aggregated stats from episodes_stats are close to reference stats."""
+    agg_stats = aggregate_stats(list(dataset.meta.episodes_stats.values()))
+    for key, ft in dataset.features.items():
+        # These values might need some fine-tuning
+        if ft["dtype"] == "video":
+            # to account for image sub-sampling
+            rtol, atol = video_rtol_atol
+        else:
+            rtol, atol = default_rtol_atol
+
+        for stat, val in agg_stats[key].items():
+            if key in reference_stats and stat in reference_stats[key]:
+                err_msg = f"feature='{key}' stats='{stat}'"
+                np.testing.assert_allclose(
+                    val, reference_stats[key][stat], rtol=rtol, atol=atol, err_msg=err_msg
+                )
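For orientation, a sketch of how these stats helpers are driven; it mirrors the call sequence in the v2.0 to v2.1 conversion script earlier in this diff (the repo id is illustrative).

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import load_stats
from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats

dataset = LeRobotDataset("lerobot/pusht", revision="v2.0", force_cache_sync=True)
convert_stats(dataset, num_workers=4)  # computes and writes per-episode stats
check_aggregate_stats(dataset, load_stats(dataset.root))  # compares against the legacy stats.json
```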
@@ -27,6 +27,35 @@ import torch
 import torchvision
 from datasets.features.features import register_feature
 from PIL import Image
+from torchcodec.decoders import VideoDecoder
+
+
+def decode_video_frames(
+    video_path: Path | str,
+    timestamps: list[float],
+    tolerance_s: float,
+    backend: str = "torchcodec",
+) -> torch.Tensor:
+    """
+    Decodes video frames using the specified backend.
+
+    Args:
+        video_path (Path): Path to the video file.
+        timestamps (list[float]): List of timestamps to extract frames.
+        tolerance_s (float): Allowed deviation in seconds for frame retrieval.
+        backend (str, optional): Backend to use for decoding. Defaults to "torchcodec".
+
+    Returns:
+        torch.Tensor: Decoded frames.
+
+    Currently supports torchcodec on cpu and pyav.
+    """
+    if backend == "torchcodec":
+        return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
+    elif backend in ["pyav", "video_reader"]:
+        return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
+    else:
+        raise ValueError(f"Unsupported video backend: {backend}")
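A usage sketch for the new dispatch helper (the video path is hypothetical; the backends are the ones handled above).

```python
import torch

# Decode two frames at 0.0 s and 0.5 s with the default torchcodec backend.
frames = decode_video_frames("episode_000000.mp4", timestamps=[0.0, 0.5], tolerance_s=1e-4)
assert isinstance(frames, torch.Tensor)  # float32, channel-first, values in [0, 1]

# The same call can use the pyav backend instead:
frames = decode_video_frames("episode_000000.mp4", [0.0, 0.5], 1e-4, backend="pyav")
```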
 def decode_video_frames_torchvision(
@@ -69,11 +98,11 @@ def decode_video_frames_torchvision(

     # set the first and last requested timestamps
     # Note: previous timestamps are usually loaded, since we need to access the previous key frame
-    first_ts = timestamps[0]
-    last_ts = timestamps[-1]
+    first_ts = min(timestamps)
+    last_ts = max(timestamps)

     # access closest key frame of the first requested frame
-    # Note: closest key frame timestamp is usally smaller than `first_ts` (e.g. key frame can be the first frame of the video)
+    # Note: closest key frame timestamp is usually smaller than `first_ts` (e.g. key frame can be the first frame of the video)
     # for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek
     reader.seek(first_ts, keyframes_only=keyframes_only)

@@ -127,6 +156,75 @@ def decode_video_frames_torchvision(
     return closest_frames


+def decode_video_frames_torchcodec(
+    video_path: Path | str,
+    timestamps: list[float],
+    tolerance_s: float,
+    device: str = "cpu",
+    log_loaded_timestamps: bool = False,
+) -> torch.Tensor:
+    """Loads frames associated with the requested timestamps of a video using torchcodec.
+
+    Note: Setting device="cuda" outside the main process, e.g. in data loader workers, will lead to CUDA initialization errors.
+
+    Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
+    the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
+    that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
+    and all subsequent frames until reaching the requested frame. The number of key frames in a video
+    can be adjusted during encoding to take into account decoding time and video size in bytes.
+    """
+    # initialize video decoder
+    decoder = VideoDecoder(video_path, device=device, seek_mode="approximate")
+    loaded_frames = []
+    loaded_ts = []
+    # get metadata for frame information
+    metadata = decoder.metadata
+    average_fps = metadata.average_fps
+
+    # convert timestamps to frame indices
+    frame_indices = [round(ts * average_fps) for ts in timestamps]
+
+    # retrieve frames based on indices
+    frames_batch = decoder.get_frames_at(indices=frame_indices)
+
+    for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=False):
+        loaded_frames.append(frame)
+        loaded_ts.append(pts.item())
+        if log_loaded_timestamps:
+            logging.info(f"Frame loaded at timestamp={pts:.4f}")
+
+    query_ts = torch.tensor(timestamps)
+    loaded_ts = torch.tensor(loaded_ts)
+
+    # compute distances between each query timestamp and loaded timestamps
+    dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
+    min_, argmin_ = dist.min(1)
+
+    is_within_tol = min_ < tolerance_s
+    assert is_within_tol.all(), (
+        f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
+        "It means that the closest frame that can be loaded from the video is too far away in time."
+        "This might be due to synchronization issues with timestamps during data collection."
+        "To be safe, we advise to ignore this item during training."
+        f"\nqueried timestamps: {query_ts}"
+        f"\nloaded timestamps: {loaded_ts}"
+        f"\nvideo: {video_path}"
+    )
+
+    # get closest frames to the query timestamps
+    closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
+    closest_ts = loaded_ts[argmin_]
+
+    if log_loaded_timestamps:
+        logging.info(f"{closest_ts=}")
+
+    # convert to float32 in [0,1] range (channel first)
+    closest_frames = closest_frames.type(torch.float32) / 255
+
+    assert len(timestamps) == len(closest_frames)
+    return closest_frames
+
+
 def encode_video_frames(
     imgs_dir: Path | str,
     video_path: Path | str,
@@ -1 +1,15 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# [standard Apache-2.0 license header, identical to the other files in this commit]
+
 from .configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv  # noqa: F401
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# [standard Apache-2.0 license header, identical to the other files in this commit]
+
 import abc
 from dataclasses import dataclass, field
@@ -37,12 +37,12 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
     Args:
         cfg (EnvConfig): the config of the environment to instantiate.
        n_envs (int, optional): The number of parallelized env to return. Defaults to 1.
-        use_async_envs (bool, optional): Wether to return an AsyncVectorEnv or a SyncVectorEnv. Defaults to
+        use_async_envs (bool, optional): Whether to return an AsyncVectorEnv or a SyncVectorEnv. Defaults to
             False.

     Raises:
         ValueError: if n_envs < 1
-        ModuleNotFoundError: If the requested env package is not intalled
+        ModuleNotFoundError: If the requested env package is not installed

     Returns:
         gym.vector.VectorEnv: The parallelized gym.env instance.
@@ -1 +1,15 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# [standard Apache-2.0 license header, identical to the other files in this commit]
+
 from .optimizers import OptimizerConfig as OptimizerConfig
@@ -1,6 +1,20 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# [standard Apache-2.0 license header, identical to the other files in this commit]
+
 from .act.configuration_act import ACTConfig as ACTConfig
+from .dexvla.configuration_dexvla import DexVLAConfig as DexVLAConfig
 from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
 from .pi0.configuration_pi0 import PI0Config as PI0Config
 from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
 from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
-from .dexvla.configuration_dexvla import DexVLAConfig as DexVLAConfig
@@ -64,7 +64,7 @@ class ACTConfig(PreTrainedConfig):
         output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
             original scale. Note that this is also used for normalizing the training targets.
         vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
-        pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone.
+        pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
             `None` means no pretrained weights.
         replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated
             convolution.
@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,31 +13,29 @@
 # limitations under the License.
 """Qwen2VL model configuration"""

+from dataclasses import dataclass, field
 from typing import Tuple

-from dataclasses import dataclass, field

 from transformers import AutoConfig
+from transformers.utils import logging

 from lerobot.common.optim.optimizers import AdamWConfig
 from lerobot.common.optim.schedulers import (
     CosineDecayWithWarmupSchedulerConfig,
     ConstantWithWarmupSchedulerConfig
 )
-from transformers.utils import logging
 from lerobot.configs.policies import PreTrainedConfig
-from lerobot.common.policies.dexvla.policy_heads.configuration_scaledp import ScaleDPPolicyConfig
-from lerobot.common.policies.dexvla.policy_heads.configuration_unet_diffusion import UnetDiffusionPolicyConfig
-from lerobot.common.policies.dexvla.qwe2_vla.configuration_qwen2_vla import Qwen2VLAConfig
 from lerobot.configs.types import NormalizationMode

 logger = logging.get_logger(__name__)


 @PreTrainedConfig.register_subclass("dexvla")
 @dataclass
 class DexVLAConfig(PreTrainedConfig):
     # For loading policy head
-    policy_head_type: str = 'scale_dp_policy'
-    policy_head_size: str = 'ScaleDP_L'
+    policy_head_type: str = "scale_dp_policy"
+    policy_head_size: str = "ScaleDP_L"
     action_dim: int = 14
     state_dim: int = 14
     chunk_size: int = 50
@@ -86,33 +83,37 @@ class DexVLAConfig(PreTrainedConfig):
                 f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
             )
         if self.using_reasoning:
-            assert self.using_film, f"using_reasoning requires `using_film=True`"
-            assert self.with_llm_head, f"using_reasoning requires `with_llm_head=True`"
+            assert self.using_film, "using_reasoning requires `using_film=True`"
+            assert self.with_llm_head, "using_reasoning requires `with_llm_head=True`"
             print("You have set using_reasoning=True, please make sure your data has key 'reasoning'.")
         else:
-            print(f"Warning:DexVLA recommends to use reasoning data which can better handle long-horizon and dexterous tasks. You can set 'using_reaasoning=True'.")
+            print(
+                "Warning:DexVLA recommends to use reasoning data which can better handle long-horizon and dexterous tasks. You can set 'using_reaasoning=True'."
+            )

         if self.qwen2_vl_path is None:
-            raise ValueError("DexVLA is built on official qwen2_vl-2B. You have to download the official weights of qwen2_vl-2B first and set 'qwen2_vl_path'.")
+            raise ValueError(
+                "DexVLA is built on official qwen2_vl-2B. You have to download the official weights of qwen2_vl-2B first and set 'qwen2_vl_path'."
+            )

-        if self.policy_head_type == 'scale_dp_policy':
+        if self.policy_head_type == "scale_dp_policy":
             self.policy_head_config = AutoConfig.for_model(
                 model_type=self.policy_head_type,
                 model_size=self.policy_head_size,
                 cond_dim=self.hidden_size,
                 action_dim=self.action_dim,
                 prediction_horizon=self.chunk_size,
-                state_dim=self.state_dim
+                state_dim=self.state_dim,
             )
-        elif self.policy_head_type == 'unet_diffusion':
+        elif self.policy_head_type == "unet_diffusion":
             self.policy_head_config = AutoConfig.for_model(
                 model_type=self.policy_head_type,
                 global_cond_dim=self.hidden_size,
                 action_dim=self.action_dim,
-                state_dim=self.state_dim
+                state_dim=self.state_dim,
             )
         else:
-            raise ValueError(f'Policy head type {self.policy_head_type} not supported')
+            raise ValueError(f"Policy head type {self.policy_head_type} not supported")

         if self.training_stage not in [2,3]:
             raise ValueError(f"Training stage must be 2 or 3. Got {self.training_stage}.")
@@ -164,6 +165,3 @@ class DexVLAConfig(PreTrainedConfig):
     @property
     def reward_delta_indices(self) -> None:
         return None
-
-
-
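To make the head-selection logic above concrete, here is a minimal, self-contained sketch that mirrors how `__post_init__` assembles keyword arguments for `AutoConfig.for_model` from the fields shown in this hunk. The class name `ToyDexVLAConfig` and the `hidden_size` value are illustrative assumptions, not part of the diff.

from dataclasses import dataclass


@dataclass
class ToyDexVLAConfig:
    policy_head_type: str = "scale_dp_policy"
    policy_head_size: str = "ScaleDP_L"
    action_dim: int = 14
    state_dim: int = 14
    chunk_size: int = 50
    hidden_size: int = 1536  # assumed stand-in for the VLM hidden size

    def head_kwargs(self) -> dict:
        # Same branching as the __post_init__ shown above, minus the AutoConfig call.
        if self.policy_head_type == "scale_dp_policy":
            return {
                "model_type": self.policy_head_type,
                "model_size": self.policy_head_size,
                "cond_dim": self.hidden_size,
                "action_dim": self.action_dim,
                "prediction_horizon": self.chunk_size,
                "state_dim": self.state_dim,
            }
        if self.policy_head_type == "unet_diffusion":
            return {
                "model_type": self.policy_head_type,
                "global_cond_dim": self.hidden_size,
                "action_dim": self.action_dim,
                "state_dim": self.state_dim,
            }
        raise ValueError(f"Policy head type {self.policy_head_type} not supported")


print(ToyDexVLAConfig().head_kwargs()["model_size"])  # ScaleDP_L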
@@ -1,10 +1,12 @@
 import torch.nn as nn


 class ActionProjector(nn.Module):
     def __init__(self, in_dim, out_dim=1024):
-        super(ActionProjector, self).__init__()
+        super().__init__()
         self.global_1d_pool = nn.AdaptiveAvgPool1d(1)
-        self.mlps = nn.ModuleList([
+        self.mlps = nn.ModuleList(
+            [
                 # nn.LayerNorm(in_dim),
                 nn.Linear(in_dim, in_dim),
                 nn.GELU(),
@@ -22,7 +24,7 @@ class ActionProjector(nn.Module):

 class FiLM(nn.Module):
     def __init__(self, feature_dim, condition_dim):
-        super(FiLM, self).__init__()
+        super().__init__()
         self.scale_fc = nn.Linear(condition_dim, feature_dim)
         self.shift_fc = nn.Linear(condition_dim, feature_dim)
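The `FiLM` module above only defines the two linear layers; how the predicted scale and shift are applied is not visible in this hunk. A common application is sketched below under that assumption: the `(1 + gamma)` form follows the FiLM paper, and `FiLMSketch` is a made-up name, not a class from the diff.

import torch
import torch.nn as nn


class FiLMSketch(nn.Module):
    # Mirrors the scale_fc / shift_fc layers shown above: gamma and beta are
    # predicted from the condition and applied element-wise to the features.
    def __init__(self, feature_dim, condition_dim):
        super().__init__()
        self.scale_fc = nn.Linear(condition_dim, feature_dim)
        self.shift_fc = nn.Linear(condition_dim, feature_dim)

    def forward(self, features, condition):
        gamma = self.scale_fc(condition)  # (batch, feature_dim)
        beta = self.shift_fc(condition)   # (batch, feature_dim)
        return features * (1 + gamma) + beta


film = FiLMSketch(feature_dim=8, condition_dim=4)
out = film(torch.randn(2, 8), torch.randn(2, 4))
print(out.shape)  # torch.Size([2, 8])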
@@ -1,16 +1,18 @@
-import torch
-from torch import Tensor
-
-from lerobot.common.policies.normalize import Normalize, Unnormalize
-from lerobot.common.policies.dexvla.configuration_dexvla import DexVLAConfig
-from lerobot.common.policies.dexvla.qwe2_vla.modeling_qwen2_vla import (
-    Qwen2VLForConditionalGenerationForVLA
-)
-from lerobot.common.policies.pretrained import PreTrainedPolicy
 from collections import deque
-from lerobot.common.policies.dexvla.policy_heads.modeling_unet_diffusion import ConditionalUnet1D
-from lerobot.common.policies.dexvla.policy_heads.modeling_scaledp import ScaleDP
+
+import torch
+import torchvision.transforms as transforms
+from torch import Tensor
+from transformers import AutoProcessor, AutoTokenizer
+
+from lerobot.common.policies.dexvla.configuration_dexvla import DexVLAConfig
+from lerobot.common.policies.dexvla.qwe2_vla.modeling_qwen2_vla import Qwen2VLForConditionalGenerationForVLA
 from lerobot.common.policies.dexvla.robot_data_processor import Qwen2VLAProcess
+from lerobot.common.policies.normalize import Normalize, Unnormalize
+from lerobot.common.policies.pretrained import PreTrainedPolicy


+from collections import deque
 from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM
 import torchvision.transforms as transforms
 import os
@@ -45,7 +47,7 @@ class DexVLAPolicy(PreTrainedPolicy):
             config.output_features, config.normalization_mapping, dataset_stats
         )

-        for k in ['using_film', 'llm_loss_weight', 'with_llm_head', 'policy_head_config']:
+        for k in ["using_film", "llm_loss_weight", "with_llm_head", "policy_head_config"]:
             setattr(config.qwen2_vla_config, k, config.__dict__[k])

         # if self.config.training_stage == 2:
@@ -82,10 +84,10 @@ class DexVLAPolicy(PreTrainedPolicy):
             self.model.requires_grad_(False)
             self.model.policy_head.requires_grad_(True)
         self.qwen2_vl_processor = AutoProcessor.from_pretrained(config.qwen2_vl_path)
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            config.qwen2_vl_path
-        )
-        self.vla_processor = Qwen2VLAProcess(tokenizer=self.tokenizer, multimodal_processor=self.qwen2_vl_processor)  # process the input data into VLM format
+        self.tokenizer = AutoTokenizer.from_pretrained(config.qwen2_vl_path)
+        self.vla_processor = Qwen2VLAProcess(
+            tokenizer=self.tokenizer, multimodal_processor=self.qwen2_vl_processor
+        )  # process the input data into VLM format

         self.resize_size = self.config.resize_size
         ratio = 0.95
@@ -104,14 +106,14 @@ class DexVLAPolicy(PreTrainedPolicy):
         batch = self.normalize_inputs(batch)
         batch = self.normalize_targets(batch)
         present_img_keys = [key for key in self.config.image_features if key in batch]
-        task_descs = batch['task']
+        task_descs = batch["task"]
         try:
-            reasonings = batch['reasoning']
+            reasonings = batch["reasoning"]
         except KeyError:
-            reasonings = ['no reasoning'] * len(task_descs)
+            reasonings = ["no reasoning"] * len(task_descs)

             pass
-        is_pad = batch['action_is_pad']
+        is_pad = batch["action_is_pad"]
         all_cam_images = []
         for k in present_img_keys:
             all_cam_images.append(batch[k])
@@ -120,8 +122,8 @@ class DexVLAPolicy(PreTrainedPolicy):
         image_data = torch.stack(all_cam_images) * 255
         image_data = image_data.to(dtype=torch.uint8)
         # construct observations
-        qpos_data = batch['observation.state'].float()
-        action_data = batch['action'].float()
+        qpos_data = batch["observation.state"].float()
+        action_data = batch["action"].float()

         orig_shape = image_data.shape
         image_data = image_data.view(-1, *orig_shape[2:])
@@ -131,29 +133,24 @@ class DexVLAPolicy(PreTrainedPolicy):
         image_data = image_data.view(*orig_shape[:3], *self.resize_size)

-        vl_data = {
-            'images': image_data,
-            'raw_langs': task_descs,
-            'reasonings': reasonings
-        }
+        vl_data = {"images": image_data, "raw_langs": task_descs, "reasonings": reasonings}
         # processing vl_data into qwen2_vl format
         vla_inputs = self.vla_processor.forward(vl_data, use_reasoning=self.config.using_reasoning)
-        vla_inputs['states'] = qpos_data
-        vla_inputs['is_pad'] = is_pad
-        vla_inputs['actions'] = action_data
+        vla_inputs["states"] = qpos_data
+        vla_inputs["is_pad"] = is_pad
+        vla_inputs["actions"] = action_data
         return vla_inputs


     def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]:

         processed_batch = self.process_batch(batch)

         ret = self.model.forward(**processed_batch)
-        loss_dict = ret['loss']
-        loss = loss_dict['loss'].mean()
+        loss_dict = ret["loss"]
+        loss = loss_dict["loss"].mean()
         return loss, loss_dict

-    def dexvla_predict_action(self,
+    def dexvla_predict_action(
+        self,
         input_ids: torch.LongTensor = None,
         actions=None,
         states=None,
@@ -162,15 +159,15 @@ class DexVLAPolicy(PreTrainedPolicy):
         is_eval=True,
         pixel_values=None,
         attention_mask=None,
-        image_grid_thw=None,
+        image_grid_spatiotemporal=None,
     ):
-        input_ids = input_ids.to('cuda')
+        input_ids = input_ids.to("cuda")
         with torch.inference_mode():
             outputs = self.model.generate(
                 input_ids,
                 pixel_values=pixel_values,
                 attention_mask=attention_mask,
-                image_grid_thw=image_grid_thw,
+                image_grid_spatiotemporal=image_grid_spatiotemporal,
                 is_eval=is_eval,
                 num_beams=1,
                 do_sample=False,
@@ -188,7 +185,7 @@ class DexVLAPolicy(PreTrainedPolicy):
         input_token_len = input_ids.shape[1]
         n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
         if n_diff_input_output > 0:
-            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
+            print(f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids")
         outputs_text = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=False)[0]

         outputs_text = outputs_text.strip()
@@ -198,14 +195,19 @@ class DexVLAPolicy(PreTrainedPolicy):
         action_hidden_states = None

         if self.model.using_film:
-            action_hidden_states = self.model.film_forward(labels=torch.ones_like(output_ids),
+            action_hidden_states = self.model.film_forward(
+                labels=torch.ones_like(output_ids),
                 input_ids=output_ids,
-                hidden_states=torch.cat(last_hidden_states, dim=1))
+                hidden_states=torch.cat(last_hidden_states, dim=1),
+            )

-        action = self.model.policy_head(actions, action_hidden_states, states.to(all_hidden_states.dtype), is_pad)
+        action = self.model.policy_head(
+            actions, action_hidden_states, states.to(all_hidden_states.dtype), is_pad
+        )
         return action, outputs_text

-    def tinyvla_predict_action(self,
+    def tinyvla_predict_action(
+        self,
         input_ids: torch.LongTensor = None,
         actions=None,
         states=None,
@@ -213,20 +215,24 @@ class DexVLAPolicy(PreTrainedPolicy):
         is_eval=True,
         pixel_values=None,
         attention_mask=None,
-        image_grid_thw=None,
+        image_grid_spatiotemporal=None,
     ):
-        input_ids = input_ids.to('cuda')
+        input_ids = input_ids.to("cuda")
         with torch.inference_mode():
-            all_hidden_states = self.model.forward(input_ids,
+            all_hidden_states = self.model.forward(
+                input_ids,
                 pixel_values=pixel_values,
                 attention_mask=attention_mask,
-                image_grid_thw=image_grid_thw,
+                image_grid_spatiotemporal=image_grid_spatiotemporal,
                 is_eval=is_eval,
-                tinyvla=True)
+                tinyvla=True,
+            )

         all_hidden_states = torch.mean(all_hidden_states, dim=1).unsqueeze(1)

-        action = self.model.policy_head(actions, all_hidden_states, states.to(all_hidden_states.dtype), is_pad)
+        action = self.model.policy_head(
+            actions, all_hidden_states, states.to(all_hidden_states.dtype), is_pad
+        )
         return action, "tinyvla generates no reasoning"

     def reset(self):
@@ -250,7 +256,7 @@ class DexVLAPolicy(PreTrainedPolicy):
         if len(self._action_queue) == 0:
             present_img_keys = [key for key in self.config.image_features if key in batch]
             try:
-                task_descs = batch['task']
+                task_descs = batch["task"]
             except KeyError:
                 task_descs = " "
                 print("No task descriptions found for this task")
@@ -263,7 +269,7 @@ class DexVLAPolicy(PreTrainedPolicy):
             image_data = torch.stack(all_cam_images) * 255
             image_data = image_data.to(dtype=torch.uint8)
             # construct observations
-            qpos_data = batch['observation.state'].float()
+            qpos_data = batch["observation.state"].float()

             image_data = image_data.squeeze(0)
@@ -271,11 +277,15 @@ class DexVLAPolicy(PreTrainedPolicy):
             image_data = transform(image_data)

             # processing vl_data into qwen2_vl format
-            vla_inputs = self.vla_processor.single_forward_process(images=image_data, raw_lang=task_descs, reasoning=None, eval=True)
-            vla_inputs['states'] = qpos_data
+            vla_inputs = self.vla_processor.single_forward_process(
+                images=image_data, raw_lang=task_descs, reasoning=None, eval=True
+            )
+            vla_inputs["states"] = qpos_data

             if self.config.using_film and self.config.with_llm_head:  # dexvla
-                all_actions, outputs = self.dexvla_predict_action(**vla_inputs, is_eval=True, tokenizer=self.tokenizer)
+                all_actions, outputs = self.dexvla_predict_action(
+                    **vla_inputs, is_eval=True, tokenizer=self.tokenizer
+                )
             else:  # tinyvla
                 all_actions, outputs = self.tinyvla_predict_action(**vla_inputs, is_eval=True)
@@ -283,8 +293,3 @@ class DexVLAPolicy(PreTrainedPolicy):
             self._action_queue.extend(actions.transpose(0, 1))

         return self._action_queue.popleft()
-
-
-
-
-
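For context on `select_action` above, here is a self-contained sketch of the action-chunking queue pattern it uses: predict a chunk of future actions once, then pop one action per control step and only re-query the model when the queue runs dry. `ChunkedPolicySketch` and `predict_chunk` are stand-ins, not names from the diff.

from collections import deque

import torch


class ChunkedPolicySketch:
    def __init__(self, chunk_size=50, action_dim=14):
        self.chunk_size = chunk_size
        self.action_dim = action_dim
        self._action_queue = deque(maxlen=chunk_size)

    def predict_chunk(self, observation):
        # Stand-in for dexvla_predict_action / tinyvla_predict_action.
        return torch.zeros(1, self.chunk_size, self.action_dim)

    def select_action(self, observation):
        if len(self._action_queue) == 0:
            all_actions = self.predict_chunk(observation)  # (batch, chunk, action_dim)
            # extend iterates over the time dimension, so transpose to (chunk, batch, action_dim)
            self._action_queue.extend(all_actions.transpose(0, 1))
        return self._action_queue.popleft()


policy = ChunkedPolicySketch()
print(policy.select_action(None).shape)  # torch.Size([1, 14])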
@@ -1,21 +1,32 @@
 import os
-from typing import Union, List
-from transformers import PretrainedConfig
+from typing import Union

+from transformers import AutoConfig, PretrainedConfig
 from transformers.utils import logging
-from transformers import AutoConfig, AutoModelForCausalLM
 logger = logging.get_logger(__name__)

 MODEL_STRUCTURE = {
-    'ScaleDP_H': {'depth': 32, 'n_emb': 1280, 'num_heads': 16, },
-    'ScaleDP_L': {'depth': 24, 'n_emb': 1024, 'num_heads': 16, },  # 400M
+    "scaledp_h": {
+        "depth": 32,
+        "n_emb": 1280,
+        "num_heads": 16,
+    },
+    "scaledp_l": {
+        "depth": 24,
+        "n_emb": 1024,
+        "num_heads": 16,
+    },  # 400M
 }


 class ScaleDPPolicyConfig(PretrainedConfig):
-    '''
+    """
     Configuration for ScaleDP policy head
-    '''
+    """

     model_type = "scale_dp_policy"

     def __init__(
         self,
         eval: bool = False,
@@ -36,12 +47,12 @@ class ScaleDPPolicyConfig(PretrainedConfig):
         num_inference_timesteps: int = 10,
         noise_samples: int = 1,
         num_train_timesteps: int = 100,
-        **kwargs
+        **kwargs,
     ):
         if model_size != "none":
-            depth = MODEL_STRUCTURE[model_size]['depth']
-            n_emb = MODEL_STRUCTURE[model_size]['n_emb']
-            num_heads = MODEL_STRUCTURE[model_size]['num_heads']
+            depth = MODEL_STRUCTURE[model_size]["depth"]
+            n_emb = MODEL_STRUCTURE[model_size]["n_emb"]
+            num_heads = MODEL_STRUCTURE[model_size]["num_heads"]
         else:
             # raise ValueError("model_size show not be 'none'")
             pass
@@ -52,7 +63,6 @@ class ScaleDPPolicyConfig(PretrainedConfig):
         self.output_dim = action_dim
         self.prediction_horizon = prediction_horizon

-
         self.cond_dim = cond_dim
         self.state_dim = state_dim
@@ -72,7 +82,9 @@ class ScaleDPPolicyConfig(PretrainedConfig):
         super().__init__(**kwargs)

     @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
+    def from_pretrained(
+        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
+    ) -> "PretrainedConfig":
         cls._set_token_in_kwargs(kwargs)

         config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
@@ -81,7 +93,11 @@ class ScaleDPPolicyConfig(PretrainedConfig):
         if config_dict.get("model_type") == "llava_pythia":
             config_dict = config_dict["action_head"]

-        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
+        if (
+            "model_type" in config_dict
+            and hasattr(cls, "model_type")
+            and config_dict["model_type"] != cls.model_type
+        ):
             logger.warning(
                 f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                 f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
@@ -89,4 +105,5 @@ class ScaleDPPolicyConfig(PretrainedConfig):

         return cls.from_dict(config_dict, **kwargs)

+
 AutoConfig.register("scale_dp_policy", ScaleDPPolicyConfig)
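Because `AutoConfig.register("scale_dp_policy", ScaleDPPolicyConfig)` runs at import time, `AutoConfig.for_model` can later build the head config from its `model_type` string, which is what `DexVLAConfig.__post_init__` relies on. A hedged sketch of that registration pattern, using a toy model type so it cannot clash with the real one:

from transformers import AutoConfig, PretrainedConfig


class ToyHeadConfig(PretrainedConfig):
    # Minimal registered config mirroring the pattern above (illustrative only).
    model_type = "toy_scale_dp_policy"

    def __init__(self, model_size="none", action_dim=14, **kwargs):
        self.model_size = model_size
        self.action_dim = action_dim
        super().__init__(**kwargs)


AutoConfig.register("toy_scale_dp_policy", ToyHeadConfig)

cfg = AutoConfig.for_model("toy_scale_dp_policy", model_size="scaledp_l", action_dim=14)
print(type(cfg).__name__, cfg.model_size)  # ToyHeadConfig scaledp_l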
@@ -1,15 +1,17 @@
 import os
-from typing import Union, List
-from transformers import PretrainedConfig
+from typing import Union

+from transformers import AutoConfig, PretrainedConfig
 from transformers.utils import logging
-from transformers import AutoConfig, AutoModelForCausalLM
 logger = logging.get_logger(__name__)


 class UnetDiffusionPolicyConfig(PretrainedConfig):
-    '''
+    """
     Configuration for dit diffusion policy head
-    '''
+    """

     model_type = "unet_diffusion_policy"

     def __init__(
@@ -17,7 +19,7 @@ class UnetDiffusionPolicyConfig(PretrainedConfig):
         action_dim=10,
         global_cond_dim=2048,
         diffusion_step_embed_dim=256,
-        down_dims=[256, 512, 1024],
+        down_dims=None,
         kernel_size=5,
         n_groups=8,
         state_dim=7,
@@ -25,8 +27,10 @@ class UnetDiffusionPolicyConfig(PretrainedConfig):
         noise_samples=1,
         num_inference_timesteps=10,
         num_train_timesteps=100,
-        **kwargs
+        **kwargs,
     ):
+        if down_dims is None:
+            down_dims = [256, 512, 1024]
         self.input_dim = action_dim
         self.noise_samples = noise_samples
         self.prediction_horizon = prediction_horizon
@@ -42,7 +46,9 @@ class UnetDiffusionPolicyConfig(PretrainedConfig):
         super().__init__(**kwargs)

     @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
+    def from_pretrained(
+        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
+    ) -> "PretrainedConfig":
         cls._set_token_in_kwargs(kwargs)

         config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
@@ -51,7 +57,11 @@ class UnetDiffusionPolicyConfig(PretrainedConfig):
         if config_dict.get("model_type") == "llava_pythia":
             config_dict = config_dict["action_head"]

-        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
+        if (
+            "model_type" in config_dict
+            and hasattr(cls, "model_type")
+            and config_dict["model_type"] != cls.model_type
+        ):
             logger.warning(
                 f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                 f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
@@ -59,4 +69,5 @@ class UnetDiffusionPolicyConfig(PretrainedConfig):

         return cls.from_dict(config_dict, **kwargs)

+
 AutoConfig.register("unet_diffusion_policy", UnetDiffusionPolicyConfig)
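The change from `down_dims=[256, 512, 1024]` to `down_dims=None` above is the standard fix for Python's mutable default argument pitfall: a list default is created once at function definition time and shared across calls. A small self-contained illustration of why the `if ... is None` pattern is preferred:

def bad(acc=[]):
    acc.append(1)
    return acc


def good(acc=None):
    if acc is None:  # same pattern as the down_dims fix above
        acc = []
    acc.append(1)
    return acc


print(bad())   # [1]
print(bad())   # [1, 1], the default list is shared across calls
print(good())  # [1]
print(good())  # [1], a fresh list per call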
@@ -1,27 +1,20 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
-from typing import Tuple

-import timm
-import numpy as np
 import logging

 import math
 from typing import Tuple

-try:
-    from typing import Literal
-except ImportError:
-    from typing_extensions import Literal
+import numpy as np

 import torch
 import torch.nn as nn
-import torch.nn.functional as F
+import torch.nn.functional as func
 import torch.utils.checkpoint
-from torch.jit import Final
 from timm.models.vision_transformer import Mlp, use_fused_attn
+from torch.jit import Final
+from transformers import AutoModel
 from transformers.modeling_utils import PreTrainedModel
-from transformers import AutoModel, AutoModelForCausalLM
+
+from .configuration_scaledp import ScaleDPPolicyConfig

 _logger = logging.getLogger(__name__)
@@ -35,12 +28,12 @@ class Attention(nn.Module):
         num_heads: int = 8,
         qkv_bias: bool = False,
         qk_norm: bool = False,
-        attn_drop: float = 0.,
-        proj_drop: float = 0.,
+        attn_drop: float = 0.0,
+        proj_drop: float = 0.0,
         norm_layer: nn.Module = nn.LayerNorm,
     ) -> None:
         super().__init__()
-        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+        assert dim % num_heads == 0, "dim should be divisible by num_heads"
         self.num_heads = num_heads
         self.head_dim = dim // num_heads
         self.scale = self.head_dim**-0.5
@@ -54,15 +47,18 @@ class Attention(nn.Module):
         self.proj_drop = nn.Dropout(proj_drop)

     def forward(self, x: torch.Tensor, attn_mask=None) -> torch.Tensor:
-        B, N, C = x.shape
-        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+        b, n, c = x.shape
+        qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
         q, k, v = qkv.unbind(0)
         q, k = self.q_norm(q), self.k_norm(k)

         if self.fused_attn:
-            x = F.scaled_dot_product_attention(
-                q, k, v, attn_mask=attn_mask,
-                dropout_p=self.attn_drop.p if self.training else 0.,
+            x = func.scaled_dot_product_attention(
+                q,
+                k,
+                v,
+                attn_mask=attn_mask,
+                dropout_p=self.attn_drop.p if self.training else 0.0,
             )
         else:
             q = q * self.scale
@@ -79,7 +75,7 @@ class Attention(nn.Module):
                 attn_scores += attn_mask

             # Apply softmax to get attention weights (softmax is applied along the last dimension)
-            attn_weights = F.softmax(attn_scores, dim=-1)
+            attn_weights = func.softmax(attn_scores, dim=-1)

             # Dropout on attention weights (if dropout is used)
             attn_weights = self.attn_drop(attn_weights)
@@ -87,7 +83,7 @@ class Attention(nn.Module):
             # Apply attention weights to value tensor (V)
             x = torch.matmul(attn_weights, v)

-        x = x.transpose(1, 2).reshape(B, N, C)
+        x = x.transpose(1, 2).reshape(b, n, c)
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
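The two branches above compute the same attention; `func.scaled_dot_product_attention` is simply the fused path. A quick numerical check of that equivalence with dropout disabled (shapes chosen arbitrarily for the sketch):

import torch
import torch.nn.functional as func

torch.manual_seed(0)
b, h, n, d = 2, 4, 8, 16
q, k, v = (torch.randn(b, h, n, d) for _ in range(3))

# Fused kernel: scales queries by 1/sqrt(d) internally.
fused = func.scaled_dot_product_attention(q, k, v)

# Manual path, mirroring the else-branch above.
scores = (q * d**-0.5) @ k.transpose(-2, -1)
manual = func.softmax(scores, dim=-1) @ v

print(torch.allclose(fused, manual, atol=1e-5))  # True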
@@ -104,6 +100,7 @@ def modulate(x, shift, scale):
 # Embedding Layers for Timesteps and Class Labels               #
 #################################################################################

+
 class TimestepEmbedder(nn.Module):
     """
     Embeds scalar timesteps into vector representations.
@@ -145,11 +142,11 @@ class TimestepEmbedder(nn.Module):
         return t_emb


 #################################################################################
 #                          Core ScaleDP Model                                   #
 #################################################################################


 class ScaleDPBlock(nn.Module):
     """
     A ScaleDP block with adaptive layer norm zero (adaLN-Zero) conScaleDPioning.
@@ -161,16 +158,20 @@ class ScaleDPBlock(nn.Module):
         self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
         self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
         mlp_hidden_dim = int(hidden_size * mlp_ratio)
-        approx_gelu = lambda: nn.GELU(approximate="tanh")
+        def approx_gelu():
+            return nn.GELU(approximate="tanh")
+
         self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
-        self.adaLN_modulation = nn.Sequential(
-            nn.SiLU(),
-            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
-        )
+        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))

     def forward(self, x, c, attn_mask=None):
-        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
-        x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), attn_mask=attn_mask)  # norm, scale&shift, attn, scale,
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(
+            6, dim=1
+        )
+        x = x + gate_msa.unsqueeze(1) * self.attn(
+            modulate(self.norm1(x), shift_msa, scale_msa), attn_mask=attn_mask
+        )  # norm, scale&shift, attn, scale,
         x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
         return x
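The reflowed block above is the adaLN-Zero pattern from DiT: the conditioning vector is projected to six chunks (shift, scale, and gate for the attention branch and for the MLP branch) and applied around each sub-layer through `modulate`. The body of `modulate` is not shown in this hunk; the sketch below assumes the DiT reference form `x * (1 + scale) + shift` and omits the attention call itself.

import torch


def modulate(x, shift, scale):
    # Assumed DiT-style body: broadcast per-sample shift/scale over the token dimension.
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


n, t, d = 2, 16, 32
x = torch.randn(n, t, d)
c = torch.randn(n, d)
ada = torch.nn.Sequential(torch.nn.SiLU(), torch.nn.Linear(d, 6 * d))
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada(c).chunk(6, dim=1)

# Gated residual update around a (here omitted) attention sub-layer.
x = x + gate_msa.unsqueeze(1) * modulate(x, shift_msa, scale_msa)
print(x.shape)  # torch.Size([2, 16, 32])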
@@ -184,10 +185,7 @@ class FinalLayer(nn.Module):
         super().__init__()
         self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
         self.linear = nn.Linear(hidden_size, output_dim, bias=True)
-        self.adaLN_modulation = nn.Sequential(
-            nn.SiLU(),
-            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
-        )
+        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

     def forward(self, x, c):
         shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
@@ -195,12 +193,14 @@ class FinalLayer(nn.Module):
         x = self.linear(x)
         return x

-from .configuration_scaledp import ScaleDPPolicyConfig
+
 class ScaleDP(PreTrainedModel):
     """
     Diffusion models with a Transformer backbone.
     """
+
     config_class = ScaleDPPolicyConfig

     def __init__(
         self,
         config: ScaleDPPolicyConfig,
@@ -209,15 +209,15 @@ class ScaleDP(PreTrainedModel):
         # compute number of tokens for main trunk and conScaleDPion encoder
         if config.n_obs_steps is None:
             config.n_obs_steps = config.prediction_horizon
-        T = config.prediction_horizon
-        T_cond = 1
+        t = config.prediction_horizon
+        t_cond = 1
         if not config.time_as_cond:
-            T += 1
-            T_cond -= 1
+            t += 1
+            t_cond -= 1
         obs_as_cond = config.cond_dim > 0
         if obs_as_cond:
             assert config.time_as_cond
-            T_cond += config.n_obs_steps
+            t_cond += config.n_obs_steps

         # self.combine = nn.Linear(cond_dim+state_dim, cond_dim)
         self.combine = nn.Sequential(
@@ -225,7 +225,7 @@
             nn.ReLU(),
             nn.Linear(1024, 1024),
             nn.ReLU(),
-            nn.Linear(1024, config.cond_dim)
+            nn.Linear(1024, config.cond_dim),
         )
         self.learn_sigma = config.learn_sigma
         self.input_dim = config.input_dim
@@ -241,32 +241,34 @@
         # Will use fixed sin-cos embedding:
         self.pos_embed = nn.Parameter(torch.zeros(1, config.prediction_horizon, config.n_emb))

-        self.blocks = nn.ModuleList([
-            ScaleDPBlock(config.n_emb, config.num_heads, mlp_ratio=config.mlp_ratio) for _ in range(config.depth)
-        ])
+        self.blocks = nn.ModuleList(
+            [
+                ScaleDPBlock(config.n_emb, config.num_heads, mlp_ratio=config.mlp_ratio)
+                for _ in range(config.depth)
+            ]
+        )
         self.final_layer = FinalLayer(config.n_emb, output_dim=config.output_dim)
         # self.initialize_weights()
         # constants
-        self.T = T
-        self.T_cond = T_cond
+        self.t = t
+        self.t_cond = t_cond
         self.prediction_horizon = config.prediction_horizon
         self.time_as_cond = config.time_as_cond
         self.action_dim = config.output_dim
         self.obs_as_cond = obs_as_cond
-        logger.info(
-            "number of parameters in ScaleDP: %e", sum(p.numel() for p in self.parameters())
-        )
+        logger.info("number of parameters in ScaleDP: %e", sum(p.numel() for p in self.parameters()))

         from diffusers.schedulers.scheduling_ddim import DDIMScheduler

         self.num_inference_timesteps = config.num_inference_timesteps
         # self.proj_to_action = nn.Identity()
         self.noise_scheduler = DDIMScheduler(
             num_train_timesteps=config.num_train_timesteps,  # 100
-            beta_schedule='squaredcos_cap_v2',
+            beta_schedule="squaredcos_cap_v2",
             clip_sample=True,
             set_alpha_to_one=True,
             steps_offset=0,
-            prediction_type='epsilon'
+            prediction_type="epsilon",
         )
         self.num_queries = config.num_queries  # 16
         self.noise_samples = config.noise_samples  # 1
@@ -308,7 +310,6 @@
         nn.init.constant_(self.final_layer.linear.weight, 0)
         nn.init.constant_(self.final_layer.linear.bias, 0)

-
     def get_optim_groups(self, weight_decay: float = 1e-3):
         """
         This long function is unfortunately doing something very simple and is being very defensive:
@@ -323,8 +324,8 @@
         whitelist_weight_modules = (torch.nn.Linear, Attention)
         blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
         for mn, m in self.named_modules():
-            for pn, p in m.named_parameters():
-                fpn = "%s.%s" % (mn, pn) if mn else pn  # full param name
+            for pn, _p in m.named_parameters():
+                fpn = "{}.{}".format(mn, pn) if mn else pn  # full param name

                 if pn.endswith("bias"):
                     # all biases will not be decayed
@@ -340,70 +341,71 @@
                     no_decay.add(fpn)

         # validate that we considered every parameter
-        param_dict = {pn: p for pn, p in self.named_parameters()}
+        param_dict = dict(self.named_parameters())
         inter_params = decay & no_decay
         union_params = decay | no_decay
-        assert (
-            len(inter_params) == 0
-        ), "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),)
-        assert (
-            len(param_dict.keys() - union_params) == 0
-        ), "parameters %s were not separated into either decay/no_decay set!" % (
-            str(param_dict.keys() - union_params),
-        )
+        assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format(
+            str(inter_params)
+        )
+        assert len(param_dict.keys() - union_params) == 0, (
+            "parameters {} were not separated into either decay/no_decay set!".format(
+                str(param_dict.keys() - union_params),
+            )
+        )

         # create the pytorch optimizer object
         optim_groups = [
             {
-                "params": [param_dict[pn] for pn in sorted(list(decay))],
+                "params": [param_dict[pn] for pn in sorted(decay)],
                 "weight_decay": weight_decay,
             },
             {
-                "params": [param_dict[pn] for pn in sorted(list(no_decay))],
+                "params": [param_dict[pn] for pn in sorted(no_decay)],
                 "weight_decay": 0.0,
             },
         ]
         return optim_groups

-    def configure_optimizers(self,
+    def configure_optimizers(
+        self,
         learning_rate: float = 1e-4,
         weight_decay: float = 1e-3,
-        betas: Tuple[float, float] = (0.9, 0.95)):
+        betas: Tuple[float, float] = (0.9, 0.95),
+    ):
         optim_groups = self.get_optim_groups(weight_decay=weight_decay)
-        optimizer = torch.optim.AdamW(
-            optim_groups, lr=learning_rate, betas=betas
-        )
+        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
         return optimizer

     def forward(self, actions, hidden_states, states, is_pad):
         """
         Forward pass for the diffusion head.
-        :param actions: target actions, shape [B, Ta, D] D:10 = 3+6+1
-        :param hidden_states: hidden states from the llava_pythia, as the conScaleDPion for the diffusion, shape [B,Tokens, D] 8 1200 1024
-        :param states: robot states, shape [B, D]
+        :param actions: target actions, shape [b, Ta, D] D:10 = 3+6+1
+        :param hidden_states: hidden states from the llava_pythia, as the conScaleDPion for the diffusion, shape [b,Tokens, D] 8 1200 1024
+        :param states: robot states, shape [b, D]
         :return: loss
         """
         if actions is not None:  # training time
-            B = actions.size(0)
+            b = actions.size(0)
             actions = actions[:, : self.num_queries]
             is_pad = is_pad[:, : self.num_queries]
             num_noise_samples = self.noise_samples
             # sample noise to add to actions
-            noise = torch.randn([num_noise_samples] + list(actions.shape), device=actions.device,
-                                dtype=actions.dtype)  # num_noise, B, Ta, D(1, 2, 16, 14)
+            noise = torch.randn(
+                [num_noise_samples] + list(actions.shape), device=actions.device, dtype=actions.dtype
+            )  # num_noise, b, Ta, D(1, 2, 16, 14)
             # sample a diffusion iteration for each data point
             timesteps = torch.randint(
-                0, self.noise_scheduler.config.num_train_timesteps,
-                (B,), device=actions.device
+                0, self.noise_scheduler.config.num_train_timesteps, (b,), device=actions.device
             ).long()

             timesteps, noise = timesteps.to(actions.device), noise.to(actions.device)

             # add noise to the clean actions according to the noise magnitude at each diffusion iteration
             # (this is the forward diffusion process)
-            noisy_actions = torch.cat([self.noise_scheduler.add_noise(
-                actions, noise[i], timesteps)
-                for i in range(len(noise))], dim=0)  # [num_noise_samples * B, Ta, action_dim]
+            noisy_actions = torch.cat(
+                [self.noise_scheduler.add_noise(actions, noise[i], timesteps) for i in range(len(noise))],
+                dim=0,
+            )  # [num_noise_samples * b, Ta, action_dim]

             noisy_actions = noisy_actions.to(dtype=actions.dtype)
             assert hidden_states.ndim == 3
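The training branch above is the usual epsilon-prediction objective: sample noise and a timestep, create noisy actions with `add_noise`, and regress the predicted noise against the sampled noise. A compact sketch of that noising step (requires `diffusers`; the zero tensor stands in for the ScaleDP forward pass):

import torch
from diffusers.schedulers.scheduling_ddim import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=100, beta_schedule="squaredcos_cap_v2", prediction_type="epsilon")

b, horizon, action_dim = 4, 16, 14
actions = torch.randn(b, horizon, action_dim)
noise = torch.randn_like(actions)
timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (b,)).long()

# Forward diffusion: mix clean actions with noise according to the sampled timestep.
noisy_actions = scheduler.add_noise(actions, noise, timesteps)
noise_pred = torch.zeros_like(noisy_actions)  # placeholder for model_forward(...)
loss = torch.nn.functional.mse_loss(noise_pred, noise)
print(noisy_actions.shape, loss.item() >= 0)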
@@ -413,20 +415,22 @@
             is_pad = is_pad.repeat(num_noise_samples, 1)
             states = states.repeat(num_noise_samples, 1)

-            noise_pred = self.model_forward(noisy_actions, timesteps, global_cond=hidden_states, states=states)
+            noise_pred = self.model_forward(
+                noisy_actions, timesteps, global_cond=hidden_states, states=states
+            )
             noise = noise.view(noise.size(0) * noise.size(1), *noise.size()[2:])
-            loss = torch.nn.functional.mse_loss(noise_pred, noise, reduction='none')
+            loss = torch.nn.functional.mse_loss(noise_pred, noise, reduction="none")
             loss = (loss * ~is_pad.unsqueeze(-1)).mean()
             # loss_dict['loss'] = loss
-            return {'loss': loss}
+            return {"loss": loss}
             # return loss
         else:  # inference time
-            B = 1
-            Tp = self.num_queries
+            b = 1
+            tp = self.num_queries
             action_dim = self.action_dim

-            # initialize action from Guassian noise
-            noisy_action = torch.randn((B, Tp, action_dim)).cuda()
+            # initialize action from Gaussian noise
+            noisy_action = torch.randn((b, tp, action_dim)).cuda()

             naction = noisy_action.to(dtype=hidden_states.dtype)
             # init scheduler
@@ -438,9 +442,7 @@
                 # inverse diffusion step (remove noise)
                 naction = self.noise_scheduler.step(
-                    model_output=noise_pred,
-                    timestep=k,
-                    sample=naction
+                    model_output=noise_pred, timestep=k, sample=naction
                 ).prev_sample

             return naction
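The inference branch above is a standard DDIM sampling loop: start from Gaussian noise and repeatedly call `scheduler.step` with the predicted noise until the trajectory is denoised. A sketch with a placeholder predictor (requires `diffusers`; the real code would call `model_forward` inside the loop):

import torch
from diffusers.schedulers.scheduling_ddim import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=100, beta_schedule="squaredcos_cap_v2", prediction_type="epsilon")
scheduler.set_timesteps(10)  # matches num_inference_timesteps=10 in the config above

naction = torch.randn(1, 16, 14)  # (b, Tp, action_dim), initialized from Gaussian noise
for k in scheduler.timesteps:
    noise_pred = torch.zeros_like(naction)  # placeholder for model_forward(naction, k, ...)
    naction = scheduler.step(model_output=noise_pred, timestep=k, sample=naction).prev_sample

print(naction.shape)  # torch.Size([1, 16, 14])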
@@ -462,7 +464,9 @@
         t = t[None].to(x.device)
         t = t.expand(t.shape[0])

-        x = self.x_embedder(x) + self.pos_embed.to(device=x.device, dtype=x.dtype)  # (N, T, D), where T = prediction_horizon
+        x = self.x_embedder(x) + self.pos_embed.to(
+            device=x.device, dtype=x.dtype
+        )  # (N, T, D), where T = prediction_horizon
         t = self.t_embedder(t)  # (N, D)
         if self.obs_as_cond:
             global_cond = self.cond_obs_emb(global_cond)  # (N, D)
@@ -474,11 +478,13 @@
         x = self.final_layer(x, c)  # (N, T, output_dim)
         return x

+
 #################################################################################
 # Sine/Cosine Positional Embedding Functions                    #
 #################################################################################
 # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py


 def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
     """
     grid_size: int of the grid height and width
@@ -516,11 +522,11 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
     """
     assert embed_dim % 2 == 0
     omega = np.arange(embed_dim // 2, dtype=np.float64)
-    omega /= embed_dim / 2.
-    omega = 1. / 10000 ** omega  # (D/2,)
+    omega /= embed_dim / 2.0
+    omega = 1.0 / 10000**omega  # (D/2,)

     pos = pos.reshape(-1)  # (M,)
-    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
+    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

     emb_sin = np.sin(out)  # (M, D/2)
     emb_cos = np.cos(out)  # (M, D/2)
@@ -533,12 +539,13 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
 # ScaleDP Configs                                               #
 #################################################################################

-def ScaleDP_H(**kwargs):
+
+def scaledp_h(**kwargs):
     return ScaleDP(depth=32, n_emb=1280, num_heads=16, **kwargs)

-def ScaleDP_L(**kwargs):
+def scaledp_l(**kwargs):
     return ScaleDP(depth=24, n_emb=1024, num_heads=16, **kwargs)


 AutoModel.register(ScaleDPPolicyConfig, ScaleDP)
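The reformatted lines above compute the classic fixed sinusoidal embedding. A tiny standalone version for checking shapes; the final concatenation of the sine and cosine halves follows the MAE reference implementation linked above and is not itself shown in this hunk.

import numpy as np


def sincos_1d(embed_dim, pos):
    # Same computation as above: half the channels use sin, half use cos.
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2)
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)  # (M, D)


emb = sincos_1d(8, np.arange(16))
print(emb.shape)  # (16, 8)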
@ -1,24 +1,24 @@
|
||||||
"""
|
"""
|
||||||
Implementation of Diffusion Policy https://diffusion-policy.cs.columbia.edu/ by Cheng Chi
|
Implementation of Diffusion Policy https://diffusion-policy.cs.columbia.edu/ by Cheng Chi
|
||||||
"""
|
"""
|
||||||
from typing import Callable, Union
|
|
||||||
|
import copy
|
||||||
import math
|
import math
|
||||||
from collections import OrderedDict, deque
|
from typing import Union
|
||||||
from packaging.version import parse as parse_version
|
|
||||||
import random
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
|
||||||
# requires diffusers==0.11.1
|
# requires diffusers==0.11.1
|
||||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
|
||||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||||
from diffusers.training_utils import EMAModel
|
from transformers import AutoModel
|
||||||
from .configuration_unet_diffusion import UnetDiffusionPolicyConfig
|
|
||||||
from transformers.modeling_utils import PreTrainedModel
|
from transformers.modeling_utils import PreTrainedModel
|
||||||
from transformers import AutoModel, AutoModelForCausalLM
|
|
||||||
import copy
|
from .configuration_unet_diffusion import UnetDiffusionPolicyConfig
|
||||||
|
|
||||||
# =================== UNet for Diffusion ==============
|
# =================== UNet for Diffusion ==============
|
||||||
|
|
||||||
|
|
||||||
class SinusoidalPosEmb(nn.Module):
|
class SinusoidalPosEmb(nn.Module):
|
||||||
def __init__(self, dim, dtype):
|
def __init__(self, dim, dtype):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -54,9 +54,9 @@ class Upsample1d(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class Conv1dBlock(nn.Module):
|
class Conv1dBlock(nn.Module):
|
||||||
'''
|
"""
|
||||||
Conv1d --> GroupNorm --> Mish
|
Conv1d --> GroupNorm --> Mish
|
||||||
'''
|
"""
|
||||||
|
|
||||||
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
|
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
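Note: the docstring converted above (`Conv1d --> GroupNorm --> Mish`) summarizes the whole block. A minimal sketch of such a block is below; the constructor body is not shown in this hunk, so the `padding=kernel_size // 2` choice is an assumption based on the usual Diffusion Policy implementation.

```python
import torch
import torch.nn as nn

class Conv1dBlock(nn.Module):
    """Conv1d --> GroupNorm --> Mish (sketch; padding choice assumed)."""
    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
            nn.GroupNorm(n_groups, out_channels),
            nn.Mish(),
        )

    def forward(self, x):  # x: (batch, inp_channels, horizon)
        return self.block(x)

y = Conv1dBlock(10, 32, kernel_size=3)(torch.randn(2, 10, 16))  # -> (2, 32, 16)
```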
@ -72,46 +72,41 @@ class Conv1dBlock(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class ConditionalResidualBlock1D(nn.Module):
|
class ConditionalResidualBlock1D(nn.Module):
|
||||||
def __init__(self,
|
def __init__(self, in_channels, out_channels, cond_dim, kernel_size=3, n_groups=8):
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
cond_dim,
|
|
||||||
kernel_size=3,
|
|
||||||
n_groups=8):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.blocks = nn.ModuleList([
|
self.blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups),
|
Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups),
|
||||||
Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups),
|
Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
|
|
||||||
# FiLM modulation https://arxiv.org/abs/1709.07871
|
# FiLM modulation https://arxiv.org/abs/1709.07871
|
||||||
# predicts per-channel scale and bias
|
# predicts per-channel scale and bias
|
||||||
cond_channels = out_channels * 2
|
cond_channels = out_channels * 2
|
||||||
self.out_channels = out_channels
|
self.out_channels = out_channels
|
||||||
self.cond_encoder = nn.Sequential(
|
self.cond_encoder = nn.Sequential(
|
||||||
nn.Mish(),
|
nn.Mish(), nn.Linear(cond_dim, cond_channels), nn.Unflatten(-1, (-1, 1))
|
||||||
nn.Linear(cond_dim, cond_channels),
|
|
||||||
nn.Unflatten(-1, (-1, 1))
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# make sure dimensions compatible
|
# make sure dimensions compatible
|
||||||
self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \
|
self.residual_conv = (
|
||||||
if in_channels != out_channels else nn.Identity()
|
nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x, cond):
|
def forward(self, x, cond):
|
||||||
'''
|
"""
|
||||||
x : [ batch_size x in_channels x horizon ]
|
x : [ batch_size x in_channels x horizon ]
|
||||||
cond : [ batch_size x cond_dim]
|
cond : [ batch_size x cond_dim]
|
||||||
|
|
||||||
returns:
|
returns:
|
||||||
out : [ batch_size x out_channels x horizon ]
|
out : [ batch_size x out_channels x horizon ]
|
||||||
'''
|
"""
|
||||||
out = self.blocks[0](x)
|
out = self.blocks[0](x)
|
||||||
embed = self.cond_encoder(cond)
|
embed = self.cond_encoder(cond)
|
||||||
|
|
||||||
embed = embed.reshape(
|
embed = embed.reshape(embed.shape[0], 2, self.out_channels, 1)
|
||||||
embed.shape[0], 2, self.out_channels, 1)
|
|
||||||
scale = embed[:, 0, ...]
|
scale = embed[:, 0, ...]
|
||||||
bias = embed[:, 1, ...]
|
bias = embed[:, 1, ...]
|
||||||
out = scale * out + bias
|
out = scale * out + bias
|
||||||
|
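Note: the forward pass reflowed above applies FiLM conditioning — the conditioning vector is projected to `2 * out_channels` values that become a per-channel scale and bias, broadcast over the horizon axis. A stripped-down illustration of just that step, independent of the surrounding residual block (tensor sizes are made up):

```python
import torch
import torch.nn as nn

batch, out_channels, horizon, cond_dim = 4, 32, 16, 64

cond_encoder = nn.Sequential(
    nn.Mish(), nn.Linear(cond_dim, out_channels * 2), nn.Unflatten(-1, (-1, 1))
)

out = torch.randn(batch, out_channels, horizon)   # features from the first Conv1dBlock
cond = torch.randn(batch, cond_dim)               # global conditioning vector

embed = cond_encoder(cond)                        # (batch, 2 * out_channels, 1)
embed = embed.reshape(embed.shape[0], 2, out_channels, 1)
scale, bias = embed[:, 0, ...], embed[:, 1, ...]  # each (batch, out_channels, 1)
out = scale * out + bias                          # FiLM: per-channel scale and shift
```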
@ -125,16 +120,15 @@ class ConditionalUnet1D(PreTrainedModel):
|
||||||
_no_split_modules = ["mid_modules", "down_modules", "up_modules"]
|
_no_split_modules = ["mid_modules", "down_modules", "up_modules"]
|
||||||
|
|
||||||
config_class = UnetDiffusionPolicyConfig
|
config_class = UnetDiffusionPolicyConfig
|
||||||
def __init__(self,
|
|
||||||
config: UnetDiffusionPolicyConfig
|
def __init__(self, config: UnetDiffusionPolicyConfig):
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
input_dim: Dim of actions.
|
input_dim: Dim of actions.
|
||||||
global_cond_dim: Dim of global conditioning applied with FiLM
|
global_cond_dim: Dim of global conditioning applied with FiLM
|
||||||
in addition to diffusion step embedding. This is usually obs_horizon * obs_dim
|
in addition to diffusion step embedding. This is usually obs_horizon * obs_dim
|
||||||
diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k
|
diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k
|
||||||
down_dims: Channel size for each UNet level.
|
down_dims: Channel size for each UNet level.
|
||||||
The length of this array determines numebr of levels.
|
The length of this array determines number of levels.
|
||||||
kernel_size: Conv kernel size
|
kernel_size: Conv kernel size
|
||||||
n_groups: Number of groups for GroupNorm
|
n_groups: Number of groups for GroupNorm
|
||||||
"""
|
"""
|
||||||
|
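Note: the docstring fix above ("the length of `down_dims` determines the number of levels") is made concrete by the `zip` that builds per-level channel pairs a few lines later. A small worked example; the `input_dim`/`down_dims` values are illustrative, and the `all_dims` construction is assumed from the surrounding code rather than shown in this hunk.

```python
# Illustrative values only; the real ones come from UnetDiffusionPolicyConfig.
input_dim = 10                    # action dimension
down_dims = [256, 512, 1024]      # three UNet levels

all_dims = [input_dim] + list(down_dims)          # assumed, matching start_dim/all_dims usage
in_out = list(zip(all_dims[:-1], all_dims[1:], strict=False))
print(in_out)  # [(10, 256), (256, 512), (512, 1024)] -> one (dim_in, dim_out) pair per level
```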
@ -158,44 +152,76 @@ class ConditionalUnet1D(PreTrainedModel):
|
||||||
)
|
)
|
||||||
cond_dim = dsed + config.global_cond_dim
|
cond_dim = dsed + config.global_cond_dim
|
||||||
|
|
||||||
in_out = list(zip(all_dims[:-1], all_dims[1:]))
|
in_out = list(zip(all_dims[:-1], all_dims[1:], strict=False))
|
||||||
mid_dim = all_dims[-1]
|
mid_dim = all_dims[-1]
|
||||||
self.mid_modules = nn.ModuleList([
|
self.mid_modules = nn.ModuleList(
|
||||||
|
[
|
||||||
ConditionalResidualBlock1D(
|
ConditionalResidualBlock1D(
|
||||||
mid_dim, mid_dim, cond_dim=cond_dim,
|
mid_dim,
|
||||||
kernel_size=config.kernel_size, n_groups=config.n_groups
|
mid_dim,
|
||||||
|
cond_dim=cond_dim,
|
||||||
|
kernel_size=config.kernel_size,
|
||||||
|
n_groups=config.n_groups,
|
||||||
),
|
),
|
||||||
ConditionalResidualBlock1D(
|
ConditionalResidualBlock1D(
|
||||||
mid_dim, mid_dim, cond_dim=cond_dim,
|
mid_dim,
|
||||||
kernel_size=config.kernel_size, n_groups=config.n_groups
|
mid_dim,
|
||||||
|
cond_dim=cond_dim,
|
||||||
|
kernel_size=config.kernel_size,
|
||||||
|
n_groups=config.n_groups,
|
||||||
),
|
),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
|
|
||||||
down_modules = nn.ModuleList([])
|
down_modules = nn.ModuleList([])
|
||||||
for ind, (dim_in, dim_out) in enumerate(in_out):
|
for ind, (dim_in, dim_out) in enumerate(in_out):
|
||||||
is_last = ind >= (len(in_out) - 1)
|
is_last = ind >= (len(in_out) - 1)
|
||||||
down_modules.append(nn.ModuleList([
|
down_modules.append(
|
||||||
|
nn.ModuleList(
|
||||||
|
[
|
||||||
ConditionalResidualBlock1D(
|
ConditionalResidualBlock1D(
|
||||||
dim_in, dim_out, cond_dim=cond_dim,
|
dim_in,
|
||||||
kernel_size=config.kernel_size, n_groups=config.n_groups),
|
dim_out,
|
||||||
|
cond_dim=cond_dim,
|
||||||
|
kernel_size=config.kernel_size,
|
||||||
|
n_groups=config.n_groups,
|
||||||
|
),
|
||||||
ConditionalResidualBlock1D(
|
ConditionalResidualBlock1D(
|
||||||
dim_out, dim_out, cond_dim=cond_dim,
|
dim_out,
|
||||||
kernel_size=config.kernel_size, n_groups=config.n_groups),
|
dim_out,
|
||||||
Downsample1d(dim_out) if not is_last else nn.Identity()
|
cond_dim=cond_dim,
|
||||||
]))
|
kernel_size=config.kernel_size,
|
||||||
|
n_groups=config.n_groups,
|
||||||
|
),
|
||||||
|
Downsample1d(dim_out) if not is_last else nn.Identity(),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
up_modules = nn.ModuleList([])
|
up_modules = nn.ModuleList([])
|
||||||
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
|
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
|
||||||
is_last = ind >= (len(in_out) - 1)
|
is_last = ind >= (len(in_out) - 1)
|
||||||
up_modules.append(nn.ModuleList([
|
up_modules.append(
|
||||||
|
nn.ModuleList(
|
||||||
|
[
|
||||||
ConditionalResidualBlock1D(
|
ConditionalResidualBlock1D(
|
||||||
dim_out * 2, dim_in, cond_dim=cond_dim,
|
dim_out * 2,
|
||||||
kernel_size=config.kernel_size, n_groups=config.n_groups),
|
dim_in,
|
||||||
|
cond_dim=cond_dim,
|
||||||
|
kernel_size=config.kernel_size,
|
||||||
|
n_groups=config.n_groups,
|
||||||
|
),
|
||||||
ConditionalResidualBlock1D(
|
ConditionalResidualBlock1D(
|
||||||
dim_in, dim_in, cond_dim=cond_dim,
|
dim_in,
|
||||||
kernel_size=config.kernel_size, n_groups=config.n_groups),
|
dim_in,
|
||||||
Upsample1d(dim_in) if not is_last else nn.Identity()
|
cond_dim=cond_dim,
|
||||||
]))
|
kernel_size=config.kernel_size,
|
||||||
|
n_groups=config.n_groups,
|
||||||
|
),
|
||||||
|
Upsample1d(dim_in) if not is_last else nn.Identity(),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
final_conv = nn.Sequential(
|
final_conv = nn.Sequential(
|
||||||
Conv1dBlock(start_dim, start_dim, kernel_size=config.kernel_size),
|
Conv1dBlock(start_dim, start_dim, kernel_size=config.kernel_size),
|
||||||
|
@ -207,20 +233,17 @@ class ConditionalUnet1D(PreTrainedModel):
|
||||||
self.down_modules = down_modules
|
self.down_modules = down_modules
|
||||||
self.final_conv = final_conv
|
self.final_conv = final_conv
|
||||||
|
|
||||||
print("number of parameters: {:e}".format(
|
print("number of parameters: {:e}".format(sum(p.numel() for p in self.parameters())))
|
||||||
sum(p.numel() for p in self.parameters()))
|
|
||||||
)
|
|
||||||
|
|
||||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
|
||||||
self.num_inference_timesteps = config.num_inference_timesteps
|
self.num_inference_timesteps = config.num_inference_timesteps
|
||||||
# self.proj_to_action = nn.Identity()
|
# self.proj_to_action = nn.Identity()
|
||||||
self.noise_scheduler = DDIMScheduler(
|
self.noise_scheduler = DDIMScheduler(
|
||||||
num_train_timesteps=config.num_train_timesteps, # 100
|
num_train_timesteps=config.num_train_timesteps, # 100
|
||||||
beta_schedule='squaredcos_cap_v2',
|
beta_schedule="squaredcos_cap_v2",
|
||||||
clip_sample=True,
|
clip_sample=True,
|
||||||
set_alpha_to_one=True,
|
set_alpha_to_one=True,
|
||||||
steps_offset=0,
|
steps_offset=0,
|
||||||
prediction_type='epsilon'
|
prediction_type="epsilon",
|
||||||
)
|
)
|
||||||
|
|
||||||
# self.num_inference_timesteps = config.num_inference_timesteps # 100
|
# self.num_inference_timesteps = config.num_inference_timesteps # 100
|
||||||
|
@ -228,32 +251,33 @@ class ConditionalUnet1D(PreTrainedModel):
|
||||||
def forward(self, actions, hidden_states, states, is_pad):
|
def forward(self, actions, hidden_states, states, is_pad):
|
||||||
"""
|
"""
|
||||||
Forward pass for the diffusion head.
|
Forward pass for the diffusion head.
|
||||||
:param actions: target actions, shape [B, Ta, D] D:10 = 3+6+1
|
:param actions: target actions, shape [b, Ta, D] D:10 = 3+6+1
|
||||||
:param hidden_states: hidden states from the llava_pythia, as the condition for the diffusion, shape [B,Tokens, D] 8 1200 1024
|
:param hidden_states: hidden states from the llava_pythia, as the condition for the diffusion, shape [b,Tokens, D] 8 1200 1024
|
||||||
:param states: robot states, shape [B, D]
|
:param states: robot states, shape [b, D]
|
||||||
:return: loss
|
:return: loss
|
||||||
"""
|
"""
|
||||||
if actions is not None: # training time
|
if actions is not None: # training time
|
||||||
B = actions.size(0)
|
b = actions.size(0)
|
||||||
actions = copy.deepcopy(actions[:, : self.num_queries])
|
actions = copy.deepcopy(actions[:, : self.num_queries])
|
||||||
is_pad = copy.deepcopy(is_pad[:, : self.num_queries])
|
is_pad = copy.deepcopy(is_pad[:, : self.num_queries])
|
||||||
num_noise_samples = self.noise_samples
|
num_noise_samples = self.noise_samples
|
||||||
# sample noise to add to actions
|
# sample noise to add to actions
|
||||||
noise = torch.randn([num_noise_samples] + list(actions.shape), device=actions.device,
|
noise = torch.randn(
|
||||||
dtype=actions.dtype) # num_noise, B, Ta, D
|
[num_noise_samples] + list(actions.shape), device=actions.device, dtype=actions.dtype
|
||||||
|
) # num_noise, b, Ta, D
|
||||||
# sample a diffusion iteration for each data point
|
# sample a diffusion iteration for each data point
|
||||||
timesteps = torch.randint(
|
timesteps = torch.randint(
|
||||||
0, self.noise_scheduler.config.num_train_timesteps,
|
0, self.noise_scheduler.config.num_train_timesteps, (b,), device=actions.device
|
||||||
(B,), device=actions.device
|
|
||||||
).long()
|
).long()
|
||||||
|
|
||||||
timesteps, noise = timesteps.to(actions.device), noise.to(actions.device)
|
timesteps, noise = timesteps.to(actions.device), noise.to(actions.device)
|
||||||
|
|
||||||
# add noise to the clean actions according to the noise magnitude at each diffusion iteration
|
# add noise to the clean actions according to the noise magnitude at each diffusion iteration
|
||||||
# (this is the forward diffusion process)
|
# (this is the forward diffusion process)
|
||||||
noisy_actions = torch.cat([self.noise_scheduler.add_noise(
|
noisy_actions = torch.cat(
|
||||||
actions, noise[i], timesteps)
|
[self.noise_scheduler.add_noise(actions, noise[i], timesteps) for i in range(len(noise))],
|
||||||
for i in range(len(noise))], dim=0) # [num_noise_samples * B, Ta, action_dim]
|
dim=0,
|
||||||
|
) # [num_noise_samples * b, Ta, action_dim]
|
||||||
|
|
||||||
noisy_actions = noisy_actions.to(dtype=actions.dtype)
|
noisy_actions = noisy_actions.to(dtype=actions.dtype)
|
||||||
assert hidden_states.ndim == 3
|
assert hidden_states.ndim == 3
|
||||||
|
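Note: the training branch in this and the next hunk follows the standard denoising-diffusion recipe — draw Gaussian noise, pick a random timestep per sample, let the scheduler corrupt the clean actions, then regress the predicted noise against the true noise with a masked MSE. A minimal single-noise-sample sketch of that recipe with the `diffusers` DDIM scheduler configured as in this file; the tiny linear denoiser is a placeholder for the conditional UNet, not the real model.

```python
import torch
import torch.nn as nn
from diffusers.schedulers.scheduling_ddim import DDIMScheduler

b, horizon, action_dim = 8, 16, 14
scheduler = DDIMScheduler(
    num_train_timesteps=100,
    beta_schedule="squaredcos_cap_v2",
    clip_sample=True,
    prediction_type="epsilon",
)
denoiser = nn.Linear(action_dim, action_dim)  # placeholder; the real model also gets timesteps + conditioning

actions = torch.randn(b, horizon, action_dim)       # clean target actions
noise = torch.randn_like(actions)
timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (b,)).long()

noisy_actions = scheduler.add_noise(actions, noise, timesteps)   # forward diffusion
noise_pred = denoiser(noisy_actions)
loss = nn.functional.mse_loss(noise_pred, noise)    # padded steps would be masked out as in the diff
```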
@ -263,20 +287,22 @@ class ConditionalUnet1D(PreTrainedModel):
|
||||||
is_pad = is_pad.repeat(num_noise_samples, 1)
|
is_pad = is_pad.repeat(num_noise_samples, 1)
|
||||||
states = states.repeat(num_noise_samples, 1)
|
states = states.repeat(num_noise_samples, 1)
|
||||||
|
|
||||||
noise_pred = self.model_forward(noisy_actions, timesteps, global_cond=hidden_states, states=states)
|
noise_pred = self.model_forward(
|
||||||
|
noisy_actions, timesteps, global_cond=hidden_states, states=states
|
||||||
|
)
|
||||||
noise = noise.view(noise.size(0) * noise.size(1), *noise.size()[2:])
|
noise = noise.view(noise.size(0) * noise.size(1), *noise.size()[2:])
|
||||||
loss = torch.nn.functional.mse_loss(noise_pred, noise, reduction='none')
|
loss = torch.nn.functional.mse_loss(noise_pred, noise, reduction="none")
|
||||||
loss = (loss * ~is_pad.unsqueeze(-1)).mean()
|
loss = (loss * ~is_pad.unsqueeze(-1)).mean()
|
||||||
# loss_dict['loss'] = loss
|
# loss_dict['loss'] = loss
|
||||||
return {'loss': loss}
|
return {"loss": loss}
|
||||||
# return loss
|
# return loss
|
||||||
else: # inference time
|
else: # inference time
|
||||||
B = 1
|
b = 1
|
||||||
Tp = self.num_queries
|
tp = self.num_queries
|
||||||
action_dim = 14
|
action_dim = 14
|
||||||
|
|
||||||
# initialize action from Guassian noise
|
# initialize action from Gaussian noise
|
||||||
noisy_action = torch.randn((B, Tp, action_dim)).cuda()
|
noisy_action = torch.randn((b, tp, action_dim)).cuda()
|
||||||
|
|
||||||
naction = noisy_action.to(dtype=hidden_states.dtype)
|
naction = noisy_action.to(dtype=hidden_states.dtype)
|
||||||
# init scheduler
|
# init scheduler
|
||||||
|
@ -288,27 +314,23 @@ class ConditionalUnet1D(PreTrainedModel):
|
||||||
|
|
||||||
# inverse diffusion step (remove noise)
|
# inverse diffusion step (remove noise)
|
||||||
naction = self.noise_scheduler.step(
|
naction = self.noise_scheduler.step(
|
||||||
model_output=noise_pred,
|
model_output=noise_pred, timestep=k, sample=naction
|
||||||
timestep=k,
|
|
||||||
sample=naction
|
|
||||||
).prev_sample
|
).prev_sample
|
||||||
|
|
||||||
return naction
|
return naction
|
||||||
|
|
||||||
def model_forward(self,
|
def model_forward(
|
||||||
sample: torch.Tensor,
|
self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int], global_cond=None, states=None
|
||||||
timestep: Union[torch.Tensor, float, int],
|
):
|
||||||
global_cond=None,
|
|
||||||
states=None):
|
|
||||||
"""
|
"""
|
||||||
x: (B,T,input_dim)
|
x: (b,T,input_dim)
|
||||||
timestep: (B,) or int, diffusion step
|
timestep: (b,) or int, diffusion step
|
||||||
global_cond: (B,global_cond_dim)
|
global_cond: (b,global_cond_dim)
|
||||||
output: (B,T,input_dim)
|
output: (b,T,input_dim)
|
||||||
"""
|
"""
|
||||||
# (B,T,C)
|
# (b,t,c)
|
||||||
sample = sample.moveaxis(-1, -2)
|
sample = sample.moveaxis(-1, -2)
|
||||||
# (B,C,T)
|
# (b,c,t)
|
||||||
# global_cond = self.global_1d_pool(global_cond.permute(0, 2, 1)).squeeze(-1)
|
# global_cond = self.global_1d_pool(global_cond.permute(0, 2, 1)).squeeze(-1)
|
||||||
global_cond = global_cond.squeeze(1)
|
global_cond = global_cond.squeeze(1)
|
||||||
|
|
||||||
|
@ -327,13 +349,11 @@ class ConditionalUnet1D(PreTrainedModel):
|
||||||
global_feature = self.diffusion_step_encoder(timesteps)
|
global_feature = self.diffusion_step_encoder(timesteps)
|
||||||
|
|
||||||
if global_cond is not None:
|
if global_cond is not None:
|
||||||
global_feature = torch.cat([
|
global_feature = torch.cat([global_feature, global_cond], axis=-1)
|
||||||
global_feature, global_cond
|
|
||||||
], axis=-1)
|
|
||||||
|
|
||||||
x = sample
|
x = sample
|
||||||
h = []
|
h = []
|
||||||
for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
|
for _idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
|
||||||
x = resnet(x, global_feature)
|
x = resnet(x, global_feature)
|
||||||
x = resnet2(x, global_feature)
|
x = resnet2(x, global_feature)
|
||||||
h.append(x)
|
h.append(x)
|
||||||
|
@ -342,7 +362,7 @@ class ConditionalUnet1D(PreTrainedModel):
|
||||||
for mid_module in self.mid_modules:
|
for mid_module in self.mid_modules:
|
||||||
x = mid_module(x, global_feature)
|
x = mid_module(x, global_feature)
|
||||||
|
|
||||||
for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
|
for _idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
|
||||||
x = torch.cat((x, h.pop()), dim=1)
|
x = torch.cat((x, h.pop()), dim=1)
|
||||||
x = resnet(x, global_feature)
|
x = resnet(x, global_feature)
|
||||||
x = resnet2(x, global_feature)
|
x = resnet2(x, global_feature)
|
||||||
|
@ -350,9 +370,10 @@ class ConditionalUnet1D(PreTrainedModel):
|
||||||
|
|
||||||
x = self.final_conv(x)
|
x = self.final_conv(x)
|
||||||
|
|
||||||
# (B,C,T)
|
# (b,c,t)
|
||||||
x = x.moveaxis(-1, -2)
|
x = x.moveaxis(-1, -2)
|
||||||
# (B,T,C)
|
# (b,t,c)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
AutoModel.register(UnetDiffusionPolicyConfig, ConditionalUnet1D)
|
AutoModel.register(UnetDiffusionPolicyConfig, ConditionalUnet1D)
|
||||||
|
|
|
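Note: at inference the same scheduler is run in reverse, as in the `prev_sample` loop a few hunks above — start from Gaussian noise, set the (smaller) number of inference timesteps, and repeatedly call `step(...)` to remove the predicted noise. A compact sketch of that loop; the linear denoiser is a placeholder for `model_forward`, and it runs on CPU instead of `.cuda()`.

```python
import torch
import torch.nn as nn
from diffusers.schedulers.scheduling_ddim import DDIMScheduler

b, tp, action_dim = 1, 16, 14
scheduler = DDIMScheduler(num_train_timesteps=100, beta_schedule="squaredcos_cap_v2",
                          clip_sample=True, prediction_type="epsilon")
denoiser = nn.Linear(action_dim, action_dim)   # placeholder for model_forward(...)

naction = torch.randn(b, tp, action_dim)       # start from pure noise
scheduler.set_timesteps(10)                    # num_inference_timesteps
for k in scheduler.timesteps:
    noise_pred = denoiser(naction)             # real call also passes k and the VLM hidden states
    naction = scheduler.step(model_output=noise_pred, timestep=k, sample=naction).prev_sample
```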
@ -1,4 +1,3 @@
|
||||||
# coding=utf-8
|
|
||||||
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -17,10 +16,10 @@
|
||||||
import os
|
import os
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
|
from transformers import AutoConfig
|
||||||
from transformers.configuration_utils import PretrainedConfig
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
from transformers.modeling_rope_utils import rope_config_validation
|
from transformers.modeling_rope_utils import rope_config_validation
|
||||||
from transformers.utils import logging
|
from transformers.utils import logging
|
||||||
from transformers import AutoModel, AutoConfig
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
@ -56,7 +55,9 @@ class Qwen2VLVisionConfig(PretrainedConfig):
|
||||||
self.temporal_patch_size = temporal_patch_size
|
self.temporal_patch_size = temporal_patch_size
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
def from_pretrained(
|
||||||
|
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
|
||||||
|
) -> "PretrainedConfig":
|
||||||
cls._set_token_in_kwargs(kwargs)
|
cls._set_token_in_kwargs(kwargs)
|
||||||
|
|
||||||
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
||||||
|
@ -64,7 +65,11 @@ class Qwen2VLVisionConfig(PretrainedConfig):
|
||||||
if config_dict.get("model_type") == "qwen2_vl":
|
if config_dict.get("model_type") == "qwen2_vl":
|
||||||
config_dict = config_dict["vision_config"]
|
config_dict = config_dict["vision_config"]
|
||||||
|
|
||||||
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
if (
|
||||||
|
"model_type" in config_dict
|
||||||
|
and hasattr(cls, "model_type")
|
||||||
|
and config_dict["model_type"] != cls.model_type
|
||||||
|
):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
||||||
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
||||||
|
@ -204,7 +209,7 @@ class Qwen2VLAConfig(PretrainedConfig):
|
||||||
vision_config=None,
|
vision_config=None,
|
||||||
rope_scaling=None,
|
rope_scaling=None,
|
||||||
# For loading policy head
|
# For loading policy head
|
||||||
policy_head_type='scale_dp_policy', # unet_diffusion_policy
|
policy_head_type="scale_dp_policy", # unet_diffusion_policy
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if isinstance(vision_config, dict):
|
if isinstance(vision_config, dict):
|
||||||
|
@ -238,7 +243,7 @@ class Qwen2VLAConfig(PretrainedConfig):
|
||||||
|
|
||||||
# Validate the correctness of rotary position embeddings parameters
|
# Validate the correctness of rotary position embeddings parameters
|
||||||
# BC: if there is a 'type' field, move it to 'rope_type'.
|
# BC: if there is a 'type' field, move it to 'rope_type'.
|
||||||
# and change type from 'mrope' to 'default' because `mrope` does defeault RoPE calculations
|
# and change type from 'mrope' to 'default' because `mrope` does default RoPE calculations
|
||||||
# one can set it to "linear"/"dynamic" etc. to have scaled RoPE
|
# one can set it to "linear"/"dynamic" etc. to have scaled RoPE
|
||||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||||
if self.rope_scaling["type"] == "mrope":
|
if self.rope_scaling["type"] == "mrope":
|
||||||
|
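Note: the backward-compatibility comment fixed above describes a small dict rewrite — a legacy `'type'` key is mirrored into `'rope_type'`, and `'mrope'` is treated as `'default'` because mrope reuses the default RoPE math. A sketch of what that normalization amounts to, using an illustrative `rope_scaling` value (the `mrope_section` numbers are made up, and the exact assignment order is assumed from the visible lines).

```python
# Illustrative rope_scaling dict
rope_scaling = {"type": "mrope", "mrope_section": [16, 24, 24]}

if rope_scaling is not None and "type" in rope_scaling:
    if rope_scaling["type"] == "mrope":
        rope_scaling["type"] = "default"        # mrope falls back to default RoPE calculations
    rope_scaling["rope_type"] = rope_scaling["type"]

print(rope_scaling)
# {'type': 'default', 'mrope_section': [16, 24, 24], 'rope_type': 'default'}
```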
@ -248,5 +253,5 @@ class Qwen2VLAConfig(PretrainedConfig):
|
||||||
|
|
||||||
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
||||||
|
|
||||||
from transformers import AutoConfig
|
|
||||||
AutoConfig.register("qwen2_vla", Qwen2VLAConfig)
|
AutoConfig.register("qwen2_vla", Qwen2VLAConfig)
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
# coding=utf-8
|
|
||||||
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||||
|
@ -19,16 +18,17 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""PyTorch Qwen2-VL model."""
|
"""PyTorch Qwen2-VL model."""
|
||||||
|
|
||||||
|
import gc
|
||||||
import math
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as func
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch.nn import CrossEntropyLoss, LayerNorm
|
from torch.nn import CrossEntropyLoss, LayerNorm
|
||||||
|
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
|
||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
from transformers.cache_utils import Cache, SlidingWindowCache, StaticCache
|
from transformers.cache_utils import Cache, SlidingWindowCache, StaticCache
|
||||||
from transformers.generation import GenerationMixin
|
from transformers.generation import GenerationMixin
|
||||||
|
@ -37,8 +37,6 @@ from transformers.modeling_outputs import (
|
||||||
BaseModelOutputWithPast,
|
BaseModelOutputWithPast,
|
||||||
ModelOutput,
|
ModelOutput,
|
||||||
)
|
)
|
||||||
from lerobot.common.policies.dexvla.fusion_modules import *
|
|
||||||
|
|
||||||
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
||||||
from transformers.modeling_utils import PreTrainedModel
|
from transformers.modeling_utils import PreTrainedModel
|
||||||
from transformers.utils import (
|
from transformers.utils import (
|
||||||
|
@ -49,14 +47,13 @@ from transformers.utils import (
|
||||||
logging,
|
logging,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
from .configuration_qwen2_vla import Qwen2VLAConfig, Qwen2VLVisionConfig
|
|
||||||
from transformers import AutoConfig, AutoModel
|
|
||||||
import gc
|
|
||||||
|
|
||||||
|
from lerobot.common.policies.dexvla.fusion_modules import ActionProjector, FiLM
|
||||||
|
|
||||||
|
from .configuration_qwen2_vla import Qwen2VLAConfig, Qwen2VLVisionConfig
|
||||||
|
|
||||||
if is_flash_attn_2_available():
|
if is_flash_attn_2_available():
|
||||||
from flash_attn import flash_attn_varlen_func
|
from flash_attn import flash_attn_varlen_func
|
||||||
|
|
||||||
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
||||||
else:
|
else:
|
||||||
flash_attn_varlen_func = None
|
flash_attn_varlen_func = None
|
||||||
|
@ -164,7 +161,9 @@ class Qwen2VLRotaryEmbedding(nn.Module):
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
if (
|
||||||
|
seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len
|
||||||
|
): # reset
|
||||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||||
self.max_seq_len_cached = self.original_max_seq_len
|
self.max_seq_len_cached = self.original_max_seq_len
|
||||||
|
|
||||||
|
@ -173,7 +172,7 @@ class Qwen2VLRotaryEmbedding(nn.Module):
|
||||||
if "dynamic" in self.rope_type:
|
if "dynamic" in self.rope_type:
|
||||||
self._dynamic_frequency_update(position_ids, device=x.device)
|
self._dynamic_frequency_update(position_ids, device=x.device)
|
||||||
|
|
||||||
# Core RoPE block. In contrast to other models, Qwen2_VL has different position ids for thw grids
|
# Core RoPE block. In contrast to other models, Qwen2_VL has different position ids for spatiotemporal grids
|
||||||
# So we expand the inv_freq to shape (3, ...)
|
# So we expand the inv_freq to shape (3, ...)
|
||||||
inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
|
inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
|
||||||
position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
|
position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
|
||||||
|
@ -207,7 +206,7 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim
|
||||||
Explanation:
|
Explanation:
|
||||||
Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding
|
Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding
|
||||||
sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For
|
sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For
|
||||||
vision embedding part, we apply rotary position embedding on temporal, height and width dimension seperately.
|
vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately.
|
||||||
Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding.
|
Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding.
|
||||||
For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal,
|
For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal,
|
||||||
height and width) of text embedding is always the same, so the text embedding rotary position embedding has no
|
height and width) of text embedding is always the same, so the text embedding rotary position embedding has no
|
||||||
|
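Note: the docstring above says the channel dimension is split into three chunks (temporal, height, width), each rotated with its own position index, while text tokens reuse the same index on all three. The sketch below only illustrates that chunk-wise mixing of a per-axis cos table; it is a simplified illustration with made-up sizes, not the exact `apply_multimodal_rotary_pos_emb` implementation.

```python
import torch

head_dim = 128
mrope_section = [16, 24, 24]            # t/h/w chunk sizes (illustrative); sums to head_dim // 2
cos = torch.randn(3, 2, 10, head_dim)   # per-axis tables: (3 axes, batch, seq, head_dim)

doubled = mrope_section * 2             # the table covers both rotary halves
# take chunk i from axis i % 3, then stitch the chunks back together along the channel dim
cos_mixed = torch.cat(
    [chunk[i % 3] for i, chunk in enumerate(cos.split(doubled, dim=-1))], dim=-1
)
print(cos_mixed.shape)                  # (2, 10, 128): one table mixing t/h/w per channel chunk
```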
@ -335,7 +334,9 @@ class VisionAttention(nn.Module):
|
||||||
self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
|
self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
seq_length = hidden_states.shape[0]
|
seq_length = hidden_states.shape[0]
|
||||||
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
q, k, v = (
|
||||||
|
self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
||||||
|
)
|
||||||
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
||||||
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
||||||
|
|
||||||
|
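Note: the one-liner being wrapped above derives Q, K and V from a single fused projection — reshape the output to expose a size-3 axis, move that axis to the front, and `unbind` it. A small shape-only illustration with made-up dimensions:

```python
import torch
import torch.nn as nn

seq_length, num_heads, head_dim = 10, 4, 32
qkv = nn.Linear(num_heads * head_dim, 3 * num_heads * head_dim)   # fused QKV projection

hidden_states = torch.randn(seq_length, num_heads * head_dim)
q, k, v = (
    qkv(hidden_states).reshape(seq_length, 3, num_heads, -1).permute(1, 0, 2, 3).unbind(0)
)
print(q.shape, k.shape, v.shape)   # each (seq_length, num_heads, head_dim)
```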
@ -369,7 +370,9 @@ class VisionFlashAttention2(nn.Module):
|
||||||
self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
|
self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
seq_length = hidden_states.shape[0]
|
seq_length = hidden_states.shape[0]
|
||||||
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
q, k, v = (
|
||||||
|
self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
||||||
|
)
|
||||||
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
||||||
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
||||||
|
|
||||||
|
@ -392,7 +395,9 @@ class VisionSdpaAttention(nn.Module):
|
||||||
self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
|
self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
seq_length = hidden_states.shape[0]
|
seq_length = hidden_states.shape[0]
|
||||||
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
q, k, v = (
|
||||||
|
self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
||||||
|
)
|
||||||
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
||||||
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
||||||
|
|
||||||
|
@ -402,7 +407,7 @@ class VisionSdpaAttention(nn.Module):
|
||||||
q = q.transpose(0, 1)
|
q = q.transpose(0, 1)
|
||||||
k = k.transpose(0, 1)
|
k = k.transpose(0, 1)
|
||||||
v = v.transpose(0, 1)
|
v = v.transpose(0, 1)
|
||||||
attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
|
attn_output = func.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
|
||||||
attn_output = attn_output.transpose(0, 1)
|
attn_output = attn_output.transpose(0, 1)
|
||||||
attn_output = attn_output.reshape(seq_length, -1)
|
attn_output = attn_output.reshape(seq_length, -1)
|
||||||
attn_output = self.proj(attn_output)
|
attn_output = self.proj(attn_output)
|
||||||
|
@ -538,7 +543,9 @@ class Qwen2VLAttention(nn.Module):
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
position_embeddings: Optional[
|
||||||
|
Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
] = None, # will become mandatory in v4.46
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
@ -569,8 +576,14 @@ class Qwen2VLAttention(nn.Module):
|
||||||
)
|
)
|
||||||
|
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
cache_kwargs = {
|
||||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
"sin": sin,
|
||||||
|
"cos": cos,
|
||||||
|
"cache_position": cache_position,
|
||||||
|
} # Specific to RoPE models
|
||||||
|
key_states, value_states = past_key_value.update(
|
||||||
|
key_states, value_states, self.layer_idx, cache_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
# repeat k/v heads if n_kv_heads < n_heads
|
# repeat k/v heads if n_kv_heads < n_heads
|
||||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
|
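Note: `repeat_kv`, used right after the cache update above, expands the key/value heads so grouped-query attention can be computed with the usual batched matmul — each KV head is repeated `num_key_value_groups` times. A self-contained sketch of that expansion, mirroring the standard Transformers helper:

```python
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """(batch, num_kv_heads, seq, head_dim) -> (batch, num_kv_heads * n_rep, seq, head_dim)."""
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)

k = torch.randn(2, 4, 10, 64)      # 4 KV heads
print(repeat_kv(k, 7).shape)       # (2, 28, 10, 64): ready to pair with 28 query heads
```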
@ -585,7 +598,9 @@ class Qwen2VLAttention(nn.Module):
|
||||||
# Fix precision issues in Qwen2-VL float16 inference
|
# Fix precision issues in Qwen2-VL float16 inference
|
||||||
# Replace inf values with zeros in attention weights to prevent NaN propagation
|
# Replace inf values with zeros in attention weights to prevent NaN propagation
|
||||||
if query_states.dtype == torch.float16:
|
if query_states.dtype == torch.float16:
|
||||||
attn_weights = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights)
|
attn_weights = torch.where(
|
||||||
|
torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights
|
||||||
|
)
|
||||||
|
|
||||||
# upcast attention to fp32
|
# upcast attention to fp32
|
||||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||||
|
@ -621,7 +636,7 @@ class Qwen2VLFlashAttention2(Qwen2VLAttention):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||||
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
||||||
|
|
||||||
|
@ -634,7 +649,9 @@ class Qwen2VLFlashAttention2(Qwen2VLAttention):
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
position_embeddings: Optional[
|
||||||
|
Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
] = None, # will become mandatory in v4.46
|
||||||
):
|
):
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
@ -696,10 +713,18 @@ class Qwen2VLFlashAttention2(Qwen2VLAttention):
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
attention_mask = attention_mask[:, slicing_tokens:]
|
attention_mask = attention_mask[:, slicing_tokens:]
|
||||||
attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
|
attention_mask = torch.cat(
|
||||||
|
[attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1
|
||||||
|
)
|
||||||
|
|
||||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
cache_kwargs = {
|
||||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
"sin": sin,
|
||||||
|
"cos": cos,
|
||||||
|
"cache_position": cache_position,
|
||||||
|
} # Specific to RoPE models
|
||||||
|
key_states, value_states = past_key_value.update(
|
||||||
|
key_states, value_states, self.layer_idx, cache_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
# repeat k/v heads if n_kv_heads < n_heads
|
# repeat k/v heads if n_kv_heads < n_heads
|
||||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
|
@ -781,7 +806,9 @@ class Qwen2VLSdpaAttention(Qwen2VLAttention):
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
position_embeddings: Optional[
|
||||||
|
Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
] = None, # will become mandatory in v4.46
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
|
@ -826,8 +853,14 @@ class Qwen2VLSdpaAttention(Qwen2VLAttention):
|
||||||
)
|
)
|
||||||
|
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
cache_kwargs = {
|
||||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
"sin": sin,
|
||||||
|
"cos": cos,
|
||||||
|
"cache_position": cache_position,
|
||||||
|
} # Specific to RoPE models
|
||||||
|
key_states, value_states = past_key_value.update(
|
||||||
|
key_states, value_states, self.layer_idx, cache_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||||
|
@ -846,7 +879,7 @@ class Qwen2VLSdpaAttention(Qwen2VLAttention):
|
||||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||||
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
|
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
|
||||||
is_causal = True if causal_mask is None and q_len > 1 else False
|
is_causal = bool(causal_mask is None and q_len > 1)
|
||||||
|
|
||||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||||
query_states,
|
query_states,
|
||||||
|
@ -897,7 +930,9 @@ class Qwen2VLDecoderLayer(nn.Module):
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
use_cache: Optional[bool] = False,
|
use_cache: Optional[bool] = False,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
position_embeddings: Optional[
|
||||||
|
Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
] = None, # will become mandatory in v4.46
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
"""
|
"""
|
||||||
|
@ -1031,9 +1066,9 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
|
||||||
def get_device(self) -> torch.device:
|
def get_device(self) -> torch.device:
|
||||||
return self.blocks[0].mlp.fc2.weight.device
|
return self.blocks[0].mlp.fc2.weight.device
|
||||||
|
|
||||||
def rot_pos_emb(self, grid_thw):
|
def rot_pos_emb(self, grid_spatiotemporal):
|
||||||
pos_ids = []
|
pos_ids = []
|
||||||
for t, h, w in grid_thw:
|
for t, h, w in grid_spatiotemporal:
|
||||||
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
|
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
|
||||||
hpos_ids = hpos_ids.reshape(
|
hpos_ids = hpos_ids.reshape(
|
||||||
h // self.spatial_merge_size,
|
h // self.spatial_merge_size,
|
||||||
|
@ -1055,19 +1090,19 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
|
||||||
wpos_ids = wpos_ids.flatten()
|
wpos_ids = wpos_ids.flatten()
|
||||||
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
|
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
|
||||||
pos_ids = torch.cat(pos_ids, dim=0)
|
pos_ids = torch.cat(pos_ids, dim=0)
|
||||||
max_grid_size = grid_thw[:, 1:].max()
|
max_grid_size = grid_spatiotemporal[:, 1:].max()
|
||||||
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
|
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
|
||||||
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
|
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
|
||||||
return rotary_pos_emb
|
return rotary_pos_emb
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
|
def forward(self, hidden_states: torch.Tensor, grid_spatiotemporal: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.patch_embed(hidden_states)
|
hidden_states = self.patch_embed(hidden_states)
|
||||||
rotary_pos_emb = self.rot_pos_emb(grid_thw)
|
rotary_pos_emb = self.rot_pos_emb(grid_spatiotemporal)
|
||||||
|
|
||||||
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
|
cu_seqlens = torch.repeat_interleave(
|
||||||
dim=0, dtype=torch.int32
|
grid_spatiotemporal[:, 1] * grid_spatiotemporal[:, 2], grid_spatiotemporal[:, 0]
|
||||||
)
|
).cumsum(dim=0, dtype=torch.int32)
|
||||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
cu_seqlens = func.pad(cu_seqlens, (1, 0), value=0)
|
||||||
|
|
||||||
for blk in self.blocks:
|
for blk in self.blocks:
|
||||||
hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
|
hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
|
||||||
|
@ -1116,7 +1151,9 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = (
|
||||||
|
output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
)
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
)
|
)
|
||||||
|
@ -1127,8 +1164,7 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
|
||||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training and use_cache:
|
||||||
if use_cache:
|
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||||
)
|
)
|
||||||
|
@ -1208,7 +1244,9 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
|
||||||
next_cache = next_decoder_cache if use_cache else None
|
next_cache = next_decoder_cache if use_cache else None
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
return tuple(
|
||||||
|
v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None
|
||||||
|
)
|
||||||
return BaseModelOutputWithPast(
|
return BaseModelOutputWithPast(
|
||||||
last_hidden_state=hidden_states,
|
last_hidden_state=hidden_states,
|
||||||
past_key_values=next_cache,
|
past_key_values=next_cache,
|
||||||
|
@ -1242,13 +1280,13 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
|
||||||
self.config._attn_implementation == "sdpa"
|
self.config._attn_implementation == "sdpa"
|
||||||
and not (using_static_cache or using_sliding_window_cache)
|
and not (using_static_cache or using_sliding_window_cache)
|
||||||
and not output_attentions
|
and not output_attentions
|
||||||
):
|
and AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
|
||||||
attention_mask,
|
attention_mask,
|
||||||
inputs_embeds=input_tensor,
|
inputs_embeds=input_tensor,
|
||||||
past_key_values_length=past_seen_tokens,
|
past_key_values_length=past_seen_tokens,
|
||||||
sliding_window=self.config.sliding_window,
|
sliding_window=self.config.sliding_window,
|
||||||
is_training=self.training,
|
is_training=self.training,
|
||||||
|
)
|
||||||
):
|
):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -1319,7 +1357,7 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
|
||||||
dtype (`torch.dtype`):
|
dtype (`torch.dtype`):
|
||||||
The dtype to use for the 4D attention mask.
|
The dtype to use for the 4D attention mask.
|
||||||
device (`torch.device`):
|
device (`torch.device`):
|
||||||
The device to plcae the 4D attention mask on.
|
The device to place the 4D attention mask on.
|
||||||
cache_position (`torch.Tensor`):
|
cache_position (`torch.Tensor`):
|
||||||
Indices depicting the position of the input sequence tokens in the sequence.
|
Indices depicting the position of the input sequence tokens in the sequence.
|
||||||
batch_size (`torch.Tensor`):
|
batch_size (`torch.Tensor`):
|
||||||
|
@ -1338,10 +1376,11 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
|
||||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
||||||
)
|
)
|
||||||
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||||
if config.sliding_window is not None:
|
if config.sliding_window is not None and (
|
||||||
|
not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length
|
||||||
|
):
|
||||||
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
|
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
|
||||||
# the check is needed to verify is current checkpoint was trained with sliding window or not
|
# the check is needed to verify is current checkpoint was trained with sliding window or not
|
||||||
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
|
|
||||||
sliding_attend_mask = torch.arange(target_length, device=device) <= (
|
sliding_attend_mask = torch.arange(target_length, device=device) <= (
|
||||||
cache_position.reshape(-1, 1) - config.sliding_window
|
cache_position.reshape(-1, 1) - config.sliding_window
|
||||||
)
|
)
|
||||||
|
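Note: the mask construction restructured above combines two `torch.arange` comparisons — a causal part that blocks keys after the current position and, when a sliding window is configured, a part that also blocks keys more than `sliding_window` tokens in the past. A small standalone illustration of those two comparisons; the window size and lengths are made up, and combining the masks with `|` is shown only for clarity.

```python
import torch

sequence_length, target_length, sliding_window = 4, 4, 2
cache_position = torch.arange(sequence_length)    # positions of the current query tokens

diagonal_attend_mask = torch.arange(target_length) > cache_position.reshape(-1, 1)  # future keys
sliding_attend_mask = torch.arange(target_length) <= (
    cache_position.reshape(-1, 1) - sliding_window
)                                                 # keys too far in the past
blocked = diagonal_attend_mask | sliding_attend_mask   # True where attention is NOT allowed
print(blocked.int())   # row i may only attend to keys in (i - sliding_window, i]
```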
@ -1428,14 +1467,15 @@ QWEN2_VL_INPUTS_DOCSTRING = r"""
|
||||||
The tensors corresponding to the input videos. Pixel values can be obtained using
|
The tensors corresponding to the input videos. Pixel values can be obtained using
|
||||||
[`AutoImageProcessor`]. See [`Qwen2VLImageProcessor.__call__`] for details. [`Qwen2VLProcessor`] uses
|
[`AutoImageProcessor`]. See [`Qwen2VLImageProcessor.__call__`] for details. [`Qwen2VLProcessor`] uses
|
||||||
[`Qwen2VLImageProcessor`] for processing videos.
|
[`Qwen2VLImageProcessor`] for processing videos.
|
||||||
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
|
image_grid_spatiotemporal (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
|
||||||
The temporal, height and width of feature shape of each image in LLM.
|
The temporal, height and width of feature shape of each image in LLM.
|
||||||
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
|
video_grid_spatiotemporal (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
|
||||||
The temporal, height and width of feature shape of each video in LLM.
|
The temporal, height and width of feature shape of each video in LLM.
|
||||||
rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
|
rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
|
||||||
The rope index difference between sequence length and multimodal rope.
|
The rope index difference between sequence length and multimodal rope.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMixin):
|
class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMixin):
|
||||||
_tied_weights_keys = ["lm_head.weight"]
|
_tied_weights_keys = ["lm_head.weight"]
|
||||||
|
|
||||||
|
@ -1491,8 +1531,8 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
|
||||||
def get_rope_index(
|
def get_rope_index(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor,
|
input_ids: torch.LongTensor,
|
||||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
image_grid_spatiotemporal: Optional[torch.LongTensor] = None,
|
||||||
video_grid_thw: Optional[torch.LongTensor] = None,
|
video_grid_spatiotemporal: Optional[torch.LongTensor] = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
|
@ -1501,7 +1541,7 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
|
||||||
Explanation:
|
Explanation:
|
||||||
Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
|
Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
|
||||||
|
|
||||||
For pure text embedding sequence, the rotary position embedding has no difference with mordern LLMs.
|
For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
|
||||||
Examples:
|
Examples:
|
||||||
input_ids: [T T T T T], here T is for text.
|
input_ids: [T T T T T], here T is for text.
|
||||||
temporal position_ids: [0, 1, 2, 3, 4]
|
temporal position_ids: [0, 1, 2, 3, 4]
|
||||||
|
@ -1525,9 +1565,9 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
|
||||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
||||||
it.
|
it.
|
||||||
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
|
image_grid_spatiotemporal (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
|
||||||
The temporal, height and width of feature shape of each image in LLM.
|
The temporal, height and width of feature shape of each image in LLM.
|
||||||
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
|
video_grid_spatiotemporal (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
|
||||||
The temporal, height and width of feature shape of each video in LLM.
|
The temporal, height and width of feature shape of each video in LLM.
|
||||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||||
|
@ -1544,7 +1584,7 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
|
||||||
video_token_id = self.config.video_token_id
|
video_token_id = self.config.video_token_id
|
||||||
vision_start_token_id = self.config.vision_start_token_id
|
vision_start_token_id = self.config.vision_start_token_id
|
||||||
mrope_position_deltas = []
|
mrope_position_deltas = []
|
||||||
if image_grid_thw is not None or video_grid_thw is not None:
|
if image_grid_spatiotemporal is not None or video_grid_spatiotemporal is not None:
|
||||||
total_input_ids = input_ids
|
total_input_ids = input_ids
|
||||||
position_ids = torch.ones(
|
position_ids = torch.ones(
|
||||||
3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device
|
3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device
|
||||||
|
@ -1573,18 +1613,18 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
|
||||||
ed_video = len(input_tokens) + 1
|
ed_video = len(input_tokens) + 1
|
||||||
if ed_image < ed_video:
|
if ed_image < ed_video:
|
||||||
t, h, w = (
|
t, h, w = (
|
||||||
image_grid_thw[image_index][0],
|
image_grid_spatiotemporal[image_index][0],
|
||||||
image_grid_thw[image_index][1],
|
image_grid_spatiotemporal[image_index][1],
|
||||||
image_grid_thw[image_index][2],
|
image_grid_spatiotemporal[image_index][2],
|
||||||
)
|
)
|
||||||
image_index += 1
|
image_index += 1
|
||||||
remain_images -= 1
|
remain_images -= 1
|
||||||
ed = ed_image
|
ed = ed_image
|
||||||
else:
|
else:
|
||||||
t, h, w = (
|
t, h, w = (
|
||||||
video_grid_thw[video_index][0],
|
video_grid_spatiotemporal[video_index][0],
|
||||||
video_grid_thw[video_index][1],
|
video_grid_spatiotemporal[video_index][1],
|
||||||
video_grid_thw[video_index][2],
|
video_grid_spatiotemporal[video_index][2],
|
||||||
)
|
)
|
||||||
video_index += 1
|
video_index += 1
|
||||||
remain_videos -= 1
|
remain_videos -= 1
|
||||||
|
@ -1599,9 +1639,15 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
|
||||||
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||||
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
|
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
|
||||||
|
|
||||||
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
|
t_index = (
|
||||||
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
|
torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
|
||||||
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
|
)
|
||||||
|
h_index = (
|
||||||
|
torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
|
||||||
|
)
|
||||||
|
w_index = (
|
||||||
|
torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
|
||||||
|
)
|
||||||
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
|
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
|
||||||
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
|
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
|
||||||
|
|
||||||
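For reference, the index construction reformatted in this hunk can be reproduced in isolation. The sketch below assumes a single vision block of grid size (t, h, w) and uses a hypothetical helper name `build_vision_position_ids`; it only illustrates the mRoPE id layout, not the full `get_rope_index` logic.

```python
import torch

def build_vision_position_ids(llm_grid_t, llm_grid_h, llm_grid_w, st_idx=0):
    """Minimal sketch of the mRoPE index layout for one vision block.

    Each axis index is broadcast over the other two axes and flattened, so each token of
    the block gets a distinct (temporal, height, width) triple, offset by the running
    text position st_idx. Returns a (3, t*h*w) LongTensor.
    """
    t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
    h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
    return torch.stack([t_index, h_index, w_index]) + st_idx

# Example: a 2x3x4 vision grid placed after 5 text tokens.
print(build_vision_position_ids(2, 3, 4, st_idx=5).shape)  # torch.Size([3, 24])
```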
|
@ -1671,8 +1717,8 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
pixel_values: Optional[torch.Tensor] = None,
|
pixel_values: Optional[torch.Tensor] = None,
|
||||||
pixel_values_videos: Optional[torch.FloatTensor] = None,
|
pixel_values_videos: Optional[torch.FloatTensor] = None,
|
||||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
image_grid_spatiotemporal: Optional[torch.LongTensor] = None,
|
||||||
video_grid_thw: Optional[torch.LongTensor] = None,
|
video_grid_spatiotemporal: Optional[torch.LongTensor] = None,
|
||||||
rope_deltas: Optional[torch.LongTensor] = None,
|
rope_deltas: Optional[torch.LongTensor] = None,
|
||||||
actions: Optional[torch.LongTensor] = None,
|
actions: Optional[torch.LongTensor] = None,
|
||||||
states: Optional[torch.FloatTensor] = None,
|
states: Optional[torch.FloatTensor] = None,
|
||||||
|
@ -1725,14 +1771,16 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
|
||||||
attention_mask = attention_mask.to("cuda")
|
attention_mask = attention_mask.to("cuda")
|
||||||
if not is_eval:
|
if not is_eval:
|
||||||
labels = labels.to("cuda")
|
labels = labels.to("cuda")
|
||||||
actions = actions.to(dtype=self.computed_type, device='cuda')
|
actions = actions.to(dtype=self.computed_type, device="cuda")
|
||||||
states = states.to(dtype=self.computed_type, device='cuda')
|
states = states.to(dtype=self.computed_type, device="cuda")
|
||||||
position_ids, rope_deltas = self.get_rope_index(
|
position_ids, rope_deltas = self.get_rope_index(
|
||||||
input_ids, image_grid_thw, video_grid_thw, attention_mask
|
input_ids, image_grid_spatiotemporal, video_grid_spatiotemporal, attention_mask
|
||||||
)
|
)
|
||||||
if pixel_values is not None:
|
if pixel_values is not None:
|
||||||
pixel_values = pixel_values.to(dtype=self.computed_type, device='cuda')
|
pixel_values = pixel_values.to(dtype=self.computed_type, device="cuda")
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = (
|
||||||
|
output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
)
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
)
|
)
|
||||||
|
@ -1742,7 +1790,7 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
|
||||||
inputs_embeds = self.model.embed_tokens(input_ids)
|
inputs_embeds = self.model.embed_tokens(input_ids)
|
||||||
if pixel_values is not None:
|
if pixel_values is not None:
|
||||||
pixel_values = pixel_values.type(self.visual.get_dtype())
|
pixel_values = pixel_values.type(self.visual.get_dtype())
|
||||||
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
image_embeds = self.visual(pixel_values, grid_spatiotemporal=image_grid_spatiotemporal)
|
||||||
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
|
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
|
||||||
n_image_features = image_embeds.shape[0]
|
n_image_features = image_embeds.shape[0]
|
||||||
if n_image_tokens != n_image_features:
|
if n_image_tokens != n_image_features:
|
||||||
|
@ -1760,7 +1808,7 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
|
||||||
|
|
||||||
if pixel_values_videos is not None:
|
if pixel_values_videos is not None:
|
||||||
pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
|
pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
|
||||||
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
video_embeds = self.visual(pixel_values_videos, grid_spatiotemporal=video_grid_spatiotemporal)
|
||||||
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
|
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
|
||||||
n_video_features = video_embeds.shape[0]
|
n_video_features = video_embeds.shape[0]
|
||||||
if n_video_tokens != n_video_features:
|
if n_video_tokens != n_video_features:
|
||||||
|
@ -1833,21 +1881,28 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.using_film:
|
if self.using_film:
|
||||||
action_hidden_states = self.film_forward(labels=labels, input_ids=input_ids,
|
action_hidden_states = self.film_forward(
|
||||||
hidden_states=hidden_states)
|
labels=labels, input_ids=input_ids, hidden_states=hidden_states
|
||||||
|
)
|
||||||
else: # tinyvla
|
else: # tinyvla
|
||||||
action_hidden_states = hidden_states
|
action_hidden_states = hidden_states
|
||||||
|
|
||||||
ret = self.policy_head(actions=actions, hidden_states=action_hidden_states, states=states, is_pad=is_pad)
|
ret = self.policy_head(
|
||||||
|
actions=actions, hidden_states=action_hidden_states, states=states, is_pad=is_pad
|
||||||
|
)
|
||||||
|
|
||||||
if self.with_llm_head:
|
if self.with_llm_head:
|
||||||
loss = {'loss': ret['loss'] + self.llm_loss_weight * llm_loss,
|
loss = {
|
||||||
'llm_loss': llm_loss,
|
"loss": ret["loss"] + self.llm_loss_weight * llm_loss,
|
||||||
'action_loss': ret['loss']}
|
"llm_loss": llm_loss,
|
||||||
|
"action_loss": ret["loss"],
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
loss = {'loss': ret['loss'],
|
loss = {
|
||||||
'llm_loss': (torch.ones(1)*(-100)).to(ret['loss'].dtype).squeeze(0),
|
"loss": ret["loss"],
|
||||||
'action_loss': ret['loss']}
|
"llm_loss": (torch.ones(1) * (-100)).to(ret["loss"].dtype).squeeze(0),
|
||||||
|
"action_loss": ret["loss"],
|
||||||
|
}
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[1:]
|
output = (logits,) + outputs[1:]
|
||||||
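The loss dictionary assembled in the hunk above weights the language-modeling loss against the action (policy-head) loss, and substitutes a -100 sentinel for `llm_loss` when no LLM head is trained so the logging schema stays stable. A minimal sketch of that bookkeeping, assuming scalar loss tensors and an illustrative helper name `combine_losses`:

```python
from typing import Optional

import torch

def combine_losses(action_loss: torch.Tensor, llm_loss: Optional[torch.Tensor], llm_loss_weight: float) -> dict:
    """Weight the optional LLM loss against the action loss; -100 is a logging sentinel."""
    if llm_loss is not None:
        return {
            "loss": action_loss + llm_loss_weight * llm_loss,
            "llm_loss": llm_loss,
            "action_loss": action_loss,
        }
    return {
        "loss": action_loss,
        "llm_loss": (torch.ones(1) * (-100)).to(action_loss.dtype).squeeze(0),
        "action_loss": action_loss,
    }

print(combine_losses(torch.tensor(0.8), torch.tensor(2.0), llm_loss_weight=0.1))
```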
|
@ -1862,7 +1917,7 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
|
||||||
del inputs_embeds
|
del inputs_embeds
|
||||||
del labels
|
del labels
|
||||||
del pixel_values
|
del pixel_values
|
||||||
del image_grid_thw
|
del image_grid_spatiotemporal
|
||||||
del actions
|
del actions
|
||||||
del states
|
del states
|
||||||
return Qwen2VLCausalLMOutputWithPast(
|
return Qwen2VLCausalLMOutputWithPast(
|
||||||
|
@ -1882,12 +1937,12 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
|
||||||
inputs_index = inputs_index.int()
|
inputs_index = inputs_index.int()
|
||||||
|
|
||||||
xor_array = torch.bitwise_xor(inputs_index[:, :-1], inputs_index[:, 1:])
|
xor_array = torch.bitwise_xor(inputs_index[:, :-1], inputs_index[:, 1:])
|
||||||
indexs = torch.argmax((xor_array != 0).float(), dim=1)
|
indexes = torch.argmax((xor_array != 0).float(), dim=1)
|
||||||
input_embeddings = []
|
input_embeddings = []
|
||||||
reasoning_embeddings = []
|
reasoning_embeddings = []
|
||||||
identity = []
|
identity = []
|
||||||
for i in range(indexs.shape[0]):
|
for i in range(indexes.shape[0]):
|
||||||
end = indexs[i] + 1
|
end = indexes[i] + 1
|
||||||
temp = input_ids[i] == 151643 # pad token id for qwen2_vl
|
temp = input_ids[i] == 151643 # pad token id for qwen2_vl
|
||||||
start = sum(temp.int())
|
start = sum(temp.int())
|
||||||
input_embeddings.append(self.input_action_proj(hidden_states[i, start:end, :]))
|
input_embeddings.append(self.input_action_proj(hidden_states[i, start:end, :]))
|
||||||
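The boundary search in this hunk relies on a small XOR trick: `inputs_index` flags which tokens belong to the prompt, and the first position where that flag flips marks the end of the prompt span. A standalone sketch with made-up flag values:

```python
import torch

# inputs_index: 1 for prompt tokens, 0 afterwards (illustrative values).
inputs_index = torch.tensor([[1, 1, 1, 0, 0, 0],
                             [1, 1, 1, 1, 0, 0]]).int()

# XOR of neighbouring flags is non-zero exactly at the transition; argmax over the
# boolean mask returns the first such position per sequence.
xor_array = torch.bitwise_xor(inputs_index[:, :-1], inputs_index[:, 1:])
indexes = torch.argmax((xor_array != 0).float(), dim=1)
print(indexes)  # tensor([2, 3]) -> index of the last prompt token in each row
```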
|
@ -1914,8 +1969,8 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
|
||||||
use_cache=True,
|
use_cache=True,
|
||||||
pixel_values=None,
|
pixel_values=None,
|
||||||
pixel_values_videos=None,
|
pixel_values_videos=None,
|
||||||
image_grid_thw=None,
|
image_grid_spatiotemporal=None,
|
||||||
video_grid_thw=None,
|
video_grid_spatiotemporal=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||||
|
@ -1924,19 +1979,23 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
|
||||||
if past_key_values is not None:
|
if past_key_values is not None:
|
||||||
if inputs_embeds is not None: # Exception 1
|
if inputs_embeds is not None: # Exception 1
|
||||||
input_ids = input_ids[:, -cache_position.shape[0] :]
|
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||||
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
elif (
|
||||||
|
input_ids.shape[1] != cache_position.shape[0]
|
||||||
|
): # Default case (the "else", a no op, is Exception 2)
|
||||||
input_ids = input_ids[:, cache_position]
|
input_ids = input_ids[:, cache_position]
|
||||||
|
|
||||||
rope_deltas = kwargs.get("rope_deltas", None)
|
rope_deltas = kwargs.get("rope_deltas")
|
||||||
if attention_mask is not None and position_ids is None:
|
if attention_mask is not None and position_ids is None:
|
||||||
if cache_position is None or (cache_position is not None and cache_position[0] == 0):
|
if cache_position is None or (cache_position is not None and cache_position[0] == 0):
|
||||||
position_ids, rope_deltas = self.get_rope_index(
|
position_ids, rope_deltas = self.get_rope_index(
|
||||||
input_ids, image_grid_thw, video_grid_thw, attention_mask
|
input_ids, image_grid_spatiotemporal, video_grid_spatiotemporal, attention_mask
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
batch_size, seq_length = input_ids.shape
|
batch_size, seq_length = input_ids.shape
|
||||||
delta = (
|
delta = (
|
||||||
cache_position[0] + rope_deltas if cache_position is not None and rope_deltas is not None else 0
|
cache_position[0] + rope_deltas
|
||||||
|
if cache_position is not None and rope_deltas is not None
|
||||||
|
else 0
|
||||||
)
|
)
|
||||||
position_ids = torch.arange(seq_length, device=input_ids.device)
|
position_ids = torch.arange(seq_length, device=input_ids.device)
|
||||||
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
|
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
|
||||||
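In the cached-decoding branch above, newly generated tokens only need 1-D position ids shifted by the offset returned earlier by `get_rope_index`. A simplified sketch of just that shift arithmetic, assuming per-sequence deltas of shape `(batch_size, 1)` and made-up values:

```python
import torch

batch_size, seq_length = 2, 4
cache_position = torch.tensor([10])      # index of the first new token in the KV cache
rope_deltas = torch.tensor([[3], [5]])   # per-sequence offsets from get_rope_index

delta = cache_position[0] + rope_deltas  # (batch_size, 1)
position_ids = torch.arange(seq_length).view(1, -1).expand(batch_size, -1) + delta
print(position_ids)
# tensor([[13, 14, 15, 16],
#         [15, 16, 17, 18]])
```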
|
@ -1981,8 +2040,8 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"pixel_values": pixel_values,
|
"pixel_values": pixel_values,
|
||||||
"pixel_values_videos": pixel_values_videos,
|
"pixel_values_videos": pixel_values_videos,
|
||||||
"image_grid_thw": image_grid_thw,
|
"image_grid_spatiotemporal": image_grid_spatiotemporal,
|
||||||
"video_grid_thw": video_grid_thw,
|
"video_grid_spatiotemporal": video_grid_spatiotemporal,
|
||||||
"rope_deltas": rope_deltas,
|
"rope_deltas": rope_deltas,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
@ -1990,6 +2049,4 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
|
||||||
return model_inputs
|
return model_inputs
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
from transformers import AutoModelForCausalLM
|
|
||||||
AutoModelForCausalLM.register(Qwen2VLAConfig, Qwen2VLForConditionalGenerationForVLA)
|
AutoModelForCausalLM.register(Qwen2VLAConfig, Qwen2VLForConditionalGenerationForVLA)
|
||||||
|
|
|
@ -1,10 +1,9 @@
|
||||||
from PIL import Image
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torchvision.transforms.functional import to_pil_image, to_tensor
|
|
||||||
import torchvision.transforms as transforms
|
|
||||||
import torch
|
import torch
|
||||||
from qwen_vl_utils import process_vision_info
|
from PIL import Image
|
||||||
from qwen_vl_utils import fetch_image
|
from qwen_vl_utils import fetch_image
|
||||||
|
|
||||||
|
|
||||||
class Qwen2VLAProcess:
|
class Qwen2VLAProcess:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -20,10 +19,10 @@ class Qwen2VLAProcess:
|
||||||
def qwen2_image_preprocess(self, each):
|
def qwen2_image_preprocess(self, each):
|
||||||
ele = {}
|
ele = {}
|
||||||
each = Image.fromarray(each.squeeze(0).permute(1, 2, 0).cpu().numpy().astype(np.uint8))
|
each = Image.fromarray(each.squeeze(0).permute(1, 2, 0).cpu().numpy().astype(np.uint8))
|
||||||
ele['image'] = each
|
ele["image"] = each
|
||||||
|
|
||||||
ele['resized_height'] = each.height
|
ele["resized_height"] = each.height
|
||||||
ele['resized_width'] = each.width
|
ele["resized_width"] = each.width
|
||||||
each = fetch_image(ele)
|
each = fetch_image(ele)
|
||||||
return torch.from_numpy(np.array(each))
|
return torch.from_numpy(np.array(each))
|
||||||
|
|
||||||
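`qwen2_image_preprocess` above converts a `(1, C, H, W)` camera tensor into a PIL image before handing it to `qwen_vl_utils.fetch_image` for resizing. The tensor/PIL round trip itself can be sketched in isolation (random pixel values and shapes assumed):

```python
import numpy as np
import torch
from PIL import Image

each = torch.randint(0, 256, (1, 3, 480, 640), dtype=torch.uint8)  # (1, C, H, W) camera frame
img = Image.fromarray(each.squeeze(0).permute(1, 2, 0).cpu().numpy().astype(np.uint8))
print(img.size)  # (640, 480) -- PIL reports (width, height)

back = torch.from_numpy(np.array(img))  # back to an (H, W, C) tensor
print(back.shape)  # torch.Size([480, 640, 3])
```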
|
@ -31,15 +30,13 @@ class Qwen2VLAProcess:
|
||||||
len_views = images.shape[0]
|
len_views = images.shape[0]
|
||||||
messages = self.construct_chat_data(len_views, raw_lang)
|
messages = self.construct_chat_data(len_views, raw_lang)
|
||||||
|
|
||||||
data_dict = dict(
|
data_dict = {"messages": messages}
|
||||||
messages=messages,
|
|
||||||
)
|
|
||||||
|
|
||||||
image_data = torch.chunk(images, len_views, 0)
|
image_data = torch.chunk(images, len_views, 0)
|
||||||
|
|
||||||
images_list = []
|
images_list = []
|
||||||
|
|
||||||
for i, each in enumerate(image_data):
|
for _i, each in enumerate(image_data):
|
||||||
img_pil = self.qwen2_image_preprocess(each)
|
img_pil = self.qwen2_image_preprocess(each)
|
||||||
images_list.append(img_pil)
|
images_list.append(img_pil)
|
||||||
|
|
||||||
|
@ -58,75 +55,78 @@ class Qwen2VLAProcess:
|
||||||
if eval:
|
if eval:
|
||||||
return model_inputs
|
return model_inputs
|
||||||
|
|
||||||
input_labels = torch.ones_like(model_inputs['input_ids']) * -100
|
input_labels = torch.ones_like(model_inputs["input_ids"]) * -100
|
||||||
if use_reasoning:
|
answer = reasoning + "Next action:" + "<|im_end|>" if use_reasoning else "" + "<|im_end|>"
|
||||||
answer =reasoning + "Next action:" + '<|im_end|>'
|
|
||||||
else:
|
|
||||||
answer = '' + '<|im_end|>'
|
|
||||||
|
|
||||||
output_text = self.tokenizer(answer, padding=True, return_tensors="pt")
|
output_text = self.tokenizer(answer, padding=True, return_tensors="pt")
|
||||||
output_labels = output_text['input_ids']
|
output_labels = output_text["input_ids"]
|
||||||
model_inputs['input_ids'] = torch.cat((model_inputs['input_ids'], output_text['input_ids']), dim=-1)
|
model_inputs["input_ids"] = torch.cat((model_inputs["input_ids"], output_text["input_ids"]), dim=-1)
|
||||||
model_inputs['attention_mask'] = torch.cat((model_inputs['attention_mask'], output_text['attention_mask']), dim=-1)
|
model_inputs["attention_mask"] = torch.cat(
|
||||||
|
(model_inputs["attention_mask"], output_text["attention_mask"]), dim=-1
|
||||||
|
)
|
||||||
labels = torch.cat((input_labels, output_labels), dim=-1)
|
labels = torch.cat((input_labels, output_labels), dim=-1)
|
||||||
|
|
||||||
data_dict['labels'] = labels
|
data_dict["labels"] = labels
|
||||||
for k, v in model_inputs.items():
|
for k, v in model_inputs.items():
|
||||||
data_dict[k] = v
|
data_dict[k] = v
|
||||||
return data_dict
|
return data_dict
|
||||||
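The label construction above masks the prompt with -100 so that cross-entropy only covers the appended answer (the reasoning text plus `Next action:<|im_end|>`). A toy illustration with made-up token ids:

```python
import torch

prompt_ids = torch.tensor([[101, 102, 103, 104]])  # tokens of the vision/text prompt
answer_ids = torch.tensor([[201, 202, 2]])         # tokens of the reasoning answer

input_ids = torch.cat((prompt_ids, answer_ids), dim=-1)
labels = torch.cat((torch.full_like(prompt_ids, -100), answer_ids), dim=-1)
print(input_ids)  # tensor([[ 101,  102,  103,  104,  201,  202,    2]])
print(labels)     # tensor([[-100, -100, -100, -100,  201,  202,    2]])
```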
|
|
||||||
def forward(self, batch, use_reasoning=True):
|
def forward(self, batch, use_reasoning=True):
|
||||||
"""This is the main process function for processing vl data into Qwen2_vl format"""
|
"""This is the main process function for processing vl data into Qwen2_vl format"""
|
||||||
all_images = batch['images']
|
all_images = batch["images"]
|
||||||
all_images = torch.einsum('v b c h w -> b v c h w', all_images) # camera_views, batch_size, channel, height, width
|
all_images = torch.einsum(
|
||||||
|
"v b c h w -> b v c h w", all_images
|
||||||
|
) # camera_views, batch_size, channel, height, width
|
||||||
|
|
||||||
ret_l = []
|
ret_l = []
|
||||||
|
|
||||||
for idx, images in enumerate(all_images):
|
for idx, images in enumerate(all_images):
|
||||||
raw_lang = batch['raw_langs'][idx]
|
raw_lang = batch["raw_langs"][idx]
|
||||||
reasoning = batch['reasonings'][idx]
|
reasoning = batch["reasonings"][idx]
|
||||||
ret_dict = self.single_forward_process(images, raw_lang, reasoning, use_reasoning=use_reasoning)
|
ret_dict = self.single_forward_process(images, raw_lang, reasoning, use_reasoning=use_reasoning)
|
||||||
ret_l.append(ret_dict)
|
ret_l.append(ret_dict)
|
||||||
|
|
||||||
return self.post_process(ret_l)
|
return self.post_process(ret_l)
|
||||||
|
|
||||||
def post_process(self, instances):
|
def post_process(self, instances):
|
||||||
input_ids = [torch.flip(instance['input_ids'].squeeze(0), dims=[0]) for instance in instances]
|
input_ids = [torch.flip(instance["input_ids"].squeeze(0), dims=[0]) for instance in instances]
|
||||||
labels = [torch.flip(instance['labels'].squeeze(0), dims=[0]) for instance in instances]
|
labels = [torch.flip(instance["labels"].squeeze(0), dims=[0]) for instance in instances]
|
||||||
|
|
||||||
image_grid_thw = torch.stack([instances['image_grid_thw'] for instances in instances])
|
image_grid_spatiotemporal = torch.stack(
|
||||||
pixel_values = torch.stack([instances['pixel_values'] for instances in instances])
|
[instances["image_grid_spatiotemporal"] for instances in instances]
|
||||||
|
)
|
||||||
|
pixel_values = torch.stack([instances["pixel_values"] for instances in instances])
|
||||||
pixel_values_videos = None
|
pixel_values_videos = None
|
||||||
video_grid_thw = None
|
video_grid_spatiotemporal = None
|
||||||
|
|
||||||
labels = torch.nn.utils.rnn.pad_sequence(labels,
|
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
|
||||||
batch_first=True,
|
|
||||||
padding_value=-100)
|
|
||||||
labels = torch.flip(labels, dims=[1])
|
labels = torch.flip(labels, dims=[1])
|
||||||
input_ids = torch.nn.utils.rnn.pad_sequence(input_ids,
|
input_ids = torch.nn.utils.rnn.pad_sequence(
|
||||||
batch_first=True,
|
input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
|
||||||
padding_value=self.tokenizer.pad_token_id)
|
)
|
||||||
input_ids = torch.flip(input_ids, dims=[1])
|
input_ids = torch.flip(input_ids, dims=[1])
|
||||||
b = input_ids.shape[0]
|
b = input_ids.shape[0]
|
||||||
|
|
||||||
image_grid_thw = image_grid_thw.reshape(b * image_grid_thw.shape[1], image_grid_thw.shape[2])
|
image_grid_spatiotemporal = image_grid_spatiotemporal.reshape(
|
||||||
|
b * image_grid_spatiotemporal.shape[1], image_grid_spatiotemporal.shape[2]
|
||||||
|
)
|
||||||
pixel_values = pixel_values.reshape(b * pixel_values.shape[1], pixel_values.shape[2])
|
pixel_values = pixel_values.reshape(b * pixel_values.shape[1], pixel_values.shape[2])
|
||||||
|
|
||||||
attention_mask = input_ids.ne(self.tokenizer.pad_token_id),
|
attention_mask = (input_ids.ne(self.tokenizer.pad_token_id),)
|
||||||
|
|
||||||
|
batch = {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"attention_mask": attention_mask[0],
|
||||||
|
"labels": labels,
|
||||||
|
"image_grid_spatiotemporal": image_grid_spatiotemporal,
|
||||||
|
"pixel_values_videos": pixel_values_videos,
|
||||||
|
"video_grid_spatiotemporal": video_grid_spatiotemporal,
|
||||||
|
"pixel_values": pixel_values,
|
||||||
|
}
|
||||||
|
|
||||||
batch = dict(
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask[0],
|
|
||||||
labels=labels,
|
|
||||||
image_grid_thw=image_grid_thw,
|
|
||||||
pixel_values_videos=pixel_values_videos,
|
|
||||||
video_grid_thw=video_grid_thw,
|
|
||||||
pixel_values=pixel_values,
|
|
||||||
)
|
|
||||||
return batch
|
return batch
|
||||||
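`post_process` above left-pads variable-length sequences by flipping them, right-padding with `pad_sequence`, and flipping back, since `pad_sequence` only supports right padding. A standalone sketch with two toy sequences (token values invented, pad id taken from the comment in the code above):

```python
import torch

pad_id = 151643  # Qwen2-VL pad token id, per the comment in the code above
seqs = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]

flipped = [torch.flip(s, dims=[0]) for s in seqs]
padded = torch.nn.utils.rnn.pad_sequence(flipped, batch_first=True, padding_value=pad_id)
input_ids = torch.flip(padded, dims=[1])   # padding now sits on the left
attention_mask = input_ids.ne(pad_id)
print(input_ids)
# tensor([[     5,      6,      7],
#         [151643,      8,      9]])
```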
|
|
||||||
def construct_chat_data(self, len_image, raw_lang):
|
def construct_chat_data(self, len_image, raw_lang):
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
|
@ -134,12 +134,14 @@ class Qwen2VLAProcess:
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
for i in range(len_image):
|
for _i in range(len_image):
|
||||||
messages[0]['content'].append({
|
messages[0]["content"].append(
|
||||||
|
{
|
||||||
"type": "image",
|
"type": "image",
|
||||||
"image": None,
|
"image": None,
|
||||||
})
|
}
|
||||||
messages[0]['content'].append({"type": "text", "text": f""})
|
)
|
||||||
messages[0]['content'][-1]['text'] = raw_lang
|
messages[0]["content"].append({"type": "text", "text": ""})
|
||||||
|
messages[0]["content"][-1]["text"] = raw_lang
|
||||||
|
|
||||||
return messages
|
return messages
|
|
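For reference, the chat structure built by `construct_chat_data` resembles the literal below for two camera views, assuming `content` starts empty; the instruction string is invented. Each view contributes an image placeholder, and the final text entry carries `raw_lang`:

```python
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": None},  # one placeholder per camera view
            {"type": "image", "image": None},
            {"type": "text", "text": "pick up the red cube"},  # raw_lang goes here
        ],
    }
]
print(messages[0]["content"][-1]["text"])
```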
@ -68,7 +68,7 @@ class DiffusionConfig(PreTrainedConfig):
|
||||||
within the image size. If None, no cropping is done.
|
within the image size. If None, no cropping is done.
|
||||||
crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
|
crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
|
||||||
mode).
|
mode).
|
||||||
pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone.
|
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
|
||||||
`None` means no pretrained weights.
|
`None` means no pretrained weights.
|
||||||
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
|
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
|
||||||
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
|
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
|
||||||
|
@ -99,7 +99,7 @@ class DiffusionConfig(PreTrainedConfig):
|
||||||
num_inference_steps: Number of reverse diffusion steps to use at inference time (steps are evenly
|
num_inference_steps: Number of reverse diffusion steps to use at inference time (steps are evenly
|
||||||
spaced). If not provided, this defaults to be the same as `num_train_timesteps`.
|
spaced). If not provided, this defaults to be the same as `num_train_timesteps`.
|
||||||
do_mask_loss_for_padding: Whether to mask the loss when there are copy-padded actions. See
|
do_mask_loss_for_padding: Whether to mask the loss when there are copy-padded actions. See
|
||||||
`LeRobotDataset` and `load_previous_and_future_frames` for mor information. Note, this defaults
|
`LeRobotDataset` and `load_previous_and_future_frames` for more information. Note, this defaults
|
||||||
to False as the original Diffusion Policy implementation does the same.
|
to False as the original Diffusion Policy implementation does the same.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,6 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||||
|
@ -24,9 +23,9 @@ from lerobot.common.datasets.utils import dataset_to_policy_features
|
||||||
from lerobot.common.envs.configs import EnvConfig
|
from lerobot.common.envs.configs import EnvConfig
|
||||||
from lerobot.common.envs.utils import env_to_policy_features
|
from lerobot.common.envs.utils import env_to_policy_features
|
||||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||||
|
from lerobot.common.policies.dexvla.configuration_dexvla import DexVLAConfig
|
||||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||||
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
||||||
from lerobot.common.policies.dexvla.configuration_dexvla import DexVLAConfig
|
|
||||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||||
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||||
|
@ -83,7 +82,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||||
|
|
||||||
def make_policy(
|
def make_policy(
|
||||||
cfg: PreTrainedConfig,
|
cfg: PreTrainedConfig,
|
||||||
device: str | torch.device,
|
|
||||||
ds_meta: LeRobotDatasetMetadata | None = None,
|
ds_meta: LeRobotDatasetMetadata | None = None,
|
||||||
env_cfg: EnvConfig | None = None,
|
env_cfg: EnvConfig | None = None,
|
||||||
) -> PreTrainedPolicy:
|
) -> PreTrainedPolicy:
|
||||||
|
@ -95,7 +93,6 @@ def make_policy(
|
||||||
Args:
|
Args:
|
||||||
cfg (PreTrainedConfig): The config of the policy to make. If `pretrained_path` is set, the policy will
|
cfg (PreTrainedConfig): The config of the policy to make. If `pretrained_path` is set, the policy will
|
||||||
be loaded with the weights from that path.
|
be loaded with the weights from that path.
|
||||||
device (str): the device to load the policy onto.
|
|
||||||
ds_meta (LeRobotDatasetMetadata | None, optional): Dataset metadata to take input/output shapes and
|
ds_meta (LeRobotDatasetMetadata | None, optional): Dataset metadata to take input/output shapes and
|
||||||
statistics to use for (un)normalization of inputs/outputs in the policy. Defaults to None.
|
statistics to use for (un)normalization of inputs/outputs in the policy. Defaults to None.
|
||||||
env_cfg (EnvConfig | None, optional): The config of a gym environment to parse features from. Must be
|
env_cfg (EnvConfig | None, optional): The config of a gym environment to parse features from. Must be
|
||||||
|
@ -103,7 +100,7 @@ def make_policy(
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: Either ds_meta or env and env_cfg must be provided.
|
ValueError: Either ds_meta or env and env_cfg must be provided.
|
||||||
NotImplementedError: if the policy.type is 'vqbet' and the device 'mps' (due to an incompatibility)
|
NotImplementedError: if the policy.type is 'vqbet' and the policy device is 'mps' (due to an incompatibility)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
PreTrainedPolicy: _description_
|
PreTrainedPolicy: _description_
|
||||||
|
@ -118,7 +115,7 @@ def make_policy(
|
||||||
# https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment
|
# https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment
|
||||||
# variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be
|
# variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be
|
||||||
# slower than running natively on MPS.
|
# slower than running natively on MPS.
|
||||||
if cfg.type == "vqbet" and str(device) == "mps":
|
if cfg.type == "vqbet" and cfg.device == "mps":
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Current implementation of VQBeT does not support `mps` backend. "
|
"Current implementation of VQBeT does not support `mps` backend. "
|
||||||
"Please use `cpu` or `cuda` backend."
|
"Please use `cpu` or `cuda` backend."
|
||||||
|
@ -152,7 +149,7 @@ def make_policy(
|
||||||
# Make a fresh policy.
|
# Make a fresh policy.
|
||||||
policy = policy_cls(**kwargs)
|
policy = policy_cls(**kwargs)
|
||||||
|
|
||||||
policy.to(device)
|
policy.to(cfg.device)
|
||||||
assert isinstance(policy, nn.Module)
|
assert isinstance(policy, nn.Module)
|
||||||
|
|
||||||
# policy = torch.compile(policy, mode="reduce-overhead")
|
# policy = torch.compile(policy, mode="reduce-overhead")
|
||||||
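The factory changes above drop the explicit `device` argument and read the target device from the policy config instead. A self-contained sketch of that control flow, with a dummy config and an `nn.Linear` standing in for a real policy class:

```python
import torch
from torch import nn

class DummyCfg:
    type = "act"
    device = "cuda" if torch.cuda.is_available() else "cpu"

def make_policy_sketch(cfg) -> nn.Module:
    # Mirrors the guard and device handling above: reject vqbet on mps, then move the
    # freshly built policy onto cfg.device.
    if cfg.type == "vqbet" and cfg.device == "mps":
        raise NotImplementedError("Current implementation of VQBeT does not support `mps`.")
    policy = nn.Linear(4, 2)  # stand-in for the real policy class
    policy.to(cfg.device)
    assert isinstance(policy, nn.Module)
    return policy

print(next(make_policy_sketch(DummyCfg()).parameters()).device)
```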
|
|
|
@ -13,6 +13,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
|
||||||
|
@ -77,17 +78,29 @@ def create_stats_buffers(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
|
||||||
if stats:
|
if stats:
|
||||||
|
if isinstance(stats[key]["mean"], np.ndarray):
|
||||||
|
if norm_mode is NormalizationMode.MEAN_STD:
|
||||||
|
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
|
||||||
|
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
|
||||||
|
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||||
|
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
|
||||||
|
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
|
||||||
|
elif isinstance(stats[key]["mean"], torch.Tensor):
|
||||||
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
||||||
# tensors anywhere (for example, when we use the same stats for normalization and
|
# tensors anywhere (for example, when we use the same stats for normalization and
|
||||||
# unnormalization). See the logic here
|
# unnormalization). See the logic here
|
||||||
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
||||||
if norm_mode is NormalizationMode.MEAN_STD:
|
if norm_mode is NormalizationMode.MEAN_STD:
|
||||||
buffer["mean"].data = stats[key]["mean"].clone()
|
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
|
||||||
buffer["std"].data = stats[key]["std"].clone()
|
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
|
||||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||||
buffer["min"].data = stats[key]["min"].clone()
|
buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
|
||||||
buffer["max"].data = stats[key]["max"].clone()
|
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
|
||||||
|
else:
|
||||||
|
type_ = type(stats[key]["mean"])
|
||||||
|
raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
|
||||||
|
|
||||||
stats_buffers[key] = buffer
|
stats_buffers[key] = buffer
|
||||||
return stats_buffers
|
return stats_buffers
|
||||||
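The hunk above harmonises dataset stats that may arrive as numpy arrays or torch tensors, always ending up with cloned float32 buffers. A minimal sketch of that conversion, using a hypothetical helper name `to_buffer`:

```python
import numpy as np
import torch

def to_buffer(value) -> torch.Tensor:
    """Return a float32 torch copy of a numpy or torch stat, cloning to avoid shared storage."""
    if isinstance(value, np.ndarray):
        return torch.from_numpy(value).to(dtype=torch.float32)
    if isinstance(value, torch.Tensor):
        return value.clone().to(dtype=torch.float32)
    raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type(value)}' instead.")

print(to_buffer(np.array([0.1, 0.2])).dtype)                           # torch.float32
print(to_buffer(torch.tensor([1.0, 2.0], dtype=torch.float64)).dtype)  # torch.float32
```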
|
@ -141,6 +154,7 @@ class Normalize(nn.Module):
|
||||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||||
for key, ft in self.features.items():
|
for key, ft in self.features.items():
|
||||||
if key not in batch:
|
if key not in batch:
|
||||||
|
# FIXME(aliberts, rcadene): This might lead to silent fail!
|
||||||
continue
|
continue
|
||||||
|
|
||||||
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
from lerobot.common.optim.optimizers import AdamWConfig
|
from lerobot.common.optim.optimizers import AdamWConfig
|
||||||
|
@ -76,6 +90,7 @@ class PI0Config(PreTrainedConfig):
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
|
|
||||||
|
# TODO(Steven): Validate device and amp? in all policy configs?
|
||||||
"""Input validation (not exhaustive)."""
|
"""Input validation (not exhaustive)."""
|
||||||
if self.n_action_steps > self.chunk_size:
|
if self.n_action_steps > self.chunk_size:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
@ -31,7 +45,7 @@ def main():
|
||||||
|
|
||||||
cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
|
cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
|
||||||
cfg.pretrained_path = ckpt_torch_dir
|
cfg.pretrained_path = ckpt_torch_dir
|
||||||
policy = make_policy(cfg, device, ds_meta=dataset.meta)
|
policy = make_policy(cfg, ds_meta=dataset.meta)
|
||||||
|
|
||||||
# policy = torch.compile(policy, mode="reduce-overhead")
|
# policy = torch.compile(policy, mode="reduce-overhead")
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import pickle
|
import pickle
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -87,7 +101,7 @@ def main():
|
||||||
|
|
||||||
cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
|
cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
|
||||||
cfg.pretrained_path = ckpt_torch_dir
|
cfg.pretrained_path = ckpt_torch_dir
|
||||||
policy = make_policy(cfg, device, dataset_meta)
|
policy = make_policy(cfg, dataset_meta)
|
||||||
|
|
||||||
# loss_dict = policy.forward(batch, noise=noise, time=time_beta)
|
# loss_dict = policy.forward(batch, noise=noise, time=time_beta)
|
||||||
# loss_dict["loss"].backward()
|
# loss_dict["loss"].backward()
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from transformers import GemmaConfig, PaliGemmaConfig
|
from transformers import GemmaConfig, PaliGemmaConfig
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,22 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Convert pi0 parameters from Jax to Pytorch
|
Convert pi0 parameters from Jax to Pytorch
|
||||||
|
|
||||||
Follow [README of openpi](https://github.com/Physical-Intelligence/openpi) to create a new environment
|
Follow [README of openpi](https://github.com/Physical-Intelligence/openpi) to create a new environment
|
||||||
and install the required librairies.
|
and install the required libraries.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cd ~/code/openpi
|
cd ~/code/openpi
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
from packaging.version import Version
|
from packaging.version import Version
|
||||||
|
|
|
@ -313,7 +313,7 @@ class PI0Policy(PreTrainedPolicy):
|
||||||
state = self.prepare_state(batch)
|
state = self.prepare_state(batch)
|
||||||
lang_tokens, lang_masks = self.prepare_language(batch)
|
lang_tokens, lang_masks = self.prepare_language(batch)
|
||||||
actions = self.prepare_action(batch)
|
actions = self.prepare_action(batch)
|
||||||
actions_is_pad = batch.get("actions_id_pad")
|
actions_is_pad = batch.get("actions_is_pad")
|
||||||
|
|
||||||
loss_dict = {}
|
loss_dict = {}
|
||||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
|
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
|
||||||
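The corrected key above (`actions_is_pad`) flags copy-padded action steps so they can be excluded from the loss. A toy sketch of that masking, with invented shapes and without the extra reductions the real PI0 loss performs:

```python
import torch

losses = torch.rand(2, 6)  # per-step losses, (batch, action_horizon), illustrative
actions_is_pad = torch.tensor([[False, False, False, False, True, True],
                               [False, False, False, False, False, False]])

in_episode = ~actions_is_pad
masked_loss = (losses * in_episode).sum() / in_episode.sum()  # average over real steps only
print(masked_loss)
```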
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
|
@ -1,3 +1,16 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
import abc
|
import abc
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
@ -73,7 +86,6 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||||
cache_dir: str | Path | None = None,
|
cache_dir: str | Path | None = None,
|
||||||
local_files_only: bool = False,
|
local_files_only: bool = False,
|
||||||
revision: str | None = None,
|
revision: str | None = None,
|
||||||
map_location: str = "cpu",
|
|
||||||
strict: bool = False,
|
strict: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> T:
|
) -> T:
|
||||||
|
@ -98,7 +110,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||||
if os.path.isdir(model_id):
|
if os.path.isdir(model_id):
|
||||||
print("Loading weights from local directory")
|
print("Loading weights from local directory")
|
||||||
model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
|
model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
|
||||||
policy = cls._load_as_safetensor(instance, model_file, map_location, strict)
|
policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
model_file = hf_hub_download(
|
model_file = hf_hub_download(
|
||||||
|
@ -112,13 +124,13 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||||
token=token,
|
token=token,
|
||||||
local_files_only=local_files_only,
|
local_files_only=local_files_only,
|
||||||
)
|
)
|
||||||
policy = cls._load_as_safetensor(instance, model_file, map_location, strict)
|
policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
|
||||||
except HfHubHTTPError as e:
|
except HfHubHTTPError as e:
|
||||||
raise FileNotFoundError(
|
raise FileNotFoundError(
|
||||||
f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
|
f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
policy.to(map_location)
|
policy.to(config.device)
|
||||||
policy.eval()
|
policy.eval()
|
||||||
return policy
|
return policy
|
||||||
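The loading path above now materialises the safetensors weights on `config.device` rather than a separate `map_location`. A rough, self-contained sketch of that flow using a throwaway file name and a plain `nn.Linear`, not the actual `_load_as_safetensor` implementation:

```python
import torch
from torch import nn
from safetensors.torch import load_file, save_file

model = nn.Linear(4, 2)
save_file(model.state_dict(), "policy.safetensors")       # throwaway checkpoint

device = "cuda" if torch.cuda.is_available() else "cpu"   # stands in for config.device
state_dict = load_file("policy.safetensors", device=device)
model.load_state_dict(state_dict)
model.to(device).eval()
print(next(model.parameters()).device)
```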
|
|
||||||
|
|
|
@ -76,7 +76,7 @@ class TDMPCConfig(PreTrainedConfig):
|
||||||
n_pi_samples: Number of samples to draw from the policy / world model rollout every CEM iteration. Can
|
n_pi_samples: Number of samples to draw from the policy / world model rollout every CEM iteration. Can
|
||||||
be zero.
|
be zero.
|
||||||
uncertainty_regularizer_coeff: Coefficient for the uncertainty regularization used when estimating
|
uncertainty_regularizer_coeff: Coefficient for the uncertainty regularization used when estimating
|
||||||
trajectory values (this is the λ coeffiecient in eqn 4 of FOWM).
|
trajectory values (this is the λ coefficient in eqn 4 of FOWM).
|
||||||
n_elites: The number of elite samples to use for updating the gaussian parameters every CEM iteration.
|
n_elites: The number of elite samples to use for updating the gaussian parameters every CEM iteration.
|
||||||
elite_weighting_temperature: The temperature to use for softmax weighting (by trajectory value) of the
|
elite_weighting_temperature: The temperature to use for softmax weighting (by trajectory value) of the
|
||||||
elites, when updating the gaussian parameters for CEM.
|
elites, when updating the gaussian parameters for CEM.
|
||||||
|
@ -165,7 +165,7 @@ class TDMPCConfig(PreTrainedConfig):
|
||||||
"""Input validation (not exhaustive)."""
|
"""Input validation (not exhaustive)."""
|
||||||
if self.n_gaussian_samples <= 0:
|
if self.n_gaussian_samples <= 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"The number of guassian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`"
|
f"The number of gaussian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`"
|
||||||
)
|
)
|
||||||
if self.normalization_mapping["ACTION"] is not NormalizationMode.MIN_MAX:
|
if self.normalization_mapping["ACTION"] is not NormalizationMode.MIN_MAX:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|
|
@ -66,7 +66,7 @@ class VQBeTConfig(PreTrainedConfig):
|
||||||
within the image size. If None, no cropping is done.
|
within the image size. If None, no cropping is done.
|
||||||
crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
|
crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
|
||||||
mode).
|
mode).
|
||||||
pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone.
|
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
|
||||||
`None` means no pretrained weights.
|
`None` means no pretrained weights.
|
||||||
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
|
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
|
||||||
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
|
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
|
||||||
|
|
|
@ -485,7 +485,7 @@ class VQBeTHead(nn.Module):
|
||||||
def forward(self, x, **kwargs) -> dict:
|
def forward(self, x, **kwargs) -> dict:
|
||||||
# N is the batch size, and T is number of action query tokens, which are process through same GPT
|
# N is the batch size, and T is number of action query tokens, which are process through same GPT
|
||||||
N, T, _ = x.shape
|
N, T, _ = x.shape
|
||||||
# we calculate N and T side parallely. Thus, the dimensions would be
|
# we calculate N and T in parallel. Thus, the dimensions would be
|
||||||
# (batch size * number of action query tokens, action chunk size, action dimension)
|
# (batch size * number of action query tokens, action chunk size, action dimension)
|
||||||
x = einops.rearrange(x, "N T WA -> (N T) WA")
|
x = einops.rearrange(x, "N T WA -> (N T) WA")
|
||||||
|
|
||||||
|
@ -772,7 +772,7 @@ class VqVae(nn.Module):
|
||||||
Encoder and decoder are MLPs consisting of an input, output layer, and hidden layer, respectively.
|
Encoder and decoder are MLPs consisting of an input, output layer, and hidden layer, respectively.
|
||||||
The vq_layer uses residual VQs.
|
The vq_layer uses residual VQs.
|
||||||
|
|
||||||
This class contains functions for training the encoder and decoder along with the residual VQ layer (for trainign phase 1),
|
This class contains functions for training the encoder and decoder along with the residual VQ layer (for training phase 1),
|
||||||
as well as functions to help BeT training part in training phase 2.
|
as well as functions to help BeT training part in training phase 2.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
|
@ -38,7 +38,7 @@ from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||||
This file is part of a VQ-BeT that utilizes code from the following repositories:
|
This file is part of a VQ-BeT that utilizes code from the following repositories:
|
||||||
|
|
||||||
- Vector Quantize PyTorch code is licensed under the MIT License:
|
- Vector Quantize PyTorch code is licensed under the MIT License:
|
||||||
Origianl source: https://github.com/lucidrains/vector-quantize-pytorch
|
Original source: https://github.com/lucidrains/vector-quantize-pytorch
|
||||||
|
|
||||||
- nanoGPT part is an adaptation of Andrej Karpathy's nanoGPT implementation in PyTorch.
|
- nanoGPT part is an adaptation of Andrej Karpathy's nanoGPT implementation in PyTorch.
|
||||||
Original source: https://github.com/karpathy/nanoGPT
|
Original source: https://github.com/karpathy/nanoGPT
|
||||||
|
@ -289,7 +289,7 @@ class GPT(nn.Module):
|
||||||
This file is a part for Residual Vector Quantization that utilizes code from the following repository:
|
This file is a part for Residual Vector Quantization that utilizes code from the following repository:
|
||||||
|
|
||||||
- Phil Wang's vector-quantize-pytorch implementation in PyTorch.
|
- Phil Wang's vector-quantize-pytorch implementation in PyTorch.
|
||||||
Origianl source: https://github.com/lucidrains/vector-quantize-pytorch
|
Original source: https://github.com/lucidrains/vector-quantize-pytorch
|
||||||
|
|
||||||
- The vector-quantize-pytorch code is licensed under the MIT License:
|
- The vector-quantize-pytorch code is licensed under the MIT License:
|
||||||
|
|
||||||
|
@ -1349,9 +1349,9 @@ class EuclideanCodebook(nn.Module):
|
||||||
|
|
||||||
# calculate distributed variance
|
# calculate distributed variance
|
||||||
|
|
||||||
variance_numer = reduce((data - batch_mean) ** 2, "h n d -> h 1 d", "sum")
|
variance_numerator = reduce((data - batch_mean) ** 2, "h n d -> h 1 d", "sum")
|
||||||
distributed.all_reduce(variance_numer)
|
distributed.all_reduce(variance_numerator)
|
||||||
batch_variance = variance_numer / num_vectors
|
batch_variance = variance_numerator / num_vectors
|
||||||
|
|
||||||
self.update_with_decay("batch_variance", batch_variance, self.affine_param_batch_decay)
|
self.update_with_decay("batch_variance", batch_variance, self.affine_param_batch_decay)
|
||||||
|
|
||||||
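The renamed variable in this hunk is the numerator of a variance estimate: the sum of squared deviations, later divided by the vector count (and, in the distributed case, all-reduced across ranks first). A single-process sketch with made-up shapes:

```python
import torch
from einops import reduce

data = torch.randn(2, 5, 3)  # (heads, vectors, dim), illustrative shape
batch_mean = reduce(data, "h n d -> h 1 d", "mean")
variance_numerator = reduce((data - batch_mean) ** 2, "h n d -> h 1 d", "sum")
batch_variance = variance_numerator / data.shape[1]  # divide by the number of vectors
print(torch.allclose(batch_variance, data.var(dim=1, keepdim=True, unbiased=False)))  # True
```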
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
This file contains utilities for recording frames from Intel Realsense cameras.
|
This file contains utilities for recording frames from Intel Realsense cameras.
|
||||||
"""
|
"""
|
||||||
|
@ -34,7 +48,7 @@ def find_cameras(raise_when_empty=True, mock=False) -> list[dict]:
|
||||||
connected to the computer.
|
connected to the computer.
|
||||||
"""
|
"""
|
||||||
if mock:
|
if mock:
|
||||||
import tests.mock_pyrealsense2 as rs
|
import tests.cameras.mock_pyrealsense2 as rs
|
||||||
else:
|
else:
|
||||||
import pyrealsense2 as rs
|
import pyrealsense2 as rs
|
||||||
|
|
||||||
|
@ -86,7 +100,7 @@ def save_images_from_cameras(
|
||||||
serial_numbers = [cam["serial_number"] for cam in camera_infos]
|
serial_numbers = [cam["serial_number"] for cam in camera_infos]
|
||||||
|
|
||||||
if mock:
|
if mock:
|
||||||
import tests.mock_cv2 as cv2
|
import tests.cameras.mock_cv2 as cv2
|
||||||
else:
|
else:
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
|
@ -100,7 +114,7 @@ def save_images_from_cameras(
|
||||||
camera = IntelRealSenseCamera(config)
|
camera = IntelRealSenseCamera(config)
|
||||||
camera.connect()
|
camera.connect()
|
||||||
print(
|
print(
|
||||||
f"IntelRealSenseCamera({camera.serial_number}, fps={camera.fps}, width={camera.width}, height={camera.height}, color_mode={camera.color_mode})"
|
f"IntelRealSenseCamera({camera.serial_number}, fps={camera.fps}, width={camera.capture_width}, height={camera.capture_height}, color_mode={camera.color_mode})"
|
||||||
)
|
)
|
||||||
cameras.append(camera)
|
cameras.append(camera)
|
||||||
|
|
||||||
|
@ -210,9 +224,20 @@ class IntelRealSenseCamera:
|
||||||
self.serial_number = self.find_serial_number_from_name(config.name)
|
self.serial_number = self.find_serial_number_from_name(config.name)
|
||||||
else:
|
else:
|
||||||
self.serial_number = config.serial_number
|
self.serial_number = config.serial_number
|
||||||
self.fps = config.fps
|
|
||||||
|
# Store the raw (capture) resolution from the config.
|
||||||
|
self.capture_width = config.width
|
||||||
|
self.capture_height = config.height
|
||||||
|
|
||||||
|
# If rotated by ±90, swap width and height.
|
||||||
|
if config.rotation in [-90, 90]:
|
||||||
|
self.width = config.height
|
||||||
|
self.height = config.width
|
||||||
|
else:
|
||||||
self.width = config.width
|
self.width = config.width
|
||||||
self.height = config.height
|
self.height = config.height
|
||||||
|
|
||||||
|
self.fps = config.fps
|
||||||
self.channels = config.channels
|
self.channels = config.channels
|
||||||
self.color_mode = config.color_mode
|
self.color_mode = config.color_mode
|
||||||
self.use_depth = config.use_depth
|
self.use_depth = config.use_depth
|
||||||
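The constructor change above separates the sensor's raw resolution (capture_width / capture_height) from the rotated output resolution (width / height). A small standalone sketch of that relationship, using the same swap rule as the diff:

def output_resolution(capture_width: int, capture_height: int, rotation: int | None) -> tuple[int, int]:
    # A ±90 degree rotation swaps the output dimensions; the capture dimensions stay as configured.
    if rotation in (-90, 90):
        return capture_height, capture_width
    return capture_width, capture_height

assert output_resolution(640, 480, 90) == (480, 640)
assert output_resolution(640, 480, None) == (640, 480)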
|
@ -228,11 +253,10 @@ class IntelRealSenseCamera:
|
||||||
self.logs = {}
|
self.logs = {}
|
||||||
|
|
||||||
if self.mock:
|
if self.mock:
|
||||||
import tests.mock_cv2 as cv2
|
import tests.cameras.mock_cv2 as cv2
|
||||||
else:
|
else:
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
# TODO(alibets): Do we keep original width/height or do we define them after rotation?
|
|
||||||
self.rotation = None
|
self.rotation = None
|
||||||
if config.rotation == -90:
|
if config.rotation == -90:
|
||||||
self.rotation = cv2.ROTATE_90_COUNTERCLOCKWISE
|
self.rotation = cv2.ROTATE_90_COUNTERCLOCKWISE
|
||||||
|
@ -263,22 +287,26 @@ class IntelRealSenseCamera:
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.mock:
|
if self.mock:
|
||||||
import tests.mock_pyrealsense2 as rs
|
import tests.cameras.mock_pyrealsense2 as rs
|
||||||
else:
|
else:
|
||||||
import pyrealsense2 as rs
|
import pyrealsense2 as rs
|
||||||
|
|
||||||
config = rs.config()
|
config = rs.config()
|
||||||
config.enable_device(str(self.serial_number))
|
config.enable_device(str(self.serial_number))
|
||||||
|
|
||||||
if self.fps and self.width and self.height:
|
if self.fps and self.capture_width and self.capture_height:
|
||||||
# TODO(rcadene): can we set rgb8 directly?
|
# TODO(rcadene): can we set rgb8 directly?
|
||||||
config.enable_stream(rs.stream.color, self.width, self.height, rs.format.rgb8, self.fps)
|
config.enable_stream(
|
||||||
|
rs.stream.color, self.capture_width, self.capture_height, rs.format.rgb8, self.fps
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
config.enable_stream(rs.stream.color)
|
config.enable_stream(rs.stream.color)
|
||||||
|
|
||||||
if self.use_depth:
|
if self.use_depth:
|
||||||
if self.fps and self.width and self.height:
|
if self.fps and self.capture_width and self.capture_height:
|
||||||
config.enable_stream(rs.stream.depth, self.width, self.height, rs.format.z16, self.fps)
|
config.enable_stream(
|
||||||
|
rs.stream.depth, self.capture_width, self.capture_height, rs.format.z16, self.fps
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
config.enable_stream(rs.stream.depth)
|
config.enable_stream(rs.stream.depth)
|
||||||
|
|
||||||
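With the streams now requested at the raw capture resolution, the RealSense setup reduces to the calls visible in the hunk. A hedged sketch with example values; the pipeline creation and start lines are standard pyrealsense2 usage and are not part of this diff:

import pyrealsense2 as rs

serial_number, capture_width, capture_height, fps = "128422271347", 640, 480, 30  # example values only

config = rs.config()
config.enable_device(serial_number)
config.enable_stream(rs.stream.color, capture_width, capture_height, rs.format.rgb8, fps)
config.enable_stream(rs.stream.depth, capture_width, capture_height, rs.format.z16, fps)

pipeline = rs.pipeline()
profile = pipeline.start(config)  # requires a connected camera with that serial number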
|
@ -316,18 +344,18 @@ class IntelRealSenseCamera:
|
||||||
raise OSError(
|
raise OSError(
|
||||||
f"Can't set {self.fps=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_fps}."
|
f"Can't set {self.fps=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_fps}."
|
||||||
)
|
)
|
||||||
if self.width is not None and self.width != actual_width:
|
if self.capture_width is not None and self.capture_width != actual_width:
|
||||||
raise OSError(
|
raise OSError(
|
||||||
f"Can't set {self.width=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_width}."
|
f"Can't set {self.capture_width=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_width}."
|
||||||
)
|
)
|
||||||
if self.height is not None and self.height != actual_height:
|
if self.capture_height is not None and self.capture_height != actual_height:
|
||||||
raise OSError(
|
raise OSError(
|
||||||
f"Can't set {self.height=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_height}."
|
f"Can't set {self.capture_height=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_height}."
|
||||||
)
|
)
|
||||||
|
|
||||||
self.fps = round(actual_fps)
|
self.fps = round(actual_fps)
|
||||||
self.width = round(actual_width)
|
self.capture_width = round(actual_width)
|
||||||
self.height = round(actual_height)
|
self.capture_height = round(actual_height)
|
||||||
|
|
||||||
self.is_connected = True
|
self.is_connected = True
|
||||||
|
|
||||||
|
@ -347,7 +375,7 @@ class IntelRealSenseCamera:
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.mock:
|
if self.mock:
|
||||||
import tests.mock_cv2 as cv2
|
import tests.cameras.mock_cv2 as cv2
|
||||||
else:
|
else:
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
|
@ -373,7 +401,7 @@ class IntelRealSenseCamera:
|
||||||
color_image = cv2.cvtColor(color_image, cv2.COLOR_RGB2BGR)
|
color_image = cv2.cvtColor(color_image, cv2.COLOR_RGB2BGR)
|
||||||
|
|
||||||
h, w, _ = color_image.shape
|
h, w, _ = color_image.shape
|
||||||
if h != self.height or w != self.width:
|
if h != self.capture_height or w != self.capture_width:
|
||||||
raise OSError(
|
raise OSError(
|
||||||
f"Can't capture color image with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead."
|
f"Can't capture color image with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead."
|
||||||
)
|
)
|
||||||
|
@ -395,7 +423,7 @@ class IntelRealSenseCamera:
|
||||||
depth_map = np.asanyarray(depth_frame.get_data())
|
depth_map = np.asanyarray(depth_frame.get_data())
|
||||||
|
|
||||||
h, w = depth_map.shape
|
h, w = depth_map.shape
|
||||||
if h != self.height or w != self.width:
|
if h != self.capture_height or w != self.capture_width:
|
||||||
raise OSError(
|
raise OSError(
|
||||||
f"Can't capture depth map with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead."
|
f"Can't capture depth map with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead."
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
This file contains utilities for recording frames from cameras. For more info look at `OpenCVCamera` docstring.
|
This file contains utilities for recording frames from cameras. For more info look at `OpenCVCamera` docstring.
|
||||||
"""
|
"""
|
||||||
|
@ -66,7 +80,7 @@ def _find_cameras(
|
||||||
possible_camera_ids: list[int | str], raise_when_empty=False, mock=False
|
possible_camera_ids: list[int | str], raise_when_empty=False, mock=False
|
||||||
) -> list[int | str]:
|
) -> list[int | str]:
|
||||||
if mock:
|
if mock:
|
||||||
import tests.mock_cv2 as cv2
|
import tests.cameras.mock_cv2 as cv2
|
||||||
else:
|
else:
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
|
@ -130,8 +144,8 @@ def save_images_from_cameras(
|
||||||
camera = OpenCVCamera(config)
|
camera = OpenCVCamera(config)
|
||||||
camera.connect()
|
camera.connect()
|
||||||
print(
|
print(
|
||||||
f"OpenCVCamera({camera.camera_index}, fps={camera.fps}, width={camera.width}, "
|
f"OpenCVCamera({camera.camera_index}, fps={camera.fps}, width={camera.capture_width}, "
|
||||||
f"height={camera.height}, color_mode={camera.color_mode})"
|
f"height={camera.capture_height}, color_mode={camera.color_mode})"
|
||||||
)
|
)
|
||||||
cameras.append(camera)
|
cameras.append(camera)
|
||||||
|
|
||||||
|
@ -230,9 +244,19 @@ class OpenCVCamera:
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Please check the provided camera_index: {self.camera_index}")
|
raise ValueError(f"Please check the provided camera_index: {self.camera_index}")
|
||||||
|
|
||||||
self.fps = config.fps
|
# Store the raw (capture) resolution from the config.
|
||||||
|
self.capture_width = config.width
|
||||||
|
self.capture_height = config.height
|
||||||
|
|
||||||
|
# If rotated by ±90, swap width and height.
|
||||||
|
if config.rotation in [-90, 90]:
|
||||||
|
self.width = config.height
|
||||||
|
self.height = config.width
|
||||||
|
else:
|
||||||
self.width = config.width
|
self.width = config.width
|
||||||
self.height = config.height
|
self.height = config.height
|
||||||
|
|
||||||
|
self.fps = config.fps
|
||||||
self.channels = config.channels
|
self.channels = config.channels
|
||||||
self.color_mode = config.color_mode
|
self.color_mode = config.color_mode
|
||||||
self.mock = config.mock
|
self.mock = config.mock
|
||||||
|
@ -245,11 +269,10 @@ class OpenCVCamera:
|
||||||
self.logs = {}
|
self.logs = {}
|
||||||
|
|
||||||
if self.mock:
|
if self.mock:
|
||||||
import tests.mock_cv2 as cv2
|
import tests.cameras.mock_cv2 as cv2
|
||||||
else:
|
else:
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
# TODO(aliberts): Do we keep original width/height or do we define them after rotation?
|
|
||||||
self.rotation = None
|
self.rotation = None
|
||||||
if config.rotation == -90:
|
if config.rotation == -90:
|
||||||
self.rotation = cv2.ROTATE_90_COUNTERCLOCKWISE
|
self.rotation = cv2.ROTATE_90_COUNTERCLOCKWISE
|
||||||
|
@ -263,7 +286,7 @@ class OpenCVCamera:
|
||||||
raise RobotDeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.")
|
raise RobotDeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.")
|
||||||
|
|
||||||
if self.mock:
|
if self.mock:
|
||||||
import tests.mock_cv2 as cv2
|
import tests.cameras.mock_cv2 as cv2
|
||||||
else:
|
else:
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
|
@ -271,10 +294,20 @@ class OpenCVCamera:
|
||||||
# when other threads are used to save the images.
|
# when other threads are used to save the images.
|
||||||
cv2.setNumThreads(1)
|
cv2.setNumThreads(1)
|
||||||
|
|
||||||
|
backend = (
|
||||||
|
cv2.CAP_V4L2
|
||||||
|
if platform.system() == "Linux"
|
||||||
|
else cv2.CAP_DSHOW
|
||||||
|
if platform.system() == "Windows"
|
||||||
|
else cv2.CAP_AVFOUNDATION
|
||||||
|
if platform.system() == "Darwin"
|
||||||
|
else cv2.CAP_ANY
|
||||||
|
)
|
||||||
|
|
||||||
camera_idx = f"/dev/video{self.camera_index}" if platform.system() == "Linux" else self.camera_index
|
camera_idx = f"/dev/video{self.camera_index}" if platform.system() == "Linux" else self.camera_index
|
||||||
# First create a temporary camera trying to access `camera_index`,
|
# First create a temporary camera trying to access `camera_index`,
|
||||||
# and verify it is a valid camera by calling `isOpened`.
|
# and verify it is a valid camera by calling `isOpened`.
|
||||||
tmp_camera = cv2.VideoCapture(camera_idx)
|
tmp_camera = cv2.VideoCapture(camera_idx, backend)
|
||||||
is_camera_open = tmp_camera.isOpened()
|
is_camera_open = tmp_camera.isOpened()
|
||||||
# Release camera to make it accessible for `find_camera_indices`
|
# Release camera to make it accessible for `find_camera_indices`
|
||||||
tmp_camera.release()
|
tmp_camera.release()
|
||||||
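The connect path above now asks OpenCV for an OS-appropriate capture backend instead of the default. The selection logic, restated as a plain function (a sketch; the constants are standard OpenCV backend identifiers):

import platform
import cv2

def pick_backend() -> int:
    system = platform.system()
    if system == "Linux":
        return cv2.CAP_V4L2
    if system == "Windows":
        return cv2.CAP_DSHOW
    if system == "Darwin":
        return cv2.CAP_AVFOUNDATION
    return cv2.CAP_ANY

camera = cv2.VideoCapture(0, pick_backend())  # camera index 0 is only an example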
|
@ -297,14 +330,14 @@ class OpenCVCamera:
|
||||||
# Secondly, create the camera that will be used downstream.
|
# Secondly, create the camera that will be used downstream.
|
||||||
# Note: For some unknown reason, calling `isOpened` blocks the camera which then
|
# Note: For some unknown reason, calling `isOpened` blocks the camera which then
|
||||||
# needs to be re-created.
|
# needs to be re-created.
|
||||||
self.camera = cv2.VideoCapture(camera_idx)
|
self.camera = cv2.VideoCapture(camera_idx, backend)
|
||||||
|
|
||||||
if self.fps is not None:
|
if self.fps is not None:
|
||||||
self.camera.set(cv2.CAP_PROP_FPS, self.fps)
|
self.camera.set(cv2.CAP_PROP_FPS, self.fps)
|
||||||
if self.width is not None:
|
if self.capture_width is not None:
|
||||||
self.camera.set(cv2.CAP_PROP_FRAME_WIDTH, self.width)
|
self.camera.set(cv2.CAP_PROP_FRAME_WIDTH, self.capture_width)
|
||||||
if self.height is not None:
|
if self.capture_height is not None:
|
||||||
self.camera.set(cv2.CAP_PROP_FRAME_HEIGHT, self.height)
|
self.camera.set(cv2.CAP_PROP_FRAME_HEIGHT, self.capture_height)
|
||||||
|
|
||||||
actual_fps = self.camera.get(cv2.CAP_PROP_FPS)
|
actual_fps = self.camera.get(cv2.CAP_PROP_FPS)
|
||||||
actual_width = self.camera.get(cv2.CAP_PROP_FRAME_WIDTH)
|
actual_width = self.camera.get(cv2.CAP_PROP_FRAME_WIDTH)
|
||||||
|
@ -316,19 +349,22 @@ class OpenCVCamera:
|
||||||
raise OSError(
|
raise OSError(
|
||||||
f"Can't set {self.fps=} for OpenCVCamera({self.camera_index}). Actual value is {actual_fps}."
|
f"Can't set {self.fps=} for OpenCVCamera({self.camera_index}). Actual value is {actual_fps}."
|
||||||
)
|
)
|
||||||
if self.width is not None and not math.isclose(self.width, actual_width, rel_tol=1e-3):
|
if self.capture_width is not None and not math.isclose(
|
||||||
|
self.capture_width, actual_width, rel_tol=1e-3
|
||||||
|
):
|
||||||
raise OSError(
|
raise OSError(
|
||||||
f"Can't set {self.width=} for OpenCVCamera({self.camera_index}). Actual value is {actual_width}."
|
f"Can't set {self.capture_width=} for OpenCVCamera({self.camera_index}). Actual value is {actual_width}."
|
||||||
)
|
)
|
||||||
if self.height is not None and not math.isclose(self.height, actual_height, rel_tol=1e-3):
|
if self.capture_height is not None and not math.isclose(
|
||||||
|
self.capture_height, actual_height, rel_tol=1e-3
|
||||||
|
):
|
||||||
raise OSError(
|
raise OSError(
|
||||||
f"Can't set {self.height=} for OpenCVCamera({self.camera_index}). Actual value is {actual_height}."
|
f"Can't set {self.capture_height=} for OpenCVCamera({self.camera_index}). Actual value is {actual_height}."
|
||||||
)
|
)
|
||||||
|
|
||||||
self.fps = round(actual_fps)
|
self.fps = round(actual_fps)
|
||||||
self.width = round(actual_width)
|
self.capture_width = round(actual_width)
|
||||||
self.height = round(actual_height)
|
self.capture_height = round(actual_height)
|
||||||
|
|
||||||
self.is_connected = True
|
self.is_connected = True
|
||||||
|
|
||||||
def read(self, temporary_color_mode: str | None = None) -> np.ndarray:
|
def read(self, temporary_color_mode: str | None = None) -> np.ndarray:
|
||||||
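The validation above compares the requested capture size against what the driver reports; math.isclose is used because OpenCV returns these properties as floats. A tiny illustration of the check, with example values:

import math

requested_width = 640
actual_width = 640.0  # e.g. camera.get(cv2.CAP_PROP_FRAME_WIDTH) returns a float
if not math.isclose(requested_width, actual_width, rel_tol=1e-3):
    raise OSError(f"Can't set width={requested_width}; camera reports {actual_width}.")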
|
@ -362,14 +398,14 @@ class OpenCVCamera:
|
||||||
# so we convert the image color from BGR to RGB.
|
# so we convert the image color from BGR to RGB.
|
||||||
if requested_color_mode == "rgb":
|
if requested_color_mode == "rgb":
|
||||||
if self.mock:
|
if self.mock:
|
||||||
import tests.mock_cv2 as cv2
|
import tests.cameras.mock_cv2 as cv2
|
||||||
else:
|
else:
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
color_image = cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB)
|
color_image = cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB)
|
||||||
|
|
||||||
h, w, _ = color_image.shape
|
h, w, _ = color_image.shape
|
||||||
if h != self.height or w != self.width:
|
if h != self.capture_height or w != self.capture_width:
|
||||||
raise OSError(
|
raise OSError(
|
||||||
f"Can't capture color image with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead."
|
f"Can't capture color image with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead."
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Protocol
|
from typing import Protocol
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -31,7 +45,7 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> list[C
|
||||||
|
|
||||||
cameras[key] = IntelRealSenseCamera(cfg)
|
cameras[key] = IntelRealSenseCamera(cfg)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"The motor type '{cfg.type}' is not valid.")
|
raise ValueError(f"The camera type '{cfg.type}' is not valid.")
|
||||||
|
|
||||||
return cameras
|
return cameras
|
||||||
|
|
||||||
|
|
|
@ -1,14 +1,25 @@
|
||||||
import logging
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import draccus
|
import draccus
|
||||||
|
|
||||||
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
||||||
from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
|
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -43,11 +54,6 @@ class RecordControlConfig(ControlConfig):
|
||||||
# Root directory where the dataset will be stored (e.g. 'dataset/path').
|
# Root directory where the dataset will be stored (e.g. 'dataset/path').
|
||||||
root: str | Path | None = None
|
root: str | Path | None = None
|
||||||
policy: PreTrainedConfig | None = None
|
policy: PreTrainedConfig | None = None
|
||||||
# TODO(rcadene, aliberts): By default, use device and use_amp values from policy checkpoint.
|
|
||||||
device: str | None = None # cuda | cpu | mps
|
|
||||||
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
|
|
||||||
# automatic gradient scaling is used.
|
|
||||||
use_amp: bool | None = None
|
|
||||||
# Limit the frames per second. By default, uses the policy fps.
|
# Limit the frames per second. By default, uses the policy fps.
|
||||||
fps: int | None = None
|
fps: int | None = None
|
||||||
# Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize.
|
# Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize.
|
||||||
|
@ -60,15 +66,13 @@ class RecordControlConfig(ControlConfig):
|
||||||
num_episodes: int = 50
|
num_episodes: int = 50
|
||||||
# Encode frames in the dataset into video
|
# Encode frames in the dataset into video
|
||||||
video: bool = True
|
video: bool = True
|
||||||
# By default, run the computation of the data statistics at the end of data collection. Compute intensive and not required to just replay an episode.
|
|
||||||
run_compute_stats: bool = True
|
|
||||||
# Upload dataset to Hugging Face hub.
|
# Upload dataset to Hugging Face hub.
|
||||||
push_to_hub: bool = True
|
push_to_hub: bool = True
|
||||||
# Upload on private repository on the Hugging Face hub.
|
# Upload on private repository on the Hugging Face hub.
|
||||||
private: bool = False
|
private: bool = False
|
||||||
# Add tags to your dataset on the hub.
|
# Add tags to your dataset on the hub.
|
||||||
tags: list[str] | None = None
|
tags: list[str] | None = None
|
||||||
# Number of subprocesses handling the saving of frames as PNGs. Set to 0 to use threads only;
|
# Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only;
|
||||||
# set to ≥1 to use subprocesses, each using threads to write images. The best number of processes
|
# set to ≥1 to use subprocesses, each using threads to write images. The best number of processes
|
||||||
# and threads depends on your system. We recommend 4 threads per camera with 0 processes.
|
# and threads depends on your system. We recommend 4 threads per camera with 0 processes.
|
||||||
# If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses.
|
# If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses.
|
||||||
|
@ -83,9 +87,6 @@ class RecordControlConfig(ControlConfig):
|
||||||
play_sounds: bool = True
|
play_sounds: bool = True
|
||||||
# Resume recording on an existing dataset.
|
# Resume recording on an existing dataset.
|
||||||
resume: bool = False
|
resume: bool = False
|
||||||
# TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument
|
|
||||||
# Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.
|
|
||||||
local_files_only: bool = False
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||||
|
@ -95,27 +96,6 @@ class RecordControlConfig(ControlConfig):
|
||||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||||
self.policy.pretrained_path = policy_path
|
self.policy.pretrained_path = policy_path
|
||||||
|
|
||||||
# When no device or use_amp are given, use the one from training config.
|
|
||||||
if self.device is None or self.use_amp is None:
|
|
||||||
train_cfg = TrainPipelineConfig.from_pretrained(policy_path)
|
|
||||||
if self.device is None:
|
|
||||||
self.device = train_cfg.device
|
|
||||||
if self.use_amp is None:
|
|
||||||
self.use_amp = train_cfg.use_amp
|
|
||||||
|
|
||||||
# Automatically switch to available device if necessary
|
|
||||||
if not is_torch_device_available(self.device):
|
|
||||||
auto_device = auto_select_torch_device()
|
|
||||||
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
|
|
||||||
self.device = auto_device
|
|
||||||
|
|
||||||
# Automatically deactivate AMP if necessary
|
|
||||||
if self.use_amp and not is_amp_available(self.device):
|
|
||||||
logging.warning(
|
|
||||||
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
|
|
||||||
)
|
|
||||||
self.use_amp = False
|
|
||||||
|
|
||||||
|
|
||||||
@ControlConfig.register_subclass("replay")
|
@ControlConfig.register_subclass("replay")
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -130,9 +110,12 @@ class ReplayControlConfig(ControlConfig):
|
||||||
fps: int | None = None
|
fps: int | None = None
|
||||||
# Use vocal synthesis to read events.
|
# Use vocal synthesis to read events.
|
||||||
play_sounds: bool = True
|
play_sounds: bool = True
|
||||||
# TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument
|
|
||||||
# Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.
|
|
||||||
local_files_only: bool = False
|
@ControlConfig.register_subclass("remote_robot")
|
||||||
|
@dataclass
|
||||||
|
class RemoteRobotConfig(ControlConfig):
|
||||||
|
log_interval: int = 100
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
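With device, use_amp, and local_files_only dropped from the control configs, the runtime device and AMP settings now come from the loaded policy's own config (see the control_loop hunk further down). A hedged helper sketch of that flow; the helper itself is hypothetical, while get_safe_torch_device and the policy.config fields follow the names used in this diff:

import contextlib
import torch
from lerobot.common.utils.utils import get_safe_torch_device

def autocast_context(policy):
    # Hypothetical helper: build the device/AMP context from the policy's own config.
    device = get_safe_torch_device(policy.config.device)
    if policy.config.use_amp:
        return torch.autocast(device_type=device.type)
    return contextlib.nullcontext()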
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
########################################################################################
|
########################################################################################
|
||||||
# Utilities
|
# Utilities
|
||||||
########################################################################################
|
########################################################################################
|
||||||
|
@ -12,13 +26,13 @@ from functools import cache
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import torch
|
import torch
|
||||||
import tqdm
|
|
||||||
from deepdiff import DeepDiff
|
from deepdiff import DeepDiff
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
|
|
||||||
from lerobot.common.datasets.image_writer import safe_stop_image_writer
|
from lerobot.common.datasets.image_writer import safe_stop_image_writer
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.common.datasets.utils import get_features_from_robot
|
from lerobot.common.datasets.utils import get_features_from_robot
|
||||||
|
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.common.robot_devices.robots.utils import Robot
|
from lerobot.common.robot_devices.robots.utils import Robot
|
||||||
from lerobot.common.robot_devices.utils import busy_wait
|
from lerobot.common.robot_devices.utils import busy_wait
|
||||||
from lerobot.common.utils.utils import get_safe_torch_device, has_method
|
from lerobot.common.utils.utils import get_safe_torch_device, has_method
|
||||||
|
@ -180,9 +194,8 @@ def record_episode(
|
||||||
episode_time_s,
|
episode_time_s,
|
||||||
display_cameras,
|
display_cameras,
|
||||||
policy,
|
policy,
|
||||||
device,
|
|
||||||
use_amp,
|
|
||||||
fps,
|
fps,
|
||||||
|
single_task,
|
||||||
):
|
):
|
||||||
control_loop(
|
control_loop(
|
||||||
robot=robot,
|
robot=robot,
|
||||||
|
@ -191,10 +204,9 @@ def record_episode(
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
events=events,
|
events=events,
|
||||||
policy=policy,
|
policy=policy,
|
||||||
device=device,
|
|
||||||
use_amp=use_amp,
|
|
||||||
fps=fps,
|
fps=fps,
|
||||||
teleoperate=policy is None,
|
teleoperate=policy is None,
|
||||||
|
single_task=single_task,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -206,10 +218,9 @@ def control_loop(
|
||||||
display_cameras=False,
|
display_cameras=False,
|
||||||
dataset: LeRobotDataset | None = None,
|
dataset: LeRobotDataset | None = None,
|
||||||
events=None,
|
events=None,
|
||||||
policy=None,
|
policy: PreTrainedPolicy = None,
|
||||||
device: torch.device | str | None = None,
|
|
||||||
use_amp: bool | None = None,
|
|
||||||
fps: int | None = None,
|
fps: int | None = None,
|
||||||
|
single_task: str | None = None,
|
||||||
):
|
):
|
||||||
# TODO(rcadene): Add option to record logs
|
# TODO(rcadene): Add option to record logs
|
||||||
if not robot.is_connected:
|
if not robot.is_connected:
|
||||||
|
@ -224,12 +235,12 @@ def control_loop(
|
||||||
if teleoperate and policy is not None:
|
if teleoperate and policy is not None:
|
||||||
raise ValueError("When `teleoperate` is True, `policy` should be None.")
|
raise ValueError("When `teleoperate` is True, `policy` should be None.")
|
||||||
|
|
||||||
|
if dataset is not None and single_task is None:
|
||||||
|
raise ValueError("You need to provide a task as argument in `single_task`.")
|
||||||
|
|
||||||
if dataset is not None and fps is not None and dataset.fps != fps:
|
if dataset is not None and fps is not None and dataset.fps != fps:
|
||||||
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
|
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
|
||||||
|
|
||||||
if isinstance(device, str):
|
|
||||||
device = get_safe_torch_device(device)
|
|
||||||
|
|
||||||
timestamp = 0
|
timestamp = 0
|
||||||
start_episode_t = time.perf_counter()
|
start_episode_t = time.perf_counter()
|
||||||
while timestamp < control_time_s:
|
while timestamp < control_time_s:
|
||||||
|
@ -241,14 +252,16 @@ def control_loop(
|
||||||
observation = robot.capture_observation()
|
observation = robot.capture_observation()
|
||||||
|
|
||||||
if policy is not None:
|
if policy is not None:
|
||||||
pred_action = predict_action(observation, policy, device, use_amp)
|
pred_action = predict_action(
|
||||||
|
observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
|
||||||
|
)
|
||||||
# Action can eventually be clipped using `max_relative_target`,
|
# Action can eventually be clipped using `max_relative_target`,
|
||||||
# so action actually sent is saved in the dataset.
|
# so action actually sent is saved in the dataset.
|
||||||
action = robot.send_action(pred_action)
|
action = robot.send_action(pred_action)
|
||||||
action = {"action": action}
|
action = {"action": action}
|
||||||
|
|
||||||
if dataset is not None:
|
if dataset is not None:
|
||||||
frame = {**observation, **action}
|
frame = {**observation, **action, "task": single_task}
|
||||||
dataset.add_frame(frame)
|
dataset.add_frame(frame)
|
||||||
|
|
||||||
if display_cameras and not is_headless():
|
if display_cameras and not is_headless():
|
||||||
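The loop above now reads the device and AMP flag from policy.config and stores the task string in every recorded frame. A hedged usage sketch of the updated control_loop call (keyword names taken from this diff; robot, dataset, events, and policy are assumed to exist, and the task string is only an example):

control_loop(
    robot=robot,
    control_time_s=60,
    display_cameras=False,
    dataset=dataset,
    events=events,
    policy=policy,              # device and use_amp are read from policy.config
    fps=30,
    teleoperate=policy is None,
    single_task="Pick up the cube and place it in the bin.",  # saved as "task" in each frame
)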
|
@ -270,24 +283,18 @@ def control_loop(
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
def reset_environment(robot, events, reset_time_s):
|
def reset_environment(robot, events, reset_time_s, fps):
|
||||||
# TODO(rcadene): refactor warmup_record and reset_environment
|
# TODO(rcadene): refactor warmup_record and reset_environment
|
||||||
# TODO(alibets): allow for teleop during reset
|
|
||||||
if has_method(robot, "teleop_safety_stop"):
|
if has_method(robot, "teleop_safety_stop"):
|
||||||
robot.teleop_safety_stop()
|
robot.teleop_safety_stop()
|
||||||
|
|
||||||
timestamp = 0
|
control_loop(
|
||||||
start_vencod_t = time.perf_counter()
|
robot=robot,
|
||||||
|
control_time_s=reset_time_s,
|
||||||
# Wait if necessary
|
events=events,
|
||||||
with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar:
|
fps=fps,
|
||||||
while timestamp < reset_time_s:
|
teleoperate=True,
|
||||||
time.sleep(1)
|
)
|
||||||
timestamp = time.perf_counter() - start_vencod_t
|
|
||||||
pbar.update(1)
|
|
||||||
if events["exit_early"]:
|
|
||||||
events["exit_early"] = False
|
|
||||||
break
|
|
||||||
|
|
||||||
|
|
||||||
def stop_recording(robot, listener, display_cameras):
|
def stop_recording(robot, listener, display_cameras):
|
||||||
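The refactor above replaces the tqdm wait loop: resetting the environment is now just a teleoperated control_loop for reset_time_s seconds. Reassembled from the new-side lines (a sketch; has_method and control_loop come from the same module):

def reset_environment(robot, events, reset_time_s, fps):
    if has_method(robot, "teleop_safety_stop"):
        robot.teleop_safety_stop()
    control_loop(
        robot=robot,
        control_time_s=reset_time_s,
        events=events,
        fps=fps,
        teleoperate=True,
    )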
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
|
@ -242,7 +256,7 @@ class DriveMode(enum.Enum):
|
||||||
class CalibrationMode(enum.Enum):
|
class CalibrationMode(enum.Enum):
|
||||||
# Joints with rotational motions are expressed in degrees in nominal range of [-180, 180]
|
# Joints with rotational motions are expressed in degrees in nominal range of [-180, 180]
|
||||||
DEGREE = 0
|
DEGREE = 0
|
||||||
# Joints with linear motions (like gripper of Aloha) are experessed in nominal range of [0, 100]
|
# Joints with linear motions (like gripper of Aloha) are expressed in nominal range of [0, 100]
|
||||||
LINEAR = 1
|
LINEAR = 1
|
||||||
|
|
||||||
|
|
||||||
|
@ -318,7 +332,7 @@ class DynamixelMotorsBus:
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.mock:
|
if self.mock:
|
||||||
import tests.mock_dynamixel_sdk as dxl
|
import tests.motors.mock_dynamixel_sdk as dxl
|
||||||
else:
|
else:
|
||||||
import dynamixel_sdk as dxl
|
import dynamixel_sdk as dxl
|
||||||
|
|
||||||
|
@ -342,7 +356,7 @@ class DynamixelMotorsBus:
|
||||||
|
|
||||||
def reconnect(self):
|
def reconnect(self):
|
||||||
if self.mock:
|
if self.mock:
|
||||||
import tests.mock_dynamixel_sdk as dxl
|
import tests.motors.mock_dynamixel_sdk as dxl
|
||||||
else:
|
else:
|
||||||
import dynamixel_sdk as dxl
|
import dynamixel_sdk as dxl
|
||||||
|
|
||||||
|
@ -610,7 +624,7 @@ class DynamixelMotorsBus:
|
||||||
# 0-centered resolution range (e.g. [-2048, 2048] for resolution=4096)
|
# 0-centered resolution range (e.g. [-2048, 2048] for resolution=4096)
|
||||||
values[i] = values[i] / HALF_TURN_DEGREE * (resolution // 2)
|
values[i] = values[i] / HALF_TURN_DEGREE * (resolution // 2)
|
||||||
|
|
||||||
# Substract the homing offsets to come back to actual motor range of values
|
# Subtract the homing offsets to come back to actual motor range of values
|
||||||
# which can be arbitrary.
|
# which can be arbitrary.
|
||||||
values[i] -= homing_offset
|
values[i] -= homing_offset
|
||||||
|
|
||||||
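The conversion touched above maps a joint angle in degrees onto the motor's 0-centered step range and then removes the homing offset. A standalone sketch (HALF_TURN_DEGREE and the resolution-based scaling follow the diff; the rounding is added here for illustration):

HALF_TURN_DEGREE = 180

def degrees_to_steps(value_deg: float, resolution: int, homing_offset: int) -> int:
    # Map [-180, 180] degrees onto the 0-centered step range, e.g. [-2048, 2048] for resolution=4096,
    # then subtract the homing offset to come back to the motor's raw value range.
    steps = value_deg / HALF_TURN_DEGREE * (resolution // 2)
    return round(steps) - homing_offset

assert degrees_to_steps(90.0, 4096, 0) == 1024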
|
@ -632,7 +646,7 @@ class DynamixelMotorsBus:
|
||||||
|
|
||||||
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
|
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
|
||||||
if self.mock:
|
if self.mock:
|
||||||
import tests.mock_dynamixel_sdk as dxl
|
import tests.motors.mock_dynamixel_sdk as dxl
|
||||||
else:
|
else:
|
||||||
import dynamixel_sdk as dxl
|
import dynamixel_sdk as dxl
|
||||||
|
|
||||||
|
@ -677,7 +691,7 @@ class DynamixelMotorsBus:
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
if self.mock:
|
if self.mock:
|
||||||
import tests.mock_dynamixel_sdk as dxl
|
import tests.motors.mock_dynamixel_sdk as dxl
|
||||||
else:
|
else:
|
||||||
import dynamixel_sdk as dxl
|
import dynamixel_sdk as dxl
|
||||||
|
|
||||||
|
@ -743,7 +757,7 @@ class DynamixelMotorsBus:
|
||||||
|
|
||||||
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
|
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
|
||||||
if self.mock:
|
if self.mock:
|
||||||
import tests.mock_dynamixel_sdk as dxl
|
import tests.motors.mock_dynamixel_sdk as dxl
|
||||||
else:
|
else:
|
||||||
import dynamixel_sdk as dxl
|
import dynamixel_sdk as dxl
|
||||||
|
|
||||||
|
@ -779,7 +793,7 @@ class DynamixelMotorsBus:
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
if self.mock:
|
if self.mock:
|
||||||
import tests.mock_dynamixel_sdk as dxl
|
import tests.motors.mock_dynamixel_sdk as dxl
|
||||||
else:
|
else:
|
||||||
import dynamixel_sdk as dxl
|
import dynamixel_sdk as dxl
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
|
@ -221,7 +235,7 @@ class DriveMode(enum.Enum):
|
||||||
class CalibrationMode(enum.Enum):
|
class CalibrationMode(enum.Enum):
|
||||||
# Joints with rotational motions are expressed in degrees in nominal range of [-180, 180]
|
# Joints with rotational motions are expressed in degrees in nominal range of [-180, 180]
|
||||||
DEGREE = 0
|
DEGREE = 0
|
||||||
# Joints with linear motions (like gripper of Aloha) are experessed in nominal range of [0, 100]
|
# Joints with linear motions (like gripper of Aloha) are expressed in nominal range of [0, 100]
|
||||||
LINEAR = 1
|
LINEAR = 1
|
||||||
|
|
||||||
|
|
||||||
|
@ -299,7 +313,7 @@ class FeetechMotorsBus:
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.mock:
|
if self.mock:
|
||||||
import tests.mock_scservo_sdk as scs
|
import tests.motors.mock_scservo_sdk as scs
|
||||||
else:
|
else:
|
||||||
import scservo_sdk as scs
|
import scservo_sdk as scs
|
||||||
|
|
||||||
|
@ -323,7 +337,7 @@ class FeetechMotorsBus:
|
||||||
|
|
||||||
def reconnect(self):
|
def reconnect(self):
|
||||||
if self.mock:
|
if self.mock:
|
||||||
import tests.mock_scservo_sdk as scs
|
import tests.motors.mock_scservo_sdk as scs
|
||||||
else:
|
else:
|
||||||
import scservo_sdk as scs
|
import scservo_sdk as scs
|
||||||
|
|
||||||
|
@ -591,7 +605,7 @@ class FeetechMotorsBus:
|
||||||
# 0-centered resolution range (e.g. [-2048, 2048] for resolution=4096)
|
# 0-centered resolution range (e.g. [-2048, 2048] for resolution=4096)
|
||||||
values[i] = values[i] / HALF_TURN_DEGREE * (resolution // 2)
|
values[i] = values[i] / HALF_TURN_DEGREE * (resolution // 2)
|
||||||
|
|
||||||
# Substract the homing offsets to come back to actual motor range of values
|
# Subtract the homing offsets to come back to actual motor range of values
|
||||||
# which can be arbitrary.
|
# which can be arbitrary.
|
||||||
values[i] -= homing_offset
|
values[i] -= homing_offset
|
||||||
|
|
||||||
|
@ -632,7 +646,7 @@ class FeetechMotorsBus:
|
||||||
track["prev"][idx] = values[i]
|
track["prev"][idx] = values[i]
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Detect a full rotation occured
|
# Detect a full rotation occurred
|
||||||
if abs(track["prev"][idx] - values[i]) > 2048:
|
if abs(track["prev"][idx] - values[i]) > 2048:
|
||||||
# Position went below 0 and got reset to 4095
|
# Position went below 0 and got reset to 4095
|
||||||
if track["prev"][idx] < values[i]:
|
if track["prev"][idx] < values[i]:
|
||||||
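The branch above detects encoder wrap-around on a 12-bit Feetech position register: a jump of more than half the range between consecutive reads means the counter wrapped rather than physically moved that far. A hedged sketch of the same test (the sign convention of the return value is illustrative, not the source's):

def detect_wrap(prev: int, current: int, resolution: int = 4096) -> int:
    # Return -1 if the position went below 0 and reset near 4095,
    # +1 if it went above 4095 and reset near 0, and 0 when no wrap occurred.
    if abs(prev - current) <= resolution // 2:
        return 0
    return -1 if prev < current else 1

assert detect_wrap(10, 4090) == -1
assert detect_wrap(4090, 10) == 1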
|
@ -650,7 +664,7 @@ class FeetechMotorsBus:
|
||||||
|
|
||||||
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
|
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
|
||||||
if self.mock:
|
if self.mock:
|
||||||
import tests.mock_scservo_sdk as scs
|
import tests.motors.mock_scservo_sdk as scs
|
||||||
else:
|
else:
|
||||||
import scservo_sdk as scs
|
import scservo_sdk as scs
|
||||||
|
|
||||||
|
@ -688,7 +702,7 @@ class FeetechMotorsBus:
|
||||||
|
|
||||||
def read(self, data_name, motor_names: str | list[str] | None = None):
|
def read(self, data_name, motor_names: str | list[str] | None = None):
|
||||||
if self.mock:
|
if self.mock:
|
||||||
import tests.mock_scservo_sdk as scs
|
import tests.motors.mock_scservo_sdk as scs
|
||||||
else:
|
else:
|
||||||
import scservo_sdk as scs
|
import scservo_sdk as scs
|
||||||
|
|
||||||
|
@ -768,7 +782,7 @@ class FeetechMotorsBus:
|
||||||
|
|
||||||
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
|
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
|
||||||
if self.mock:
|
if self.mock:
|
||||||
import tests.mock_scservo_sdk as scs
|
import tests.motors.mock_scservo_sdk as scs
|
||||||
else:
|
else:
|
||||||
import scservo_sdk as scs
|
import scservo_sdk as scs
|
||||||
|
|
||||||
|
@ -804,7 +818,7 @@ class FeetechMotorsBus:
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
if self.mock:
|
if self.mock:
|
||||||
import tests.mock_scservo_sdk as scs
|
import tests.motors.mock_scservo_sdk as scs
|
||||||
else:
|
else:
|
||||||
import scservo_sdk as scs
|
import scservo_sdk as scs
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Protocol
|
from typing import Protocol
|
||||||
|
|
||||||
from lerobot.common.robot_devices.motors.configs import (
|
from lerobot.common.robot_devices.motors.configs import (
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Sequence
|
from typing import Sequence
|
||||||
|
@ -514,3 +528,86 @@ class StretchRobotConfig(RobotConfig):
|
||||||
)
|
)
|
||||||
|
|
||||||
mock: bool = False
|
mock: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
@RobotConfig.register_subclass("lekiwi")
|
||||||
|
@dataclass
|
||||||
|
class LeKiwiRobotConfig(RobotConfig):
|
||||||
|
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||||
|
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||||
|
# the number of motors in your follower arms.
|
||||||
|
max_relative_target: int | None = None
|
||||||
|
|
||||||
|
# Network Configuration
|
||||||
|
ip: str = "192.168.0.193"
|
||||||
|
port: int = 5555
|
||||||
|
video_port: int = 5556
|
||||||
|
|
||||||
|
cameras: dict[str, CameraConfig] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"front": OpenCVCameraConfig(
|
||||||
|
camera_index="/dev/video0", fps=30, width=640, height=480, rotation=90
|
||||||
|
),
|
||||||
|
"wrist": OpenCVCameraConfig(
|
||||||
|
camera_index="/dev/video2", fps=30, width=640, height=480, rotation=180
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
calibration_dir: str = ".cache/calibration/lekiwi"
|
||||||
|
|
||||||
|
leader_arms: dict[str, MotorsBusConfig] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"main": FeetechMotorsBusConfig(
|
||||||
|
port="/dev/tty.usbmodem585A0077581",
|
||||||
|
motors={
|
||||||
|
# name: (index, model)
|
||||||
|
"shoulder_pan": [1, "sts3215"],
|
||||||
|
"shoulder_lift": [2, "sts3215"],
|
||||||
|
"elbow_flex": [3, "sts3215"],
|
||||||
|
"wrist_flex": [4, "sts3215"],
|
||||||
|
"wrist_roll": [5, "sts3215"],
|
||||||
|
"gripper": [6, "sts3215"],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
follower_arms: dict[str, MotorsBusConfig] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"main": FeetechMotorsBusConfig(
|
||||||
|
port="/dev/ttyACM0",
|
||||||
|
motors={
|
||||||
|
# name: (index, model)
|
||||||
|
"shoulder_pan": [1, "sts3215"],
|
||||||
|
"shoulder_lift": [2, "sts3215"],
|
||||||
|
"elbow_flex": [3, "sts3215"],
|
||||||
|
"wrist_flex": [4, "sts3215"],
|
||||||
|
"wrist_roll": [5, "sts3215"],
|
||||||
|
"gripper": [6, "sts3215"],
|
||||||
|
"left_wheel": (7, "sts3215"),
|
||||||
|
"back_wheel": (8, "sts3215"),
|
||||||
|
"right_wheel": (9, "sts3215"),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
teleop_keys: dict[str, str] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
# Movement
|
||||||
|
"forward": "w",
|
||||||
|
"backward": "s",
|
||||||
|
"left": "a",
|
||||||
|
"right": "d",
|
||||||
|
"rotate_left": "z",
|
||||||
|
"rotate_right": "x",
|
||||||
|
# Speed control
|
||||||
|
"speed_up": "r",
|
||||||
|
"speed_down": "f",
|
||||||
|
# quit teleop
|
||||||
|
"quit": "q",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
mock: bool = False
|
||||||
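The new LeKiwi entry registers under the "lekiwi" robot type and bundles network settings, two OpenCV cameras, a Feetech leader arm, and a follower arm that adds three wheel motors. A hedged usage sketch (assuming the class lives next to the other robot configs in lerobot.common.robot_devices.robots.configs; the IP below is only an example):

from lerobot.common.robot_devices.robots.configs import LeKiwiRobotConfig

config = LeKiwiRobotConfig(ip="192.168.1.42", port=5555, video_port=5556)
print(config.teleop_keys["forward"])      # -> "w"
print(config.follower_arms["main"].port)  # -> "/dev/ttyACM0"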
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
"""Logic to calibrate a robot arm built with dynamixel motors"""
|
"""Logic to calibrate a robot arm built with dynamixel motors"""
|
||||||
# TODO(rcadene, aliberts): move this logic into the robot code when refactoring
|
# TODO(rcadene, aliberts): move this logic into the robot code when refactoring
|
||||||
|
|
||||||
|
@ -87,7 +101,7 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
|
||||||
# For instance, if the motor rotates 90 degree, and its value is -90 after applying the homing offset, then we know its rotation direction
|
# For instance, if the motor rotates 90 degree, and its value is -90 after applying the homing offset, then we know its rotation direction
|
||||||
# is inverted. However, for the calibration being successful, we need everyone to follow the same target position.
|
# is inverted. However, for the calibration being successful, we need everyone to follow the same target position.
|
||||||
# Sometimes, there is only one possible rotation direction. For instance, if the gripper is closed, there is only one direction which
|
# Sometimes, there is only one possible rotation direction. For instance, if the gripper is closed, there is only one direction which
|
||||||
# corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarely rotate clockwise from the point of view
|
# corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarily rotate clockwise from the point of view
|
||||||
# of the previous motor in the kinetic chain.
|
# of the previous motor in the kinetic chain.
|
||||||
print("\nMove arm to rotated target position")
|
print("\nMove arm to rotated target position")
|
||||||
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated"))
|
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated"))
|
||||||
|
@ -115,7 +129,7 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
|
||||||
|
|
||||||
# TODO(rcadene): make type of joints (DEGREE or LINEAR) configurable from yaml?
|
# TODO(rcadene): make type of joints (DEGREE or LINEAR) configurable from yaml?
|
||||||
if robot_type in ["aloha"] and "gripper" in arm.motor_names:
|
if robot_type in ["aloha"] and "gripper" in arm.motor_names:
|
||||||
# Joints with linear motions (like gripper of Aloha) are experessed in nominal range of [0, 100]
|
# Joints with linear motions (like gripper of Aloha) are expressed in nominal range of [0, 100]
|
||||||
calib_idx = arm.motor_names.index("gripper")
|
calib_idx = arm.motor_names.index("gripper")
|
||||||
calib_mode[calib_idx] = CalibrationMode.LINEAR.name
|
calib_mode[calib_idx] = CalibrationMode.LINEAR.name
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
"""Logic to calibrate a robot arm built with feetech motors"""
|
"""Logic to calibrate a robot arm built with feetech motors"""
|
||||||
# TODO(rcadene, aliberts): move this logic into the robot code when refactoring
|
# TODO(rcadene, aliberts): move this logic into the robot code when refactoring
|
||||||
|
|
||||||
|
@ -443,7 +457,7 @@ def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, a
|
||||||
# For instance, if the motor rotates 90 degree, and its value is -90 after applying the homing offset, then we know its rotation direction
|
# For instance, if the motor rotates 90 degree, and its value is -90 after applying the homing offset, then we know its rotation direction
|
||||||
# is inverted. However, for the calibration being successful, we need everyone to follow the same target position.
|
# is inverted. However, for the calibration being successful, we need everyone to follow the same target position.
|
||||||
# Sometimes, there is only one possible rotation direction. For instance, if the gripper is closed, there is only one direction which
|
# Sometimes, there is only one possible rotation direction. For instance, if the gripper is closed, there is only one direction which
|
||||||
# corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarely rotate clockwise from the point of view
|
# corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarily rotate clockwise from the point of view
|
||||||
# of the previous motor in the kinetic chain.
|
# of the previous motor in the kinetic chain.
|
||||||
print("\nMove arm to rotated target position")
|
print("\nMove arm to rotated target position")
|
||||||
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated"))
|
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated"))
|
||||||
|
|
|
@ -0,0 +1,224 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import zmq
|
||||||
|
|
||||||
|
from lerobot.common.robot_devices.robots.mobile_manipulator import LeKiwi
|
||||||
|
|
||||||
|
|
||||||
|
def setup_zmq_sockets(config):
|
||||||
|
context = zmq.Context()
|
||||||
|
cmd_socket = context.socket(zmq.PULL)
|
||||||
|
cmd_socket.setsockopt(zmq.CONFLATE, 1)
|
||||||
|
cmd_socket.bind(f"tcp://*:{config.port}")
|
||||||
|
|
||||||
|
video_socket = context.socket(zmq.PUSH)
|
||||||
|
video_socket.setsockopt(zmq.CONFLATE, 1)
|
||||||
|
video_socket.bind(f"tcp://*:{config.video_port}")
|
||||||
|
|
||||||
|
return context, cmd_socket, video_socket
|
||||||
|
|
||||||
|
|
||||||
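setup_zmq_sockets binds a PULL socket for incoming commands and a PUSH socket for outgoing camera frames, both with CONFLATE so only the newest message is kept. A hedged sketch of what a matching client side could look like (the IP and ports mirror the LeKiwiRobotConfig defaults; the command payload and the reply format are hypothetical, since the host's send path appears later in this file):

import json
import zmq

context = zmq.Context()

cmd_socket = context.socket(zmq.PUSH)    # the host PULLs commands, so the client PUSHes them
cmd_socket.setsockopt(zmq.CONFLATE, 1)
cmd_socket.connect("tcp://192.168.0.193:5555")

video_socket = context.socket(zmq.PULL)  # the host PUSHes frames, so the client PULLs them
video_socket.setsockopt(zmq.CONFLATE, 1)
video_socket.connect("tcp://192.168.0.193:5556")

cmd_socket.send_string(json.dumps({"example_command": 0}))  # hypothetical payload
latest = video_socket.recv_string()  # blocks until the host sends; exact format is defined by the host script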
|
def run_camera_capture(cameras, images_lock, latest_images_dict, stop_event):
|
||||||
|
while not stop_event.is_set():
|
||||||
|
local_dict = {}
|
||||||
|
for name, cam in cameras.items():
|
||||||
|
frame = cam.async_read()
|
||||||
|
ret, buffer = cv2.imencode(".jpg", frame, [int(cv2.IMWRITE_JPEG_QUALITY), 90])
|
||||||
|
if ret:
|
||||||
|
local_dict[name] = base64.b64encode(buffer).decode("utf-8")
|
||||||
|
else:
|
||||||
|
local_dict[name] = ""
|
||||||
|
with images_lock:
|
||||||
|
latest_images_dict.update(local_dict)
|
||||||
|
time.sleep(0.01)
|
||||||
|
|
||||||
|
|
||||||
|
def calibrate_follower_arm(motors_bus, calib_dir_str):
|
||||||
|
"""
|
||||||
|
Calibrates the follower arm. Attempts to load an existing calibration file;
|
||||||
|
if not found, runs manual calibration and saves the result.
|
||||||
|
"""
|
||||||
|
calib_dir = Path(calib_dir_str)
|
||||||
|
calib_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
calib_file = calib_dir / "main_follower.json"
|
||||||
|
try:
|
||||||
|
from lerobot.common.robot_devices.robots.feetech_calibration import run_arm_manual_calibration
|
||||||
|
except ImportError:
|
||||||
|
print("[WARNING] Calibration function not available. Skipping calibration.")
|
||||||
|
return
|
||||||
|
|
||||||
|
if calib_file.exists():
|
||||||
|
with open(calib_file) as f:
|
||||||
|
calibration = json.load(f)
|
||||||
|
print(f"[INFO] Loaded calibration from {calib_file}")
|
||||||
|
else:
|
||||||
|
print("[INFO] Calibration file not found. Running manual calibration...")
|
||||||
|
calibration = run_arm_manual_calibration(motors_bus, "lekiwi", "follower_arm", "follower")
|
||||||
|
print(f"[INFO] Calibration complete. Saving to {calib_file}")
|
||||||
|
with open(calib_file, "w") as f:
|
||||||
|
json.dump(calibration, f)
|
||||||
|
try:
|
||||||
|
motors_bus.set_calibration(calibration)
|
||||||
|
print("[INFO] Applied calibration for follower arm.")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[WARNING] Could not apply calibration: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def run_lekiwi(robot_config):
|
||||||
|
"""
|
||||||
|
Runs the LeKiwi robot:
|
||||||
|
- Sets up cameras and connects them.
|
||||||
|
- Initializes the follower arm motors.
|
||||||
|
- Calibrates the follower arm if necessary.
|
||||||
|
- Creates ZeroMQ sockets for receiving commands and streaming observations.
|
||||||
|
- Processes incoming commands (arm and wheel commands) and sends back sensor and camera data.
|
||||||
|
"""
|
||||||
|
# Import helper functions and classes
|
||||||
|
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
|
||||||
|
from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus, TorqueMode
|
||||||
|
|
||||||
|
# Initialize cameras from the robot configuration.
|
||||||
|
cameras = make_cameras_from_configs(robot_config.cameras)
|
||||||
|
for cam in cameras.values():
|
||||||
|
cam.connect()
|
||||||
|
|
||||||
|
# Initialize the motors bus using the follower arm configuration.
|
||||||
|
motor_config = robot_config.follower_arms.get("main")
|
||||||
|
if motor_config is None:
|
||||||
|
print("[ERROR] Follower arm 'main' configuration not found.")
|
||||||
|
return
|
||||||
|
motors_bus = FeetechMotorsBus(motor_config)
|
||||||
|
motors_bus.connect()
|
||||||
|
|
||||||
|
# Calibrate the follower arm.
|
||||||
|
calibrate_follower_arm(motors_bus, robot_config.calibration_dir)
|
||||||
|
|
||||||
|
# Create the LeKiwi robot instance.
|
||||||
|
robot = LeKiwi(motors_bus)
|
||||||
|
|
||||||
|
# Define the expected arm motor IDs.
|
||||||
|
arm_motor_ids = ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"]
|
||||||
|
|
||||||
|
# Disable torque for each arm motor.
|
||||||
|
for motor in arm_motor_ids:
|
||||||
|
motors_bus.write("Torque_Enable", TorqueMode.DISABLED.value, motor)
|
||||||
|
|
||||||
|
# Set up ZeroMQ sockets.
|
||||||
|
context, cmd_socket, video_socket = setup_zmq_sockets(robot_config)
|
||||||
|
|
||||||
|
# Start the camera capture thread.
|
||||||
|
latest_images_dict = {}
|
||||||
|
images_lock = threading.Lock()
|
||||||
|
stop_event = threading.Event()
|
||||||
|
cam_thread = threading.Thread(
|
||||||
|
target=run_camera_capture, args=(cameras, images_lock, latest_images_dict, stop_event), daemon=True
|
||||||
|
)
|
||||||
|
cam_thread.start()
|
||||||
|
|
||||||
|
last_cmd_time = time.time()
|
||||||
|
print("LeKiwi robot server started. Waiting for commands...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
loop_start_time = time.time()
|
||||||
|
|
||||||
|
# Process incoming commands (non-blocking).
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
msg = cmd_socket.recv_string(zmq.NOBLOCK)
|
||||||
|
except zmq.Again:
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
data = json.loads(msg)
|
||||||
|
# Process arm position commands.
|
||||||
|
if "arm_positions" in data:
|
||||||
|
arm_positions = data["arm_positions"]
|
||||||
|
if not isinstance(arm_positions, list):
|
||||||
|
print(f"[ERROR] Invalid arm_positions: {arm_positions}")
|
||||||
|
elif len(arm_positions) < len(arm_motor_ids):
|
||||||
|
print(
|
||||||
|
f"[WARNING] Received {len(arm_positions)} arm positions, expected {len(arm_motor_ids)}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
for motor, pos in zip(arm_motor_ids, arm_positions, strict=False):
|
||||||
|
motors_bus.write("Goal_Position", pos, motor)
|
||||||
|
# Process wheel (base) commands.
|
||||||
|
if "raw_velocity" in data:
|
||||||
|
raw_command = data["raw_velocity"]
|
||||||
|
# Expect keys: "left_wheel", "back_wheel", "right_wheel".
|
||||||
|
command_speeds = [
|
||||||
|
int(raw_command.get("left_wheel", 0)),
|
||||||
|
int(raw_command.get("back_wheel", 0)),
|
||||||
|
int(raw_command.get("right_wheel", 0)),
|
||||||
|
]
|
||||||
|
robot.set_velocity(command_speeds)
|
||||||
|
last_cmd_time = time.time()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[ERROR] Parsing message failed: {e}")
|
||||||
|
|
||||||
|
# Watchdog: stop the robot if no command is received for over 0.5 seconds.
|
||||||
|
now = time.time()
|
||||||
|
if now - last_cmd_time > 0.5:
|
||||||
|
robot.stop()
|
||||||
|
last_cmd_time = now
|
||||||
|
|
||||||
|
# Read current wheel speeds from the robot.
|
||||||
|
current_velocity = robot.read_velocity()
|
||||||
|
|
||||||
|
# Read the follower arm state from the motors bus.
|
||||||
|
follower_arm_state = []
|
||||||
|
for motor in arm_motor_ids:
|
||||||
|
try:
|
||||||
|
pos = motors_bus.read("Present_Position", motor)
|
||||||
|
# Convert the position to a float (or use as is if already numeric).
|
||||||
|
follower_arm_state.append(float(pos) if not isinstance(pos, (int, float)) else pos)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[ERROR] Reading motor {motor} failed: {e}")
|
||||||
|
|
||||||
|
# Get the latest camera images.
|
||||||
|
with images_lock:
|
||||||
|
images_dict_copy = dict(latest_images_dict)
|
||||||
|
|
||||||
|
# Build the observation dictionary.
|
||||||
|
observation = {
|
||||||
|
"images": images_dict_copy,
|
||||||
|
"present_speed": current_velocity,
|
||||||
|
"follower_arm_state": follower_arm_state,
|
||||||
|
}
|
||||||
|
# Send the observation over the video socket.
|
||||||
|
video_socket.send_string(json.dumps(observation))
|
||||||
|
|
||||||
|
# Ensure a short sleep to avoid overloading the CPU.
|
||||||
|
elapsed = time.time() - loop_start_time
|
||||||
|
time.sleep(
|
||||||
|
max(0.033 - elapsed, 0)
|
||||||
|
)  # If the robot jitters, increase this sleep and monitor CPU load with `top`.
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("Shutting down LeKiwi server.")
|
||||||
|
finally:
|
||||||
|
stop_event.set()
|
||||||
|
cam_thread.join()
|
||||||
|
robot.stop()
|
||||||
|
motors_bus.disconnect()
|
||||||
|
cmd_socket.close()
|
||||||
|
video_socket.close()
|
||||||
|
context.term()
|
|
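For readers wiring up their own operator-side script, here is a minimal, hypothetical client for the protocol served by the loop above. It is not part of this diff; the IP address and ports are placeholders standing in for `config.ip`, `config.port` and `config.video_port`, and the arm values are arbitrary.

```python
# Hypothetical teleop client for the LeKiwi server loop above (IP/ports are placeholders).
import base64
import json

import cv2
import numpy as np
import zmq

context = zmq.Context()

# Commands go to the server's PULL socket (config.port).
cmd_socket = context.socket(zmq.PUSH)
cmd_socket.setsockopt(zmq.CONFLATE, 1)
cmd_socket.connect("tcp://192.168.0.50:5555")

# Observations come back from the server's PUSH socket (config.video_port).
obs_socket = context.socket(zmq.PULL)
obs_socket.setsockopt(zmq.CONFLATE, 1)
obs_socket.connect("tcp://192.168.0.50:5556")

# The server's command loop expects exactly these two keys.
cmd_socket.send_string(
    json.dumps(
        {
            "raw_velocity": {"left_wheel": 0, "back_wheel": 0, "right_wheel": 0},
            "arm_positions": [2048, 2048, 2048, 2048, 2048, 2048],  # placeholder goals, one per arm motor
        }
    )
)

# Each observation is JSON with base64-encoded JPEG frames plus wheel speed and arm state.
observation = json.loads(obs_socket.recv_string())
for name, image_b64 in observation["images"].items():
    if image_b64:
        frame = cv2.imdecode(np.frombuffer(base64.b64decode(image_b64), np.uint8), cv2.IMREAD_COLOR)
        print(name, frame.shape)
print(observation["present_speed"], observation["follower_arm_state"])
```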
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""Contains logic to instantiate a robot, read information from its motors and cameras,
and send orders to its motors.
"""
@@ -44,7 +58,7 @@ class ManipulatorRobot:
# TODO(rcadene): Implement force feedback
"""This class allows to control any manipulator robot of various number of motors.

-Non exaustive list of robots:
+Non exhaustive list of robots:
- [Koch v1.0](https://github.com/AlexanderKoch-Koch/low_cost_robot), with and without the wrist-to-elbow expansion, developed
by Alexander Koch from [Tau Robotics](https://tau-robotics.com)
- [Koch v1.1](https://github.com/jess-moss/koch-v1-1) developed by Jess Moss
@@ -55,7 +69,7 @@ class ManipulatorRobot:
robot = ManipulatorRobot(KochRobotConfig())
```

-Example of overwritting motors during instantiation:
+Example of overwriting motors during instantiation:
```python
# Defines how to communicate with the motors of the leader and follower arms
leader_arms = {
@@ -90,7 +104,7 @@ class ManipulatorRobot:
robot = ManipulatorRobot(robot_config)
```

-Example of overwritting cameras during instantiation:
+Example of overwriting cameras during instantiation:
```python
# Defines how to communicate with 2 cameras connected to the computer.
# Here, the webcam of the laptop and the phone (connected in USB to the laptop)
@@ -229,7 +243,7 @@ class ManipulatorRobot:

if self.robot_type in ["koch", "koch_bimanual", "aloha"]:
from lerobot.common.robot_devices.motors.dynamixel import TorqueMode
-elif self.robot_type in ["so100", "moss"]:
+elif self.robot_type in ["so100", "moss", "lekiwi"]:
from lerobot.common.robot_devices.motors.feetech import TorqueMode

# We assume that at connection time, arms are in a rest position, and torque can
@@ -246,7 +260,7 @@ class ManipulatorRobot:
self.set_koch_robot_preset()
elif self.robot_type == "aloha":
self.set_aloha_robot_preset()
-elif self.robot_type in ["so100", "moss"]:
+elif self.robot_type in ["so100", "moss", "lekiwi"]:
self.set_so100_robot_preset()

# Enable torque on all motors of the follower arms
@@ -299,7 +313,7 @@ class ManipulatorRobot:

calibration = run_arm_calibration(arm, self.robot_type, name, arm_type)

-elif self.robot_type in ["so100", "moss"]:
+elif self.robot_type in ["so100", "moss", "lekiwi"]:
from lerobot.common.robot_devices.robots.feetech_calibration import (
run_arm_manual_calibration,
)
@@ -348,7 +362,7 @@ class ManipulatorRobot:
set_operating_mode_(self.follower_arms[name])

# Set better PID values to close the gap between recorded states and actions
-# TODO(rcadene): Implement an automatic procedure to set optimial PID values for each motor
+# TODO(rcadene): Implement an automatic procedure to set optimal PID values for each motor
self.follower_arms[name].write("Position_P_Gain", 1500, "elbow_flex")
self.follower_arms[name].write("Position_I_Gain", 0, "elbow_flex")
self.follower_arms[name].write("Position_D_Gain", 600, "elbow_flex")
@@ -500,7 +514,7 @@ class ManipulatorRobot:
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t

-# Populate output dictionnaries
+# Populate output dictionaries
obs_dict, action_dict = {}, {}
obs_dict["observation.state"] = state
action_dict["action"] = action
@@ -540,7 +554,7 @@ class ManipulatorRobot:
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t

-# Populate output dictionnaries and format to pytorch
+# Populate output dictionaries and format to pytorch
obs_dict = {}
obs_dict["observation.state"] = state
for name in self.cameras:
@ -0,0 +1,703 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import zmq
|
||||||
|
|
||||||
|
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
|
||||||
|
from lerobot.common.robot_devices.motors.feetech import TorqueMode
|
||||||
|
from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs
|
||||||
|
from lerobot.common.robot_devices.robots.configs import LeKiwiRobotConfig
|
||||||
|
from lerobot.common.robot_devices.robots.feetech_calibration import run_arm_manual_calibration
|
||||||
|
from lerobot.common.robot_devices.robots.utils import get_arm_id
|
||||||
|
from lerobot.common.robot_devices.utils import RobotDeviceNotConnectedError
|
||||||
|
|
||||||
|
PYNPUT_AVAILABLE = True
|
||||||
|
try:
|
||||||
|
# Only import if there's a valid X server or if we're not on a Pi
|
||||||
|
if ("DISPLAY" not in os.environ) and ("linux" in sys.platform):
|
||||||
|
print("No DISPLAY set. Skipping pynput import.")
|
||||||
|
raise ImportError("pynput blocked intentionally due to no display.")
|
||||||
|
|
||||||
|
from pynput import keyboard
|
||||||
|
except ImportError:
|
||||||
|
keyboard = None
|
||||||
|
PYNPUT_AVAILABLE = False
|
||||||
|
except Exception as e:
|
||||||
|
keyboard = None
|
||||||
|
PYNPUT_AVAILABLE = False
|
||||||
|
print(f"Could not import pynput: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
class MobileManipulator:
|
||||||
|
"""
|
||||||
|
MobileManipulator is a class for connecting to and controlling a remote mobile manipulator robot.
|
||||||
|
The robot includes a three omniwheel mobile base and a remote follower arm.
|
||||||
|
The leader arm is connected locally (on the laptop) and its joint positions are recorded and then
|
||||||
|
forwarded to the remote follower arm (after applying a safety clamp).
|
||||||
|
In parallel, keyboard teleoperation is used to generate raw velocity commands for the wheels.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: LeKiwiRobotConfig):
|
||||||
|
"""
|
||||||
|
Expected keys in config:
|
||||||
|
- ip, port, video_port for the remote connection.
|
||||||
|
- calibration_dir, leader_arms, follower_arms, max_relative_target, etc.
|
||||||
|
"""
|
||||||
|
self.robot_type = config.type
|
||||||
|
self.config = config
|
||||||
|
self.remote_ip = config.ip
|
||||||
|
self.remote_port = config.port
|
||||||
|
self.remote_port_video = config.video_port
|
||||||
|
self.calibration_dir = Path(self.config.calibration_dir)
|
||||||
|
self.logs = {}
|
||||||
|
|
||||||
|
self.teleop_keys = self.config.teleop_keys
|
||||||
|
|
||||||
|
# For teleoperation, the leader arm (local) is used to record the desired arm pose.
|
||||||
|
self.leader_arms = make_motors_buses_from_configs(self.config.leader_arms)
|
||||||
|
|
||||||
|
self.follower_arms = make_motors_buses_from_configs(self.config.follower_arms)
|
||||||
|
|
||||||
|
self.cameras = make_cameras_from_configs(self.config.cameras)
|
||||||
|
|
||||||
|
self.is_connected = False
|
||||||
|
|
||||||
|
self.last_frames = {}
|
||||||
|
self.last_present_speed = {}
|
||||||
|
self.last_remote_arm_state = torch.zeros(6, dtype=torch.float32)
|
||||||
|
|
||||||
|
# Define three speed levels and a current index
|
||||||
|
self.speed_levels = [
|
||||||
|
{"xy": 0.1, "theta": 30}, # slow
|
||||||
|
{"xy": 0.2, "theta": 60}, # medium
|
||||||
|
{"xy": 0.3, "theta": 90}, # fast
|
||||||
|
]
|
||||||
|
self.speed_index = 0 # Start at slow
|
||||||
|
|
||||||
|
# ZeroMQ context and sockets.
|
||||||
|
self.context = None
|
||||||
|
self.cmd_socket = None
|
||||||
|
self.video_socket = None
|
||||||
|
|
||||||
|
# Keyboard state for base teleoperation.
|
||||||
|
self.running = True
|
||||||
|
self.pressed_keys = {
|
||||||
|
"forward": False,
|
||||||
|
"backward": False,
|
||||||
|
"left": False,
|
||||||
|
"right": False,
|
||||||
|
"rotate_left": False,
|
||||||
|
"rotate_right": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
if PYNPUT_AVAILABLE:
|
||||||
|
print("pynput is available - enabling local keyboard listener.")
|
||||||
|
self.listener = keyboard.Listener(
|
||||||
|
on_press=self.on_press,
|
||||||
|
on_release=self.on_release,
|
||||||
|
)
|
||||||
|
self.listener.start()
|
||||||
|
else:
|
||||||
|
print("pynput not available - skipping local keyboard listener.")
|
||||||
|
self.listener = None
|
||||||
|
|
||||||
|
def get_motor_names(self, arms: dict[str, MotorsBus]) -> list:
|
||||||
|
return [f"{arm}_{motor}" for arm, bus in arms.items() for motor in bus.motors]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def camera_features(self) -> dict:
|
||||||
|
cam_ft = {}
|
||||||
|
for cam_key, cam in self.cameras.items():
|
||||||
|
key = f"observation.images.{cam_key}"
|
||||||
|
cam_ft[key] = {
|
||||||
|
"shape": (cam.height, cam.width, cam.channels),
|
||||||
|
"names": ["height", "width", "channels"],
|
||||||
|
"info": None,
|
||||||
|
}
|
||||||
|
return cam_ft
|
||||||
|
|
||||||
|
@property
|
||||||
|
def motor_features(self) -> dict:
|
||||||
|
follower_arm_names = [
|
||||||
|
"shoulder_pan",
|
||||||
|
"shoulder_lift",
|
||||||
|
"elbow_flex",
|
||||||
|
"wrist_flex",
|
||||||
|
"wrist_roll",
|
||||||
|
"gripper",
|
||||||
|
]
|
||||||
|
observations = ["x_mm", "y_mm", "theta"]
|
||||||
|
combined_names = follower_arm_names + observations
|
||||||
|
return {
|
||||||
|
"action": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (len(combined_names),),
|
||||||
|
"names": combined_names,
|
||||||
|
},
|
||||||
|
"observation.state": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (len(combined_names),),
|
||||||
|
"names": combined_names,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def features(self):
|
||||||
|
return {**self.motor_features, **self.camera_features}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def has_camera(self):
|
||||||
|
return len(self.cameras) > 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_cameras(self):
|
||||||
|
return len(self.cameras)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def available_arms(self):
|
||||||
|
available = []
|
||||||
|
for name in self.leader_arms:
|
||||||
|
available.append(get_arm_id(name, "leader"))
|
||||||
|
for name in self.follower_arms:
|
||||||
|
available.append(get_arm_id(name, "follower"))
|
||||||
|
return available
|
||||||
|
|
||||||
|
def on_press(self, key):
|
||||||
|
try:
|
||||||
|
# Movement
|
||||||
|
if key.char == self.teleop_keys["forward"]:
|
||||||
|
self.pressed_keys["forward"] = True
|
||||||
|
elif key.char == self.teleop_keys["backward"]:
|
||||||
|
self.pressed_keys["backward"] = True
|
||||||
|
elif key.char == self.teleop_keys["left"]:
|
||||||
|
self.pressed_keys["left"] = True
|
||||||
|
elif key.char == self.teleop_keys["right"]:
|
||||||
|
self.pressed_keys["right"] = True
|
||||||
|
elif key.char == self.teleop_keys["rotate_left"]:
|
||||||
|
self.pressed_keys["rotate_left"] = True
|
||||||
|
elif key.char == self.teleop_keys["rotate_right"]:
|
||||||
|
self.pressed_keys["rotate_right"] = True
|
||||||
|
|
||||||
|
# Quit teleoperation
|
||||||
|
elif key.char == self.teleop_keys["quit"]:
|
||||||
|
self.running = False
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Speed control
|
||||||
|
elif key.char == self.teleop_keys["speed_up"]:
|
||||||
|
self.speed_index = min(self.speed_index + 1, 2)
|
||||||
|
print(f"Speed index increased to {self.speed_index}")
|
||||||
|
elif key.char == self.teleop_keys["speed_down"]:
|
||||||
|
self.speed_index = max(self.speed_index - 1, 0)
|
||||||
|
print(f"Speed index decreased to {self.speed_index}")
|
||||||
|
|
||||||
|
except AttributeError:
|
||||||
|
# e.g., if key is special like Key.esc
|
||||||
|
if key == keyboard.Key.esc:
|
||||||
|
self.running = False
|
||||||
|
return False
|
||||||
|
|
||||||
|
def on_release(self, key):
|
||||||
|
try:
|
||||||
|
if hasattr(key, "char"):
|
||||||
|
if key.char == self.teleop_keys["forward"]:
|
||||||
|
self.pressed_keys["forward"] = False
|
||||||
|
elif key.char == self.teleop_keys["backward"]:
|
||||||
|
self.pressed_keys["backward"] = False
|
||||||
|
elif key.char == self.teleop_keys["left"]:
|
||||||
|
self.pressed_keys["left"] = False
|
||||||
|
elif key.char == self.teleop_keys["right"]:
|
||||||
|
self.pressed_keys["right"] = False
|
||||||
|
elif key.char == self.teleop_keys["rotate_left"]:
|
||||||
|
self.pressed_keys["rotate_left"] = False
|
||||||
|
elif key.char == self.teleop_keys["rotate_right"]:
|
||||||
|
self.pressed_keys["rotate_right"] = False
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def connect(self):
|
||||||
|
if not self.leader_arms:
|
||||||
|
raise ValueError("MobileManipulator has no leader arm to connect.")
|
||||||
|
for name in self.leader_arms:
|
||||||
|
print(f"Connecting {name} leader arm.")
|
||||||
|
self.calibrate_leader()
|
||||||
|
|
||||||
|
# Set up ZeroMQ sockets to communicate with the remote mobile robot.
|
||||||
|
self.context = zmq.Context()
|
||||||
|
self.cmd_socket = self.context.socket(zmq.PUSH)
|
||||||
|
connection_string = f"tcp://{self.remote_ip}:{self.remote_port}"
|
||||||
|
self.cmd_socket.connect(connection_string)
|
||||||
|
self.cmd_socket.setsockopt(zmq.CONFLATE, 1)
|
||||||
|
self.video_socket = self.context.socket(zmq.PULL)
|
||||||
|
video_connection = f"tcp://{self.remote_ip}:{self.remote_port_video}"
|
||||||
|
self.video_socket.connect(video_connection)
|
||||||
|
self.video_socket.setsockopt(zmq.CONFLATE, 1)
|
||||||
|
print(
|
||||||
|
f"[INFO] Connected to remote robot at {connection_string} and video stream at {video_connection}."
|
||||||
|
)
|
||||||
|
self.is_connected = True
|
||||||
|
|
||||||
|
def load_or_run_calibration_(self, name, arm, arm_type):
|
||||||
|
arm_id = get_arm_id(name, arm_type)
|
||||||
|
arm_calib_path = self.calibration_dir / f"{arm_id}.json"
|
||||||
|
|
||||||
|
if arm_calib_path.exists():
|
||||||
|
with open(arm_calib_path) as f:
|
||||||
|
calibration = json.load(f)
|
||||||
|
else:
|
||||||
|
print(f"Missing calibration file '{arm_calib_path}'")
|
||||||
|
calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type)
|
||||||
|
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
|
||||||
|
arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
with open(arm_calib_path, "w") as f:
|
||||||
|
json.dump(calibration, f)
|
||||||
|
|
||||||
|
return calibration
|
||||||
|
|
||||||
|
def calibrate_leader(self):
|
||||||
|
for name, arm in self.leader_arms.items():
|
||||||
|
# Connect the bus
|
||||||
|
arm.connect()
|
||||||
|
|
||||||
|
# Disable torque on all motors
|
||||||
|
for motor_id in arm.motors:
|
||||||
|
arm.write("Torque_Enable", TorqueMode.DISABLED.value, motor_id)
|
||||||
|
|
||||||
|
# Now run calibration
|
||||||
|
calibration = self.load_or_run_calibration_(name, arm, "leader")
|
||||||
|
arm.set_calibration(calibration)
|
||||||
|
|
||||||
|
def calibrate_follower(self):
|
||||||
|
for name, bus in self.follower_arms.items():
|
||||||
|
bus.connect()
|
||||||
|
|
||||||
|
# Disable torque on all motors
|
||||||
|
for motor_id in bus.motors:
|
||||||
|
bus.write("Torque_Enable", 0, motor_id)
|
||||||
|
|
||||||
|
# Then filter out wheels
|
||||||
|
arm_only_dict = {k: v for k, v in bus.motors.items() if not k.startswith("wheel_")}
|
||||||
|
if not arm_only_dict:
|
||||||
|
continue
|
||||||
|
|
||||||
|
original_motors = bus.motors
|
||||||
|
bus.motors = arm_only_dict
|
||||||
|
|
||||||
|
calibration = self.load_or_run_calibration_(name, bus, "follower")
|
||||||
|
bus.set_calibration(calibration)
|
||||||
|
|
||||||
|
bus.motors = original_motors
|
||||||
|
|
||||||
|
def _get_data(self):
|
||||||
|
"""
|
||||||
|
Polls the video socket for up to 15 ms. If data arrives, decode only
|
||||||
|
the *latest* message, returning frames, speed, and arm state. If
|
||||||
|
nothing arrives for any field, use the last known values.
|
||||||
|
"""
|
||||||
|
frames = {}
|
||||||
|
present_speed = {}
|
||||||
|
remote_arm_state_tensor = torch.zeros(6, dtype=torch.float32)
|
||||||
|
|
||||||
|
# Poll up to 15 ms
|
||||||
|
poller = zmq.Poller()
|
||||||
|
poller.register(self.video_socket, zmq.POLLIN)
|
||||||
|
socks = dict(poller.poll(15))
|
||||||
|
if self.video_socket not in socks or socks[self.video_socket] != zmq.POLLIN:
|
||||||
|
# No new data arrived → reuse ALL old data
|
||||||
|
return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
|
||||||
|
|
||||||
|
# Drain all messages, keep only the last
|
||||||
|
last_msg = None
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
obs_string = self.video_socket.recv_string(zmq.NOBLOCK)
|
||||||
|
last_msg = obs_string
|
||||||
|
except zmq.Again:
|
||||||
|
break
|
||||||
|
|
||||||
|
if not last_msg:
|
||||||
|
# No new message → also reuse old
|
||||||
|
return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
|
||||||
|
|
||||||
|
# Decode only the final message
|
||||||
|
try:
|
||||||
|
observation = json.loads(last_msg)
|
||||||
|
|
||||||
|
images_dict = observation.get("images", {})
|
||||||
|
new_speed = observation.get("present_speed", {})
|
||||||
|
new_arm_state = observation.get("follower_arm_state", None)
|
||||||
|
|
||||||
|
# Convert images
|
||||||
|
for cam_name, image_b64 in images_dict.items():
|
||||||
|
if image_b64:
|
||||||
|
jpg_data = base64.b64decode(image_b64)
|
||||||
|
np_arr = np.frombuffer(jpg_data, dtype=np.uint8)
|
||||||
|
frame_candidate = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
|
||||||
|
if frame_candidate is not None:
|
||||||
|
frames[cam_name] = frame_candidate
|
||||||
|
|
||||||
|
# If a new arm state and frames arrived, store them; otherwise reuse the previous values.
|
||||||
|
if new_arm_state is not None and frames is not None:
|
||||||
|
self.last_frames = frames
|
||||||
|
|
||||||
|
remote_arm_state_tensor = torch.tensor(new_arm_state, dtype=torch.float32)
|
||||||
|
self.last_remote_arm_state = remote_arm_state_tensor
|
||||||
|
|
||||||
|
present_speed = new_speed
|
||||||
|
self.last_present_speed = new_speed
|
||||||
|
else:
|
||||||
|
frames = self.last_frames
|
||||||
|
|
||||||
|
remote_arm_state_tensor = self.last_remote_arm_state
|
||||||
|
|
||||||
|
present_speed = self.last_present_speed
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[DEBUG] Error decoding video message: {e}")
|
||||||
|
# If decode fails, fall back to old data
|
||||||
|
return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
|
||||||
|
|
||||||
|
return frames, present_speed, remote_arm_state_tensor
|
||||||
|
|
||||||
|
def _process_present_speed(self, present_speed: dict) -> torch.Tensor:
|
||||||
|
state_tensor = torch.zeros(3, dtype=torch.int32)
|
||||||
|
if present_speed:
|
||||||
|
decoded = {key: MobileManipulator.raw_to_degps(value) for key, value in present_speed.items()}
|
||||||
|
if "1" in decoded:
|
||||||
|
state_tensor[0] = decoded["1"]
|
||||||
|
if "2" in decoded:
|
||||||
|
state_tensor[1] = decoded["2"]
|
||||||
|
if "3" in decoded:
|
||||||
|
state_tensor[2] = decoded["3"]
|
||||||
|
return state_tensor
|
||||||
|
|
||||||
|
def teleop_step(
|
||||||
|
self, record_data: bool = False
|
||||||
|
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
|
||||||
|
if not self.is_connected:
|
||||||
|
raise RobotDeviceNotConnectedError("MobileManipulator is not connected. Run `connect()` first.")
|
||||||
|
|
||||||
|
speed_setting = self.speed_levels[self.speed_index]
|
||||||
|
xy_speed = speed_setting["xy"] # e.g. 0.1, 0.25, or 0.4
|
||||||
|
theta_speed = speed_setting["theta"] # e.g. 30, 60, or 90
|
||||||
|
|
||||||
|
# Prepare to assign the position of the leader to the follower
|
||||||
|
arm_positions = []
|
||||||
|
for name in self.leader_arms:
|
||||||
|
pos = self.leader_arms[name].read("Present_Position")
|
||||||
|
pos_tensor = torch.from_numpy(pos).float()
|
||||||
|
arm_positions.extend(pos_tensor.tolist())
|
||||||
|
|
||||||
|
y_cmd = 0.0 # m/s forward/backward
|
||||||
|
x_cmd = 0.0 # m/s lateral
|
||||||
|
theta_cmd = 0.0 # deg/s rotation
|
||||||
|
if self.pressed_keys["forward"]:
|
||||||
|
y_cmd += xy_speed
|
||||||
|
if self.pressed_keys["backward"]:
|
||||||
|
y_cmd -= xy_speed
|
||||||
|
if self.pressed_keys["left"]:
|
||||||
|
x_cmd += xy_speed
|
||||||
|
if self.pressed_keys["right"]:
|
||||||
|
x_cmd -= xy_speed
|
||||||
|
if self.pressed_keys["rotate_left"]:
|
||||||
|
theta_cmd += theta_speed
|
||||||
|
if self.pressed_keys["rotate_right"]:
|
||||||
|
theta_cmd -= theta_speed
|
||||||
|
|
||||||
|
wheel_commands = self.body_to_wheel_raw(x_cmd, y_cmd, theta_cmd)
|
||||||
|
|
||||||
|
message = {"raw_velocity": wheel_commands, "arm_positions": arm_positions}
|
||||||
|
self.cmd_socket.send_string(json.dumps(message))
|
||||||
|
|
||||||
|
if not record_data:
|
||||||
|
return
|
||||||
|
|
||||||
|
obs_dict = self.capture_observation()
|
||||||
|
|
||||||
|
arm_state_tensor = torch.tensor(arm_positions, dtype=torch.float32)
|
||||||
|
|
||||||
|
wheel_velocity_tuple = self.wheel_raw_to_body(wheel_commands)
|
||||||
|
wheel_velocity_mm = (
|
||||||
|
wheel_velocity_tuple[0] * 1000.0,
|
||||||
|
wheel_velocity_tuple[1] * 1000.0,
|
||||||
|
wheel_velocity_tuple[2],
|
||||||
|
)
|
||||||
|
wheel_tensor = torch.tensor(wheel_velocity_mm, dtype=torch.float32)
|
||||||
|
action_tensor = torch.cat([arm_state_tensor, wheel_tensor])
|
||||||
|
action_dict = {"action": action_tensor}
|
||||||
|
|
||||||
|
return obs_dict, action_dict
|
||||||
|
|
||||||
|
def capture_observation(self) -> dict:
|
||||||
|
"""
|
||||||
|
Capture observations from the remote robot: current follower arm positions,
|
||||||
|
present wheel speeds (converted to body-frame velocities: x, y, theta),
|
||||||
|
and a camera frame.
|
||||||
|
"""
|
||||||
|
if not self.is_connected:
|
||||||
|
raise RobotDeviceNotConnectedError("Not connected. Run `connect()` first.")
|
||||||
|
|
||||||
|
frames, present_speed, remote_arm_state_tensor = self._get_data()
|
||||||
|
|
||||||
|
body_state = self.wheel_raw_to_body(present_speed)
|
||||||
|
|
||||||
|
body_state_mm = (body_state[0] * 1000.0, body_state[1] * 1000.0, body_state[2]) # Convert x,y to mm/s
|
||||||
|
wheel_state_tensor = torch.tensor(body_state_mm, dtype=torch.float32)
|
||||||
|
combined_state_tensor = torch.cat((remote_arm_state_tensor, wheel_state_tensor), dim=0)
|
||||||
|
|
||||||
|
obs_dict = {"observation.state": combined_state_tensor}
|
||||||
|
|
||||||
|
# Loop over each configured camera
|
||||||
|
for cam_name, cam in self.cameras.items():
|
||||||
|
frame = frames.get(cam_name, None)
|
||||||
|
if frame is None:
|
||||||
|
# Create a black image using the camera's configured width, height, and channels
|
||||||
|
frame = np.zeros((cam.height, cam.width, cam.channels), dtype=np.uint8)
|
||||||
|
obs_dict[f"observation.images.{cam_name}"] = torch.from_numpy(frame)
|
||||||
|
|
||||||
|
return obs_dict
|
||||||
|
|
||||||
|
def send_action(self, action: torch.Tensor) -> torch.Tensor:
|
||||||
|
if not self.is_connected:
|
||||||
|
raise RobotDeviceNotConnectedError("Not connected. Run `connect()` first.")
|
||||||
|
|
||||||
|
# Ensure the action tensor has at least 9 elements:
|
||||||
|
# - First 6: arm positions.
|
||||||
|
# - Last 3: base commands.
|
||||||
|
if action.numel() < 9:
|
||||||
|
# Pad with zeros if there are not enough elements.
|
||||||
|
padded = torch.zeros(9, dtype=action.dtype)
|
||||||
|
padded[: action.numel()] = action
|
||||||
|
action = padded
|
||||||
|
|
||||||
|
# Extract arm and base actions.
|
||||||
|
arm_actions = action[:6].flatten()
|
||||||
|
base_actions = action[6:].flatten()
|
||||||
|
|
||||||
|
x_cmd_mm = base_actions[0].item() # mm/s
|
||||||
|
y_cmd_mm = base_actions[1].item() # mm/s
|
||||||
|
theta_cmd = base_actions[2].item() # deg/s
|
||||||
|
|
||||||
|
# Convert mm/s to m/s for the kinematics calculations.
|
||||||
|
x_cmd = x_cmd_mm / 1000.0 # m/s
|
||||||
|
y_cmd = y_cmd_mm / 1000.0 # m/s
|
||||||
|
|
||||||
|
# Compute wheel commands from body commands.
|
||||||
|
wheel_commands = self.body_to_wheel_raw(x_cmd, y_cmd, theta_cmd)
|
||||||
|
|
||||||
|
arm_positions_list = arm_actions.tolist()
|
||||||
|
|
||||||
|
message = {"raw_velocity": wheel_commands, "arm_positions": arm_positions_list}
|
||||||
|
self.cmd_socket.send_string(json.dumps(message))
|
||||||
|
|
||||||
|
return action
|
||||||
|
|
||||||
|
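As a usage note (not part of this diff): the action vector that send_action() expects is six arm goals followed by the base command in the same units capture_observation() reports, i.e. x and y in mm/s and theta in deg/s. A hedged sketch with placeholder values:

```python
# Hypothetical 9-dim action for send_action(): 6 arm goals + (x mm/s lateral, y mm/s forward, theta deg/s).
import torch

arm_goal = [2048.0, 2048.0, 2048.0, 2048.0, 2048.0, 2048.0]  # placeholder joint targets
base_goal = [0.0, 100.0, 15.0]                               # drive forward at 100 mm/s while turning 15 deg/s

action = torch.tensor(arm_goal + base_goal, dtype=torch.float32)
assert action.numel() == 9  # shorter tensors are zero-padded by send_action()
# robot.send_action(action)  # assumes a connected MobileManipulator instance
```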
def print_logs(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def disconnect(self):
|
||||||
|
if not self.is_connected:
|
||||||
|
raise RobotDeviceNotConnectedError("Not connected.")
|
||||||
|
if self.cmd_socket:
|
||||||
|
stop_cmd = {
|
||||||
|
"raw_velocity": {"left_wheel": 0, "back_wheel": 0, "right_wheel": 0},
|
||||||
|
"arm_positions": {},
|
||||||
|
}
|
||||||
|
self.cmd_socket.send_string(json.dumps(stop_cmd))
|
||||||
|
self.cmd_socket.close()
|
||||||
|
if self.video_socket:
|
||||||
|
self.video_socket.close()
|
||||||
|
if self.context:
|
||||||
|
self.context.term()
|
||||||
|
if PYNPUT_AVAILABLE:
|
||||||
|
self.listener.stop()
|
||||||
|
self.is_connected = False
|
||||||
|
print("[INFO] Disconnected from remote robot.")
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
if getattr(self, "is_connected", False):
|
||||||
|
self.disconnect()
|
||||||
|
if PYNPUT_AVAILABLE:
|
||||||
|
self.listener.stop()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def degps_to_raw(degps: float) -> int:
|
||||||
|
steps_per_deg = 4096.0 / 360.0
|
||||||
|
speed_in_steps = abs(degps) * steps_per_deg
|
||||||
|
speed_int = int(round(speed_in_steps))
|
||||||
|
if speed_int > 0x7FFF:
|
||||||
|
speed_int = 0x7FFF
|
||||||
|
if degps < 0:
|
||||||
|
return speed_int | 0x8000
|
||||||
|
else:
|
||||||
|
return speed_int & 0x7FFF
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def raw_to_degps(raw_speed: int) -> float:
|
||||||
|
steps_per_deg = 4096.0 / 360.0
|
||||||
|
magnitude = raw_speed & 0x7FFF
|
||||||
|
degps = magnitude / steps_per_deg
|
||||||
|
if raw_speed & 0x8000:
|
||||||
|
degps = -degps
|
||||||
|
return degps
|
||||||
|
|
||||||
|
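The raw wheel speed used by these helpers is a 16-bit sign-magnitude value: bit 15 encodes direction and the low 15 bits hold the magnitude in encoder steps/s (4096 steps per revolution). A quick round-trip check, shown here as a sketch that assumes this module imports cleanly:

```python
# Round-trip check of the sign-magnitude wheel-speed encoding.
from lerobot.common.robot_devices.robots.mobile_manipulator import MobileManipulator

steps_per_deg = 4096.0 / 360.0

raw = MobileManipulator.degps_to_raw(-90.0)             # 90 deg/s in reverse
assert raw & 0x8000                                     # bit 15 flags the negative direction
assert (raw & 0x7FFF) == round(90.0 * steps_per_deg)    # magnitude: 1024 steps/s

assert abs(MobileManipulator.raw_to_degps(raw) + 90.0) < 1e-9
```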
def body_to_wheel_raw(
|
||||||
|
self,
|
||||||
|
x_cmd: float,
|
||||||
|
y_cmd: float,
|
||||||
|
theta_cmd: float,
|
||||||
|
wheel_radius: float = 0.05,
|
||||||
|
base_radius: float = 0.125,
|
||||||
|
max_raw: int = 3000,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Convert desired body-frame velocities into wheel raw commands.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
x_cmd : Linear velocity in x (m/s).
|
||||||
|
y_cmd : Linear velocity in y (m/s).
|
||||||
|
theta_cmd : Rotational velocity (deg/s).
|
||||||
|
wheel_radius: Radius of each wheel (meters).
|
||||||
|
base_radius : Distance from the center of rotation to each wheel (meters).
|
||||||
|
max_raw : Maximum allowed raw command (ticks) per wheel.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary with wheel raw commands:
|
||||||
|
{"left_wheel": value, "back_wheel": value, "right_wheel": value}.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- Internally, the method converts theta_cmd to rad/s for the kinematics.
|
||||||
|
- The raw command is computed from the wheels angular speed in deg/s
|
||||||
|
using degps_to_raw(). If any command exceeds max_raw, all commands
|
||||||
|
are scaled down proportionally.
|
||||||
|
"""
|
||||||
|
# Convert rotational velocity from deg/s to rad/s.
|
||||||
|
theta_rad = theta_cmd * (np.pi / 180.0)
|
||||||
|
# Create the body velocity vector [x, y, theta_rad].
|
||||||
|
velocity_vector = np.array([x_cmd, y_cmd, theta_rad])
|
||||||
|
|
||||||
|
# Define the wheel mounting angles (defined from y axis cw)
|
||||||
|
angles = np.radians(np.array([300, 180, 60]))
|
||||||
|
# Build the kinematic matrix: each row maps body velocities to a wheel’s linear speed.
|
||||||
|
# The third column (base_radius) accounts for the effect of rotation.
|
||||||
|
m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles])
|
||||||
|
|
||||||
|
# Compute each wheel’s linear speed (m/s) and then its angular speed (rad/s).
|
||||||
|
wheel_linear_speeds = m.dot(velocity_vector)
|
||||||
|
wheel_angular_speeds = wheel_linear_speeds / wheel_radius
|
||||||
|
|
||||||
|
# Convert wheel angular speeds from rad/s to deg/s.
|
||||||
|
wheel_degps = wheel_angular_speeds * (180.0 / np.pi)
|
||||||
|
|
||||||
|
# Scaling
|
||||||
|
steps_per_deg = 4096.0 / 360.0
|
||||||
|
raw_floats = [abs(degps) * steps_per_deg for degps in wheel_degps]
|
||||||
|
max_raw_computed = max(raw_floats)
|
||||||
|
if max_raw_computed > max_raw:
|
||||||
|
scale = max_raw / max_raw_computed
|
||||||
|
wheel_degps = wheel_degps * scale
|
||||||
|
|
||||||
|
# Convert each wheel’s angular speed (deg/s) to a raw integer.
|
||||||
|
wheel_raw = [MobileManipulator.degps_to_raw(deg) for deg in wheel_degps]
|
||||||
|
|
||||||
|
return {"left_wheel": wheel_raw[0], "back_wheel": wheel_raw[1], "right_wheel": wheel_raw[2]}
|
||||||
|
|
||||||
|
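As a worked example of the kinematic map described above (default wheel_radius=0.05 m, base_radius=0.125 m): a pure spin command loads all three wheels equally, because the third column of the matrix is the constant base_radius. A standalone sketch of the same math:

```python
# Same forward kinematics as body_to_wheel_raw, evaluated for a 90 deg/s spin in place.
import numpy as np

wheel_radius, base_radius = 0.05, 0.125
velocity_vector = np.array([0.0, 0.0, np.radians(90.0)])  # [x m/s, y m/s, theta rad/s]

angles = np.radians([300, 180, 60])                       # wheel mounting angles
m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles])

wheel_degps = np.degrees(m @ velocity_vector / wheel_radius)
print(wheel_degps)  # ~[225. 225. 225.]: every wheel spins at the same rate
```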
def wheel_raw_to_body(
|
||||||
|
self, wheel_raw: dict, wheel_radius: float = 0.05, base_radius: float = 0.125
|
||||||
|
) -> tuple:
|
||||||
|
"""
|
||||||
|
Convert wheel raw command feedback back into body-frame velocities.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
wheel_raw : Dictionary with raw wheel commands (keys: "left_wheel", "back_wheel", "right_wheel").
|
||||||
|
wheel_radius: Radius of each wheel (meters).
|
||||||
|
base_radius : Distance from the robot center to each wheel (meters).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple (x_cmd, y_cmd, theta_cmd) where:
|
||||||
|
x_cmd : Linear velocity in x (m/s).
|
||||||
|
y_cmd : Linear velocity in y (m/s).
|
||||||
|
theta_cmd : Rotational velocity in deg/s.
|
||||||
|
"""
|
||||||
|
# Extract the raw values in order.
|
||||||
|
raw_list = [
|
||||||
|
int(wheel_raw.get("left_wheel", 0)),
|
||||||
|
int(wheel_raw.get("back_wheel", 0)),
|
||||||
|
int(wheel_raw.get("right_wheel", 0)),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Convert each raw command back to an angular speed in deg/s.
|
||||||
|
wheel_degps = np.array([MobileManipulator.raw_to_degps(r) for r in raw_list])
|
||||||
|
# Convert from deg/s to rad/s.
|
||||||
|
wheel_radps = wheel_degps * (np.pi / 180.0)
|
||||||
|
# Compute each wheel’s linear speed (m/s) from its angular speed.
|
||||||
|
wheel_linear_speeds = wheel_radps * wheel_radius
|
||||||
|
|
||||||
|
# Define the wheel mounting angles (defined from y axis cw)
|
||||||
|
angles = np.radians(np.array([300, 180, 60]))
|
||||||
|
m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles])
|
||||||
|
|
||||||
|
# Solve the inverse kinematics: body_velocity = M⁻¹ · wheel_linear_speeds.
|
||||||
|
m_inv = np.linalg.inv(m)
|
||||||
|
velocity_vector = m_inv.dot(wheel_linear_speeds)
|
||||||
|
x_cmd, y_cmd, theta_rad = velocity_vector
|
||||||
|
theta_cmd = theta_rad * (180.0 / np.pi)
|
||||||
|
return (x_cmd, y_cmd, theta_cmd)
|
||||||
|
|
||||||
|
|
||||||
|
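Since wheel_raw_to_body() inverts the same 3x3 matrix that body_to_wheel_raw() applies, the two should round-trip up to the integer quantization of the raw encoding. A quick sketch of that check; the two methods use no instance state, so the sketch skips __init__ rather than building a full config:

```python
# Round-trip sanity check: body -> raw wheel commands -> body.
from lerobot.common.robot_devices.robots.mobile_manipulator import MobileManipulator

robot = object.__new__(MobileManipulator)  # kinematics helpers don't touch instance state

x, y, theta = 0.1, -0.05, 30.0             # m/s, m/s, deg/s
raw = robot.body_to_wheel_raw(x, y, theta)
x2, y2, theta2 = robot.wheel_raw_to_body(raw)
# Differences stay within the encoding's ~0.09 deg/s per-wheel quantization step.
print(abs(x - x2), abs(y - y2), abs(theta - theta2))
```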
class LeKiwi:
|
||||||
|
def __init__(self, motor_bus):
|
||||||
|
"""
|
||||||
|
Initializes the LeKiwi with Feetech motors bus.
|
||||||
|
"""
|
||||||
|
self.motor_bus = motor_bus
|
||||||
|
self.motor_ids = ["left_wheel", "back_wheel", "right_wheel"]
|
||||||
|
|
||||||
|
# Initialize motors in velocity mode.
|
||||||
|
self.motor_bus.write("Lock", 0)
|
||||||
|
self.motor_bus.write("Mode", [1, 1, 1], self.motor_ids)
|
||||||
|
self.motor_bus.write("Lock", 1)
|
||||||
|
print("Motors set to velocity mode.")
|
||||||
|
|
||||||
|
def read_velocity(self):
|
||||||
|
"""
|
||||||
|
Reads the raw speeds for all wheels. Returns a dictionary keyed by motor name.
|
||||||
|
"""
|
||||||
|
raw_speeds = self.motor_bus.read("Present_Speed", self.motor_ids)
|
||||||
|
return {
|
||||||
|
"left_wheel": int(raw_speeds[0]),
|
||||||
|
"back_wheel": int(raw_speeds[1]),
|
||||||
|
"right_wheel": int(raw_speeds[2]),
|
||||||
|
}
|
||||||
|
|
||||||
|
def set_velocity(self, command_speeds):
|
||||||
|
"""
|
||||||
|
Sends raw velocity commands (16-bit encoded values) directly to the motor bus.
|
||||||
|
The order of speeds must correspond to self.motor_ids.
|
||||||
|
"""
|
||||||
|
self.motor_bus.write("Goal_Speed", command_speeds, self.motor_ids)
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
"""Stops the robot by setting all motor speeds to zero."""
|
||||||
|
self.motor_bus.write("Goal_Speed", [0, 0, 0], self.motor_ids)
|
||||||
|
print("Motors stopped.")
|
|
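Taken together with the server script earlier in this diff, the split is: run_lekiwi() runs on the robot and owns the motors and cameras, while MobileManipulator runs on the operator machine and only exchanges JSON over ZeroMQ. A rough usage sketch (the config contents are assumptions, not shown in this diff):

```python
# Operator-side sketch; assumes ip/port/video_port, arms and cameras are set in the config.
from lerobot.common.robot_devices.robots.configs import LeKiwiRobotConfig
from lerobot.common.robot_devices.robots.mobile_manipulator import MobileManipulator

robot = MobileManipulator(LeKiwiRobotConfig())
robot.connect()                      # calibrates the leader arm, opens both ZMQ sockets
while robot.running:                 # cleared by the configured "quit" teleop key
    robot.teleop_step(record_data=False)
robot.disconnect()

# Robot-side counterpart (e.g. on the Raspberry Pi): run_lekiwi(LeKiwiRobotConfig()).
```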
@@ -108,7 +108,7 @@ class StretchRobot(StretchAPI):
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t

-# Populate output dictionnaries
+# Populate output dictionaries
obs_dict, action_dict = {}, {}
obs_dict["observation.state"] = state
action_dict["action"] = action
@@ -153,7 +153,7 @@ class StretchRobot(StretchAPI):
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t

-# Populate output dictionnaries
+# Populate output dictionaries
obs_dict = {}
obs_dict["observation.state"] = state
for name in self.cameras:
@@ -1,9 +1,24 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import Protocol

from lerobot.common.robot_devices.robots.configs import (
AlohaRobotConfig,
KochBimanualRobotConfig,
KochRobotConfig,
+LeKiwiRobotConfig,
ManipulatorRobotConfig,
MossRobotConfig,
RobotConfig,
@@ -45,6 +60,8 @@ def make_robot_config(robot_type: str, **kwargs) -> RobotConfig:
return So100RobotConfig(**kwargs)
elif robot_type == "stretch":
return StretchRobotConfig(**kwargs)
+elif robot_type == "lekiwi":
+return LeKiwiRobotConfig(**kwargs)
else:
raise ValueError(f"Robot type '{robot_type}' is not available.")

@@ -54,6 +71,10 @@ def make_robot_from_config(config: RobotConfig):
from lerobot.common.robot_devices.robots.manipulator import ManipulatorRobot

return ManipulatorRobot(config)
+elif isinstance(config, LeKiwiRobotConfig):
+from lerobot.common.robot_devices.robots.mobile_manipulator import MobileManipulator
+
+return MobileManipulator(config)
else:
from lerobot.common.robot_devices.robots.stretch import StretchRobot

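With these two branches in place, a LeKiwi setup goes through the same factory path as every other robot; a minimal, hypothetical usage (constructor kwargs depend on your configuration):

```python
# Hypothetical use of the new "lekiwi" branches in the robot factories.
from lerobot.common.robot_devices.robots.utils import make_robot_config, make_robot_from_config

config = make_robot_config("lekiwi")    # builds a LeKiwiRobotConfig
robot = make_robot_from_config(config)  # dispatches to MobileManipulator
robot.connect()                         # opens the ZMQ command/observation sockets
```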
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import platform
import time

Some files were not shown because too many files have changed in this diff.