Merge branch 'main' into 2025_02_20_add_dexvla
This commit is contained in:
commit
0740ea8290
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
# Misc
|
# Misc
|
||||||
.git
|
.git
|
||||||
tmp
|
tmp
|
||||||
|
@ -59,7 +73,7 @@ pip-log.txt
|
||||||
pip-delete-this-directory.txt
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
# Unit test / coverage reports
|
# Unit test / coverage reports
|
||||||
!tests/data
|
!tests/artifacts
|
||||||
htmlcov/
|
htmlcov/
|
||||||
.tox/
|
.tox/
|
||||||
.nox/
|
.nox/
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
*.memmap filter=lfs diff=lfs merge=lfs -text
|
*.memmap filter=lfs diff=lfs merge=lfs -text
|
||||||
*.stl filter=lfs diff=lfs merge=lfs -text
|
*.stl filter=lfs diff=lfs merge=lfs -text
|
||||||
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
name: "\U0001F41B Bug Report"
|
name: "\U0001F41B Bug Report"
|
||||||
description: Submit a bug report to help us improve LeRobot
|
description: Submit a bug report to help us improve LeRobot
|
||||||
body:
|
body:
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
# Inspired by
|
# Inspired by
|
||||||
# https://github.com/huggingface/peft/blob/main/.github/workflows/build_docker_images.yml
|
# https://github.com/huggingface/peft/blob/main/.github/workflows/build_docker_images.yml
|
||||||
name: Builds
|
name: Builds
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
# Inspired by
|
# Inspired by
|
||||||
# https://github.com/huggingface/peft/blob/main/.github/workflows/nightly.yml
|
# https://github.com/huggingface/peft/blob/main/.github/workflows/nightly.yml
|
||||||
name: Nightly
|
name: Nightly
|
||||||
|
|
|
@ -1,161 +0,0 @@
|
||||||
# Adapted from https://github.com/huggingface/diffusers/blob/main/.github/workflows/pr_style_bot.yml
|
|
||||||
name: PR Style Bot
|
|
||||||
|
|
||||||
on:
|
|
||||||
issue_comment:
|
|
||||||
types: [created]
|
|
||||||
|
|
||||||
permissions: {}
|
|
||||||
|
|
||||||
env:
|
|
||||||
PYTHON_VERSION: "3.10"
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
check-permissions:
|
|
||||||
if: >
|
|
||||||
contains(github.event.comment.body, '@bot /style') &&
|
|
||||||
github.event.issue.pull_request != null
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
outputs:
|
|
||||||
is_authorized: ${{ steps.check_user_permission.outputs.has_permission }}
|
|
||||||
steps:
|
|
||||||
- name: Check user permission
|
|
||||||
id: check_user_permission
|
|
||||||
uses: actions/github-script@v6
|
|
||||||
with:
|
|
||||||
script: |
|
|
||||||
const comment_user = context.payload.comment.user.login;
|
|
||||||
const { data: permission } = await github.rest.repos.getCollaboratorPermissionLevel({
|
|
||||||
owner: context.repo.owner,
|
|
||||||
repo: context.repo.repo,
|
|
||||||
username: comment_user
|
|
||||||
});
|
|
||||||
|
|
||||||
const authorized =
|
|
||||||
permission.permission === 'admin' ||
|
|
||||||
permission.permission === 'write';
|
|
||||||
|
|
||||||
console.log(
|
|
||||||
`User ${comment_user} has permission level: ${permission.permission}, ` +
|
|
||||||
`authorized: ${authorized} (admins & maintainers allowed)`
|
|
||||||
);
|
|
||||||
|
|
||||||
core.setOutput('has_permission', authorized);
|
|
||||||
|
|
||||||
run-style-bot:
|
|
||||||
needs: check-permissions
|
|
||||||
if: needs.check-permissions.outputs.is_authorized == 'true'
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
permissions:
|
|
||||||
contents: write
|
|
||||||
pull-requests: write
|
|
||||||
steps:
|
|
||||||
- name: Extract PR details
|
|
||||||
id: pr_info
|
|
||||||
uses: actions/github-script@v6
|
|
||||||
with:
|
|
||||||
script: |
|
|
||||||
const prNumber = context.payload.issue.number;
|
|
||||||
const { data: pr } = await github.rest.pulls.get({
|
|
||||||
owner: context.repo.owner,
|
|
||||||
repo: context.repo.repo,
|
|
||||||
pull_number: prNumber
|
|
||||||
});
|
|
||||||
|
|
||||||
// We capture both the branch ref and the "full_name" of the head repo
|
|
||||||
// so that we can check out the correct repository & branch (including forks).
|
|
||||||
core.setOutput("prNumber", prNumber);
|
|
||||||
core.setOutput("headRef", pr.head.ref);
|
|
||||||
core.setOutput("headRepoFullName", pr.head.repo.full_name);
|
|
||||||
|
|
||||||
- name: Check out PR branch
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
env:
|
|
||||||
HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
|
|
||||||
HEADREF: ${{ steps.pr_info.outputs.headRef }}
|
|
||||||
with:
|
|
||||||
persist-credentials: true
|
|
||||||
# Instead of checking out the base repo, use the contributor's repo name
|
|
||||||
repository: ${{ env.HEADREPOFULLNAME }}
|
|
||||||
ref: ${{ env.HEADREF }}
|
|
||||||
# You may need fetch-depth: 0 for being able to push
|
|
||||||
fetch-depth: 0
|
|
||||||
token: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
|
|
||||||
- name: Debug
|
|
||||||
env:
|
|
||||||
HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
|
|
||||||
HEADREF: ${{ steps.pr_info.outputs.headRef }}
|
|
||||||
PRNUMBER: ${{ steps.pr_info.outputs.prNumber }}
|
|
||||||
run: |
|
|
||||||
echo "PR number: ${PRNUMBER}"
|
|
||||||
echo "Head Ref: ${HEADREF}"
|
|
||||||
echo "Head Repo Full Name: ${HEADREPOFULLNAME}"
|
|
||||||
|
|
||||||
- name: Set up Python
|
|
||||||
uses: actions/setup-python@v4
|
|
||||||
with:
|
|
||||||
python-version: ${{ env.PYTHON_VERSION }}
|
|
||||||
|
|
||||||
- name: Get Ruff Version from pre-commit-config.yaml
|
|
||||||
id: get-ruff-version
|
|
||||||
run: |
|
|
||||||
RUFF_VERSION=$(awk '/repo: https:\/\/github.com\/astral-sh\/ruff-pre-commit/{flag=1;next}/rev:/{if(flag){print $2;exit}}' .pre-commit-config.yaml)
|
|
||||||
echo "ruff_version=${RUFF_VERSION}" >> $GITHUB_OUTPUT
|
|
||||||
|
|
||||||
- name: Install Ruff
|
|
||||||
env:
|
|
||||||
RUFF_VERSION: ${{ steps.get-ruff-version.outputs.ruff_version }}
|
|
||||||
run: python -m pip install "ruff==${RUFF_VERSION}"
|
|
||||||
|
|
||||||
- name: Ruff check
|
|
||||||
run: ruff check --fix
|
|
||||||
|
|
||||||
- name: Ruff format
|
|
||||||
run: ruff format
|
|
||||||
|
|
||||||
- name: Commit and push changes
|
|
||||||
id: commit_and_push
|
|
||||||
env:
|
|
||||||
HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
|
|
||||||
HEADREF: ${{ steps.pr_info.outputs.headRef }}
|
|
||||||
PRNUMBER: ${{ steps.pr_info.outputs.prNumber }}
|
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
run: |
|
|
||||||
echo "HEADREPOFULLNAME: ${HEADREPOFULLNAME}, HEADREF: ${HEADREF}"
|
|
||||||
# Configure git with the Actions bot user
|
|
||||||
git config user.name "github-actions[bot]"
|
|
||||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
|
||||||
git config --local lfs.https://github.com/.locksverify false
|
|
||||||
|
|
||||||
# Make sure your 'origin' remote is set to the contributor's fork
|
|
||||||
git remote set-url origin "https://x-access-token:${GITHUB_TOKEN}@github.com/${HEADREPOFULLNAME}.git"
|
|
||||||
|
|
||||||
# If there are changes after running style/quality, commit them
|
|
||||||
if [ -n "$(git status --porcelain)" ]; then
|
|
||||||
git add .
|
|
||||||
git commit -m "Apply style fixes"
|
|
||||||
# Push to the original contributor's forked branch
|
|
||||||
git push origin HEAD:${HEADREF}
|
|
||||||
echo "changes_pushed=true" >> $GITHUB_OUTPUT
|
|
||||||
else
|
|
||||||
echo "No changes to commit."
|
|
||||||
echo "changes_pushed=false" >> $GITHUB_OUTPUT
|
|
||||||
fi
|
|
||||||
|
|
||||||
- name: Comment on PR with workflow run link
|
|
||||||
if: steps.commit_and_push.outputs.changes_pushed == 'true'
|
|
||||||
uses: actions/github-script@v6
|
|
||||||
with:
|
|
||||||
script: |
|
|
||||||
const prNumber = parseInt(process.env.prNumber, 10);
|
|
||||||
const runUrl = `${process.env.GITHUB_SERVER_URL}/${process.env.GITHUB_REPOSITORY}/actions/runs/${process.env.GITHUB_RUN_ID}`
|
|
||||||
|
|
||||||
await github.rest.issues.createComment({
|
|
||||||
owner: context.repo.owner,
|
|
||||||
repo: context.repo.repo,
|
|
||||||
issue_number: prNumber,
|
|
||||||
body: `Style fixes have been applied. [View the workflow run here](${runUrl}).`
|
|
||||||
});
|
|
||||||
env:
|
|
||||||
prNumber: ${{ steps.pr_info.outputs.prNumber }}
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
name: Quality
|
name: Quality
|
||||||
|
|
||||||
on:
|
on:
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
# Inspired by
|
# Inspired by
|
||||||
# https://github.com/huggingface/peft/blob/main/.github/workflows/test-docker-build.yml
|
# https://github.com/huggingface/peft/blob/main/.github/workflows/test-docker-build.yml
|
||||||
name: Test Dockerfiles
|
name: Test Dockerfiles
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
name: Tests
|
name: Tests
|
||||||
|
|
||||||
on:
|
on:
|
||||||
|
@ -112,7 +126,7 @@ jobs:
|
||||||
# portaudio19-dev is needed to install pyaudio
|
# portaudio19-dev is needed to install pyaudio
|
||||||
run: |
|
run: |
|
||||||
sudo apt-get update && \
|
sudo apt-get update && \
|
||||||
sudo apt-get install -y libegl1-mesa-dev portaudio19-dev
|
sudo apt-get install -y libegl1-mesa-dev ffmpeg portaudio19-dev
|
||||||
|
|
||||||
- name: Install uv and python
|
- name: Install uv and python
|
||||||
uses: astral-sh/setup-uv@v5
|
uses: astral-sh/setup-uv@v5
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
# Logging
|
# Logging
|
||||||
logs
|
logs
|
||||||
tmp
|
tmp
|
||||||
|
@ -64,7 +78,7 @@ pip-log.txt
|
||||||
pip-delete-this-directory.txt
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
# Unit test / coverage reports
|
# Unit test / coverage reports
|
||||||
!tests/data
|
!tests/artifacts
|
||||||
htmlcov/
|
htmlcov/
|
||||||
.tox/
|
.tox/
|
||||||
.nox/
|
.nox/
|
||||||
|
|
|
@ -1,7 +1,28 @@
|
||||||
exclude: ^(tests/data)
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
exclude: "tests/artifacts/.*\\.safetensors$"
|
||||||
default_language_version:
|
default_language_version:
|
||||||
python: python3.10
|
python: python3.10
|
||||||
repos:
|
repos:
|
||||||
|
##### Meta #####
|
||||||
|
- repo: meta
|
||||||
|
hooks:
|
||||||
|
- id: check-useless-excludes
|
||||||
|
- id: check-hooks-apply
|
||||||
|
|
||||||
|
|
||||||
##### Style / Misc. #####
|
##### Style / Misc. #####
|
||||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
rev: v5.0.0
|
rev: v5.0.0
|
||||||
|
@ -14,31 +35,37 @@ repos:
|
||||||
- id: check-toml
|
- id: check-toml
|
||||||
- id: end-of-file-fixer
|
- id: end-of-file-fixer
|
||||||
- id: trailing-whitespace
|
- id: trailing-whitespace
|
||||||
|
|
||||||
- repo: https://github.com/crate-ci/typos
|
- repo: https://github.com/crate-ci/typos
|
||||||
rev: v1.30.0
|
rev: v1.30.2
|
||||||
hooks:
|
hooks:
|
||||||
- id: typos
|
- id: typos
|
||||||
args: [--force-exclude]
|
args: [--force-exclude]
|
||||||
|
|
||||||
- repo: https://github.com/asottile/pyupgrade
|
- repo: https://github.com/asottile/pyupgrade
|
||||||
rev: v3.19.1
|
rev: v3.19.1
|
||||||
hooks:
|
hooks:
|
||||||
- id: pyupgrade
|
- id: pyupgrade
|
||||||
|
|
||||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
rev: v0.9.9
|
rev: v0.9.10
|
||||||
hooks:
|
hooks:
|
||||||
- id: ruff
|
- id: ruff
|
||||||
args: [--fix]
|
args: [--fix]
|
||||||
- id: ruff-format
|
- id: ruff-format
|
||||||
|
|
||||||
|
|
||||||
##### Security #####
|
##### Security #####
|
||||||
- repo: https://github.com/gitleaks/gitleaks
|
- repo: https://github.com/gitleaks/gitleaks
|
||||||
rev: v8.24.0
|
rev: v8.24.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: gitleaks
|
- id: gitleaks
|
||||||
|
|
||||||
- repo: https://github.com/woodruffw/zizmor-pre-commit
|
- repo: https://github.com/woodruffw/zizmor-pre-commit
|
||||||
rev: v1.4.1
|
rev: v1.4.1
|
||||||
hooks:
|
hooks:
|
||||||
- id: zizmor
|
- id: zizmor
|
||||||
|
|
||||||
- repo: https://github.com/PyCQA/bandit
|
- repo: https://github.com/PyCQA/bandit
|
||||||
rev: 1.8.3
|
rev: 1.8.3
|
||||||
hooks:
|
hooks:
|
||||||
|
|
|
@ -291,7 +291,7 @@ sudo apt-get install git-lfs
|
||||||
git lfs install
|
git lfs install
|
||||||
```
|
```
|
||||||
|
|
||||||
Pull artifacts if they're not in [tests/data](tests/data)
|
Pull artifacts if they're not in [tests/artifacts](tests/artifacts)
|
||||||
```bash
|
```bash
|
||||||
git lfs pull
|
git lfs pull
|
||||||
```
|
```
|
||||||
|
|
32
Makefile
32
Makefile
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
.PHONY: tests
|
.PHONY: tests
|
||||||
|
|
||||||
PYTHON_PATH := $(shell which python)
|
PYTHON_PATH := $(shell which python)
|
||||||
|
@ -33,6 +47,7 @@ test-act-ete-train:
|
||||||
--policy.dim_model=64 \
|
--policy.dim_model=64 \
|
||||||
--policy.n_action_steps=20 \
|
--policy.n_action_steps=20 \
|
||||||
--policy.chunk_size=20 \
|
--policy.chunk_size=20 \
|
||||||
|
--policy.device=$(DEVICE) \
|
||||||
--env.type=aloha \
|
--env.type=aloha \
|
||||||
--env.episode_length=5 \
|
--env.episode_length=5 \
|
||||||
--dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \
|
--dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \
|
||||||
|
@ -47,7 +62,6 @@ test-act-ete-train:
|
||||||
--save_checkpoint=true \
|
--save_checkpoint=true \
|
||||||
--log_freq=1 \
|
--log_freq=1 \
|
||||||
--wandb.enable=false \
|
--wandb.enable=false \
|
||||||
--device=$(DEVICE) \
|
|
||||||
--output_dir=tests/outputs/act/
|
--output_dir=tests/outputs/act/
|
||||||
|
|
||||||
test-act-ete-train-resume:
|
test-act-ete-train-resume:
|
||||||
|
@ -58,11 +72,11 @@ test-act-ete-train-resume:
|
||||||
test-act-ete-eval:
|
test-act-ete-eval:
|
||||||
python lerobot/scripts/eval.py \
|
python lerobot/scripts/eval.py \
|
||||||
--policy.path=tests/outputs/act/checkpoints/000004/pretrained_model \
|
--policy.path=tests/outputs/act/checkpoints/000004/pretrained_model \
|
||||||
|
--policy.device=$(DEVICE) \
|
||||||
--env.type=aloha \
|
--env.type=aloha \
|
||||||
--env.episode_length=5 \
|
--env.episode_length=5 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1
|
||||||
--device=$(DEVICE)
|
|
||||||
|
|
||||||
test-diffusion-ete-train:
|
test-diffusion-ete-train:
|
||||||
python lerobot/scripts/train.py \
|
python lerobot/scripts/train.py \
|
||||||
|
@ -70,6 +84,7 @@ test-diffusion-ete-train:
|
||||||
--policy.down_dims='[64,128,256]' \
|
--policy.down_dims='[64,128,256]' \
|
||||||
--policy.diffusion_step_embed_dim=32 \
|
--policy.diffusion_step_embed_dim=32 \
|
||||||
--policy.num_inference_steps=10 \
|
--policy.num_inference_steps=10 \
|
||||||
|
--policy.device=$(DEVICE) \
|
||||||
--env.type=pusht \
|
--env.type=pusht \
|
||||||
--env.episode_length=5 \
|
--env.episode_length=5 \
|
||||||
--dataset.repo_id=lerobot/pusht \
|
--dataset.repo_id=lerobot/pusht \
|
||||||
|
@ -84,21 +99,21 @@ test-diffusion-ete-train:
|
||||||
--save_freq=2 \
|
--save_freq=2 \
|
||||||
--log_freq=1 \
|
--log_freq=1 \
|
||||||
--wandb.enable=false \
|
--wandb.enable=false \
|
||||||
--device=$(DEVICE) \
|
|
||||||
--output_dir=tests/outputs/diffusion/
|
--output_dir=tests/outputs/diffusion/
|
||||||
|
|
||||||
test-diffusion-ete-eval:
|
test-diffusion-ete-eval:
|
||||||
python lerobot/scripts/eval.py \
|
python lerobot/scripts/eval.py \
|
||||||
--policy.path=tests/outputs/diffusion/checkpoints/000002/pretrained_model \
|
--policy.path=tests/outputs/diffusion/checkpoints/000002/pretrained_model \
|
||||||
|
--policy.device=$(DEVICE) \
|
||||||
--env.type=pusht \
|
--env.type=pusht \
|
||||||
--env.episode_length=5 \
|
--env.episode_length=5 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1
|
||||||
--device=$(DEVICE)
|
|
||||||
|
|
||||||
test-tdmpc-ete-train:
|
test-tdmpc-ete-train:
|
||||||
python lerobot/scripts/train.py \
|
python lerobot/scripts/train.py \
|
||||||
--policy.type=tdmpc \
|
--policy.type=tdmpc \
|
||||||
|
--policy.device=$(DEVICE) \
|
||||||
--env.type=xarm \
|
--env.type=xarm \
|
||||||
--env.task=XarmLift-v0 \
|
--env.task=XarmLift-v0 \
|
||||||
--env.episode_length=5 \
|
--env.episode_length=5 \
|
||||||
|
@ -114,15 +129,14 @@ test-tdmpc-ete-train:
|
||||||
--save_freq=2 \
|
--save_freq=2 \
|
||||||
--log_freq=1 \
|
--log_freq=1 \
|
||||||
--wandb.enable=false \
|
--wandb.enable=false \
|
||||||
--device=$(DEVICE) \
|
|
||||||
--output_dir=tests/outputs/tdmpc/
|
--output_dir=tests/outputs/tdmpc/
|
||||||
|
|
||||||
test-tdmpc-ete-eval:
|
test-tdmpc-ete-eval:
|
||||||
python lerobot/scripts/eval.py \
|
python lerobot/scripts/eval.py \
|
||||||
--policy.path=tests/outputs/tdmpc/checkpoints/000002/pretrained_model \
|
--policy.path=tests/outputs/tdmpc/checkpoints/000002/pretrained_model \
|
||||||
|
--policy.device=$(DEVICE) \
|
||||||
--env.type=xarm \
|
--env.type=xarm \
|
||||||
--env.episode_length=5 \
|
--env.episode_length=5 \
|
||||||
--env.task=XarmLift-v0 \
|
--env.task=XarmLift-v0 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1
|
||||||
--device=$(DEVICE)
|
|
||||||
|
|
|
@ -232,8 +232,8 @@ python lerobot/scripts/eval.py \
|
||||||
--env.type=pusht \
|
--env.type=pusht \
|
||||||
--eval.batch_size=10 \
|
--eval.batch_size=10 \
|
||||||
--eval.n_episodes=10 \
|
--eval.n_episodes=10 \
|
||||||
--use_amp=false \
|
--policy.use_amp=false \
|
||||||
--device=cuda
|
--policy.device=cuda
|
||||||
```
|
```
|
||||||
|
|
||||||
Note: After training your own policy, you can re-evaluate the checkpoints with:
|
Note: After training your own policy, you can re-evaluate the checkpoints with:
|
||||||
|
@ -384,3 +384,6 @@ Additionally, if you are using any of the particular policy architecture, pretra
|
||||||
year={2024}
|
year={2024}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
## Star History
|
||||||
|
|
||||||
|
[](https://star-history.com/#huggingface/lerobot&Timeline)
|
||||||
|
|
|
@ -67,7 +67,7 @@ def parse_int_or_none(value) -> int | None:
|
||||||
def check_datasets_formats(repo_ids: list) -> None:
|
def check_datasets_formats(repo_ids: list) -> None:
|
||||||
for repo_id in repo_ids:
|
for repo_id in repo_ids:
|
||||||
dataset = LeRobotDataset(repo_id)
|
dataset = LeRobotDataset(repo_id)
|
||||||
if dataset.video:
|
if len(dataset.meta.video_keys) > 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Use only image dataset for running this benchmark. Video dataset provided: {repo_id}"
|
f"Use only image dataset for running this benchmark. Video dataset provided: {repo_id}"
|
||||||
)
|
)
|
||||||
|
|
|
@ -99,22 +99,22 @@ Example output when identifying the leader arm's port (e.g., `/dev/tty.usbmodem5
|
||||||
```
|
```
|
||||||
Finding all available ports for the MotorBus.
|
Finding all available ports for the MotorBus.
|
||||||
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
||||||
Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
|
Remove the usb cable from your MotorsBus and press Enter when done.
|
||||||
|
|
||||||
[...Disconnect leader arm and press Enter...]
|
[...Disconnect leader arm and press Enter...]
|
||||||
|
|
||||||
The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0031751
|
The port of this MotorsBus is /dev/tty.usbmodem575E0031751
|
||||||
Reconnect the usb cable.
|
Reconnect the usb cable.
|
||||||
```
|
```
|
||||||
Example output when identifying the follower arm's port (e.g., `/dev/tty.usbmodem575E0032081`, or possibly `/dev/ttyACM1` on Linux):
|
Example output when identifying the follower arm's port (e.g., `/dev/tty.usbmodem575E0032081`, or possibly `/dev/ttyACM1` on Linux):
|
||||||
```
|
```
|
||||||
Finding all available ports for the MotorBus.
|
Finding all available ports for the MotorBus.
|
||||||
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
||||||
Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
|
Remove the usb cable from your MotorsBus and press Enter when done.
|
||||||
|
|
||||||
[...Disconnect follower arm and press Enter...]
|
[...Disconnect follower arm and press Enter...]
|
||||||
|
|
||||||
The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0032081
|
The port of this MotorsBus is /dev/tty.usbmodem575E0032081
|
||||||
Reconnect the usb cable.
|
Reconnect the usb cable.
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -454,8 +454,8 @@ Next, you'll need to calibrate your SO-100 robot to ensure that the leader and f
|
||||||
|
|
||||||
You will need to move the follower arm to these positions sequentially:
|
You will need to move the follower arm to these positions sequentially:
|
||||||
|
|
||||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||||
|---|---|---|
|
| ------------------------------------------------------------------------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||||
| <img src="../media/so100/follower_zero.webp?raw=true" alt="SO-100 follower arm zero position" title="SO-100 follower arm zero position" style="width:100%;"> | <img src="../media/so100/follower_rotated.webp?raw=true" alt="SO-100 follower arm rotated position" title="SO-100 follower arm rotated position" style="width:100%;"> | <img src="../media/so100/follower_rest.webp?raw=true" alt="SO-100 follower arm rest position" title="SO-100 follower arm rest position" style="width:100%;"> |
|
| <img src="../media/so100/follower_zero.webp?raw=true" alt="SO-100 follower arm zero position" title="SO-100 follower arm zero position" style="width:100%;"> | <img src="../media/so100/follower_rotated.webp?raw=true" alt="SO-100 follower arm rotated position" title="SO-100 follower arm rotated position" style="width:100%;"> | <img src="../media/so100/follower_rest.webp?raw=true" alt="SO-100 follower arm rest position" title="SO-100 follower arm rest position" style="width:100%;"> |
|
||||||
|
|
||||||
Make sure both arms are connected and run this script to launch manual calibration:
|
Make sure both arms are connected and run this script to launch manual calibration:
|
||||||
|
@ -470,8 +470,8 @@ python lerobot/scripts/control_robot.py \
|
||||||
#### b. Manual calibration of leader arm
|
#### b. Manual calibration of leader arm
|
||||||
Follow step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724) which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially:
|
Follow step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724) which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially:
|
||||||
|
|
||||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||||
|---|---|---|
|
| ------------------------------------------------------------------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||||
| <img src="../media/so100/leader_zero.webp?raw=true" alt="SO-100 leader arm zero position" title="SO-100 leader arm zero position" style="width:100%;"> | <img src="../media/so100/leader_rotated.webp?raw=true" alt="SO-100 leader arm rotated position" title="SO-100 leader arm rotated position" style="width:100%;"> | <img src="../media/so100/leader_rest.webp?raw=true" alt="SO-100 leader arm rest position" title="SO-100 leader arm rest position" style="width:100%;"> |
|
| <img src="../media/so100/leader_zero.webp?raw=true" alt="SO-100 leader arm zero position" title="SO-100 leader arm zero position" style="width:100%;"> | <img src="../media/so100/leader_rotated.webp?raw=true" alt="SO-100 leader arm rotated position" title="SO-100 leader arm rotated position" style="width:100%;"> | <img src="../media/so100/leader_rest.webp?raw=true" alt="SO-100 leader arm rest position" title="SO-100 leader arm rest position" style="width:100%;"> |
|
||||||
|
|
||||||
Run this script to launch manual calibration:
|
Run this script to launch manual calibration:
|
||||||
|
@ -571,18 +571,25 @@ python lerobot/scripts/train.py \
|
||||||
--policy.type=act \
|
--policy.type=act \
|
||||||
--output_dir=outputs/train/act_so100_test \
|
--output_dir=outputs/train/act_so100_test \
|
||||||
--job_name=act_so100_test \
|
--job_name=act_so100_test \
|
||||||
--device=cuda \
|
--policy.device=cuda \
|
||||||
--wandb.enable=true
|
--wandb.enable=true
|
||||||
```
|
```
|
||||||
|
|
||||||
Let's explain it:
|
Let's explain it:
|
||||||
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/so100_test`.
|
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/so100_test`.
|
||||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||||
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
|
4. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
|
||||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||||
|
|
||||||
Training should take several hours. You will find checkpoints in `outputs/train/act_so100_test/checkpoints`.
|
Training should take several hours. You will find checkpoints in `outputs/train/act_so100_test/checkpoints`.
|
||||||
|
|
||||||
|
To resume training from a checkpoint, below is an example command to resume from `last` checkpoint of the `act_so100_test` policy:
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/train.py \
|
||||||
|
--config_path=outputs/train/act_so100_test/checkpoints/last/pretrained_model/train_config.json \
|
||||||
|
--resume=true
|
||||||
|
```
|
||||||
|
|
||||||
## K. Evaluate your policy
|
## K. Evaluate your policy
|
||||||
|
|
||||||
You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
|
You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
|
||||||
|
|
|
@ -366,8 +366,8 @@ Now we have to calibrate the leader arm and the follower arm. The wheel motors d
|
||||||
|
|
||||||
You will need to move the follower arm to these positions sequentially:
|
You will need to move the follower arm to these positions sequentially:
|
||||||
|
|
||||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||||
|---|---|---|
|
| ----------------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| <img src="../media/lekiwi/mobile_calib_zero.webp?raw=true" alt="SO-100 follower arm zero position" title="SO-100 follower arm zero position" style="width:100%;"> | <img src="../media/lekiwi/mobile_calib_rotated.webp?raw=true" alt="SO-100 follower arm rotated position" title="SO-100 follower arm rotated position" style="width:100%;"> | <img src="../media/lekiwi/mobile_calib_rest.webp?raw=true" alt="SO-100 follower arm rest position" title="SO-100 follower arm rest position" style="width:100%;"> |
|
| <img src="../media/lekiwi/mobile_calib_zero.webp?raw=true" alt="SO-100 follower arm zero position" title="SO-100 follower arm zero position" style="width:100%;"> | <img src="../media/lekiwi/mobile_calib_rotated.webp?raw=true" alt="SO-100 follower arm rotated position" title="SO-100 follower arm rotated position" style="width:100%;"> | <img src="../media/lekiwi/mobile_calib_rest.webp?raw=true" alt="SO-100 follower arm rest position" title="SO-100 follower arm rest position" style="width:100%;"> |
|
||||||
|
|
||||||
Make sure the arm is connected to the Raspberry Pi and run this script (on the Raspberry Pi) to launch manual calibration:
|
Make sure the arm is connected to the Raspberry Pi and run this script (on the Raspberry Pi) to launch manual calibration:
|
||||||
|
@ -385,8 +385,8 @@ If you have the **wired** LeKiwi version please run all commands including this
|
||||||
### Calibrate leader arm
|
### Calibrate leader arm
|
||||||
Then to calibrate the leader arm (which is attached to the laptop/pc). You will need to move the leader arm to these positions sequentially:
|
Then to calibrate the leader arm (which is attached to the laptop/pc). You will need to move the leader arm to these positions sequentially:
|
||||||
|
|
||||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||||
|---|---|---|
|
| ------------------------------------------------------------------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||||
| <img src="../media/so100/leader_zero.webp?raw=true" alt="SO-100 leader arm zero position" title="SO-100 leader arm zero position" style="width:100%;"> | <img src="../media/so100/leader_rotated.webp?raw=true" alt="SO-100 leader arm rotated position" title="SO-100 leader arm rotated position" style="width:100%;"> | <img src="../media/so100/leader_rest.webp?raw=true" alt="SO-100 leader arm rest position" title="SO-100 leader arm rest position" style="width:100%;"> |
|
| <img src="../media/so100/leader_zero.webp?raw=true" alt="SO-100 leader arm zero position" title="SO-100 leader arm zero position" style="width:100%;"> | <img src="../media/so100/leader_rotated.webp?raw=true" alt="SO-100 leader arm rotated position" title="SO-100 leader arm rotated position" style="width:100%;"> | <img src="../media/so100/leader_rest.webp?raw=true" alt="SO-100 leader arm rest position" title="SO-100 leader arm rest position" style="width:100%;"> |
|
||||||
|
|
||||||
Run this script (on your laptop/pc) to launch manual calibration:
|
Run this script (on your laptop/pc) to launch manual calibration:
|
||||||
|
@ -416,22 +416,22 @@ python lerobot/scripts/control_robot.py \
|
||||||
|
|
||||||
You should see on your laptop something like this: ```[INFO] Connected to remote robot at tcp://172.17.133.91:5555 and video stream at tcp://172.17.133.91:5556.``` Now you can move the leader arm and use the keyboard (w,a,s,d) to drive forward, left, backwards, right. And use (z,x) to turn left or turn right. You can use (r,f) to increase and decrease the speed of the mobile robot. There are three speed modes, see the table below:
|
You should see on your laptop something like this: ```[INFO] Connected to remote robot at tcp://172.17.133.91:5555 and video stream at tcp://172.17.133.91:5556.``` Now you can move the leader arm and use the keyboard (w,a,s,d) to drive forward, left, backwards, right. And use (z,x) to turn left or turn right. You can use (r,f) to increase and decrease the speed of the mobile robot. There are three speed modes, see the table below:
|
||||||
| Speed Mode | Linear Speed (m/s) | Rotation Speed (deg/s) |
|
| Speed Mode | Linear Speed (m/s) | Rotation Speed (deg/s) |
|
||||||
|------------|-------------------|-----------------------|
|
| ---------- | ------------------ | ---------------------- |
|
||||||
| Fast | 0.4 | 90 |
|
| Fast | 0.4 | 90 |
|
||||||
| Medium | 0.25 | 60 |
|
| Medium | 0.25 | 60 |
|
||||||
| Slow | 0.1 | 30 |
|
| Slow | 0.1 | 30 |
|
||||||
|
|
||||||
|
|
||||||
| Key | Action |
|
| Key | Action |
|
||||||
|------|--------------------------------|
|
| --- | -------------- |
|
||||||
| W | Move forward |
|
| W | Move forward |
|
||||||
| A | Move left |
|
| A | Move left |
|
||||||
| S | Move backward |
|
| S | Move backward |
|
||||||
| D | Move right |
|
| D | Move right |
|
||||||
| Z | Turn left |
|
| Z | Turn left |
|
||||||
| X | Turn right |
|
| X | Turn right |
|
||||||
| R | Increase speed |
|
| R | Increase speed |
|
||||||
| F | Decrease speed |
|
| F | Decrease speed |
|
||||||
|
|
||||||
> [!TIP]
|
> [!TIP]
|
||||||
> If you use a different keyboard you can change the keys for each command in the [`LeKiwiRobotConfig`](../lerobot/common/robot_devices/robots/configs.py).
|
> If you use a different keyboard you can change the keys for each command in the [`LeKiwiRobotConfig`](../lerobot/common/robot_devices/robots/configs.py).
|
||||||
|
@ -549,14 +549,14 @@ python lerobot/scripts/train.py \
|
||||||
--policy.type=act \
|
--policy.type=act \
|
||||||
--output_dir=outputs/train/act_lekiwi_test \
|
--output_dir=outputs/train/act_lekiwi_test \
|
||||||
--job_name=act_lekiwi_test \
|
--job_name=act_lekiwi_test \
|
||||||
--device=cuda \
|
--policy.device=cuda \
|
||||||
--wandb.enable=true
|
--wandb.enable=true
|
||||||
```
|
```
|
||||||
|
|
||||||
Let's explain it:
|
Let's explain it:
|
||||||
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/lekiwi_test`.
|
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/lekiwi_test`.
|
||||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||||
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
|
4. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
|
||||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||||
|
|
||||||
Training should take several hours. You will find checkpoints in `outputs/train/act_lekiwi_test/checkpoints`.
|
Training should take several hours. You will find checkpoints in `outputs/train/act_lekiwi_test/checkpoints`.
|
||||||
|
|
|
@ -176,8 +176,8 @@ Next, you'll need to calibrate your Moss v1 robot to ensure that the leader and
|
||||||
|
|
||||||
You will need to move the follower arm to these positions sequentially:
|
You will need to move the follower arm to these positions sequentially:
|
||||||
|
|
||||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||||
|---|---|---|
|
| ------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| <img src="../media/moss/follower_zero.webp?raw=true" alt="Moss v1 follower arm zero position" title="Moss v1 follower arm zero position" style="width:100%;"> | <img src="../media/moss/follower_rotated.webp?raw=true" alt="Moss v1 follower arm rotated position" title="Moss v1 follower arm rotated position" style="width:100%;"> | <img src="../media/moss/follower_rest.webp?raw=true" alt="Moss v1 follower arm rest position" title="Moss v1 follower arm rest position" style="width:100%;"> |
|
| <img src="../media/moss/follower_zero.webp?raw=true" alt="Moss v1 follower arm zero position" title="Moss v1 follower arm zero position" style="width:100%;"> | <img src="../media/moss/follower_rotated.webp?raw=true" alt="Moss v1 follower arm rotated position" title="Moss v1 follower arm rotated position" style="width:100%;"> | <img src="../media/moss/follower_rest.webp?raw=true" alt="Moss v1 follower arm rest position" title="Moss v1 follower arm rest position" style="width:100%;"> |
|
||||||
|
|
||||||
Make sure both arms are connected and run this script to launch manual calibration:
|
Make sure both arms are connected and run this script to launch manual calibration:
|
||||||
|
@ -192,8 +192,8 @@ python lerobot/scripts/control_robot.py \
|
||||||
**Manual calibration of leader arm**
|
**Manual calibration of leader arm**
|
||||||
Follow step 6 of the [assembly video](https://www.youtube.com/watch?v=DA91NJOtMic) which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially:
|
Follow step 6 of the [assembly video](https://www.youtube.com/watch?v=DA91NJOtMic) which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially:
|
||||||
|
|
||||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||||
|---|---|---|
|
| ------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| <img src="../media/moss/leader_zero.webp?raw=true" alt="Moss v1 leader arm zero position" title="Moss v1 leader arm zero position" style="width:100%;"> | <img src="../media/moss/leader_rotated.webp?raw=true" alt="Moss v1 leader arm rotated position" title="Moss v1 leader arm rotated position" style="width:100%;"> | <img src="../media/moss/leader_rest.webp?raw=true" alt="Moss v1 leader arm rest position" title="Moss v1 leader arm rest position" style="width:100%;"> |
|
| <img src="../media/moss/leader_zero.webp?raw=true" alt="Moss v1 leader arm zero position" title="Moss v1 leader arm zero position" style="width:100%;"> | <img src="../media/moss/leader_rotated.webp?raw=true" alt="Moss v1 leader arm rotated position" title="Moss v1 leader arm rotated position" style="width:100%;"> | <img src="../media/moss/leader_rest.webp?raw=true" alt="Moss v1 leader arm rest position" title="Moss v1 leader arm rest position" style="width:100%;"> |
|
||||||
|
|
||||||
Run this script to launch manual calibration:
|
Run this script to launch manual calibration:
|
||||||
|
@ -293,14 +293,14 @@ python lerobot/scripts/train.py \
|
||||||
--policy.type=act \
|
--policy.type=act \
|
||||||
--output_dir=outputs/train/act_moss_test \
|
--output_dir=outputs/train/act_moss_test \
|
||||||
--job_name=act_moss_test \
|
--job_name=act_moss_test \
|
||||||
--device=cuda \
|
--policy.device=cuda \
|
||||||
--wandb.enable=true
|
--wandb.enable=true
|
||||||
```
|
```
|
||||||
|
|
||||||
Let's explain it:
|
Let's explain it:
|
||||||
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/moss_test`.
|
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/moss_test`.
|
||||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||||
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
|
4. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
|
||||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||||
|
|
||||||
Training should take several hours. You will find checkpoints in `outputs/train/act_moss_test/checkpoints`.
|
Training should take several hours. You will find checkpoints in `outputs/train/act_moss_test/checkpoints`.
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
This script demonstrates the use of `LeRobotDataset` class for handling and processing robotic datasets from Hugging Face.
|
This script demonstrates the use of `LeRobotDataset` class for handling and processing robotic datasets from Hugging Face.
|
||||||
It illustrates how to load datasets, manipulate them, and apply transformations suitable for machine learning tasks in PyTorch.
|
It illustrates how to load datasets, manipulate them, and apply transformations suitable for machine learning tasks in PyTorch.
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
This scripts demonstrates how to evaluate a pretrained policy from the HuggingFace Hub or from your local
|
This scripts demonstrates how to evaluate a pretrained policy from the HuggingFace Hub or from your local
|
||||||
training outputs directory. In the latter case, you might want to run examples/3_train_policy.py first.
|
training outputs directory. In the latter case, you might want to run examples/3_train_policy.py first.
|
||||||
|
@ -30,7 +44,7 @@ pretrained_policy_path = "lerobot/diffusion_pusht"
|
||||||
# OR a path to a local outputs/train folder.
|
# OR a path to a local outputs/train folder.
|
||||||
# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")
|
# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")
|
||||||
|
|
||||||
policy = DiffusionPolicy.from_pretrained(pretrained_policy_path, map_location=device)
|
policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
|
||||||
|
|
||||||
# Initialize evaluation environment to render two observation types:
|
# Initialize evaluation environment to render two observation types:
|
||||||
# an image of the scene and state/position of the agent. The environment
|
# an image of the scene and state/position of the agent. The environment
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
"""This scripts demonstrates how to train Diffusion Policy on the PushT environment.
|
"""This scripts demonstrates how to train Diffusion Policy on the PushT environment.
|
||||||
|
|
||||||
Once you have trained a model with this script, you can try to evaluate it on
|
Once you have trained a model with this script, you can try to evaluate it on
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
This tutorial will explain the training script, how to use it, and particularly how to configure everything needed for the training run.
|
This tutorial will explain the training script, how to use it, and particularly how to configure everything needed for the training run.
|
||||||
> **Note:** The following assume you're running these commands on a machine equipped with a cuda GPU. If you don't have one (or if you're using a Mac), you can add `--device=cpu` (`--device=mps` respectively). However, be advised that the code executes much slower on cpu.
|
> **Note:** The following assume you're running these commands on a machine equipped with a cuda GPU. If you don't have one (or if you're using a Mac), you can add `--policy.device=cpu` (`--policy.device=mps` respectively). However, be advised that the code executes much slower on cpu.
|
||||||
|
|
||||||
|
|
||||||
## The training script
|
## The training script
|
||||||
|
|
|
@ -386,14 +386,14 @@ When you connect your robot for the first time, the [`ManipulatorRobot`](../lero
|
||||||
|
|
||||||
Here are the positions you'll move the follower arm to:
|
Here are the positions you'll move the follower arm to:
|
||||||
|
|
||||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||||
|---|---|---|
|
| ----------------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| <img src="../media/koch/follower_zero.webp?raw=true" alt="Koch v1.1 follower arm zero position" title="Koch v1.1 follower arm zero position" style="width:100%;"> | <img src="../media/koch/follower_rotated.webp?raw=true" alt="Koch v1.1 follower arm rotated position" title="Koch v1.1 follower arm rotated position" style="width:100%;"> | <img src="../media/koch/follower_rest.webp?raw=true" alt="Koch v1.1 follower arm rest position" title="Koch v1.1 follower arm rest position" style="width:100%;"> |
|
| <img src="../media/koch/follower_zero.webp?raw=true" alt="Koch v1.1 follower arm zero position" title="Koch v1.1 follower arm zero position" style="width:100%;"> | <img src="../media/koch/follower_rotated.webp?raw=true" alt="Koch v1.1 follower arm rotated position" title="Koch v1.1 follower arm rotated position" style="width:100%;"> | <img src="../media/koch/follower_rest.webp?raw=true" alt="Koch v1.1 follower arm rest position" title="Koch v1.1 follower arm rest position" style="width:100%;"> |
|
||||||
|
|
||||||
And here are the corresponding positions for the leader arm:
|
And here are the corresponding positions for the leader arm:
|
||||||
|
|
||||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||||
|---|---|---|
|
| ----------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| <img src="../media/koch/leader_zero.webp?raw=true" alt="Koch v1.1 leader arm zero position" title="Koch v1.1 leader arm zero position" style="width:100%;"> | <img src="../media/koch/leader_rotated.webp?raw=true" alt="Koch v1.1 leader arm rotated position" title="Koch v1.1 leader arm rotated position" style="width:100%;"> | <img src="../media/koch/leader_rest.webp?raw=true" alt="Koch v1.1 leader arm rest position" title="Koch v1.1 leader arm rest position" style="width:100%;"> |
|
| <img src="../media/koch/leader_zero.webp?raw=true" alt="Koch v1.1 leader arm zero position" title="Koch v1.1 leader arm zero position" style="width:100%;"> | <img src="../media/koch/leader_rotated.webp?raw=true" alt="Koch v1.1 leader arm rotated position" title="Koch v1.1 leader arm rotated position" style="width:100%;"> | <img src="../media/koch/leader_rest.webp?raw=true" alt="Koch v1.1 leader arm rest position" title="Koch v1.1 leader arm rest position" style="width:100%;"> |
|
||||||
|
|
||||||
You can watch a [video tutorial of the calibration procedure](https://youtu.be/8drnU9uRY24) for more details.
|
You can watch a [video tutorial of the calibration procedure](https://youtu.be/8drnU9uRY24) for more details.
|
||||||
|
@ -898,14 +898,14 @@ python lerobot/scripts/train.py \
|
||||||
--policy.type=act \
|
--policy.type=act \
|
||||||
--output_dir=outputs/train/act_koch_test \
|
--output_dir=outputs/train/act_koch_test \
|
||||||
--job_name=act_koch_test \
|
--job_name=act_koch_test \
|
||||||
--device=cuda \
|
--policy.device=cuda \
|
||||||
--wandb.enable=true
|
--wandb.enable=true
|
||||||
```
|
```
|
||||||
|
|
||||||
Let's explain it:
|
Let's explain it:
|
||||||
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/koch_test`.
|
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/koch_test`.
|
||||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||||
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
|
4. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
|
||||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||||
|
|
||||||
For more information on the `train` script see the previous tutorial: [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md)
|
For more information on the `train` script see the previous tutorial: [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md)
|
||||||
|
|
|
@ -135,14 +135,14 @@ python lerobot/scripts/train.py \
|
||||||
--policy.type=act \
|
--policy.type=act \
|
||||||
--output_dir=outputs/train/act_aloha_test \
|
--output_dir=outputs/train/act_aloha_test \
|
||||||
--job_name=act_aloha_test \
|
--job_name=act_aloha_test \
|
||||||
--device=cuda \
|
--policy.device=cuda \
|
||||||
--wandb.enable=true
|
--wandb.enable=true
|
||||||
```
|
```
|
||||||
|
|
||||||
Let's explain it:
|
Let's explain it:
|
||||||
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/aloha_test`.
|
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/aloha_test`.
|
||||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||||
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
|
4. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
|
||||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||||
|
|
||||||
For more information on the `train` script see the previous tutorial: [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md)
|
For more information on the `train` script see the previous tutorial: [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md)
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
This script demonstrates how to use torchvision's image transformation with LeRobotDataset for data
|
This script demonstrates how to use torchvision's image transformation with LeRobotDataset for data
|
||||||
augmentation purposes. The transformations are passed to the dataset as an argument upon creation, and
|
augmentation purposes. The transformations are passed to the dataset as an argument upon creation, and
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
"""This script demonstrates how to slice a dataset and calculate the loss on a subset of the data.
|
"""This script demonstrates how to slice a dataset and calculate the loss on a subset of the data.
|
||||||
|
|
||||||
This technique can be useful for debugging and testing purposes, as well as identifying whether a policy
|
This technique can be useful for debugging and testing purposes, as well as identifying whether a policy
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,16 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
# keys
|
# keys
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import packaging.version
|
import packaging.version
|
||||||
|
|
||||||
V2_MESSAGE = """
|
V2_MESSAGE = """
|
||||||
|
|
|
@ -67,7 +67,7 @@ from lerobot.common.datasets.utils import (
|
||||||
)
|
)
|
||||||
from lerobot.common.datasets.video_utils import (
|
from lerobot.common.datasets.video_utils import (
|
||||||
VideoFrame,
|
VideoFrame,
|
||||||
decode_video_frames_torchvision,
|
decode_video_frames,
|
||||||
encode_video_frames,
|
encode_video_frames,
|
||||||
get_video_info,
|
get_video_info,
|
||||||
)
|
)
|
||||||
|
@ -462,8 +462,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
|
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
|
||||||
video files are already present on local disk, they won't be downloaded again. Defaults to
|
video files are already present on local disk, they won't be downloaded again. Defaults to
|
||||||
True.
|
True.
|
||||||
video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
|
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec.
|
||||||
a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
|
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.repo_id = repo_id
|
self.repo_id = repo_id
|
||||||
|
@ -473,7 +473,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
self.episodes = episodes
|
self.episodes = episodes
|
||||||
self.tolerance_s = tolerance_s
|
self.tolerance_s = tolerance_s
|
||||||
self.revision = revision if revision else CODEBASE_VERSION
|
self.revision = revision if revision else CODEBASE_VERSION
|
||||||
self.video_backend = video_backend if video_backend else "pyav"
|
self.video_backend = video_backend if video_backend else "torchcodec"
|
||||||
self.delta_indices = None
|
self.delta_indices = None
|
||||||
|
|
||||||
# Unused attributes
|
# Unused attributes
|
||||||
|
@ -707,9 +707,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
item = {}
|
item = {}
|
||||||
for vid_key, query_ts in query_timestamps.items():
|
for vid_key, query_ts in query_timestamps.items():
|
||||||
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
|
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
|
||||||
frames = decode_video_frames_torchvision(
|
frames = decode_video_frames(video_path, query_ts, self.tolerance_s, self.video_backend)
|
||||||
video_path, query_ts, self.tolerance_s, self.video_backend
|
|
||||||
)
|
|
||||||
item[vid_key] = frames.squeeze(0)
|
item[vid_key] = frames.squeeze(0)
|
||||||
|
|
||||||
return item
|
return item
|
||||||
|
@ -1029,7 +1027,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
obj.delta_timestamps = None
|
obj.delta_timestamps = None
|
||||||
obj.delta_indices = None
|
obj.delta_indices = None
|
||||||
obj.episode_data_index = None
|
obj.episode_data_index = None
|
||||||
obj.video_backend = video_backend if video_backend is not None else "pyav"
|
obj.video_backend = video_backend if video_backend is not None else "torchcodec"
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import traceback
|
import traceback
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
|
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
|
||||||
2.1. It will:
|
2.1. It will:
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
|
@ -27,6 +27,35 @@ import torch
|
||||||
import torchvision
|
import torchvision
|
||||||
from datasets.features.features import register_feature
|
from datasets.features.features import register_feature
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from torchcodec.decoders import VideoDecoder
|
||||||
|
|
||||||
|
|
||||||
|
def decode_video_frames(
|
||||||
|
video_path: Path | str,
|
||||||
|
timestamps: list[float],
|
||||||
|
tolerance_s: float,
|
||||||
|
backend: str = "torchcodec",
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Decodes video frames using the specified backend.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
video_path (Path): Path to the video file.
|
||||||
|
timestamps (list[float]): List of timestamps to extract frames.
|
||||||
|
tolerance_s (float): Allowed deviation in seconds for frame retrieval.
|
||||||
|
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Decoded frames.
|
||||||
|
|
||||||
|
Currently supports torchcodec on cpu and pyav.
|
||||||
|
"""
|
||||||
|
if backend == "torchcodec":
|
||||||
|
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
|
||||||
|
elif backend in ["pyav", "video_reader"]:
|
||||||
|
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported video backend: {backend}")
|
||||||
|
|
||||||
|
|
||||||
def decode_video_frames_torchvision(
|
def decode_video_frames_torchvision(
|
||||||
|
@ -127,6 +156,75 @@ def decode_video_frames_torchvision(
|
||||||
return closest_frames
|
return closest_frames
|
||||||
|
|
||||||
|
|
||||||
|
def decode_video_frames_torchcodec(
|
||||||
|
video_path: Path | str,
|
||||||
|
timestamps: list[float],
|
||||||
|
tolerance_s: float,
|
||||||
|
device: str = "cpu",
|
||||||
|
log_loaded_timestamps: bool = False,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Loads frames associated with the requested timestamps of a video using torchcodec.
|
||||||
|
|
||||||
|
Note: Setting device="cuda" outside the main process, e.g. in data loader workers, will lead to CUDA initialization errors.
|
||||||
|
|
||||||
|
Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
|
||||||
|
the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
|
||||||
|
that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
|
||||||
|
and all subsequent frames until reaching the requested frame. The number of key frames in a video
|
||||||
|
can be adjusted during encoding to take into account decoding time and video size in bytes.
|
||||||
|
"""
|
||||||
|
# initialize video decoder
|
||||||
|
decoder = VideoDecoder(video_path, device=device, seek_mode="approximate")
|
||||||
|
loaded_frames = []
|
||||||
|
loaded_ts = []
|
||||||
|
# get metadata for frame information
|
||||||
|
metadata = decoder.metadata
|
||||||
|
average_fps = metadata.average_fps
|
||||||
|
|
||||||
|
# convert timestamps to frame indices
|
||||||
|
frame_indices = [round(ts * average_fps) for ts in timestamps]
|
||||||
|
|
||||||
|
# retrieve frames based on indices
|
||||||
|
frames_batch = decoder.get_frames_at(indices=frame_indices)
|
||||||
|
|
||||||
|
for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=False):
|
||||||
|
loaded_frames.append(frame)
|
||||||
|
loaded_ts.append(pts.item())
|
||||||
|
if log_loaded_timestamps:
|
||||||
|
logging.info(f"Frame loaded at timestamp={pts:.4f}")
|
||||||
|
|
||||||
|
query_ts = torch.tensor(timestamps)
|
||||||
|
loaded_ts = torch.tensor(loaded_ts)
|
||||||
|
|
||||||
|
# compute distances between each query timestamp and loaded timestamps
|
||||||
|
dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
|
||||||
|
min_, argmin_ = dist.min(1)
|
||||||
|
|
||||||
|
is_within_tol = min_ < tolerance_s
|
||||||
|
assert is_within_tol.all(), (
|
||||||
|
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
||||||
|
"It means that the closest frame that can be loaded from the video is too far away in time."
|
||||||
|
"This might be due to synchronization issues with timestamps during data collection."
|
||||||
|
"To be safe, we advise to ignore this item during training."
|
||||||
|
f"\nqueried timestamps: {query_ts}"
|
||||||
|
f"\nloaded timestamps: {loaded_ts}"
|
||||||
|
f"\nvideo: {video_path}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# get closest frames to the query timestamps
|
||||||
|
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
|
||||||
|
closest_ts = loaded_ts[argmin_]
|
||||||
|
|
||||||
|
if log_loaded_timestamps:
|
||||||
|
logging.info(f"{closest_ts=}")
|
||||||
|
|
||||||
|
# convert to float32 in [0,1] range (channel first)
|
||||||
|
closest_frames = closest_frames.type(torch.float32) / 255
|
||||||
|
|
||||||
|
assert len(timestamps) == len(closest_frames)
|
||||||
|
return closest_frames
|
||||||
|
|
||||||
|
|
||||||
def encode_video_frames(
|
def encode_video_frames(
|
||||||
imgs_dir: Path | str,
|
imgs_dir: Path | str,
|
||||||
video_path: Path | str,
|
video_path: Path | str,
|
||||||
|
|
|
@ -1 +1,15 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from .configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv # noqa: F401
|
from .configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv # noqa: F401
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
|
|
@ -1 +1,15 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from .optimizers import OptimizerConfig as OptimizerConfig
|
from .optimizers import OptimizerConfig as OptimizerConfig
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from .act.configuration_act import ACTConfig as ACTConfig
|
from .act.configuration_act import ACTConfig as ACTConfig
|
||||||
from .dexvla.configuration_dexvla import DexVLAConfig as DexVLAConfig
|
from .dexvla.configuration_dexvla import DexVLAConfig as DexVLAConfig
|
||||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||||
|
|
|
@ -16,7 +16,6 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||||
|
@ -83,7 +82,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||||
|
|
||||||
def make_policy(
|
def make_policy(
|
||||||
cfg: PreTrainedConfig,
|
cfg: PreTrainedConfig,
|
||||||
device: str | torch.device,
|
|
||||||
ds_meta: LeRobotDatasetMetadata | None = None,
|
ds_meta: LeRobotDatasetMetadata | None = None,
|
||||||
env_cfg: EnvConfig | None = None,
|
env_cfg: EnvConfig | None = None,
|
||||||
) -> PreTrainedPolicy:
|
) -> PreTrainedPolicy:
|
||||||
|
@ -95,7 +93,6 @@ def make_policy(
|
||||||
Args:
|
Args:
|
||||||
cfg (PreTrainedConfig): The config of the policy to make. If `pretrained_path` is set, the policy will
|
cfg (PreTrainedConfig): The config of the policy to make. If `pretrained_path` is set, the policy will
|
||||||
be loaded with the weights from that path.
|
be loaded with the weights from that path.
|
||||||
device (str): the device to load the policy onto.
|
|
||||||
ds_meta (LeRobotDatasetMetadata | None, optional): Dataset metadata to take input/output shapes and
|
ds_meta (LeRobotDatasetMetadata | None, optional): Dataset metadata to take input/output shapes and
|
||||||
statistics to use for (un)normalization of inputs/outputs in the policy. Defaults to None.
|
statistics to use for (un)normalization of inputs/outputs in the policy. Defaults to None.
|
||||||
env_cfg (EnvConfig | None, optional): The config of a gym environment to parse features from. Must be
|
env_cfg (EnvConfig | None, optional): The config of a gym environment to parse features from. Must be
|
||||||
|
@ -103,7 +100,7 @@ def make_policy(
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: Either ds_meta or env and env_cfg must be provided.
|
ValueError: Either ds_meta or env and env_cfg must be provided.
|
||||||
NotImplementedError: if the policy.type is 'vqbet' and the device 'mps' (due to an incompatibility)
|
NotImplementedError: if the policy.type is 'vqbet' and the policy device 'mps' (due to an incompatibility)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
PreTrainedPolicy: _description_
|
PreTrainedPolicy: _description_
|
||||||
|
@ -118,7 +115,7 @@ def make_policy(
|
||||||
# https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment
|
# https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment
|
||||||
# variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be
|
# variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be
|
||||||
# slower than running natively on MPS.
|
# slower than running natively on MPS.
|
||||||
if cfg.type == "vqbet" and str(device) == "mps":
|
if cfg.type == "vqbet" and cfg.device == "mps":
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Current implementation of VQBeT does not support `mps` backend. "
|
"Current implementation of VQBeT does not support `mps` backend. "
|
||||||
"Please use `cpu` or `cuda` backend."
|
"Please use `cpu` or `cuda` backend."
|
||||||
|
@ -152,7 +149,7 @@ def make_policy(
|
||||||
# Make a fresh policy.
|
# Make a fresh policy.
|
||||||
policy = policy_cls(**kwargs)
|
policy = policy_cls(**kwargs)
|
||||||
|
|
||||||
policy.to(device)
|
policy.to(cfg.device)
|
||||||
assert isinstance(policy, nn.Module)
|
assert isinstance(policy, nn.Module)
|
||||||
|
|
||||||
# policy = torch.compile(policy, mode="reduce-overhead")
|
# policy = torch.compile(policy, mode="reduce-overhead")
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
from lerobot.common.optim.optimizers import AdamWConfig
|
from lerobot.common.optim.optimizers import AdamWConfig
|
||||||
|
@ -76,6 +90,7 @@ class PI0Config(PreTrainedConfig):
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
|
|
||||||
|
# TODO(Steven): Validate device and amp? in all policy configs?
|
||||||
"""Input validation (not exhaustive)."""
|
"""Input validation (not exhaustive)."""
|
||||||
if self.n_action_steps > self.chunk_size:
|
if self.n_action_steps > self.chunk_size:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
@ -31,7 +45,7 @@ def main():
|
||||||
|
|
||||||
cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
|
cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
|
||||||
cfg.pretrained_path = ckpt_torch_dir
|
cfg.pretrained_path = ckpt_torch_dir
|
||||||
policy = make_policy(cfg, device, ds_meta=dataset.meta)
|
policy = make_policy(cfg, ds_meta=dataset.meta)
|
||||||
|
|
||||||
# policy = torch.compile(policy, mode="reduce-overhead")
|
# policy = torch.compile(policy, mode="reduce-overhead")
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import pickle
|
import pickle
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -87,7 +101,7 @@ def main():
|
||||||
|
|
||||||
cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
|
cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
|
||||||
cfg.pretrained_path = ckpt_torch_dir
|
cfg.pretrained_path = ckpt_torch_dir
|
||||||
policy = make_policy(cfg, device, dataset_meta)
|
policy = make_policy(cfg, dataset_meta)
|
||||||
|
|
||||||
# loss_dict = policy.forward(batch, noise=noise, time=time_beta)
|
# loss_dict = policy.forward(batch, noise=noise, time=time_beta)
|
||||||
# loss_dict["loss"].backward()
|
# loss_dict["loss"].backward()
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from transformers import GemmaConfig, PaliGemmaConfig
|
from transformers import GemmaConfig, PaliGemmaConfig
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Convert pi0 parameters from Jax to Pytorch
|
Convert pi0 parameters from Jax to Pytorch
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
from packaging.version import Version
|
from packaging.version import Version
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
|
@ -1,3 +1,16 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
import abc
|
import abc
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
@ -73,7 +86,6 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||||
cache_dir: str | Path | None = None,
|
cache_dir: str | Path | None = None,
|
||||||
local_files_only: bool = False,
|
local_files_only: bool = False,
|
||||||
revision: str | None = None,
|
revision: str | None = None,
|
||||||
map_location: str = "cpu",
|
|
||||||
strict: bool = False,
|
strict: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> T:
|
) -> T:
|
||||||
|
@ -98,7 +110,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||||
if os.path.isdir(model_id):
|
if os.path.isdir(model_id):
|
||||||
print("Loading weights from local directory")
|
print("Loading weights from local directory")
|
||||||
model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
|
model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
|
||||||
policy = cls._load_as_safetensor(instance, model_file, map_location, strict)
|
policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
model_file = hf_hub_download(
|
model_file = hf_hub_download(
|
||||||
|
@ -112,13 +124,13 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||||
token=token,
|
token=token,
|
||||||
local_files_only=local_files_only,
|
local_files_only=local_files_only,
|
||||||
)
|
)
|
||||||
policy = cls._load_as_safetensor(instance, model_file, map_location, strict)
|
policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
|
||||||
except HfHubHTTPError as e:
|
except HfHubHTTPError as e:
|
||||||
raise FileNotFoundError(
|
raise FileNotFoundError(
|
||||||
f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
|
f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
policy.to(map_location)
|
policy.to(config.device)
|
||||||
policy.eval()
|
policy.eval()
|
||||||
return policy
|
return policy
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
This file contains utilities for recording frames from Intel Realsense cameras.
|
This file contains utilities for recording frames from Intel Realsense cameras.
|
||||||
"""
|
"""
|
||||||
|
@ -34,7 +48,7 @@ def find_cameras(raise_when_empty=True, mock=False) -> list[dict]:
|
||||||
connected to the computer.
|
connected to the computer.
|
||||||
"""
|
"""
|
||||||
if mock:
|
if mock:
|
||||||
import tests.mock_pyrealsense2 as rs
|
import tests.cameras.mock_pyrealsense2 as rs
|
||||||
else:
|
else:
|
||||||
import pyrealsense2 as rs
|
import pyrealsense2 as rs
|
||||||
|
|
||||||
|
@ -86,7 +100,7 @@ def save_images_from_cameras(
|
||||||
serial_numbers = [cam["serial_number"] for cam in camera_infos]
|
serial_numbers = [cam["serial_number"] for cam in camera_infos]
|
||||||
|
|
||||||
if mock:
|
if mock:
|
||||||
import tests.mock_cv2 as cv2
|
import tests.cameras.mock_cv2 as cv2
|
||||||
else:
|
else:
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
|
@ -100,7 +114,7 @@ def save_images_from_cameras(
|
||||||
camera = IntelRealSenseCamera(config)
|
camera = IntelRealSenseCamera(config)
|
||||||
camera.connect()
|
camera.connect()
|
||||||
print(
|
print(
|
||||||
f"IntelRealSenseCamera({camera.serial_number}, fps={camera.fps}, width={camera.width}, height={camera.height}, color_mode={camera.color_mode})"
|
f"IntelRealSenseCamera({camera.serial_number}, fps={camera.fps}, width={camera.capture_width}, height={camera.capture_height}, color_mode={camera.color_mode})"
|
||||||
)
|
)
|
||||||
cameras.append(camera)
|
cameras.append(camera)
|
||||||
|
|
||||||
|
@ -210,9 +224,20 @@ class IntelRealSenseCamera:
|
||||||
self.serial_number = self.find_serial_number_from_name(config.name)
|
self.serial_number = self.find_serial_number_from_name(config.name)
|
||||||
else:
|
else:
|
||||||
self.serial_number = config.serial_number
|
self.serial_number = config.serial_number
|
||||||
|
|
||||||
|
# Store the raw (capture) resolution from the config.
|
||||||
|
self.capture_width = config.width
|
||||||
|
self.capture_height = config.height
|
||||||
|
|
||||||
|
# If rotated by ±90, swap width and height.
|
||||||
|
if config.rotation in [-90, 90]:
|
||||||
|
self.width = config.height
|
||||||
|
self.height = config.width
|
||||||
|
else:
|
||||||
|
self.width = config.width
|
||||||
|
self.height = config.height
|
||||||
|
|
||||||
self.fps = config.fps
|
self.fps = config.fps
|
||||||
self.width = config.width
|
|
||||||
self.height = config.height
|
|
||||||
self.channels = config.channels
|
self.channels = config.channels
|
||||||
self.color_mode = config.color_mode
|
self.color_mode = config.color_mode
|
||||||
self.use_depth = config.use_depth
|
self.use_depth = config.use_depth
|
||||||
|
@ -228,11 +253,10 @@ class IntelRealSenseCamera:
|
||||||
self.logs = {}
|
self.logs = {}
|
||||||
|
|
||||||
if self.mock:
|
if self.mock:
|
||||||
import tests.mock_cv2 as cv2
|
import tests.cameras.mock_cv2 as cv2
|
||||||
else:
|
else:
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
# TODO(alibets): Do we keep original width/height or do we define them after rotation?
|
|
||||||
self.rotation = None
|
self.rotation = None
|
||||||
if config.rotation == -90:
|
if config.rotation == -90:
|
||||||
self.rotation = cv2.ROTATE_90_COUNTERCLOCKWISE
|
self.rotation = cv2.ROTATE_90_COUNTERCLOCKWISE
|
||||||
|
@ -263,22 +287,26 @@ class IntelRealSenseCamera:
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.mock:
|
if self.mock:
|
||||||
import tests.mock_pyrealsense2 as rs
|
import tests.cameras.mock_pyrealsense2 as rs
|
||||||
else:
|
else:
|
||||||
import pyrealsense2 as rs
|
import pyrealsense2 as rs
|
||||||
|
|
||||||
config = rs.config()
|
config = rs.config()
|
||||||
config.enable_device(str(self.serial_number))
|
config.enable_device(str(self.serial_number))
|
||||||
|
|
||||||
if self.fps and self.width and self.height:
|
if self.fps and self.capture_width and self.capture_height:
|
||||||
# TODO(rcadene): can we set rgb8 directly?
|
# TODO(rcadene): can we set rgb8 directly?
|
||||||
config.enable_stream(rs.stream.color, self.width, self.height, rs.format.rgb8, self.fps)
|
config.enable_stream(
|
||||||
|
rs.stream.color, self.capture_width, self.capture_height, rs.format.rgb8, self.fps
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
config.enable_stream(rs.stream.color)
|
config.enable_stream(rs.stream.color)
|
||||||
|
|
||||||
if self.use_depth:
|
if self.use_depth:
|
||||||
if self.fps and self.width and self.height:
|
if self.fps and self.capture_width and self.capture_height:
|
||||||
config.enable_stream(rs.stream.depth, self.width, self.height, rs.format.z16, self.fps)
|
config.enable_stream(
|
||||||
|
rs.stream.depth, self.capture_width, self.capture_height, rs.format.z16, self.fps
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
config.enable_stream(rs.stream.depth)
|
config.enable_stream(rs.stream.depth)
|
||||||
|
|
||||||
|
@ -316,18 +344,18 @@ class IntelRealSenseCamera:
|
||||||
raise OSError(
|
raise OSError(
|
||||||
f"Can't set {self.fps=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_fps}."
|
f"Can't set {self.fps=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_fps}."
|
||||||
)
|
)
|
||||||
if self.width is not None and self.width != actual_width:
|
if self.capture_width is not None and self.capture_width != actual_width:
|
||||||
raise OSError(
|
raise OSError(
|
||||||
f"Can't set {self.width=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_width}."
|
f"Can't set {self.capture_width=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_width}."
|
||||||
)
|
)
|
||||||
if self.height is not None and self.height != actual_height:
|
if self.capture_height is not None and self.capture_height != actual_height:
|
||||||
raise OSError(
|
raise OSError(
|
||||||
f"Can't set {self.height=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_height}."
|
f"Can't set {self.capture_height=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_height}."
|
||||||
)
|
)
|
||||||
|
|
||||||
self.fps = round(actual_fps)
|
self.fps = round(actual_fps)
|
||||||
self.width = round(actual_width)
|
self.capture_width = round(actual_width)
|
||||||
self.height = round(actual_height)
|
self.capture_height = round(actual_height)
|
||||||
|
|
||||||
self.is_connected = True
|
self.is_connected = True
|
||||||
|
|
||||||
|
@ -347,7 +375,7 @@ class IntelRealSenseCamera:
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.mock:
|
if self.mock:
|
||||||
import tests.mock_cv2 as cv2
|
import tests.cameras.mock_cv2 as cv2
|
||||||
else:
|
else:
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
|
@ -373,7 +401,7 @@ class IntelRealSenseCamera:
|
||||||
color_image = cv2.cvtColor(color_image, cv2.COLOR_RGB2BGR)
|
color_image = cv2.cvtColor(color_image, cv2.COLOR_RGB2BGR)
|
||||||
|
|
||||||
h, w, _ = color_image.shape
|
h, w, _ = color_image.shape
|
||||||
if h != self.height or w != self.width:
|
if h != self.capture_height or w != self.capture_width:
|
||||||
raise OSError(
|
raise OSError(
|
||||||
f"Can't capture color image with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead."
|
f"Can't capture color image with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead."
|
||||||
)
|
)
|
||||||
|
@ -395,7 +423,7 @@ class IntelRealSenseCamera:
|
||||||
depth_map = np.asanyarray(depth_frame.get_data())
|
depth_map = np.asanyarray(depth_frame.get_data())
|
||||||
|
|
||||||
h, w = depth_map.shape
|
h, w = depth_map.shape
|
||||||
if h != self.height or w != self.width:
|
if h != self.capture_height or w != self.capture_width:
|
||||||
raise OSError(
|
raise OSError(
|
||||||
f"Can't capture depth map with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead."
|
f"Can't capture depth map with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead."
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
This file contains utilities for recording frames from cameras. For more info look at `OpenCVCamera` docstring.
|
This file contains utilities for recording frames from cameras. For more info look at `OpenCVCamera` docstring.
|
||||||
"""
|
"""
|
||||||
|
@ -66,7 +80,7 @@ def _find_cameras(
|
||||||
possible_camera_ids: list[int | str], raise_when_empty=False, mock=False
|
possible_camera_ids: list[int | str], raise_when_empty=False, mock=False
|
||||||
) -> list[int | str]:
|
) -> list[int | str]:
|
||||||
if mock:
|
if mock:
|
||||||
import tests.mock_cv2 as cv2
|
import tests.cameras.mock_cv2 as cv2
|
||||||
else:
|
else:
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
|
@ -130,8 +144,8 @@ def save_images_from_cameras(
|
||||||
camera = OpenCVCamera(config)
|
camera = OpenCVCamera(config)
|
||||||
camera.connect()
|
camera.connect()
|
||||||
print(
|
print(
|
||||||
f"OpenCVCamera({camera.camera_index}, fps={camera.fps}, width={camera.width}, "
|
f"OpenCVCamera({camera.camera_index}, fps={camera.fps}, width={camera.capture_width}, "
|
||||||
f"height={camera.height}, color_mode={camera.color_mode})"
|
f"height={camera.capture_height}, color_mode={camera.color_mode})"
|
||||||
)
|
)
|
||||||
cameras.append(camera)
|
cameras.append(camera)
|
||||||
|
|
||||||
|
@ -230,9 +244,19 @@ class OpenCVCamera:
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Please check the provided camera_index: {self.camera_index}")
|
raise ValueError(f"Please check the provided camera_index: {self.camera_index}")
|
||||||
|
|
||||||
|
# Store the raw (capture) resolution from the config.
|
||||||
|
self.capture_width = config.width
|
||||||
|
self.capture_height = config.height
|
||||||
|
|
||||||
|
# If rotated by ±90, swap width and height.
|
||||||
|
if config.rotation in [-90, 90]:
|
||||||
|
self.width = config.height
|
||||||
|
self.height = config.width
|
||||||
|
else:
|
||||||
|
self.width = config.width
|
||||||
|
self.height = config.height
|
||||||
|
|
||||||
self.fps = config.fps
|
self.fps = config.fps
|
||||||
self.width = config.width
|
|
||||||
self.height = config.height
|
|
||||||
self.channels = config.channels
|
self.channels = config.channels
|
||||||
self.color_mode = config.color_mode
|
self.color_mode = config.color_mode
|
||||||
self.mock = config.mock
|
self.mock = config.mock
|
||||||
|
@ -245,11 +269,10 @@ class OpenCVCamera:
|
||||||
self.logs = {}
|
self.logs = {}
|
||||||
|
|
||||||
if self.mock:
|
if self.mock:
|
||||||
import tests.mock_cv2 as cv2
|
import tests.cameras.mock_cv2 as cv2
|
||||||
else:
|
else:
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
# TODO(aliberts): Do we keep original width/height or do we define them after rotation?
|
|
||||||
self.rotation = None
|
self.rotation = None
|
||||||
if config.rotation == -90:
|
if config.rotation == -90:
|
||||||
self.rotation = cv2.ROTATE_90_COUNTERCLOCKWISE
|
self.rotation = cv2.ROTATE_90_COUNTERCLOCKWISE
|
||||||
|
@ -263,7 +286,7 @@ class OpenCVCamera:
|
||||||
raise RobotDeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.")
|
raise RobotDeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.")
|
||||||
|
|
||||||
if self.mock:
|
if self.mock:
|
||||||
import tests.mock_cv2 as cv2
|
import tests.cameras.mock_cv2 as cv2
|
||||||
else:
|
else:
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
|
@ -271,10 +294,20 @@ class OpenCVCamera:
|
||||||
# when other threads are used to save the images.
|
# when other threads are used to save the images.
|
||||||
cv2.setNumThreads(1)
|
cv2.setNumThreads(1)
|
||||||
|
|
||||||
|
backend = (
|
||||||
|
cv2.CAP_V4L2
|
||||||
|
if platform.system() == "Linux"
|
||||||
|
else cv2.CAP_DSHOW
|
||||||
|
if platform.system() == "Windows"
|
||||||
|
else cv2.CAP_AVFOUNDATION
|
||||||
|
if platform.system() == "Darwin"
|
||||||
|
else cv2.CAP_ANY
|
||||||
|
)
|
||||||
|
|
||||||
camera_idx = f"/dev/video{self.camera_index}" if platform.system() == "Linux" else self.camera_index
|
camera_idx = f"/dev/video{self.camera_index}" if platform.system() == "Linux" else self.camera_index
|
||||||
# First create a temporary camera trying to access `camera_index`,
|
# First create a temporary camera trying to access `camera_index`,
|
||||||
# and verify it is a valid camera by calling `isOpened`.
|
# and verify it is a valid camera by calling `isOpened`.
|
||||||
tmp_camera = cv2.VideoCapture(camera_idx)
|
tmp_camera = cv2.VideoCapture(camera_idx, backend)
|
||||||
is_camera_open = tmp_camera.isOpened()
|
is_camera_open = tmp_camera.isOpened()
|
||||||
# Release camera to make it accessible for `find_camera_indices`
|
# Release camera to make it accessible for `find_camera_indices`
|
||||||
tmp_camera.release()
|
tmp_camera.release()
|
||||||
|
@ -297,14 +330,14 @@ class OpenCVCamera:
|
||||||
# Secondly, create the camera that will be used downstream.
|
# Secondly, create the camera that will be used downstream.
|
||||||
# Note: For some unknown reason, calling `isOpened` blocks the camera which then
|
# Note: For some unknown reason, calling `isOpened` blocks the camera which then
|
||||||
# needs to be re-created.
|
# needs to be re-created.
|
||||||
self.camera = cv2.VideoCapture(camera_idx)
|
self.camera = cv2.VideoCapture(camera_idx, backend)
|
||||||
|
|
||||||
if self.fps is not None:
|
if self.fps is not None:
|
||||||
self.camera.set(cv2.CAP_PROP_FPS, self.fps)
|
self.camera.set(cv2.CAP_PROP_FPS, self.fps)
|
||||||
if self.width is not None:
|
if self.capture_width is not None:
|
||||||
self.camera.set(cv2.CAP_PROP_FRAME_WIDTH, self.width)
|
self.camera.set(cv2.CAP_PROP_FRAME_WIDTH, self.capture_width)
|
||||||
if self.height is not None:
|
if self.capture_height is not None:
|
||||||
self.camera.set(cv2.CAP_PROP_FRAME_HEIGHT, self.height)
|
self.camera.set(cv2.CAP_PROP_FRAME_HEIGHT, self.capture_height)
|
||||||
|
|
||||||
actual_fps = self.camera.get(cv2.CAP_PROP_FPS)
|
actual_fps = self.camera.get(cv2.CAP_PROP_FPS)
|
||||||
actual_width = self.camera.get(cv2.CAP_PROP_FRAME_WIDTH)
|
actual_width = self.camera.get(cv2.CAP_PROP_FRAME_WIDTH)
|
||||||
|
@ -316,19 +349,22 @@ class OpenCVCamera:
|
||||||
raise OSError(
|
raise OSError(
|
||||||
f"Can't set {self.fps=} for OpenCVCamera({self.camera_index}). Actual value is {actual_fps}."
|
f"Can't set {self.fps=} for OpenCVCamera({self.camera_index}). Actual value is {actual_fps}."
|
||||||
)
|
)
|
||||||
if self.width is not None and not math.isclose(self.width, actual_width, rel_tol=1e-3):
|
if self.capture_width is not None and not math.isclose(
|
||||||
|
self.capture_width, actual_width, rel_tol=1e-3
|
||||||
|
):
|
||||||
raise OSError(
|
raise OSError(
|
||||||
f"Can't set {self.width=} for OpenCVCamera({self.camera_index}). Actual value is {actual_width}."
|
f"Can't set {self.capture_width=} for OpenCVCamera({self.camera_index}). Actual value is {actual_width}."
|
||||||
)
|
)
|
||||||
if self.height is not None and not math.isclose(self.height, actual_height, rel_tol=1e-3):
|
if self.capture_height is not None and not math.isclose(
|
||||||
|
self.capture_height, actual_height, rel_tol=1e-3
|
||||||
|
):
|
||||||
raise OSError(
|
raise OSError(
|
||||||
f"Can't set {self.height=} for OpenCVCamera({self.camera_index}). Actual value is {actual_height}."
|
f"Can't set {self.capture_height=} for OpenCVCamera({self.camera_index}). Actual value is {actual_height}."
|
||||||
)
|
)
|
||||||
|
|
||||||
self.fps = round(actual_fps)
|
self.fps = round(actual_fps)
|
||||||
self.width = round(actual_width)
|
self.capture_width = round(actual_width)
|
||||||
self.height = round(actual_height)
|
self.capture_height = round(actual_height)
|
||||||
|
|
||||||
self.is_connected = True
|
self.is_connected = True
|
||||||
|
|
||||||
def read(self, temporary_color_mode: str | None = None) -> np.ndarray:
|
def read(self, temporary_color_mode: str | None = None) -> np.ndarray:
|
||||||
|
@ -362,14 +398,14 @@ class OpenCVCamera:
|
||||||
# so we convert the image color from BGR to RGB.
|
# so we convert the image color from BGR to RGB.
|
||||||
if requested_color_mode == "rgb":
|
if requested_color_mode == "rgb":
|
||||||
if self.mock:
|
if self.mock:
|
||||||
import tests.mock_cv2 as cv2
|
import tests.cameras.mock_cv2 as cv2
|
||||||
else:
|
else:
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
color_image = cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB)
|
color_image = cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB)
|
||||||
|
|
||||||
h, w, _ = color_image.shape
|
h, w, _ = color_image.shape
|
||||||
if h != self.height or w != self.width:
|
if h != self.capture_height or w != self.capture_width:
|
||||||
raise OSError(
|
raise OSError(
|
||||||
f"Can't capture color image with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead."
|
f"Can't capture color image with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead."
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Protocol
|
from typing import Protocol
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -31,7 +45,7 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> list[C
|
||||||
|
|
||||||
cameras[key] = IntelRealSenseCamera(cfg)
|
cameras[key] = IntelRealSenseCamera(cfg)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"The motor type '{cfg.type}' is not valid.")
|
raise ValueError(f"The camera type '{cfg.type}' is not valid.")
|
||||||
|
|
||||||
return cameras
|
return cameras
|
||||||
|
|
||||||
|
|
|
@ -1,14 +1,25 @@
|
||||||
import logging
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import draccus
|
import draccus
|
||||||
|
|
||||||
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
||||||
from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
|
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -43,11 +54,6 @@ class RecordControlConfig(ControlConfig):
|
||||||
# Root directory where the dataset will be stored (e.g. 'dataset/path').
|
# Root directory where the dataset will be stored (e.g. 'dataset/path').
|
||||||
root: str | Path | None = None
|
root: str | Path | None = None
|
||||||
policy: PreTrainedConfig | None = None
|
policy: PreTrainedConfig | None = None
|
||||||
# TODO(rcadene, aliberts): By default, use device and use_amp values from policy checkpoint.
|
|
||||||
device: str | None = None # cuda | cpu | mps
|
|
||||||
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
|
|
||||||
# automatic gradient scaling is used.
|
|
||||||
use_amp: bool | None = None
|
|
||||||
# Limit the frames per second. By default, uses the policy fps.
|
# Limit the frames per second. By default, uses the policy fps.
|
||||||
fps: int | None = None
|
fps: int | None = None
|
||||||
# Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize.
|
# Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize.
|
||||||
|
@ -90,27 +96,6 @@ class RecordControlConfig(ControlConfig):
|
||||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||||
self.policy.pretrained_path = policy_path
|
self.policy.pretrained_path = policy_path
|
||||||
|
|
||||||
# When no device or use_amp are given, use the one from training config.
|
|
||||||
if self.device is None or self.use_amp is None:
|
|
||||||
train_cfg = TrainPipelineConfig.from_pretrained(policy_path)
|
|
||||||
if self.device is None:
|
|
||||||
self.device = train_cfg.device
|
|
||||||
if self.use_amp is None:
|
|
||||||
self.use_amp = train_cfg.use_amp
|
|
||||||
|
|
||||||
# Automatically switch to available device if necessary
|
|
||||||
if not is_torch_device_available(self.device):
|
|
||||||
auto_device = auto_select_torch_device()
|
|
||||||
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
|
|
||||||
self.device = auto_device
|
|
||||||
|
|
||||||
# Automatically deactivate AMP if necessary
|
|
||||||
if self.use_amp and not is_amp_available(self.device):
|
|
||||||
logging.warning(
|
|
||||||
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
|
|
||||||
)
|
|
||||||
self.use_amp = False
|
|
||||||
|
|
||||||
|
|
||||||
@ControlConfig.register_subclass("replay")
|
@ControlConfig.register_subclass("replay")
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
########################################################################################
|
########################################################################################
|
||||||
# Utilities
|
# Utilities
|
||||||
########################################################################################
|
########################################################################################
|
||||||
|
@ -18,6 +32,7 @@ from termcolor import colored
|
||||||
from lerobot.common.datasets.image_writer import safe_stop_image_writer
|
from lerobot.common.datasets.image_writer import safe_stop_image_writer
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.common.datasets.utils import get_features_from_robot
|
from lerobot.common.datasets.utils import get_features_from_robot
|
||||||
|
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.common.robot_devices.robots.utils import Robot
|
from lerobot.common.robot_devices.robots.utils import Robot
|
||||||
from lerobot.common.robot_devices.utils import busy_wait
|
from lerobot.common.robot_devices.utils import busy_wait
|
||||||
from lerobot.common.utils.utils import get_safe_torch_device, has_method
|
from lerobot.common.utils.utils import get_safe_torch_device, has_method
|
||||||
|
@ -179,8 +194,6 @@ def record_episode(
|
||||||
episode_time_s,
|
episode_time_s,
|
||||||
display_cameras,
|
display_cameras,
|
||||||
policy,
|
policy,
|
||||||
device,
|
|
||||||
use_amp,
|
|
||||||
fps,
|
fps,
|
||||||
single_task,
|
single_task,
|
||||||
):
|
):
|
||||||
|
@ -191,8 +204,6 @@ def record_episode(
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
events=events,
|
events=events,
|
||||||
policy=policy,
|
policy=policy,
|
||||||
device=device,
|
|
||||||
use_amp=use_amp,
|
|
||||||
fps=fps,
|
fps=fps,
|
||||||
teleoperate=policy is None,
|
teleoperate=policy is None,
|
||||||
single_task=single_task,
|
single_task=single_task,
|
||||||
|
@ -207,9 +218,7 @@ def control_loop(
|
||||||
display_cameras=False,
|
display_cameras=False,
|
||||||
dataset: LeRobotDataset | None = None,
|
dataset: LeRobotDataset | None = None,
|
||||||
events=None,
|
events=None,
|
||||||
policy=None,
|
policy: PreTrainedPolicy = None,
|
||||||
device: torch.device | str | None = None,
|
|
||||||
use_amp: bool | None = None,
|
|
||||||
fps: int | None = None,
|
fps: int | None = None,
|
||||||
single_task: str | None = None,
|
single_task: str | None = None,
|
||||||
):
|
):
|
||||||
|
@ -232,9 +241,6 @@ def control_loop(
|
||||||
if dataset is not None and fps is not None and dataset.fps != fps:
|
if dataset is not None and fps is not None and dataset.fps != fps:
|
||||||
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
|
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
|
||||||
|
|
||||||
if isinstance(device, str):
|
|
||||||
device = get_safe_torch_device(device)
|
|
||||||
|
|
||||||
timestamp = 0
|
timestamp = 0
|
||||||
start_episode_t = time.perf_counter()
|
start_episode_t = time.perf_counter()
|
||||||
while timestamp < control_time_s:
|
while timestamp < control_time_s:
|
||||||
|
@ -246,7 +252,9 @@ def control_loop(
|
||||||
observation = robot.capture_observation()
|
observation = robot.capture_observation()
|
||||||
|
|
||||||
if policy is not None:
|
if policy is not None:
|
||||||
pred_action = predict_action(observation, policy, device, use_amp)
|
pred_action = predict_action(
|
||||||
|
observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
|
||||||
|
)
|
||||||
# Action can eventually be clipped using `max_relative_target`,
|
# Action can eventually be clipped using `max_relative_target`,
|
||||||
# so action actually sent is saved in the dataset.
|
# so action actually sent is saved in the dataset.
|
||||||
action = robot.send_action(pred_action)
|
action = robot.send_action(pred_action)
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
|
@ -318,7 +332,7 @@ class DynamixelMotorsBus:
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.mock:
|
if self.mock:
|
||||||
import tests.mock_dynamixel_sdk as dxl
|
import tests.motors.mock_dynamixel_sdk as dxl
|
||||||
else:
|
else:
|
||||||
import dynamixel_sdk as dxl
|
import dynamixel_sdk as dxl
|
||||||
|
|
||||||
|
@ -342,7 +356,7 @@ class DynamixelMotorsBus:
|
||||||
|
|
||||||
def reconnect(self):
|
def reconnect(self):
|
||||||
if self.mock:
|
if self.mock:
|
||||||
import tests.mock_dynamixel_sdk as dxl
|
import tests.motors.mock_dynamixel_sdk as dxl
|
||||||
else:
|
else:
|
||||||
import dynamixel_sdk as dxl
|
import dynamixel_sdk as dxl
|
||||||
|
|
||||||
|
@ -632,7 +646,7 @@ class DynamixelMotorsBus:
|
||||||
|
|
||||||
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
|
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
|
||||||
if self.mock:
|
if self.mock:
|
||||||
import tests.mock_dynamixel_sdk as dxl
|
import tests.motors.mock_dynamixel_sdk as dxl
|
||||||
else:
|
else:
|
||||||
import dynamixel_sdk as dxl
|
import dynamixel_sdk as dxl
|
||||||
|
|
||||||
|
@ -677,7 +691,7 @@ class DynamixelMotorsBus:
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
if self.mock:
|
if self.mock:
|
||||||
import tests.mock_dynamixel_sdk as dxl
|
import tests.motors.mock_dynamixel_sdk as dxl
|
||||||
else:
|
else:
|
||||||
import dynamixel_sdk as dxl
|
import dynamixel_sdk as dxl
|
||||||
|
|
||||||
|
@ -743,7 +757,7 @@ class DynamixelMotorsBus:
|
||||||
|
|
||||||
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
|
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
|
||||||
if self.mock:
|
if self.mock:
|
||||||
import tests.mock_dynamixel_sdk as dxl
|
import tests.motors.mock_dynamixel_sdk as dxl
|
||||||
else:
|
else:
|
||||||
import dynamixel_sdk as dxl
|
import dynamixel_sdk as dxl
|
||||||
|
|
||||||
|
@ -779,7 +793,7 @@ class DynamixelMotorsBus:
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
if self.mock:
|
if self.mock:
|
||||||
import tests.mock_dynamixel_sdk as dxl
|
import tests.motors.mock_dynamixel_sdk as dxl
|
||||||
else:
|
else:
|
||||||
import dynamixel_sdk as dxl
|
import dynamixel_sdk as dxl
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
|
@ -299,7 +313,7 @@ class FeetechMotorsBus:
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.mock:
|
if self.mock:
|
||||||
import tests.mock_scservo_sdk as scs
|
import tests.motors.mock_scservo_sdk as scs
|
||||||
else:
|
else:
|
||||||
import scservo_sdk as scs
|
import scservo_sdk as scs
|
||||||
|
|
||||||
|
@ -323,7 +337,7 @@ class FeetechMotorsBus:
|
||||||
|
|
||||||
def reconnect(self):
|
def reconnect(self):
|
||||||
if self.mock:
|
if self.mock:
|
||||||
import tests.mock_scservo_sdk as scs
|
import tests.motors.mock_scservo_sdk as scs
|
||||||
else:
|
else:
|
||||||
import scservo_sdk as scs
|
import scservo_sdk as scs
|
||||||
|
|
||||||
|
@ -650,7 +664,7 @@ class FeetechMotorsBus:
|
||||||
|
|
||||||
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
|
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
|
||||||
if self.mock:
|
if self.mock:
|
||||||
import tests.mock_scservo_sdk as scs
|
import tests.motors.mock_scservo_sdk as scs
|
||||||
else:
|
else:
|
||||||
import scservo_sdk as scs
|
import scservo_sdk as scs
|
||||||
|
|
||||||
|
@ -688,7 +702,7 @@ class FeetechMotorsBus:
|
||||||
|
|
||||||
def read(self, data_name, motor_names: str | list[str] | None = None):
|
def read(self, data_name, motor_names: str | list[str] | None = None):
|
||||||
if self.mock:
|
if self.mock:
|
||||||
import tests.mock_scservo_sdk as scs
|
import tests.motors.mock_scservo_sdk as scs
|
||||||
else:
|
else:
|
||||||
import scservo_sdk as scs
|
import scservo_sdk as scs
|
||||||
|
|
||||||
|
@ -768,7 +782,7 @@ class FeetechMotorsBus:
|
||||||
|
|
||||||
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
|
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
|
||||||
if self.mock:
|
if self.mock:
|
||||||
import tests.mock_scservo_sdk as scs
|
import tests.motors.mock_scservo_sdk as scs
|
||||||
else:
|
else:
|
||||||
import scservo_sdk as scs
|
import scservo_sdk as scs
|
||||||
|
|
||||||
|
@ -804,7 +818,7 @@ class FeetechMotorsBus:
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
if self.mock:
|
if self.mock:
|
||||||
import tests.mock_scservo_sdk as scs
|
import tests.motors.mock_scservo_sdk as scs
|
||||||
else:
|
else:
|
||||||
import scservo_sdk as scs
|
import scservo_sdk as scs
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Protocol
|
from typing import Protocol
|
||||||
|
|
||||||
from lerobot.common.robot_devices.motors.configs import (
|
from lerobot.common.robot_devices.motors.configs import (
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Sequence
|
from typing import Sequence
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
"""Logic to calibrate a robot arm built with dynamixel motors"""
|
"""Logic to calibrate a robot arm built with dynamixel motors"""
|
||||||
# TODO(rcadene, aliberts): move this logic into the robot code when refactoring
|
# TODO(rcadene, aliberts): move this logic into the robot code when refactoring
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
"""Logic to calibrate a robot arm built with feetech motors"""
|
"""Logic to calibrate a robot arm built with feetech motors"""
|
||||||
# TODO(rcadene, aliberts): move this logic into the robot code when refactoring
|
# TODO(rcadene, aliberts): move this logic into the robot code when refactoring
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
import threading
|
import threading
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
"""Contains logic to instantiate a robot, read information from its motors and cameras,
|
"""Contains logic to instantiate a robot, read information from its motors and cameras,
|
||||||
and send orders to its motors.
|
and send orders to its motors.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Protocol
|
from typing import Protocol
|
||||||
|
|
||||||
from lerobot.common.robot_devices.robots.configs import (
|
from lerobot.common.robot_devices.robots.configs import (
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import platform
|
import platform
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from tempfile import TemporaryDirectory
|
from tempfile import TemporaryDirectory
|
||||||
from typing import Any, Type, TypeVar
|
from typing import Any, Type, TypeVar
|
||||||
|
|
|
@ -51,8 +51,10 @@ def auto_select_torch_device() -> torch.device:
|
||||||
return torch.device("cpu")
|
return torch.device("cpu")
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(Steven): Remove log. log shouldn't be an argument, this should be handled by the logger level
|
||||||
def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
|
def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
|
||||||
"""Given a string, return a torch.device with checks on whether the device is available."""
|
"""Given a string, return a torch.device with checks on whether the device is available."""
|
||||||
|
try_device = str(try_device)
|
||||||
match try_device:
|
match try_device:
|
||||||
case "cuda":
|
case "cuda":
|
||||||
assert torch.cuda.is_available()
|
assert torch.cuda.is_available()
|
||||||
|
@ -85,6 +87,7 @@ def get_safe_dtype(dtype: torch.dtype, device: str | torch.device):
|
||||||
|
|
||||||
|
|
||||||
def is_torch_device_available(try_device: str) -> bool:
|
def is_torch_device_available(try_device: str) -> bool:
|
||||||
|
try_device = str(try_device) # Ensure try_device is a string
|
||||||
if try_device == "cuda":
|
if try_device == "cuda":
|
||||||
return torch.cuda.is_available()
|
return torch.cuda.is_available()
|
||||||
elif try_device == "mps":
|
elif try_device == "mps":
|
||||||
|
@ -92,7 +95,7 @@ def is_torch_device_available(try_device: str) -> bool:
|
||||||
elif try_device == "cpu":
|
elif try_device == "cpu":
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown device '{try_device}.")
|
raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu.")
|
||||||
|
|
||||||
|
|
||||||
def is_amp_available(device: str):
|
def is_amp_available(device: str):
|
||||||
|
|
|
@ -1,14 +1,26 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import datetime as dt
|
import datetime as dt
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from lerobot.common import envs, policies # noqa: F401
|
from lerobot.common import envs, policies # noqa: F401
|
||||||
from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
|
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.default import EvalConfig
|
from lerobot.configs.default import EvalConfig
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -21,11 +33,6 @@ class EvalPipelineConfig:
|
||||||
policy: PreTrainedConfig | None = None
|
policy: PreTrainedConfig | None = None
|
||||||
output_dir: Path | None = None
|
output_dir: Path | None = None
|
||||||
job_name: str | None = None
|
job_name: str | None = None
|
||||||
# TODO(rcadene, aliberts): By default, use device and use_amp values from policy checkpoint.
|
|
||||||
device: str | None = None # cuda | cpu | mps
|
|
||||||
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
|
|
||||||
# automatic gradient scaling is used.
|
|
||||||
use_amp: bool = False
|
|
||||||
seed: int | None = 1000
|
seed: int | None = 1000
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
@ -36,27 +43,6 @@ class EvalPipelineConfig:
|
||||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||||
self.policy.pretrained_path = policy_path
|
self.policy.pretrained_path = policy_path
|
||||||
|
|
||||||
# When no device or use_amp are given, use the one from training config.
|
|
||||||
if self.device is None or self.use_amp is None:
|
|
||||||
train_cfg = TrainPipelineConfig.from_pretrained(policy_path)
|
|
||||||
if self.device is None:
|
|
||||||
self.device = train_cfg.device
|
|
||||||
if self.use_amp is None:
|
|
||||||
self.use_amp = train_cfg.use_amp
|
|
||||||
|
|
||||||
# Automatically switch to available device if necessary
|
|
||||||
if not is_torch_device_available(self.device):
|
|
||||||
auto_device = auto_select_torch_device()
|
|
||||||
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
|
|
||||||
self.device = auto_device
|
|
||||||
|
|
||||||
# Automatically deactivate AMP if necessary
|
|
||||||
if self.use_amp and not is_amp_available(self.device):
|
|
||||||
logging.warning(
|
|
||||||
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
|
|
||||||
)
|
|
||||||
self.use_amp = False
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logging.warning(
|
logging.warning(
|
||||||
"No pretrained path was provided, evaluated policy will be built from scratch (random weights)."
|
"No pretrained path was provided, evaluated policy will be built from scratch (random weights)."
|
||||||
|
@ -73,11 +59,6 @@ class EvalPipelineConfig:
|
||||||
eval_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
|
eval_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
|
||||||
self.output_dir = Path("outputs/eval") / eval_dir
|
self.output_dir = Path("outputs/eval") / eval_dir
|
||||||
|
|
||||||
if self.device is None:
|
|
||||||
raise ValueError("Set one of the following device: cuda, cpu or mps")
|
|
||||||
elif self.device == "cuda" and self.use_amp is None:
|
|
||||||
raise ValueError("Set 'use_amp' to True or False.")
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def __get_path_fields__(cls) -> list[str]:
|
def __get_path_fields__(cls) -> list[str]:
|
||||||
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
||||||
|
|
|
@ -1,4 +1,19 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import importlib
|
||||||
import inspect
|
import inspect
|
||||||
|
import pkgutil
|
||||||
import sys
|
import sys
|
||||||
from argparse import ArgumentError
|
from argparse import ArgumentError
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
@ -10,6 +25,7 @@ import draccus
|
||||||
from lerobot.common.utils.utils import has_method
|
from lerobot.common.utils.utils import has_method
|
||||||
|
|
||||||
PATH_KEY = "path"
|
PATH_KEY = "path"
|
||||||
|
PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path"
|
||||||
draccus.set_config_type("json")
|
draccus.set_config_type("json")
|
||||||
|
|
||||||
|
|
||||||
|
@ -45,6 +61,86 @@ def parse_arg(arg_name: str, args: Sequence[str] | None = None) -> str | None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def parse_plugin_args(plugin_arg_suffix: str, args: Sequence[str]) -> dict:
|
||||||
|
"""Parse plugin-related arguments from command-line arguments.
|
||||||
|
|
||||||
|
This function extracts arguments from command-line arguments that match a specified suffix pattern.
|
||||||
|
It processes arguments in the format '--key=value' and returns them as a dictionary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_arg_suffix (str): The suffix to identify plugin-related arguments.
|
||||||
|
cli_args (Sequence[str]): A sequence of command-line arguments to parse.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: A dictionary containing the parsed plugin arguments where:
|
||||||
|
- Keys are the argument names (with '--' prefix removed if present)
|
||||||
|
- Values are the corresponding argument values
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> args = ['--env.discover_packages_path=my_package',
|
||||||
|
... '--other_arg=value']
|
||||||
|
>>> parse_plugin_args('discover_packages_path', args)
|
||||||
|
{'env.discover_packages_path': 'my_package'}
|
||||||
|
"""
|
||||||
|
plugin_args = {}
|
||||||
|
for arg in args:
|
||||||
|
if "=" in arg and plugin_arg_suffix in arg:
|
||||||
|
key, value = arg.split("=", 1)
|
||||||
|
# Remove leading '--' if present
|
||||||
|
if key.startswith("--"):
|
||||||
|
key = key[2:]
|
||||||
|
plugin_args[key] = value
|
||||||
|
return plugin_args
|
||||||
|
|
||||||
|
|
||||||
|
class PluginLoadError(Exception):
|
||||||
|
"""Raised when a plugin fails to load."""
|
||||||
|
|
||||||
|
|
||||||
|
def load_plugin(plugin_path: str) -> None:
|
||||||
|
"""Load and initialize a plugin from a given Python package path.
|
||||||
|
|
||||||
|
This function attempts to load a plugin by importing its package and any submodules.
|
||||||
|
Plugin registration is expected to happen during package initialization, i.e. when
|
||||||
|
the package is imported the gym environment should be registered and the config classes
|
||||||
|
registered with their parents using the `register_subclass` decorator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_path (str): The Python package path to the plugin (e.g. "mypackage.plugins.myplugin")
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
PluginLoadError: If the plugin cannot be loaded due to import errors or if the package path is invalid.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> load_plugin("external_plugin.core") # Loads plugin from external package
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- The plugin package should handle its own registration during import
|
||||||
|
- All submodules in the plugin package will be imported
|
||||||
|
- Implementation follows the plugin discovery pattern from Python packaging guidelines
|
||||||
|
|
||||||
|
See Also:
|
||||||
|
https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
package_module = importlib.import_module(plugin_path, __package__)
|
||||||
|
except (ImportError, ModuleNotFoundError) as e:
|
||||||
|
raise PluginLoadError(
|
||||||
|
f"Failed to load plugin '{plugin_path}'. Verify the path and installation: {str(e)}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
def iter_namespace(ns_pkg):
|
||||||
|
return pkgutil.iter_modules(ns_pkg.__path__, ns_pkg.__name__ + ".")
|
||||||
|
|
||||||
|
try:
|
||||||
|
for _finder, pkg_name, _ispkg in iter_namespace(package_module):
|
||||||
|
importlib.import_module(pkg_name)
|
||||||
|
except ImportError as e:
|
||||||
|
raise PluginLoadError(
|
||||||
|
f"Failed to load plugin '{plugin_path}'. Verify the path and installation: {str(e)}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
def get_path_arg(field_name: str, args: Sequence[str] | None = None) -> str | None:
|
def get_path_arg(field_name: str, args: Sequence[str] | None = None) -> str | None:
|
||||||
return parse_arg(f"{field_name}.{PATH_KEY}", args)
|
return parse_arg(f"{field_name}.{PATH_KEY}", args)
|
||||||
|
|
||||||
|
@ -92,10 +188,13 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
|
||||||
|
|
||||||
def wrap(config_path: Path | None = None):
|
def wrap(config_path: Path | None = None):
|
||||||
"""
|
"""
|
||||||
HACK: Similar to draccus.wrap but does two additional things:
|
HACK: Similar to draccus.wrap but does three additional things:
|
||||||
- Will remove '.path' arguments from CLI in order to process them later on.
|
- Will remove '.path' arguments from CLI in order to process them later on.
|
||||||
- If a 'config_path' is passed and the main config class has a 'from_pretrained' method, will
|
- If a 'config_path' is passed and the main config class has a 'from_pretrained' method, will
|
||||||
initialize it from there to allow to fetch configs from the hub directly
|
initialize it from there to allow to fetch configs from the hub directly
|
||||||
|
- Will load plugins specified in the CLI arguments. These plugins will typically register
|
||||||
|
their own subclasses of config classes, so that draccus can find the right class to instantiate
|
||||||
|
from the CLI '.type' arguments
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def wrapper_outer(fn):
|
def wrapper_outer(fn):
|
||||||
|
@ -108,6 +207,14 @@ def wrap(config_path: Path | None = None):
|
||||||
args = args[1:]
|
args = args[1:]
|
||||||
else:
|
else:
|
||||||
cli_args = sys.argv[1:]
|
cli_args = sys.argv[1:]
|
||||||
|
plugin_args = parse_plugin_args(PLUGIN_DISCOVERY_SUFFIX, cli_args)
|
||||||
|
for plugin_cli_arg, plugin_path in plugin_args.items():
|
||||||
|
try:
|
||||||
|
load_plugin(plugin_path)
|
||||||
|
except PluginLoadError as e:
|
||||||
|
# add the relevant CLI arg to the error message
|
||||||
|
raise PluginLoadError(f"{e}\nFailed plugin CLI Arg: {plugin_cli_arg}") from e
|
||||||
|
cli_args = filter_arg(plugin_cli_arg, cli_args)
|
||||||
config_path_cli = parse_arg("config_path", cli_args)
|
config_path_cli = parse_arg("config_path", cli_args)
|
||||||
if has_method(argtype, "__get_path_fields__"):
|
if has_method(argtype, "__get_path_fields__"):
|
||||||
path_fields = argtype.__get_path_fields__()
|
path_fields = argtype.__get_path_fields__()
|
||||||
|
|
|
@ -1,4 +1,18 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
import abc
|
import abc
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -12,6 +26,7 @@ from huggingface_hub.errors import HfHubHTTPError
|
||||||
from lerobot.common.optim.optimizers import OptimizerConfig
|
from lerobot.common.optim.optimizers import OptimizerConfig
|
||||||
from lerobot.common.optim.schedulers import LRSchedulerConfig
|
from lerobot.common.optim.schedulers import LRSchedulerConfig
|
||||||
from lerobot.common.utils.hub import HubMixin
|
from lerobot.common.utils.hub import HubMixin
|
||||||
|
from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
|
||||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||||
|
|
||||||
# Generic variable that is either PreTrainedConfig or a subclass thereof
|
# Generic variable that is either PreTrainedConfig or a subclass thereof
|
||||||
|
@ -40,8 +55,24 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||||
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||||
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||||
|
|
||||||
|
device: str | None = None # cuda | cpu | mp
|
||||||
|
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
|
||||||
|
# automatic gradient scaling is used.
|
||||||
|
use_amp: bool = False
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self.pretrained_path = None
|
self.pretrained_path = None
|
||||||
|
if not self.device or not is_torch_device_available(self.device):
|
||||||
|
auto_device = auto_select_torch_device()
|
||||||
|
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
|
||||||
|
self.device = auto_device.type
|
||||||
|
|
||||||
|
# Automatically deactivate AMP if necessary
|
||||||
|
if self.use_amp and not is_amp_available(self.device):
|
||||||
|
logging.warning(
|
||||||
|
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
|
||||||
|
)
|
||||||
|
self.use_amp = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def type(self) -> str:
|
def type(self) -> str:
|
||||||
|
|
|
@ -1,5 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
import datetime as dt
|
import datetime as dt
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -13,7 +25,6 @@ from lerobot.common import envs
|
||||||
from lerobot.common.optim import OptimizerConfig
|
from lerobot.common.optim import OptimizerConfig
|
||||||
from lerobot.common.optim.schedulers import LRSchedulerConfig
|
from lerobot.common.optim.schedulers import LRSchedulerConfig
|
||||||
from lerobot.common.utils.hub import HubMixin
|
from lerobot.common.utils.hub import HubMixin
|
||||||
from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available
|
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig
|
from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
|
@ -35,10 +46,6 @@ class TrainPipelineConfig(HubMixin):
|
||||||
# Note that when resuming a run, the default behavior is to use the configuration from the checkpoint,
|
# Note that when resuming a run, the default behavior is to use the configuration from the checkpoint,
|
||||||
# regardless of what's provided with the training command at the time of resumption.
|
# regardless of what's provided with the training command at the time of resumption.
|
||||||
resume: bool = False
|
resume: bool = False
|
||||||
device: str | None = None # cuda | cpu | mp
|
|
||||||
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
|
|
||||||
# automatic gradient scaling is used.
|
|
||||||
use_amp: bool = False
|
|
||||||
# `seed` is used for training (eg: model initialization, dataset shuffling)
|
# `seed` is used for training (eg: model initialization, dataset shuffling)
|
||||||
# AND for the evaluation environments.
|
# AND for the evaluation environments.
|
||||||
seed: int | None = 1000
|
seed: int | None = 1000
|
||||||
|
@ -61,18 +68,6 @@ class TrainPipelineConfig(HubMixin):
|
||||||
self.checkpoint_path = None
|
self.checkpoint_path = None
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
if not self.device:
|
|
||||||
logging.warning("No device specified, trying to infer device automatically")
|
|
||||||
device = auto_select_torch_device()
|
|
||||||
self.device = device.type
|
|
||||||
|
|
||||||
# Automatically deactivate AMP if necessary
|
|
||||||
if self.use_amp and not is_amp_available(self.device):
|
|
||||||
logging.warning(
|
|
||||||
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
|
|
||||||
)
|
|
||||||
self.use_amp = False
|
|
||||||
|
|
||||||
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
|
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
|
||||||
policy_path = parser.get_path_arg("policy")
|
policy_path = parser.get_path_arg("policy")
|
||||||
if policy_path:
|
if policy_path:
|
||||||
|
|
|
@ -1,3 +1,16 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
# Note: We subclass str so that serialization is straightforward
|
# Note: We subclass str so that serialization is straightforward
|
||||||
# https://stackoverflow.com/questions/24481852/serialising-an-enum-member-to-json
|
# https://stackoverflow.com/questions/24481852/serialising-an-enum-member-to-json
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
|
@ -1,3 +1,16 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
"""
|
"""
|
||||||
This script configure a single motor at a time to a given ID and baudrate.
|
This script configure a single motor at a time to a given ID and baudrate.
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,16 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
"""
|
"""
|
||||||
Utilities to control a robot.
|
Utilities to control a robot.
|
||||||
|
|
||||||
|
@ -254,7 +267,7 @@ def record(
|
||||||
)
|
)
|
||||||
|
|
||||||
# Load pretrained policy
|
# Load pretrained policy
|
||||||
policy = None if cfg.policy is None else make_policy(cfg.policy, cfg.device, ds_meta=dataset.meta)
|
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
|
||||||
|
|
||||||
if not robot.is_connected:
|
if not robot.is_connected:
|
||||||
robot.connect()
|
robot.connect()
|
||||||
|
@ -285,8 +298,6 @@ def record(
|
||||||
episode_time_s=cfg.episode_time_s,
|
episode_time_s=cfg.episode_time_s,
|
||||||
display_cameras=cfg.display_cameras,
|
display_cameras=cfg.display_cameras,
|
||||||
policy=policy,
|
policy=policy,
|
||||||
device=cfg.device,
|
|
||||||
use_amp=cfg.use_amp,
|
|
||||||
fps=cfg.fps,
|
fps=cfg.fps,
|
||||||
single_task=cfg.single_task,
|
single_task=cfg.single_task,
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,3 +1,16 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
"""
|
"""
|
||||||
Utilities to control a robot in simulation.
|
Utilities to control a robot in simulation.
|
||||||
|
|
||||||
|
|
|
@ -458,7 +458,7 @@ def eval_main(cfg: EvalPipelineConfig):
|
||||||
logging.info(pformat(asdict(cfg)))
|
logging.info(pformat(asdict(cfg)))
|
||||||
|
|
||||||
# Check device is available
|
# Check device is available
|
||||||
device = get_safe_torch_device(cfg.device, log=True)
|
device = get_safe_torch_device(cfg.policy.device, log=True)
|
||||||
|
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
@ -470,14 +470,14 @@ def eval_main(cfg: EvalPipelineConfig):
|
||||||
env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
|
env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
|
||||||
|
|
||||||
logging.info("Making policy.")
|
logging.info("Making policy.")
|
||||||
|
|
||||||
policy = make_policy(
|
policy = make_policy(
|
||||||
cfg=cfg.policy,
|
cfg=cfg.policy,
|
||||||
device=device,
|
|
||||||
env_cfg=cfg.env,
|
env_cfg=cfg.env,
|
||||||
)
|
)
|
||||||
policy.eval()
|
policy.eval()
|
||||||
|
|
||||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
|
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
|
||||||
info = eval_policy(
|
info = eval_policy(
|
||||||
env,
|
env,
|
||||||
policy,
|
policy,
|
||||||
|
|
|
@ -1,3 +1,16 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
|
@ -120,7 +120,7 @@ def train(cfg: TrainPipelineConfig):
|
||||||
set_seed(cfg.seed)
|
set_seed(cfg.seed)
|
||||||
|
|
||||||
# Check device is available
|
# Check device is available
|
||||||
device = get_safe_torch_device(cfg.device, log=True)
|
device = get_safe_torch_device(cfg.policy.device, log=True)
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
|
||||||
|
@ -138,13 +138,12 @@ def train(cfg: TrainPipelineConfig):
|
||||||
logging.info("Creating policy")
|
logging.info("Creating policy")
|
||||||
policy = make_policy(
|
policy = make_policy(
|
||||||
cfg=cfg.policy,
|
cfg=cfg.policy,
|
||||||
device=device,
|
|
||||||
ds_meta=dataset.meta,
|
ds_meta=dataset.meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("Creating optimizer and scheduler")
|
logging.info("Creating optimizer and scheduler")
|
||||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||||
grad_scaler = GradScaler(device, enabled=cfg.use_amp)
|
grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)
|
||||||
|
|
||||||
step = 0 # number of policy updates (forward + backward + optim)
|
step = 0 # number of policy updates (forward + backward + optim)
|
||||||
|
|
||||||
|
@ -218,7 +217,7 @@ def train(cfg: TrainPipelineConfig):
|
||||||
cfg.optimizer.grad_clip_norm,
|
cfg.optimizer.grad_clip_norm,
|
||||||
grad_scaler=grad_scaler,
|
grad_scaler=grad_scaler,
|
||||||
lr_scheduler=lr_scheduler,
|
lr_scheduler=lr_scheduler,
|
||||||
use_amp=cfg.use_amp,
|
use_amp=cfg.policy.use_amp,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
||||||
|
@ -249,7 +248,10 @@ def train(cfg: TrainPipelineConfig):
|
||||||
if cfg.env and is_eval_step:
|
if cfg.env and is_eval_step:
|
||||||
step_id = get_step_identifier(step, cfg.steps)
|
step_id = get_step_identifier(step, cfg.steps)
|
||||||
logging.info(f"Eval policy at step {step}")
|
logging.info(f"Eval policy at step {step}")
|
||||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
|
with (
|
||||||
|
torch.no_grad(),
|
||||||
|
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
|
||||||
|
):
|
||||||
eval_info = eval_policy(
|
eval_info = eval_policy(
|
||||||
eval_env,
|
eval_env,
|
||||||
policy,
|
policy,
|
||||||
|
|
|
@ -265,13 +265,25 @@ def main():
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--tolerance-s",
|
||||||
|
type=float,
|
||||||
|
default=1e-4,
|
||||||
|
help=(
|
||||||
|
"Tolerance in seconds used to ensure data timestamps respect the dataset fps value"
|
||||||
|
"This is argument passed to the constructor of LeRobotDataset and maps to its tolerance_s constructor argument"
|
||||||
|
"If not given, defaults to 1e-4."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
kwargs = vars(args)
|
kwargs = vars(args)
|
||||||
repo_id = kwargs.pop("repo_id")
|
repo_id = kwargs.pop("repo_id")
|
||||||
root = kwargs.pop("root")
|
root = kwargs.pop("root")
|
||||||
|
tolerance_s = kwargs.pop("tolerance_s")
|
||||||
|
|
||||||
logging.info("Loading dataset")
|
logging.info("Loading dataset")
|
||||||
dataset = LeRobotDataset(repo_id, root=root)
|
dataset = LeRobotDataset(repo_id, root=root, tolerance_s=tolerance_s)
|
||||||
|
|
||||||
visualize_dataset(dataset, **vars(args))
|
visualize_dataset(dataset, **vars(args))
|
||||||
|
|
||||||
|
|
|
@ -234,7 +234,7 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
|
||||||
This file will be loaded by Dygraph javascript to plot data in real time."""
|
This file will be loaded by Dygraph javascript to plot data in real time."""
|
||||||
columns = []
|
columns = []
|
||||||
|
|
||||||
selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] == "float32"]
|
selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] in ["float32", "int32"]]
|
||||||
selected_columns.remove("timestamp")
|
selected_columns.remove("timestamp")
|
||||||
|
|
||||||
ignored_columns = []
|
ignored_columns = []
|
||||||
|
@ -446,15 +446,31 @@ def main():
|
||||||
help="Delete the output directory if it exists already.",
|
help="Delete the output directory if it exists already.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--tolerance-s",
|
||||||
|
type=float,
|
||||||
|
default=1e-4,
|
||||||
|
help=(
|
||||||
|
"Tolerance in seconds used to ensure data timestamps respect the dataset fps value"
|
||||||
|
"This is argument passed to the constructor of LeRobotDataset and maps to its tolerance_s constructor argument"
|
||||||
|
"If not given, defaults to 1e-4."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
kwargs = vars(args)
|
kwargs = vars(args)
|
||||||
repo_id = kwargs.pop("repo_id")
|
repo_id = kwargs.pop("repo_id")
|
||||||
load_from_hf_hub = kwargs.pop("load_from_hf_hub")
|
load_from_hf_hub = kwargs.pop("load_from_hf_hub")
|
||||||
root = kwargs.pop("root")
|
root = kwargs.pop("root")
|
||||||
|
tolerance_s = kwargs.pop("tolerance_s")
|
||||||
|
|
||||||
dataset = None
|
dataset = None
|
||||||
if repo_id:
|
if repo_id:
|
||||||
dataset = LeRobotDataset(repo_id, root=root) if not load_from_hf_hub else get_dataset_info(repo_id)
|
dataset = (
|
||||||
|
LeRobotDataset(repo_id, root=root, tolerance_s=tolerance_s)
|
||||||
|
if not load_from_hf_hub
|
||||||
|
else get_dataset_info(repo_id)
|
||||||
|
)
|
||||||
|
|
||||||
visualize_dataset_html(dataset, **vars(args))
|
visualize_dataset_html(dataset, **vars(args))
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
homepage = "https://github.com/huggingface/lerobot"
|
homepage = "https://github.com/huggingface/lerobot"
|
||||||
issues = "https://github.com/huggingface/lerobot/issues"
|
issues = "https://github.com/huggingface/lerobot/issues"
|
||||||
|
@ -8,18 +22,19 @@ name = "lerobot"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
|
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
|
||||||
authors = [
|
authors = [
|
||||||
{name = "Rémi Cadène", email = "re.cadene@gmail.com"},
|
{ name = "Rémi Cadène", email = "re.cadene@gmail.com" },
|
||||||
{name = "Simon Alibert", email = "alibert.sim@gmail.com"},
|
{ name = "Simon Alibert", email = "alibert.sim@gmail.com" },
|
||||||
{name = "Alexander Soare", email = "alexander.soare159@gmail.com"},
|
{ name = "Alexander Soare", email = "alexander.soare159@gmail.com" },
|
||||||
{name = "Quentin Gallouédec", email = "quentin.gallouedec@ec-lyon.fr"},
|
{ name = "Quentin Gallouédec", email = "quentin.gallouedec@ec-lyon.fr" },
|
||||||
{name = "Adil Zouitine", email = "adilzouitinegm@gmail.com"},
|
{ name = "Adil Zouitine", email = "adilzouitinegm@gmail.com" },
|
||||||
{name = "Thomas Wolf", email = "thomaswolfcontact@gmail.com"},
|
{ name = "Thomas Wolf", email = "thomaswolfcontact@gmail.com" },
|
||||||
|
{ name = "Steven Palma", email = "imstevenpmwork@ieee.org" },
|
||||||
]
|
]
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = {text = "Apache-2.0"}
|
license = { text = "Apache-2.0" }
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
keywords = ["robotics", "deep learning", "pytorch"]
|
keywords = ["robotics", "deep learning", "pytorch"]
|
||||||
classifiers=[
|
classifiers = [
|
||||||
"Development Status :: 3 - Alpha",
|
"Development Status :: 3 - Alpha",
|
||||||
"Intended Audience :: Developers",
|
"Intended Audience :: Developers",
|
||||||
"Intended Audience :: Education",
|
"Intended Audience :: Education",
|
||||||
|
@ -38,10 +53,9 @@ dependencies = [
|
||||||
"einops>=0.8.0",
|
"einops>=0.8.0",
|
||||||
"flask>=3.0.3",
|
"flask>=3.0.3",
|
||||||
"gdown>=5.1.0",
|
"gdown>=5.1.0",
|
||||||
"gymnasium==0.29.1", # TODO(rcadene, aliberts): Make gym 1.0.0 work
|
"gymnasium==0.29.1", # TODO(rcadene, aliberts): Make gym 1.0.0 work
|
||||||
"h5py>=3.10.0",
|
"h5py>=3.10.0",
|
||||||
"huggingface-hub[hf-transfer,cli]>=0.27.1 ; python_version < '4.0'",
|
"huggingface-hub[hf-transfer,cli]>=0.27.1 ; python_version < '4.0'",
|
||||||
"hydra-core>=1.3.2",
|
|
||||||
"imageio[ffmpeg]>=2.34.0",
|
"imageio[ffmpeg]>=2.34.0",
|
||||||
"jsonlines>=4.0.0",
|
"jsonlines>=4.0.0",
|
||||||
"numba>=0.59.0",
|
"numba>=0.59.0",
|
||||||
|
@ -55,6 +69,7 @@ dependencies = [
|
||||||
"rerun-sdk>=0.21.0",
|
"rerun-sdk>=0.21.0",
|
||||||
"termcolor>=2.4.0",
|
"termcolor>=2.4.0",
|
||||||
"torch>=2.2.1",
|
"torch>=2.2.1",
|
||||||
|
"torchcodec>=0.2.1",
|
||||||
"torchvision>=0.21.0",
|
"torchvision>=0.21.0",
|
||||||
"wandb>=0.16.3",
|
"wandb>=0.16.3",
|
||||||
"zarr>=2.17.0",
|
"zarr>=2.17.0",
|
||||||
|
@ -63,7 +78,9 @@ dependencies = [
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
aloha = ["gym-aloha>=0.1.1 ; python_version < '4.0'"]
|
aloha = ["gym-aloha>=0.1.1 ; python_version < '4.0'"]
|
||||||
dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1"]
|
dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1"]
|
||||||
dora = ["gym-dora @ git+https://github.com/dora-rs/dora-lerobot.git#subdirectory=gym_dora ; python_version < '4.0'"]
|
dora = [
|
||||||
|
"gym-dora @ git+https://github.com/dora-rs/dora-lerobot.git#subdirectory=gym_dora ; python_version < '4.0'",
|
||||||
|
]
|
||||||
dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
|
dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
|
||||||
feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
|
feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
|
||||||
intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
|
intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
|
||||||
|
@ -74,7 +91,7 @@ stretch = [
|
||||||
"hello-robot-stretch-body>=0.7.27 ; python_version < '4.0' and sys_platform == 'linux'",
|
"hello-robot-stretch-body>=0.7.27 ; python_version < '4.0' and sys_platform == 'linux'",
|
||||||
"pyrender @ git+https://github.com/mmatl/pyrender.git ; sys_platform == 'linux'",
|
"pyrender @ git+https://github.com/mmatl/pyrender.git ; sys_platform == 'linux'",
|
||||||
"pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'",
|
"pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'",
|
||||||
"pynput>=1.7.7"
|
"pynput>=1.7.7",
|
||||||
]
|
]
|
||||||
test = ["pytest>=8.1.0", "pytest-cov>=5.0.0", "pyserial>=3.5"]
|
test = ["pytest>=8.1.0", "pytest-cov>=5.0.0", "pyserial>=3.5"]
|
||||||
umi = ["imagecodecs>=2024.1.1"]
|
umi = ["imagecodecs>=2024.1.1"]
|
||||||
|
@ -87,30 +104,7 @@ requires-poetry = ">=2.1"
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
line-length = 110
|
line-length = 110
|
||||||
target-version = "py310"
|
target-version = "py310"
|
||||||
exclude = [
|
exclude = ["tests/artifacts/**/*.safetensors"]
|
||||||
"tests/data",
|
|
||||||
".bzr",
|
|
||||||
".direnv",
|
|
||||||
".eggs",
|
|
||||||
".git",
|
|
||||||
".git-rewrite",
|
|
||||||
".hg",
|
|
||||||
".mypy_cache",
|
|
||||||
".nox",
|
|
||||||
".pants.d",
|
|
||||||
".pytype",
|
|
||||||
".ruff_cache",
|
|
||||||
".svn",
|
|
||||||
".tox",
|
|
||||||
".venv",
|
|
||||||
"__pypackages__",
|
|
||||||
"_build",
|
|
||||||
"buck-out",
|
|
||||||
"build",
|
|
||||||
"dist",
|
|
||||||
"node_modules",
|
|
||||||
"venv",
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.ruff.lint]
|
[tool.ruff.lint]
|
||||||
select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"]
|
select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"]
|
||||||
|
@ -128,8 +122,8 @@ skips = ["B101", "B311", "B404", "B603"]
|
||||||
|
|
||||||
[tool.typos]
|
[tool.typos]
|
||||||
default.extend-ignore-re = [
|
default.extend-ignore-re = [
|
||||||
"(?Rm)^.*(#|//)\\s*spellchecker:disable-line$", # spellchecker:disable-line
|
"(?Rm)^.*(#|//)\\s*spellchecker:disable-line$", # spellchecker:disable-line
|
||||||
"(?s)(#|//)\\s*spellchecker:off.*?\\n\\s*(#|//)\\s*spellchecker:on" # spellchecker:<on|off>
|
"(?s)(#|//)\\s*spellchecker:off.*?\\n\\s*(#|//)\\s*spellchecker:on", # spellchecker:<on|off>
|
||||||
]
|
]
|
||||||
default.extend-ignore-identifiers-re = [
|
default.extend-ignore-identifiers-re = [
|
||||||
# Add individual words here to ignore them
|
# Add individual words here to ignore them
|
||||||
|
|
|
@ -0,0 +1,13 @@
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue