diff --git a/.dockerignore b/.dockerignore
index b8c1be15..4f074d44 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
# Misc
.git
tmp
diff --git a/.gitattributes b/.gitattributes
index 7da36424..44e16cf1 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
*.memmap filter=lfs diff=lfs merge=lfs -text
*.stl filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml
index 7cbed673..2fb23051 100644
--- a/.github/ISSUE_TEMPLATE/bug-report.yml
+++ b/.github/ISSUE_TEMPLATE/bug-report.yml
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
name: "\U0001F41B Bug Report"
description: Submit a bug report to help us improve LeRobot
body:
diff --git a/.github/workflows/build-docker-images.yml b/.github/workflows/build-docker-images.yml
index 3c63fa11..0cb11d57 100644
--- a/.github/workflows/build-docker-images.yml
+++ b/.github/workflows/build-docker-images.yml
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
# Inspired by
# https://github.com/huggingface/peft/blob/main/.github/workflows/build_docker_images.yml
name: Builds
diff --git a/.github/workflows/nightly-tests.yml b/.github/workflows/nightly-tests.yml
index 210a690c..adac9f20 100644
--- a/.github/workflows/nightly-tests.yml
+++ b/.github/workflows/nightly-tests.yml
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
# Inspired by
# https://github.com/huggingface/peft/blob/main/.github/workflows/nightly.yml
name: Nightly
diff --git a/.github/workflows/pr_style_bot.yml b/.github/workflows/pr_style_bot.yml
deleted file mode 100644
index a34042c4..00000000
--- a/.github/workflows/pr_style_bot.yml
+++ /dev/null
@@ -1,125 +0,0 @@
-# Adapted from https://github.com/huggingface/diffusers/blob/main/.github/workflows/pr_style_bot.yml
-name: PR Style Bot
-
-on:
- issue_comment:
- types: [created]
-
-permissions:
- contents: write
- pull-requests: write
-
-jobs:
- run-style-bot:
- if: >
- contains(github.event.comment.body, '@bot /style') &&
- github.event.issue.pull_request != null
- runs-on: ubuntu-latest
-
- steps:
- - name: Extract PR details
- id: pr_info
- uses: actions/github-script@v6
- with:
- script: |
- const prNumber = context.payload.issue.number;
- const { data: pr } = await github.rest.pulls.get({
- owner: context.repo.owner,
- repo: context.repo.repo,
- pull_number: prNumber
- });
-
- // We capture both the branch ref and the "full_name" of the head repo
- // so that we can check out the correct repository & branch (including forks).
- core.setOutput("prNumber", prNumber);
- core.setOutput("headRef", pr.head.ref);
- core.setOutput("headRepoFullName", pr.head.repo.full_name);
-
- - name: Check out PR branch
- uses: actions/checkout@v4
- env:
- HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
- HEADREF: ${{ steps.pr_info.outputs.headRef }}
- with:
- persist-credentials: false
- # Instead of checking out the base repo, use the contributor's repo name
- repository: ${{ env.HEADREPOFULLNAME }}
- ref: ${{ env.HEADREF }}
- # You may need fetch-depth: 0 for being able to push
- fetch-depth: 0
- token: ${{ secrets.GITHUB_TOKEN }}
-
- - name: Debug
- env:
- HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
- HEADREF: ${{ steps.pr_info.outputs.headRef }}
- PRNUMBER: ${{ steps.pr_info.outputs.prNumber }}
- run: |
- echo "PR number: $PRNUMBER"
- echo "Head Ref: $HEADREF"
- echo "Head Repo Full Name: $HEADREPOFULLNAME"
-
- - name: Set up Python
- uses: actions/setup-python@v4
-
- - name: Get Ruff Version from pre-commit-config.yaml
- id: get-ruff-version
- run: |
- RUFF_VERSION=$(awk '/repo: https:\/\/github.com\/astral-sh\/ruff-pre-commit/{flag=1;next}/rev:/{if(flag){print $2;exit}}' .pre-commit-config.yaml)
- echo "ruff_version=${RUFF_VERSION}" >> $GITHUB_OUTPUT
-
- - name: Install Ruff
- env:
- RUFF_VERSION: ${{ steps.get-ruff-version.outputs.ruff_version }}
- run: python -m pip install "ruff==${RUFF_VERSION}"
-
- - name: Ruff check
- run: ruff check --fix
-
- - name: Ruff format
- run: ruff format
-
- - name: Commit and push changes
- id: commit_and_push
- env:
- HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
- HEADREF: ${{ steps.pr_info.outputs.headRef }}
- PRNUMBER: ${{ steps.pr_info.outputs.prNumber }}
- GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- run: |
- echo "HEADREPOFULLNAME: $HEADREPOFULLNAME, HEADREF: $HEADREF"
- # Configure git with the Actions bot user
- git config user.name "github-actions[bot]"
- git config user.email "github-actions[bot]@users.noreply.github.com"
-
- # Make sure your 'origin' remote is set to the contributor's fork
- git remote set-url origin "https://x-access-token:${GITHUB_TOKEN}@github.com/$HEADREPOFULLNAME.git"
-
- # If there are changes after running style/quality, commit them
- if [ -n "$(git status --porcelain)" ]; then
- git add .
- git commit -m "Apply style fixes"
- # Push to the original contributor's forked branch
- git push origin HEAD:$HEADREF
- echo "changes_pushed=true" >> $GITHUB_OUTPUT
- else
- echo "No changes to commit."
- echo "changes_pushed=false" >> $GITHUB_OUTPUT
- fi
-
- - name: Comment on PR with workflow run link
- if: steps.commit_and_push.outputs.changes_pushed == 'true'
- uses: actions/github-script@v6
- with:
- script: |
- const prNumber = parseInt(process.env.prNumber, 10);
- const runUrl = `${process.env.GITHUB_SERVER_URL}/${process.env.GITHUB_REPOSITORY}/actions/runs/${process.env.GITHUB_RUN_ID}`
-
- await github.rest.issues.createComment({
- owner: context.repo.owner,
- repo: context.repo.repo,
- issue_number: prNumber,
- body: `Style fixes have been applied. [View the workflow run here](${runUrl}).`
- });
- env:
- prNumber: ${{ steps.pr_info.outputs.prNumber }}
diff --git a/.github/workflows/quality.yml b/.github/workflows/quality.yml
index f785d52f..332b543c 100644
--- a/.github/workflows/quality.yml
+++ b/.github/workflows/quality.yml
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
name: Quality
on:
diff --git a/.github/workflows/test-docker-build.yml b/.github/workflows/test-docker-build.yml
index 3ee84a27..e77c570e 100644
--- a/.github/workflows/test-docker-build.yml
+++ b/.github/workflows/test-docker-build.yml
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
# Inspired by
# https://github.com/huggingface/peft/blob/main/.github/workflows/test-docker-build.yml
name: Test Dockerfiles
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 9c3f5756..3ef47887 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
name: Tests
on:
diff --git a/.github/workflows/trufflehog.yml b/.github/workflows/trufflehog.yml
index 487ccea5..166e0590 100644
--- a/.github/workflows/trufflehog.yml
+++ b/.github/workflows/trufflehog.yml
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
on:
push:
diff --git a/.gitignore b/.gitignore
index 0a0ffe10..da4b1089 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
# Logging
logs
tmp
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index b921f4e1..21016efa 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,7 +1,22 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
exclude: ^(tests/data)
default_language_version:
python: python3.10
repos:
+ ##### Style / Misc. #####
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
@@ -14,7 +29,7 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/crate-ci/typos
- rev: v1.29.10
+ rev: v1.30.0
hooks:
- id: typos
args: [--force-exclude]
@@ -23,16 +38,24 @@ repos:
hooks:
- id: pyupgrade
- repo: https://github.com/astral-sh/ruff-pre-commit
- rev: v0.9.6
+ rev: v0.9.9
hooks:
- id: ruff
args: [--fix]
- id: ruff-format
+
+ ##### Security #####
- repo: https://github.com/gitleaks/gitleaks
- rev: v8.23.3
+ rev: v8.24.0
hooks:
- id: gitleaks
- repo: https://github.com/woodruffw/zizmor-pre-commit
- rev: v1.3.1
+ rev: v1.4.1
hooks:
- id: zizmor
+ - repo: https://github.com/PyCQA/bandit
+ rev: 1.8.3
+ hooks:
+ - id: bandit
+ args: ["-c", "pyproject.toml"]
+ additional_dependencies: ["bandit[toml]"]
diff --git a/Makefile b/Makefile
index 772da320..c82483cc 100644
--- a/Makefile
+++ b/Makefile
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
.PHONY: tests
PYTHON_PATH := $(shell which python)
@@ -33,6 +47,7 @@ test-act-ete-train:
--policy.dim_model=64 \
--policy.n_action_steps=20 \
--policy.chunk_size=20 \
+ --policy.device=$(DEVICE) \
--env.type=aloha \
--env.episode_length=5 \
--dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \
@@ -47,7 +62,6 @@ test-act-ete-train:
--save_checkpoint=true \
--log_freq=1 \
--wandb.enable=false \
- --device=$(DEVICE) \
--output_dir=tests/outputs/act/
test-act-ete-train-resume:
@@ -58,11 +72,11 @@ test-act-ete-train-resume:
test-act-ete-eval:
python lerobot/scripts/eval.py \
--policy.path=tests/outputs/act/checkpoints/000004/pretrained_model \
+ --policy.device=$(DEVICE) \
--env.type=aloha \
--env.episode_length=5 \
--eval.n_episodes=1 \
- --eval.batch_size=1 \
- --device=$(DEVICE)
+ --eval.batch_size=1
test-diffusion-ete-train:
python lerobot/scripts/train.py \
@@ -70,6 +84,7 @@ test-diffusion-ete-train:
--policy.down_dims='[64,128,256]' \
--policy.diffusion_step_embed_dim=32 \
--policy.num_inference_steps=10 \
+ --policy.device=$(DEVICE) \
--env.type=pusht \
--env.episode_length=5 \
--dataset.repo_id=lerobot/pusht \
@@ -84,21 +99,21 @@ test-diffusion-ete-train:
--save_freq=2 \
--log_freq=1 \
--wandb.enable=false \
- --device=$(DEVICE) \
--output_dir=tests/outputs/diffusion/
test-diffusion-ete-eval:
python lerobot/scripts/eval.py \
--policy.path=tests/outputs/diffusion/checkpoints/000002/pretrained_model \
+ --policy.device=$(DEVICE) \
--env.type=pusht \
--env.episode_length=5 \
--eval.n_episodes=1 \
- --eval.batch_size=1 \
- --device=$(DEVICE)
+ --eval.batch_size=1
test-tdmpc-ete-train:
python lerobot/scripts/train.py \
--policy.type=tdmpc \
+ --policy.device=$(DEVICE) \
--env.type=xarm \
--env.task=XarmLift-v0 \
--env.episode_length=5 \
@@ -114,15 +129,14 @@ test-tdmpc-ete-train:
--save_freq=2 \
--log_freq=1 \
--wandb.enable=false \
- --device=$(DEVICE) \
--output_dir=tests/outputs/tdmpc/
test-tdmpc-ete-eval:
python lerobot/scripts/eval.py \
--policy.path=tests/outputs/tdmpc/checkpoints/000002/pretrained_model \
+ --policy.device=$(DEVICE) \
--env.type=xarm \
--env.episode_length=5 \
--env.task=XarmLift-v0 \
--eval.n_episodes=1 \
- --eval.batch_size=1 \
- --device=$(DEVICE)
+ --eval.batch_size=1
diff --git a/README.md b/README.md
index 59929341..b16e3469 100644
--- a/README.md
+++ b/README.md
@@ -23,15 +23,24 @@
Want to take it to the next level? Make your SO-100 mobile by building LeKiwi!
+
Check out the LeKiwi tutorial and bring your robot to life on wheels.
+
+
@@ -375,3 +384,6 @@ Additionally, if you are using any of the particular policy architecture, pretra
year={2024}
}
```
+## Star History
+
+[](https://star-history.com/#huggingface/lerobot&Timeline)
diff --git a/examples/1_load_lerobot_dataset.py b/examples/1_load_lerobot_dataset.py
index 96c104b6..c374a375 100644
--- a/examples/1_load_lerobot_dataset.py
+++ b/examples/1_load_lerobot_dataset.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""
This script demonstrates the use of `LeRobotDataset` class for handling and processing robotic datasets from Hugging Face.
It illustrates how to load datasets, manipulate them, and apply transformations suitable for machine learning tasks in PyTorch.
diff --git a/examples/2_evaluate_pretrained_policy.py b/examples/2_evaluate_pretrained_policy.py
index 0a7b8deb..edbbad38 100644
--- a/examples/2_evaluate_pretrained_policy.py
+++ b/examples/2_evaluate_pretrained_policy.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""
This scripts demonstrates how to evaluate a pretrained policy from the HuggingFace Hub or from your local
training outputs directory. In the latter case, you might want to run examples/3_train_policy.py first.
@@ -30,7 +44,7 @@ pretrained_policy_path = "lerobot/diffusion_pusht"
# OR a path to a local outputs/train folder.
# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")
-policy = DiffusionPolicy.from_pretrained(pretrained_policy_path, map_location=device)
+policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
# Initialize evaluation environment to render two observation types:
# an image of the scene and state/position of the agent. The environment
diff --git a/examples/3_train_policy.py b/examples/3_train_policy.py
index cf5d4d3e..6c3af54e 100644
--- a/examples/3_train_policy.py
+++ b/examples/3_train_policy.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""This scripts demonstrates how to train Diffusion Policy on the PushT environment.
Once you have trained a model with this script, you can try to evaluate it on
@@ -85,7 +99,7 @@ def main():
done = False
while not done:
for batch in dataloader:
- batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
+ batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
loss, _ = policy.forward(batch)
loss.backward()
optimizer.step()
diff --git a/examples/advanced/1_add_image_transforms.py b/examples/advanced/1_add_image_transforms.py
index 882710e3..f1460926 100644
--- a/examples/advanced/1_add_image_transforms.py
+++ b/examples/advanced/1_add_image_transforms.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""
This script demonstrates how to use torchvision's image transformation with LeRobotDataset for data
augmentation purposes. The transformations are passed to the dataset as an argument upon creation, and
diff --git a/examples/advanced/2_calculate_validation_loss.py b/examples/advanced/2_calculate_validation_loss.py
index 6f234719..47b4dd02 100644
--- a/examples/advanced/2_calculate_validation_loss.py
+++ b/examples/advanced/2_calculate_validation_loss.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""This script demonstrates how to slice a dataset and calculate the loss on a subset of the data.
This technique can be useful for debugging and testing purposes, as well as identifying whether a policy
diff --git a/examples/port_datasets/pusht_zarr.py b/examples/port_datasets/pusht_zarr.py
index eac6f63d..ea2e8b60 100644
--- a/examples/port_datasets/pusht_zarr.py
+++ b/examples/port_datasets/pusht_zarr.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import shutil
from pathlib import Path
diff --git a/lerobot/common/cameras/intel/camera_realsense.py b/lerobot/common/cameras/intel/camera_realsense.py
index c7017d2b..13080904 100644
--- a/lerobot/common/cameras/intel/camera_realsense.py
+++ b/lerobot/common/cameras/intel/camera_realsense.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""
This file contains utilities for recording frames from Intel Realsense cameras.
"""
@@ -99,7 +113,7 @@ def save_images_from_cameras(
camera = RealSenseCamera(config)
camera.connect()
print(
- f"RealSenseCamera({camera.serial_number}, fps={camera.fps}, width={camera.width}, height={camera.height}, color_mode={camera.color_mode})"
+ f"RealSenseCamera({camera.serial_number}, fps={camera.fps}, width={camera.capture_width}, height={camera.capture_height}, color_mode={camera.color_mode})"
)
cameras.append(camera)
@@ -209,9 +223,20 @@ class RealSenseCamera(Camera):
self.serial_number = self.find_serial_number_from_name(config.name)
else:
self.serial_number = config.serial_number
+
+ # Store the raw (capture) resolution from the config.
+ self.capture_width = config.width
+ self.capture_height = config.height
+
+ # If rotated by ±90, swap width and height.
+ if config.rotation in [-90, 90]:
+ self.width = config.height
+ self.height = config.width
+ else:
+ self.width = config.width
+ self.height = config.height
+
self.fps = config.fps
- self.width = config.width
- self.height = config.height
self.channels = config.channels
self.color_mode = config.color_mode
self.use_depth = config.use_depth
@@ -231,7 +256,6 @@ class RealSenseCamera(Camera):
else:
import cv2
- # TODO(alibets): Do we keep original width/height or do we define them after rotation?
self.rotation = None
if config.rotation == -90:
self.rotation = cv2.ROTATE_90_COUNTERCLOCKWISE
@@ -267,15 +291,19 @@ class RealSenseCamera(Camera):
config = rs.config()
config.enable_device(str(self.serial_number))
- if self.fps and self.width and self.height:
+ if self.fps and self.capture_width and self.capture_height:
# TODO(rcadene): can we set rgb8 directly?
- config.enable_stream(rs.stream.color, self.width, self.height, rs.format.rgb8, self.fps)
+ config.enable_stream(
+ rs.stream.color, self.capture_width, self.capture_height, rs.format.rgb8, self.fps
+ )
else:
config.enable_stream(rs.stream.color)
if self.use_depth:
- if self.fps and self.width and self.height:
- config.enable_stream(rs.stream.depth, self.width, self.height, rs.format.z16, self.fps)
+ if self.fps and self.capture_width and self.capture_height:
+ config.enable_stream(
+ rs.stream.depth, self.capture_width, self.capture_height, rs.format.z16, self.fps
+ )
else:
config.enable_stream(rs.stream.depth)
@@ -313,18 +341,18 @@ class RealSenseCamera(Camera):
raise OSError(
f"Can't set {self.fps=} for RealSenseCamera({self.serial_number}). Actual value is {actual_fps}."
)
- if self.width is not None and self.width != actual_width:
+ if self.capture_width is not None and self.capture_width != actual_width:
raise OSError(
- f"Can't set {self.width=} for RealSenseCamera({self.serial_number}). Actual value is {actual_width}."
+ f"Can't set {self.capture_width=} for RealSenseCamera({self.serial_number}). Actual value is {actual_width}."
)
- if self.height is not None and self.height != actual_height:
+ if self.capture_height is not None and self.capture_height != actual_height:
raise OSError(
- f"Can't set {self.height=} for RealSenseCamera({self.serial_number}). Actual value is {actual_height}."
+ f"Can't set {self.capture_height=} for RealSenseCamera({self.serial_number}). Actual value is {actual_height}."
)
self.fps = round(actual_fps)
- self.width = round(actual_width)
- self.height = round(actual_height)
+ self.capture_width = round(actual_width)
+ self.capture_height = round(actual_height)
self.is_connected = True
@@ -370,7 +398,7 @@ class RealSenseCamera(Camera):
color_image = cv2.cvtColor(color_image, cv2.COLOR_RGB2BGR)
h, w, _ = color_image.shape
- if h != self.height or w != self.width:
+ if h != self.capture_height or w != self.capture_width:
raise OSError(
f"Can't capture color image with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead."
)
@@ -392,7 +420,7 @@ class RealSenseCamera(Camera):
depth_map = np.asanyarray(depth_frame.get_data())
h, w = depth_map.shape
- if h != self.height or w != self.width:
+ if h != self.capture_height or w != self.capture_width:
raise OSError(
f"Can't capture depth map with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead."
)
diff --git a/lerobot/common/cameras/intel/configuration_realsense.py b/lerobot/common/cameras/intel/configuration_realsense.py
index 5dae89b9..66bb1b4f 100644
--- a/lerobot/common/cameras/intel/configuration_realsense.py
+++ b/lerobot/common/cameras/intel/configuration_realsense.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from dataclasses import dataclass
from ..configs import CameraConfig
diff --git a/lerobot/common/cameras/opencv/camera_opencv.py b/lerobot/common/cameras/opencv/camera_opencv.py
index b35380b4..a7d5a0f3 100644
--- a/lerobot/common/cameras/opencv/camera_opencv.py
+++ b/lerobot/common/cameras/opencv/camera_opencv.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""
This file contains utilities for recording frames from cameras. For more info look at `OpenCVCamera` docstring.
"""
@@ -131,8 +145,8 @@ def save_images_from_cameras(
camera = OpenCVCamera(config)
camera.connect()
print(
- f"OpenCVCamera({camera.camera_index}, fps={camera.fps}, width={camera.width}, "
- f"height={camera.height}, color_mode={camera.color_mode})"
+ f"OpenCVCamera({camera.camera_index}, fps={camera.fps}, width={camera.capture_width}, "
+ f"height={camera.capture_height}, color_mode={camera.color_mode})"
)
cameras.append(camera)
@@ -231,9 +245,19 @@ class OpenCVCamera(Camera):
else:
raise ValueError(f"Please check the provided camera_index: {self.camera_index}")
+ # Store the raw (capture) resolution from the config.
+ self.capture_width = config.width
+ self.capture_height = config.height
+
+ # If rotated by ±90, swap width and height.
+ if config.rotation in [-90, 90]:
+ self.width = config.height
+ self.height = config.width
+ else:
+ self.width = config.width
+ self.height = config.height
+
self.fps = config.fps
- self.width = config.width
- self.height = config.height
self.channels = config.channels
self.color_mode = config.color_mode
self.mock = config.mock
@@ -250,7 +274,6 @@ class OpenCVCamera(Camera):
else:
import cv2
- # TODO(aliberts): Do we keep original width/height or do we define them after rotation?
self.rotation = None
if config.rotation == -90:
self.rotation = cv2.ROTATE_90_COUNTERCLOCKWISE
@@ -272,10 +295,20 @@ class OpenCVCamera(Camera):
# when other threads are used to save the images.
cv2.setNumThreads(1)
+ backend = (
+ cv2.CAP_V4L2
+ if platform.system() == "Linux"
+ else cv2.CAP_DSHOW
+ if platform.system() == "Windows"
+ else cv2.CAP_AVFOUNDATION
+ if platform.system() == "Darwin"
+ else cv2.CAP_ANY
+ )
+
camera_idx = f"/dev/video{self.camera_index}" if platform.system() == "Linux" else self.camera_index
# First create a temporary camera trying to access `camera_index`,
# and verify it is a valid camera by calling `isOpened`.
- tmp_camera = cv2.VideoCapture(camera_idx)
+ tmp_camera = cv2.VideoCapture(camera_idx, backend)
is_camera_open = tmp_camera.isOpened()
# Release camera to make it accessible for `find_camera_indices`
tmp_camera.release()
@@ -298,14 +331,14 @@ class OpenCVCamera(Camera):
# Secondly, create the camera that will be used downstream.
# Note: For some unknown reason, calling `isOpened` blocks the camera which then
# needs to be re-created.
- self.camera = cv2.VideoCapture(camera_idx)
+ self.camera = cv2.VideoCapture(camera_idx, backend)
if self.fps is not None:
self.camera.set(cv2.CAP_PROP_FPS, self.fps)
- if self.width is not None:
- self.camera.set(cv2.CAP_PROP_FRAME_WIDTH, self.width)
- if self.height is not None:
- self.camera.set(cv2.CAP_PROP_FRAME_HEIGHT, self.height)
+ if self.capture_width is not None:
+ self.camera.set(cv2.CAP_PROP_FRAME_WIDTH, self.capture_width)
+ if self.capture_height is not None:
+ self.camera.set(cv2.CAP_PROP_FRAME_HEIGHT, self.capture_height)
actual_fps = self.camera.get(cv2.CAP_PROP_FPS)
actual_width = self.camera.get(cv2.CAP_PROP_FRAME_WIDTH)
@@ -317,19 +350,22 @@ class OpenCVCamera(Camera):
raise OSError(
f"Can't set {self.fps=} for OpenCVCamera({self.camera_index}). Actual value is {actual_fps}."
)
- if self.width is not None and not math.isclose(self.width, actual_width, rel_tol=1e-3):
+ if self.capture_width is not None and not math.isclose(
+ self.capture_width, actual_width, rel_tol=1e-3
+ ):
raise OSError(
- f"Can't set {self.width=} for OpenCVCamera({self.camera_index}). Actual value is {actual_width}."
+ f"Can't set {self.capture_width=} for OpenCVCamera({self.camera_index}). Actual value is {actual_width}."
)
- if self.height is not None and not math.isclose(self.height, actual_height, rel_tol=1e-3):
+ if self.capture_height is not None and not math.isclose(
+ self.capture_height, actual_height, rel_tol=1e-3
+ ):
raise OSError(
- f"Can't set {self.height=} for OpenCVCamera({self.camera_index}). Actual value is {actual_height}."
+ f"Can't set {self.capture_height=} for OpenCVCamera({self.camera_index}). Actual value is {actual_height}."
)
self.fps = round(actual_fps)
- self.width = round(actual_width)
- self.height = round(actual_height)
-
+ self.capture_width = round(actual_width)
+ self.capture_height = round(actual_height)
self.is_connected = True
def read(self, temporary_color_mode: str | None = None) -> np.ndarray:
@@ -370,7 +406,7 @@ class OpenCVCamera(Camera):
color_image = cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB)
h, w, _ = color_image.shape
- if h != self.height or w != self.width:
+ if h != self.capture_height or w != self.capture_width:
raise OSError(
f"Can't capture color image with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead."
)
diff --git a/lerobot/common/constants.py b/lerobot/common/constants.py
index 9e83bf33..22d7568e 100644
--- a/lerobot/common/constants.py
+++ b/lerobot/common/constants.py
@@ -1,3 +1,16 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
# keys
import os
from pathlib import Path
diff --git a/lerobot/common/datasets/backward_compatibility.py b/lerobot/common/datasets/backward_compatibility.py
index d1b8926a..cf8e31c4 100644
--- a/lerobot/common/datasets/backward_compatibility.py
+++ b/lerobot/common/datasets/backward_compatibility.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import packaging.version
V2_MESSAGE = """
diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py
index f716e7bc..e4919788 100644
--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import contextlib
import logging
import shutil
from pathlib import Path
@@ -27,6 +28,7 @@ import torch.utils
from datasets import concatenate_datasets, load_dataset
from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.constants import REPOCARD_NAME
+from huggingface_hub.errors import RevisionNotFoundError
from lerobot.common.constants import HF_LEROBOT_HOME
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
@@ -517,6 +519,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
branch: str | None = None,
tags: list | None = None,
license: str | None = "apache-2.0",
+ tag_version: bool = True,
push_videos: bool = True,
private: bool = False,
allow_patterns: list[str] | str | None = None,
@@ -562,6 +565,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
)
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
+ if tag_version:
+ with contextlib.suppress(RevisionNotFoundError):
+ hub_api.delete_tag(self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
+ hub_api.create_tag(self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
+
def pull_from_repo(
self,
allow_patterns: list[str] | str | None = None,
diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index bc2bf20d..b1fd50cb 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -31,6 +31,7 @@ import packaging.version
import torch
from datasets.table import embed_table_storage
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
+from huggingface_hub.errors import RevisionNotFoundError
from PIL import Image as PILImage
from torchvision import transforms
@@ -325,6 +326,19 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
)
hub_versions = get_repo_versions(repo_id)
+ if not hub_versions:
+ raise RevisionNotFoundError(
+ f"""Your dataset must be tagged with a codebase version.
+ Assuming _version_ is the codebase_version value in the info.json, you can run this:
+ ```python
+ from huggingface_hub import HfApi
+
+ hub_api = HfApi()
+ hub_api.create_tag("{repo_id}", tag="_version_", repo_type="dataset")
+ ```
+ """
+ )
+
if target_version in hub_versions:
return f"v{target_version}"
diff --git a/lerobot/common/datasets/v21/_remove_language_instruction.py b/lerobot/common/datasets/v21/_remove_language_instruction.py
index dd4604cf..643ddd3f 100644
--- a/lerobot/common/datasets/v21/_remove_language_instruction.py
+++ b/lerobot/common/datasets/v21/_remove_language_instruction.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import logging
import traceback
from pathlib import Path
diff --git a/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py b/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py
index 20bda75b..176d16d0 100644
--- a/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py
+++ b/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
2.1. It will:
@@ -57,7 +71,7 @@ def convert_dataset(
dataset.meta.info["codebase_version"] = CODEBASE_VERSION
write_info(dataset.meta.info, dataset.root)
- dataset.push_to_hub(branch=branch, allow_patterns="meta/")
+ dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/")
# delete old stats.json file
if (dataset.root / STATS_PATH).is_file:
diff --git a/lerobot/common/datasets/v21/convert_stats.py b/lerobot/common/datasets/v21/convert_stats.py
index cbf584b7..4a20b427 100644
--- a/lerobot/common/datasets/v21/convert_stats.py
+++ b/lerobot/common/datasets/v21/convert_stats.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
diff --git a/lerobot/common/envs/__init__.py b/lerobot/common/envs/__init__.py
index a583ffc5..4977d11d 100644
--- a/lerobot/common/envs/__init__.py
+++ b/lerobot/common/envs/__init__.py
@@ -1 +1,15 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv # noqa: F401
diff --git a/lerobot/common/envs/configs.py b/lerobot/common/envs/configs.py
index d000d1dd..6de3cf03 100644
--- a/lerobot/common/envs/configs.py
+++ b/lerobot/common/envs/configs.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import abc
from dataclasses import dataclass, field
diff --git a/lerobot/common/motors/configs.py b/lerobot/common/motors/configs.py
index 37b781f9..0bfbaf83 100644
--- a/lerobot/common/motors/configs.py
+++ b/lerobot/common/motors/configs.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import abc
from dataclasses import dataclass
diff --git a/lerobot/common/motors/dynamixel/dynamixel.py b/lerobot/common/motors/dynamixel/dynamixel.py
index 3c87c12a..55f6069b 100644
--- a/lerobot/common/motors/dynamixel/dynamixel.py
+++ b/lerobot/common/motors/dynamixel/dynamixel.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import enum
import logging
import math
diff --git a/lerobot/common/motors/dynamixel/dynamixel_calibration.py b/lerobot/common/motors/dynamixel/dynamixel_calibration.py
index 72056e49..9426d1f8 100644
--- a/lerobot/common/motors/dynamixel/dynamixel_calibration.py
+++ b/lerobot/common/motors/dynamixel/dynamixel_calibration.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""Logic to calibrate a robot arm built with dynamixel motors"""
# TODO(rcadene, aliberts): move this logic into the robot code when refactoring
diff --git a/lerobot/common/motors/feetech/feetech.py b/lerobot/common/motors/feetech/feetech.py
index 494c834b..c183967a 100644
--- a/lerobot/common/motors/feetech/feetech.py
+++ b/lerobot/common/motors/feetech/feetech.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import enum
import logging
import math
diff --git a/lerobot/common/motors/feetech/feetech_calibration.py b/lerobot/common/motors/feetech/feetech_calibration.py
index 4f4a6dca..778dd0b2 100644
--- a/lerobot/common/motors/feetech/feetech_calibration.py
+++ b/lerobot/common/motors/feetech/feetech_calibration.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""Logic to calibrate a robot arm built with feetech motors"""
# TODO(rcadene, aliberts): move this logic into the robot code when refactoring
diff --git a/lerobot/common/motors/utils.py b/lerobot/common/motors/utils.py
index 1de3bdad..bfda5bcc 100644
--- a/lerobot/common/motors/utils.py
+++ b/lerobot/common/motors/utils.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .configs import MotorsBusConfig
from .motors_bus import MotorsBus
diff --git a/lerobot/common/optim/__init__.py b/lerobot/common/optim/__init__.py
index e1e65966..de2c4c99 100644
--- a/lerobot/common/optim/__init__.py
+++ b/lerobot/common/optim/__init__.py
@@ -1 +1,15 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .optimizers import OptimizerConfig as OptimizerConfig
diff --git a/lerobot/common/policies/__init__.py b/lerobot/common/policies/__init__.py
index 2e4486ef..b73ba5f4 100644
--- a/lerobot/common/policies/__init__.py
+++ b/lerobot/common/policies/__init__.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py
index cd440f7a..5d2f6cb5 100644
--- a/lerobot/common/policies/factory.py
+++ b/lerobot/common/policies/factory.py
@@ -16,7 +16,6 @@
import logging
-import torch
from torch import nn
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
@@ -76,7 +75,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
def make_policy(
cfg: PreTrainedConfig,
- device: str | torch.device,
ds_meta: LeRobotDatasetMetadata | None = None,
env_cfg: EnvConfig | None = None,
) -> PreTrainedPolicy:
@@ -88,7 +86,6 @@ def make_policy(
Args:
cfg (PreTrainedConfig): The config of the policy to make. If `pretrained_path` is set, the policy will
be loaded with the weights from that path.
- device (str): the device to load the policy onto.
ds_meta (LeRobotDatasetMetadata | None, optional): Dataset metadata to take input/output shapes and
statistics to use for (un)normalization of inputs/outputs in the policy. Defaults to None.
env_cfg (EnvConfig | None, optional): The config of a gym environment to parse features from. Must be
@@ -96,7 +93,7 @@ def make_policy(
Raises:
ValueError: Either ds_meta or env and env_cfg must be provided.
- NotImplementedError: if the policy.type is 'vqbet' and the device 'mps' (due to an incompatibility)
+ NotImplementedError: if the policy.type is 'vqbet' and the policy device 'mps' (due to an incompatibility)
Returns:
PreTrainedPolicy: _description_
@@ -111,7 +108,7 @@ def make_policy(
# https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment
# variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be
# slower than running natively on MPS.
- if cfg.type == "vqbet" and str(device) == "mps":
+ if cfg.type == "vqbet" and cfg.device == "mps":
raise NotImplementedError(
"Current implementation of VQBeT does not support `mps` backend. "
"Please use `cpu` or `cuda` backend."
@@ -145,7 +142,7 @@ def make_policy(
# Make a fresh policy.
policy = policy_cls(**kwargs)
- policy.to(device)
+ policy.to(cfg.device)
assert isinstance(policy, nn.Module)
# policy = torch.compile(policy, mode="reduce-overhead")
diff --git a/lerobot/common/policies/pi0/configuration_pi0.py b/lerobot/common/policies/pi0/configuration_pi0.py
index 8d2eedf6..8c7cc130 100644
--- a/lerobot/common/policies/pi0/configuration_pi0.py
+++ b/lerobot/common/policies/pi0/configuration_pi0.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from dataclasses import dataclass, field
from lerobot.common.optim.optimizers import AdamWConfig
@@ -76,6 +90,7 @@ class PI0Config(PreTrainedConfig):
def __post_init__(self):
super().__post_init__()
+ # TODO(Steven): Validate device and amp? in all policy configs?
"""Input validation (not exhaustive)."""
if self.n_action_steps > self.chunk_size:
raise ValueError(
diff --git a/lerobot/common/policies/pi0/conversion_scripts/benchmark.py b/lerobot/common/policies/pi0/conversion_scripts/benchmark.py
index 31bd1b66..cb3c0e9b 100644
--- a/lerobot/common/policies/pi0/conversion_scripts/benchmark.py
+++ b/lerobot/common/policies/pi0/conversion_scripts/benchmark.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
@@ -31,7 +45,7 @@ def main():
cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
cfg.pretrained_path = ckpt_torch_dir
- policy = make_policy(cfg, device, ds_meta=dataset.meta)
+ policy = make_policy(cfg, ds_meta=dataset.meta)
# policy = torch.compile(policy, mode="reduce-overhead")
diff --git a/lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py b/lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py
index 8b2e1c66..6bd7c91f 100644
--- a/lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py
+++ b/lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import json
import pickle
from pathlib import Path
@@ -87,7 +101,7 @@ def main():
cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
cfg.pretrained_path = ckpt_torch_dir
- policy = make_policy(cfg, device, dataset_meta)
+ policy = make_policy(cfg, dataset_meta)
# loss_dict = policy.forward(batch, noise=noise, time=time_beta)
# loss_dict["loss"].backward()
diff --git a/lerobot/common/policies/pi0/conversion_scripts/conversion_utils.py b/lerobot/common/policies/pi0/conversion_scripts/conversion_utils.py
index 8e35d0d4..8835da31 100644
--- a/lerobot/common/policies/pi0/conversion_scripts/conversion_utils.py
+++ b/lerobot/common/policies/pi0/conversion_scripts/conversion_utils.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from transformers import GemmaConfig, PaliGemmaConfig
diff --git a/lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py b/lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py
index dd8622dd..73ff506f 100644
--- a/lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py
+++ b/lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""
Convert pi0 parameters from Jax to Pytorch
diff --git a/lerobot/common/policies/pi0/flex_attention.py b/lerobot/common/policies/pi0/flex_attention.py
index 38a5b597..35628cdd 100644
--- a/lerobot/common/policies/pi0/flex_attention.py
+++ b/lerobot/common/policies/pi0/flex_attention.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import torch
import torch.nn.functional as F # noqa: N812
from packaging.version import Version
diff --git a/lerobot/common/policies/pi0/modeling_pi0.py b/lerobot/common/policies/pi0/modeling_pi0.py
index 555a86bd..ede199a6 100644
--- a/lerobot/common/policies/pi0/modeling_pi0.py
+++ b/lerobot/common/policies/pi0/modeling_pi0.py
@@ -313,7 +313,7 @@ class PI0Policy(PreTrainedPolicy):
state = self.prepare_state(batch)
lang_tokens, lang_masks = self.prepare_language(batch)
actions = self.prepare_action(batch)
- actions_is_pad = batch.get("actions_id_pad")
+ actions_is_pad = batch.get("actions_is_pad")
loss_dict = {}
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
diff --git a/lerobot/common/policies/pi0/paligemma_with_expert.py b/lerobot/common/policies/pi0/paligemma_with_expert.py
index 08c36c11..76e2ce60 100644
--- a/lerobot/common/policies/pi0/paligemma_with_expert.py
+++ b/lerobot/common/policies/pi0/paligemma_with_expert.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import List, Optional, Union
import torch
diff --git a/lerobot/common/policies/pretrained.py b/lerobot/common/policies/pretrained.py
index 1729dfb0..da4ef157 100644
--- a/lerobot/common/policies/pretrained.py
+++ b/lerobot/common/policies/pretrained.py
@@ -1,3 +1,16 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import abc
import logging
import os
@@ -73,7 +86,6 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
- map_location: str = "cpu",
strict: bool = False,
**kwargs,
) -> T:
@@ -98,7 +110,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
if os.path.isdir(model_id):
print("Loading weights from local directory")
model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
- policy = cls._load_as_safetensor(instance, model_file, map_location, strict)
+ policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
else:
try:
model_file = hf_hub_download(
@@ -112,13 +124,13 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
token=token,
local_files_only=local_files_only,
)
- policy = cls._load_as_safetensor(instance, model_file, map_location, strict)
+ policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
except HfHubHTTPError as e:
raise FileNotFoundError(
f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
) from e
- policy.to(map_location)
+ policy.to(config.device)
policy.eval()
return policy
diff --git a/lerobot/common/robots/lekiwi/README.md b/lerobot/common/robots/lekiwi/README.md
index 224a1854..4b811417 100644
--- a/lerobot/common/robots/lekiwi/README.md
+++ b/lerobot/common/robots/lekiwi/README.md
@@ -23,6 +23,9 @@ Follow this [README](https://github.com/SIGRobotics-UIUC/LeKiwi). It contains th
Before assembling, you will first need to configure your motors. To this end, we provide a nice script, so let's first install LeRobot. After configuration, we will also guide you through assembly.
+### Wired version
+If you have the **wired** LeKiwi version you can skip the installation of the Raspberry Pi and setting up SSH. You can also run all commands directly on your PC for both the LeKiwi scripts and the leader arm scripts for teleoperating.
+
## B. Install software on Pi
Now we have to setup the remote PC that will run on the LeKiwi Robot. This is normally a Raspberry Pi, but can be any PC that can run on 5V and has enough usb ports (2 or more) for the cameras and motor control board.
@@ -246,6 +249,110 @@ class LeKiwiRobotConfig(RobotConfig):
}
)
+ teleop_keys: dict[str, str] = field(
+ default_factory=lambda: {
+ # Movement
+ "forward": "w",
+ "backward": "s",
+ "left": "a",
+ "right": "d",
+ "rotate_left": "z",
+ "rotate_right": "x",
+ # Speed control
+ "speed_up": "r",
+ "speed_down": "f",
+ # quit teleop
+ "quit": "q",
+ }
+ )
+
+ mock: bool = False
+```
+
+## Wired version
+
+For the wired LeKiwi version your configured IP address should refer to your own laptop (127.0.0.1), because leader arm and LeKiwi are in this case connected to own laptop. Below and example configuration for this wired setup:
+```python
+@RobotConfig.register_subclass("lekiwi")
+@dataclass
+class LeKiwiRobotConfig(RobotConfig):
+ # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
+ # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
+ # the number of motors in your follower arms.
+ max_relative_target: int | None = None
+
+ # Network Configuration
+ ip: str = "127.0.0.1"
+ port: int = 5555
+ video_port: int = 5556
+
+ cameras: dict[str, CameraConfig] = field(
+ default_factory=lambda: {
+ "front": OpenCVCameraConfig(
+ camera_index=0, fps=30, width=640, height=480, rotation=90
+ ),
+ "wrist": OpenCVCameraConfig(
+ camera_index=1, fps=30, width=640, height=480, rotation=180
+ ),
+ }
+ )
+
+ calibration_dir: str = ".cache/calibration/lekiwi"
+
+ leader_arms: dict[str, MotorsBusConfig] = field(
+ default_factory=lambda: {
+ "main": FeetechMotorsBusConfig(
+ port="/dev/tty.usbmodem585A0077581",
+ motors={
+ # name: (index, model)
+ "shoulder_pan": [1, "sts3215"],
+ "shoulder_lift": [2, "sts3215"],
+ "elbow_flex": [3, "sts3215"],
+ "wrist_flex": [4, "sts3215"],
+ "wrist_roll": [5, "sts3215"],
+ "gripper": [6, "sts3215"],
+ },
+ ),
+ }
+ )
+
+ follower_arms: dict[str, MotorsBusConfig] = field(
+ default_factory=lambda: {
+ "main": FeetechMotorsBusConfig(
+ port="/dev/tty.usbmodem58760431061",
+ motors={
+ # name: (index, model)
+ "shoulder_pan": [1, "sts3215"],
+ "shoulder_lift": [2, "sts3215"],
+ "elbow_flex": [3, "sts3215"],
+ "wrist_flex": [4, "sts3215"],
+ "wrist_roll": [5, "sts3215"],
+ "gripper": [6, "sts3215"],
+ "left_wheel": (7, "sts3215"),
+ "back_wheel": (8, "sts3215"),
+ "right_wheel": (9, "sts3215"),
+ },
+ ),
+ }
+ )
+
+ teleop_keys: dict[str, str] = field(
+ default_factory=lambda: {
+ # Movement
+ "forward": "w",
+ "backward": "s",
+ "left": "a",
+ "right": "d",
+ "rotate_left": "z",
+ "rotate_right": "x",
+ # Speed control
+ "speed_up": "r",
+ "speed_down": "f",
+ # quit teleop
+ "quit": "q",
+ }
+ )
+
mock: bool = False
```
@@ -272,6 +379,9 @@ python lerobot/scripts/control_robot.py \
--control.arms='["main_follower"]'
```
+### Wired version
+If you have the **wired** LeKiwi version please run all commands including this calibration command on your laptop.
+
### Calibrate leader arm
Then to calibrate the leader arm (which is attached to the laptop/pc). You will need to move the leader arm to these positions sequentially:
@@ -326,6 +436,9 @@ You should see on your laptop something like this: ```[INFO] Connected to remote
> [!TIP]
> If you use a different keyboard you can change the keys for each command in the [`LeKiwiRobotConfig`](../lerobot/common/robot_devices/robots/configs.py).
+### Wired version
+If you have the **wired** LeKiwi version please run all commands including both these teleoperation commands on your laptop.
+
## Troubleshoot communication
If you are having trouble connecting to the Mobile SO100, follow these steps to diagnose and resolve the issue.
@@ -364,6 +477,13 @@ Make sure the configuration file on both your laptop/pc and the Raspberry Pi is
# G. Record a dataset
Once you're familiar with teleoperation, you can record your first dataset with LeKiwi.
+To start the program on LeKiwi, SSH into your Raspberry Pi, and run `conda activate lerobot` and this script:
+```bash
+python lerobot/scripts/control_robot.py \
+ --robot.type=lekiwi \
+ --control.type=remote_robot
+```
+
If you want to use the Hugging Face hub features for uploading your dataset and you haven't previously done it, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
```bash
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
@@ -374,8 +494,7 @@ Store your Hugging Face repository name in a variable to run these commands:
HF_USER=$(huggingface-cli whoami | head -n 1)
echo $HF_USER
```
-
-Record 2 episodes and upload your dataset to the hub:
+On your laptop then run this command to record 2 episodes and upload your dataset to the hub:
```bash
python lerobot/scripts/control_robot.py \
--robot.type=lekiwi \
@@ -393,6 +512,9 @@ python lerobot/scripts/control_robot.py \
Note: You can resume recording by adding `--control.resume=true`.
+### Wired version
+If you have the **wired** LeKiwi version please run all commands including both these record dataset commands on your laptop.
+
# H. Visualize a dataset
If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by:
diff --git a/lerobot/common/robots/lekiwi/lekiwi_remote.py b/lerobot/common/robots/lekiwi/lekiwi_remote.py
index 80b522be..1643cd5e 100644
--- a/lerobot/common/robots/lekiwi/lekiwi_remote.py
+++ b/lerobot/common/robots/lekiwi/lekiwi_remote.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import base64
import json
import threading
diff --git a/lerobot/common/robots/manipulator.py b/lerobot/common/robots/manipulator.py
index 69ce212e..29f56788 100644
--- a/lerobot/common/robots/manipulator.py
+++ b/lerobot/common/robots/manipulator.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""Contains logic to instantiate a robot, read information from its motors and cameras,
and send orders to its motors.
"""
diff --git a/lerobot/common/robots/mobile_manipulator.py b/lerobot/common/robots/mobile_manipulator.py
index 38612885..87c85245 100644
--- a/lerobot/common/robots/mobile_manipulator.py
+++ b/lerobot/common/robots/mobile_manipulator.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import base64
import json
import os
@@ -393,21 +407,19 @@ class MobileManipulator:
for name in self.leader_arms:
pos = self.leader_arms[name].read("Present_Position")
pos_tensor = torch.from_numpy(pos).float()
- # Instead of pos_tensor.item(), use tolist() to convert the entire tensor to a list
arm_positions.extend(pos_tensor.tolist())
- # (The rest of your code for generating wheel commands remains unchanged)
- x_cmd = 0.0 # m/s forward/backward
- y_cmd = 0.0 # m/s lateral
+ y_cmd = 0.0 # m/s forward/backward
+ x_cmd = 0.0 # m/s lateral
theta_cmd = 0.0 # deg/s rotation
if self.pressed_keys["forward"]:
- x_cmd += xy_speed
- if self.pressed_keys["backward"]:
- x_cmd -= xy_speed
- if self.pressed_keys["left"]:
y_cmd += xy_speed
- if self.pressed_keys["right"]:
+ if self.pressed_keys["backward"]:
y_cmd -= xy_speed
+ if self.pressed_keys["left"]:
+ x_cmd += xy_speed
+ if self.pressed_keys["right"]:
+ x_cmd -= xy_speed
if self.pressed_keys["rotate_left"]:
theta_cmd += theta_speed
if self.pressed_keys["rotate_right"]:
@@ -585,8 +597,8 @@ class MobileManipulator:
# Create the body velocity vector [x, y, theta_rad].
velocity_vector = np.array([x_cmd, y_cmd, theta_rad])
- # Define the wheel mounting angles with a -90° offset.
- angles = np.radians(np.array([240, 120, 0]) - 90)
+ # Define the wheel mounting angles (defined from y axis cw)
+ angles = np.radians(np.array([300, 180, 60]))
# Build the kinematic matrix: each row maps body velocities to a wheel’s linear speed.
# The third column (base_radius) accounts for the effect of rotation.
m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles])
@@ -642,8 +654,8 @@ class MobileManipulator:
# Compute each wheel’s linear speed (m/s) from its angular speed.
wheel_linear_speeds = wheel_radps * wheel_radius
- # Define the wheel mounting angles with a -90° offset.
- angles = np.radians(np.array([240, 120, 0]) - 90)
+ # Define the wheel mounting angles (defined from y axis cw)
+ angles = np.radians(np.array([300, 180, 60]))
m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles])
# Solve the inverse kinematics: body_velocity = M⁻¹ · wheel_linear_speeds.
diff --git a/lerobot/common/robots/so100/README.md b/lerobot/common/robots/so100/README.md
index b39a0239..b63eb146 100644
--- a/lerobot/common/robots/so100/README.md
+++ b/lerobot/common/robots/so100/README.md
@@ -4,8 +4,8 @@
- [A. Source the parts](#a-source-the-parts)
- [B. Install LeRobot](#b-install-lerobot)
- - [C. Configure the motors](#c-configure-the-motors)
- - [D. Assemble the arms](#d-assemble-the-arms)
+ - [C. Configure the Motors](#c-configure-the-motors)
+ - [D. Step-by-Step Assembly Instructions](#d-step-by-step-assembly-instructions)
- [E. Calibrate](#e-calibrate)
- [F. Teleoperate](#f-teleoperate)
- [G. Record a dataset](#g-record-a-dataset)
@@ -70,6 +70,7 @@ conda install -y -c conda-forge "opencv>=4.10.0"
```
Great :hugs:! You are now done installing LeRobot and we can begin assembling the SO100 arms :robot:.
Every time you now want to use LeRobot you can go to the `~/lerobot` folder where we installed LeRobot and run one of the commands.
+
## C. Configure the motors
> [!NOTE]
@@ -98,22 +99,22 @@ Example output when identifying the leader arm's port (e.g., `/dev/tty.usbmodem5
```
Finding all available ports for the MotorBus.
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
-Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
+Remove the usb cable from your MotorsBus and press Enter when done.
[...Disconnect leader arm and press Enter...]
-The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0031751
+The port of this MotorsBus is /dev/tty.usbmodem575E0031751
Reconnect the usb cable.
```
Example output when identifying the follower arm's port (e.g., `/dev/tty.usbmodem575E0032081`, or possibly `/dev/ttyACM1` on Linux):
```
Finding all available ports for the MotorBus.
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
-Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
+Remove the usb cable from your MotorsBus and press Enter when done.
[...Disconnect follower arm and press Enter...]
-The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0032081
+The port of this MotorsBus is /dev/tty.usbmodem575E0032081
Reconnect the usb cable.
```
@@ -221,19 +222,13 @@ Redo the process for all your motors until ID 6. Do the same for the 6 motors of
Follow the video for removing gears. You need to remove the gear for the motors of the leader arm. As a result, you will only use the position encoding of the motor and reduce friction to more easily operate the leader arm.
-#### c. Add motor horn to all 12 motors
+## D. Step-by-Step Assembly Instructions
-
-Video adding motor horn
+**Step 1: Clean Parts**
+- Remove all support material from the 3D-printed parts.
+---
-
-
-
-
-Follow the video for adding the motor horn. For SO-100, you need to align the holes on the motor horn to the motor spline to be approximately 1:30, 4:30, 7:30 and 10:30.
-Try to avoid rotating the motor while doing so to keep position 2048 set during configuration. It is especially tricky for the leader motors as it is more sensible without the gears, but it's ok if it's a bit rotated.
-
-## D. Assemble the arms
+### Additional Guidance
Video assembling arms
@@ -242,7 +237,211 @@ Try to avoid rotating the motor while doing so to keep position 2048 set during
-Follow the video for assembling the arms. It is important to insert the cables into the motor that is being assembled before you assemble the motor into the arm! Inserting the cables beforehand is much easier than doing this afterward. The first arm should take a bit more than 1 hour to assemble, but once you get used to it, you can do it under 1 hour for the second arm.
+**Note:**
+This video provides visual guidance for assembling the arms, but it doesn't specify when or how to do the wiring. Inserting the cables beforehand is much easier than doing it afterward. The first arm may take a bit more than 1 hour to assemble, but once you get used to it, you can assemble the second arm in under 1 hour.
+
+---
+
+### First Motor
+
+**Step 2: Insert Wires**
+- Insert two wires into the first motor.
+
+
+
+**Step 3: Install in Base**
+- Place the first motor into the base.
+
+
+
+**Step 4: Secure Motor**
+- Fasten the motor with 4 screws. Two from the bottom and two from top.
+
+**Step 5: Attach Motor Holder**
+- Slide over the first motor holder and fasten it using two screws (one on each side).
+
+
+
+**Step 6: Attach Motor Horns**
+- Install both motor horns, securing the top horn with a screw. Try not to move the motor position when attaching the motor horn, especially for the leader arms, where we removed the gears.
+
+
+
+ Video adding motor horn
+
+
+
+**Step 7: Attach Shoulder Part**
+- Route one wire to the back of the robot and the other to the left or in photo towards you (see photo).
+- Attach the shoulder part.
+
+
+
+**Step 8: Secure Shoulder**
+- Tighten the shoulder part with 4 screws on top and 4 on the bottom
+*(access bottom holes by turning the shoulder).*
+
+---
+
+### Second Motor Assembly
+
+**Step 9: Install Motor 2**
+- Slide the second motor in from the top and link the wire from motor 1 to motor 2.
+
+
+
+**Step 10: Attach Shoulder Holder**
+- Add the shoulder motor holder.
+- Ensure the wire from motor 1 to motor 2 goes behind the holder while the other wire is routed upward (see photo).
+- This part can be tight to assemble, you can use a workbench like the image or a similar setup to push the part around the motor.
+
+
+
+
+
+
+
+**Step 11: Secure Motor 2**
+- Fasten the second motor with 4 screws.
+
+**Step 12: Attach Motor Horn**
+- Attach both motor horns to motor 2, again use the horn screw.
+
+**Step 13: Attach Base**
+- Install the base attachment using 2 screws.
+
+
+
+**Step 14: Attach Upper Arm**
+- Attach the upper arm with 4 screws on each side.
+
+
+
+---
+
+### Third Motor Assembly
+
+**Step 15: Install Motor 3**
+- Route the motor cable from motor 2 through the cable holder to motor 3, then secure motor 3 with 4 screws.
+
+**Step 16: Attach Motor Horn**
+- Attach both motor horns to motor 3 and secure one again with a horn screw.
+
+
+
+**Step 17: Attach Forearm**
+- Connect the forearm to motor 3 using 4 screws on each side.
+
+
+
+---
+
+### Fourth Motor Assembly
+
+**Step 18: Install Motor 4**
+- Slide in motor 4, attach the cable from motor 3, and secure the cable in its holder with a screw.
+
+
+
+
+
+
+**Step 19: Attach Motor Holder 4**
+- Install the fourth motor holder (a tight fit). Ensure one wire is routed upward and the wire from motor 3 is routed downward (see photo).
+
+
+
+**Step 20: Secure Motor 4 & Attach Horn**
+- Fasten motor 4 with 4 screws and attach its motor horns, use for one a horn screw.
+
+
+
+---
+
+### Wrist Assembly
+
+**Step 21: Install Motor 5**
+- Insert motor 5 into the wrist holder and secure it with 2 front screws.
+
+
+
+**Step 22: Attach Wrist**
+- Connect the wire from motor 4 to motor 5. And already insert the other wire for the gripper.
+- Secure the wrist to motor 4 using 4 screws on both sides.
+
+
+
+**Step 23: Attach Wrist Horn**
+- Install only one motor horn on the wrist motor and secure it with a horn screw.
+
+
+
+---
+
+### Follower Configuration
+
+**Step 24: Attach Gripper**
+- Attach the gripper to motor 5.
+
+
+
+**Step 25: Install Gripper Motor**
+- Insert the gripper motor, connect the motor wire from motor 5 to motor 6, and secure it with 3 screws on each side.
+
+
+
+**Step 26: Attach Gripper Horn & Claw**
+- Attach the motor horns and again use a horn screw.
+- Install the gripper claw and secure it with 4 screws on both sides.
+
+
+
+**Step 27: Mount Controller**
+- Attach the motor controller on the back.
+
+
+
+
+
+
+*Assembly complete – proceed to Leader arm assembly.*
+
+---
+
+### Leader Configuration
+
+For the leader configuration, perform **Steps 1–23**. Make sure that you removed the motor gears from the motors.
+
+**Step 24: Attach Leader Holder**
+- Mount the leader holder onto the wrist and secure it with a screw.
+
+
+
+**Step 25: Attach Handle**
+- Attach the handle to motor 5 using 4 screws.
+
+
+
+**Step 26: Install Gripper Motor**
+- Insert the gripper motor, secure it with 3 screws on each side, attach a motor horn using a horn screw, and connect the motor wire.
+
+
+
+**Step 27: Attach Trigger**
+- Attach the follower trigger with 4 screws.
+
+
+
+**Step 28: Mount Controller**
+- Attach the motor controller on the back.
+
+
+
+
+
+
+*Assembly complete – proceed to calibration.*
+
## E. Calibrate
diff --git a/lerobot/common/utils/control_utils.py b/lerobot/common/utils/control_utils.py
index 7627abac..50d20109 100644
--- a/lerobot/common/utils/control_utils.py
+++ b/lerobot/common/utils/control_utils.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
########################################################################################
# Utilities
########################################################################################
@@ -18,6 +32,7 @@ from termcolor import colored
from lerobot.common.datasets.image_writer import safe_stop_image_writer
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import get_features_from_robot
+from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.robots.utils import Robot
from lerobot.common.utils.robot_utils import busy_wait
from lerobot.common.utils.utils import get_safe_torch_device, has_method
@@ -179,8 +194,6 @@ def record_episode(
episode_time_s,
display_cameras,
policy,
- device,
- use_amp,
fps,
single_task,
):
@@ -191,8 +204,6 @@ def record_episode(
dataset=dataset,
events=events,
policy=policy,
- device=device,
- use_amp=use_amp,
fps=fps,
teleoperate=policy is None,
single_task=single_task,
@@ -207,9 +218,7 @@ def control_loop(
display_cameras=False,
dataset: LeRobotDataset | None = None,
events=None,
- policy=None,
- device: torch.device | str | None = None,
- use_amp: bool | None = None,
+ policy: PreTrainedPolicy = None,
fps: int | None = None,
single_task: str | None = None,
):
@@ -232,9 +241,6 @@ def control_loop(
if dataset is not None and fps is not None and dataset.fps != fps:
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
- if isinstance(device, str):
- device = get_safe_torch_device(device)
-
timestamp = 0
start_episode_t = time.perf_counter()
while timestamp < control_time_s:
@@ -246,7 +252,9 @@ def control_loop(
observation = robot.capture_observation()
if policy is not None:
- pred_action = predict_action(observation, policy, device, use_amp)
+ pred_action = predict_action(
+ observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
+ )
# Action can eventually be clipped using `max_relative_target`,
# so action actually sent is saved in the dataset.
action = robot.send_action(pred_action)
diff --git a/lerobot/common/utils/hub.py b/lerobot/common/utils/hub.py
index 63fcf918..df7435c0 100644
--- a/lerobot/common/utils/hub.py
+++ b/lerobot/common/utils/hub.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Type, TypeVar
diff --git a/lerobot/common/utils/robot_utils.py b/lerobot/common/utils/robot_utils.py
index 593773b5..e6c0cfe6 100644
--- a/lerobot/common/utils/robot_utils.py
+++ b/lerobot/common/utils/robot_utils.py
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import platform
import time
diff --git a/lerobot/common/utils/utils.py b/lerobot/common/utils/utils.py
index d0c12b30..563a7b81 100644
--- a/lerobot/common/utils/utils.py
+++ b/lerobot/common/utils/utils.py
@@ -17,6 +17,7 @@ import logging
import os
import os.path as osp
import platform
+import subprocess
from copy import copy
from datetime import datetime, timezone
from pathlib import Path
@@ -50,8 +51,10 @@ def auto_select_torch_device() -> torch.device:
return torch.device("cpu")
+# TODO(Steven): Remove log. log shouldn't be an argument, this should be handled by the logger level
def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
"""Given a string, return a torch.device with checks on whether the device is available."""
+ try_device = str(try_device)
match try_device:
case "cuda":
assert torch.cuda.is_available()
@@ -84,6 +87,7 @@ def get_safe_dtype(dtype: torch.dtype, device: str | torch.device):
def is_torch_device_available(try_device: str) -> bool:
+ try_device = str(try_device) # Ensure try_device is a string
if try_device == "cuda":
return torch.cuda.is_available()
elif try_device == "mps":
@@ -91,7 +95,7 @@ def is_torch_device_available(try_device: str) -> bool:
elif try_device == "cpu":
return True
else:
- raise ValueError(f"Unknown device '{try_device}.")
+ raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu.")
def is_amp_available(device: str):
@@ -165,23 +169,31 @@ def capture_timestamp_utc():
def say(text, blocking=False):
- # Check if mac, linux, or windows.
- if platform.system() == "Darwin":
- cmd = f'say "{text}"'
- if not blocking:
- cmd += " &"
- elif platform.system() == "Linux":
- cmd = f'spd-say "{text}"'
- if blocking:
- cmd += " --wait"
- elif platform.system() == "Windows":
- # TODO(rcadene): Make blocking option work for Windows
- cmd = (
- 'PowerShell -Command "Add-Type -AssemblyName System.Speech; '
- f"(New-Object System.Speech.Synthesis.SpeechSynthesizer).Speak('{text}')\""
- )
+ system = platform.system()
- os.system(cmd)
+ if system == "Darwin":
+ cmd = ["say", text]
+
+ elif system == "Linux":
+ cmd = ["spd-say", text]
+ if blocking:
+ cmd.append("--wait")
+
+ elif system == "Windows":
+ cmd = [
+ "PowerShell",
+ "-Command",
+ "Add-Type -AssemblyName System.Speech; "
+ f"(New-Object System.Speech.Synthesis.SpeechSynthesizer).Speak('{text}')",
+ ]
+
+ else:
+ raise RuntimeError("Unsupported operating system for text-to-speech.")
+
+ if blocking:
+ subprocess.run(cmd, check=True)
+ else:
+ subprocess.Popen(cmd, creationflags=subprocess.CREATE_NO_WINDOW if system == "Windows" else 0)
def log_say(text, play_sounds, blocking=False):
diff --git a/lerobot/configs/control.py b/lerobot/configs/control.py
index 109b0ba9..75e0a093 100644
--- a/lerobot/configs/control.py
+++ b/lerobot/configs/control.py
@@ -1,14 +1,25 @@
-import logging
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from dataclasses import dataclass
from pathlib import Path
import draccus
from lerobot.common.robots import RobotConfig
-from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
-from lerobot.configs.train import TrainPipelineConfig
@dataclass
@@ -43,11 +54,6 @@ class RecordControlConfig(ControlConfig):
# Root directory where the dataset will be stored (e.g. 'dataset/path').
root: str | Path | None = None
policy: PreTrainedConfig | None = None
- # TODO(rcadene, aliberts): By default, use device and use_amp values from policy checkpoint.
- device: str | None = None # cuda | cpu | mps
- # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
- # automatic gradient scaling is used.
- use_amp: bool | None = None
# Limit the frames per second. By default, uses the policy fps.
fps: int | None = None
# Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize.
@@ -90,27 +96,6 @@ class RecordControlConfig(ControlConfig):
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
- # When no device or use_amp are given, use the one from training config.
- if self.device is None or self.use_amp is None:
- train_cfg = TrainPipelineConfig.from_pretrained(policy_path)
- if self.device is None:
- self.device = train_cfg.device
- if self.use_amp is None:
- self.use_amp = train_cfg.use_amp
-
- # Automatically switch to available device if necessary
- if not is_torch_device_available(self.device):
- auto_device = auto_select_torch_device()
- logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
- self.device = auto_device
-
- # Automatically deactivate AMP if necessary
- if self.use_amp and not is_amp_available(self.device):
- logging.warning(
- f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
- )
- self.use_amp = False
-
@ControlConfig.register_subclass("replay")
@dataclass
diff --git a/lerobot/configs/eval.py b/lerobot/configs/eval.py
index 11873352..16b35291 100644
--- a/lerobot/configs/eval.py
+++ b/lerobot/configs/eval.py
@@ -1,14 +1,26 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import datetime as dt
import logging
from dataclasses import dataclass, field
from pathlib import Path
from lerobot.common import envs, policies # noqa: F401
-from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
from lerobot.configs import parser
from lerobot.configs.default import EvalConfig
from lerobot.configs.policies import PreTrainedConfig
-from lerobot.configs.train import TrainPipelineConfig
@dataclass
@@ -21,11 +33,6 @@ class EvalPipelineConfig:
policy: PreTrainedConfig | None = None
output_dir: Path | None = None
job_name: str | None = None
- # TODO(rcadene, aliberts): By default, use device and use_amp values from policy checkpoint.
- device: str | None = None # cuda | cpu | mps
- # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
- # automatic gradient scaling is used.
- use_amp: bool = False
seed: int | None = 1000
def __post_init__(self):
@@ -36,27 +43,6 @@ class EvalPipelineConfig:
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
- # When no device or use_amp are given, use the one from training config.
- if self.device is None or self.use_amp is None:
- train_cfg = TrainPipelineConfig.from_pretrained(policy_path)
- if self.device is None:
- self.device = train_cfg.device
- if self.use_amp is None:
- self.use_amp = train_cfg.use_amp
-
- # Automatically switch to available device if necessary
- if not is_torch_device_available(self.device):
- auto_device = auto_select_torch_device()
- logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
- self.device = auto_device
-
- # Automatically deactivate AMP if necessary
- if self.use_amp and not is_amp_available(self.device):
- logging.warning(
- f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
- )
- self.use_amp = False
-
else:
logging.warning(
"No pretrained path was provided, evaluated policy will be built from scratch (random weights)."
@@ -73,11 +59,6 @@ class EvalPipelineConfig:
eval_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
self.output_dir = Path("outputs/eval") / eval_dir
- if self.device is None:
- raise ValueError("Set one of the following device: cuda, cpu or mps")
- elif self.device == "cuda" and self.use_amp is None:
- raise ValueError("Set 'use_amp' to True or False.")
-
@classmethod
def __get_path_fields__(cls) -> list[str]:
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
diff --git a/lerobot/configs/parser.py b/lerobot/configs/parser.py
index ee784877..39e31515 100644
--- a/lerobot/configs/parser.py
+++ b/lerobot/configs/parser.py
@@ -1,4 +1,19 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import importlib
import inspect
+import pkgutil
import sys
from argparse import ArgumentError
from functools import wraps
@@ -10,6 +25,7 @@ import draccus
from lerobot.common.utils.utils import has_method
PATH_KEY = "path"
+PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path"
draccus.set_config_type("json")
@@ -45,6 +61,86 @@ def parse_arg(arg_name: str, args: Sequence[str] | None = None) -> str | None:
return None
+def parse_plugin_args(plugin_arg_suffix: str, args: Sequence[str]) -> dict:
+ """Parse plugin-related arguments from command-line arguments.
+
+ This function extracts arguments from command-line arguments that match a specified suffix pattern.
+ It processes arguments in the format '--key=value' and returns them as a dictionary.
+
+ Args:
+ plugin_arg_suffix (str): The suffix to identify plugin-related arguments.
+ cli_args (Sequence[str]): A sequence of command-line arguments to parse.
+
+ Returns:
+ dict: A dictionary containing the parsed plugin arguments where:
+ - Keys are the argument names (with '--' prefix removed if present)
+ - Values are the corresponding argument values
+
+ Example:
+ >>> args = ['--env.discover_packages_path=my_package',
+ ... '--other_arg=value']
+ >>> parse_plugin_args('discover_packages_path', args)
+ {'env.discover_packages_path': 'my_package'}
+ """
+ plugin_args = {}
+ for arg in args:
+ if "=" in arg and plugin_arg_suffix in arg:
+ key, value = arg.split("=", 1)
+ # Remove leading '--' if present
+ if key.startswith("--"):
+ key = key[2:]
+ plugin_args[key] = value
+ return plugin_args
+
+
+class PluginLoadError(Exception):
+ """Raised when a plugin fails to load."""
+
+
+def load_plugin(plugin_path: str) -> None:
+ """Load and initialize a plugin from a given Python package path.
+
+ This function attempts to load a plugin by importing its package and any submodules.
+ Plugin registration is expected to happen during package initialization, i.e. when
+ the package is imported the gym environment should be registered and the config classes
+ registered with their parents using the `register_subclass` decorator.
+
+ Args:
+ plugin_path (str): The Python package path to the plugin (e.g. "mypackage.plugins.myplugin")
+
+ Raises:
+ PluginLoadError: If the plugin cannot be loaded due to import errors or if the package path is invalid.
+
+ Examples:
+ >>> load_plugin("external_plugin.core") # Loads plugin from external package
+
+ Notes:
+ - The plugin package should handle its own registration during import
+ - All submodules in the plugin package will be imported
+ - Implementation follows the plugin discovery pattern from Python packaging guidelines
+
+ See Also:
+ https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/
+ """
+ try:
+ package_module = importlib.import_module(plugin_path, __package__)
+ except (ImportError, ModuleNotFoundError) as e:
+ raise PluginLoadError(
+ f"Failed to load plugin '{plugin_path}'. Verify the path and installation: {str(e)}"
+ ) from e
+
+ def iter_namespace(ns_pkg):
+ return pkgutil.iter_modules(ns_pkg.__path__, ns_pkg.__name__ + ".")
+
+ try:
+ for _finder, pkg_name, _ispkg in iter_namespace(package_module):
+ importlib.import_module(pkg_name)
+ except ImportError as e:
+ raise PluginLoadError(
+ f"Failed to load plugin '{plugin_path}'. Verify the path and installation: {str(e)}"
+ ) from e
+
+
def get_path_arg(field_name: str, args: Sequence[str] | None = None) -> str | None:
return parse_arg(f"{field_name}.{PATH_KEY}", args)
@@ -92,10 +188,13 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
def wrap(config_path: Path | None = None):
"""
- HACK: Similar to draccus.wrap but does two additional things:
+ HACK: Similar to draccus.wrap but does three additional things:
- Will remove '.path' arguments from CLI in order to process them later on.
- If a 'config_path' is passed and the main config class has a 'from_pretrained' method, will
initialize it from there to allow to fetch configs from the hub directly
+ - Will load plugins specified in the CLI arguments. These plugins will typically register
+ their own subclasses of config classes, so that draccus can find the right class to instantiate
+ from the CLI '.type' arguments
"""
def wrapper_outer(fn):
@@ -108,6 +207,14 @@ def wrap(config_path: Path | None = None):
args = args[1:]
else:
cli_args = sys.argv[1:]
+ plugin_args = parse_plugin_args(PLUGIN_DISCOVERY_SUFFIX, cli_args)
+ for plugin_cli_arg, plugin_path in plugin_args.items():
+ try:
+ load_plugin(plugin_path)
+ except PluginLoadError as e:
+ # add the relevant CLI arg to the error message
+ raise PluginLoadError(f"{e}\nFailed plugin CLI Arg: {plugin_cli_arg}") from e
+ cli_args = filter_arg(plugin_cli_arg, cli_args)
config_path_cli = parse_arg("config_path", cli_args)
if has_method(argtype, "__get_path_fields__"):
path_fields = argtype.__get_path_fields__()
diff --git a/lerobot/configs/policies.py b/lerobot/configs/policies.py
index 9b5a7c5c..022d1fb5 100644
--- a/lerobot/configs/policies.py
+++ b/lerobot/configs/policies.py
@@ -1,4 +1,18 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import abc
+import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
@@ -12,6 +26,7 @@ from huggingface_hub.errors import HfHubHTTPError
from lerobot.common.optim.optimizers import OptimizerConfig
from lerobot.common.optim.schedulers import LRSchedulerConfig
from lerobot.common.utils.hub import HubMixin
+from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
# Generic variable that is either PreTrainedConfig or a subclass thereof
@@ -40,8 +55,24 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
+ device: str | None = None # cuda | cpu | mp
+ # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
+ # automatic gradient scaling is used.
+ use_amp: bool = False
+
def __post_init__(self):
self.pretrained_path = None
+ if not self.device or not is_torch_device_available(self.device):
+ auto_device = auto_select_torch_device()
+ logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
+ self.device = auto_device.type
+
+ # Automatically deactivate AMP if necessary
+ if self.use_amp and not is_amp_available(self.device):
+ logging.warning(
+ f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
+ )
+ self.use_amp = False
@property
def type(self) -> str:
diff --git a/lerobot/configs/train.py b/lerobot/configs/train.py
index 464c11f9..2b147a5b 100644
--- a/lerobot/configs/train.py
+++ b/lerobot/configs/train.py
@@ -1,5 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import datetime as dt
-import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
@@ -13,7 +25,6 @@ from lerobot.common import envs
from lerobot.common.optim import OptimizerConfig
from lerobot.common.optim.schedulers import LRSchedulerConfig
from lerobot.common.utils.hub import HubMixin
-from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available
from lerobot.configs import parser
from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig
from lerobot.configs.policies import PreTrainedConfig
@@ -35,10 +46,6 @@ class TrainPipelineConfig(HubMixin):
# Note that when resuming a run, the default behavior is to use the configuration from the checkpoint,
# regardless of what's provided with the training command at the time of resumption.
resume: bool = False
- device: str | None = None # cuda | cpu | mp
- # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
- # automatic gradient scaling is used.
- use_amp: bool = False
# `seed` is used for training (eg: model initialization, dataset shuffling)
# AND for the evaluation environments.
seed: int | None = 1000
@@ -61,18 +68,6 @@ class TrainPipelineConfig(HubMixin):
self.checkpoint_path = None
def validate(self):
- if not self.device:
- logging.warning("No device specified, trying to infer device automatically")
- device = auto_select_torch_device()
- self.device = device.type
-
- # Automatically deactivate AMP if necessary
- if self.use_amp and not is_amp_available(self.device):
- logging.warning(
- f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
- )
- self.use_amp = False
-
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
policy_path = parser.get_path_arg("policy")
if policy_path:
diff --git a/lerobot/configs/types.py b/lerobot/configs/types.py
index 0ca45a19..6b3d92e8 100644
--- a/lerobot/configs/types.py
+++ b/lerobot/configs/types.py
@@ -1,3 +1,16 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
# Note: We subclass str so that serialization is straightforward
# https://stackoverflow.com/questions/24481852/serialising-an-enum-member-to-json
from dataclasses import dataclass
diff --git a/lerobot/scripts/configure_motor.py b/lerobot/scripts/configure_motor.py
index 1a55c6fc..8b6cd2b0 100644
--- a/lerobot/scripts/configure_motor.py
+++ b/lerobot/scripts/configure_motor.py
@@ -1,3 +1,16 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
"""
This script configure a single motor at a time to a given ID and baudrate.
diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py
index 6db5fa56..ce8f0948 100644
--- a/lerobot/scripts/control_robot.py
+++ b/lerobot/scripts/control_robot.py
@@ -1,3 +1,16 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
"""
Utilities to control a robot.
@@ -254,7 +267,7 @@ def record(
)
# Load pretrained policy
- policy = None if cfg.policy is None else make_policy(cfg.policy, cfg.device, ds_meta=dataset.meta)
+ policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
if not robot.is_connected:
robot.connect()
@@ -285,8 +298,6 @@ def record(
episode_time_s=cfg.episode_time_s,
display_cameras=cfg.display_cameras,
policy=policy,
- device=cfg.device,
- use_amp=cfg.use_amp,
fps=cfg.fps,
single_task=cfg.single_task,
)
diff --git a/lerobot/scripts/control_sim_robot.py b/lerobot/scripts/control_sim_robot.py
index c222cbbc..b0dc9630 100644
--- a/lerobot/scripts/control_sim_robot.py
+++ b/lerobot/scripts/control_sim_robot.py
@@ -1,3 +1,16 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
"""
Utilities to control a robot in simulation.
diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index a4f79afc..d7a4201f 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -454,11 +454,11 @@ def _compile_episode_data(
@parser.wrap()
-def eval(cfg: EvalPipelineConfig):
+def eval_main(cfg: EvalPipelineConfig):
logging.info(pformat(asdict(cfg)))
# Check device is available
- device = get_safe_torch_device(cfg.device, log=True)
+ device = get_safe_torch_device(cfg.policy.device, log=True)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
@@ -470,14 +470,14 @@ def eval(cfg: EvalPipelineConfig):
env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
logging.info("Making policy.")
+
policy = make_policy(
cfg=cfg.policy,
- device=device,
env_cfg=cfg.env,
)
policy.eval()
- with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
+ with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
info = eval_policy(
env,
policy,
@@ -499,4 +499,4 @@ def eval(cfg: EvalPipelineConfig):
if __name__ == "__main__":
init_logging()
- eval()
+ eval_main()
diff --git a/lerobot/scripts/find_motors_bus_port.py b/lerobot/scripts/find_motors_bus_port.py
index 67b92ad7..68f2315d 100644
--- a/lerobot/scripts/find_motors_bus_port.py
+++ b/lerobot/scripts/find_motors_bus_port.py
@@ -1,3 +1,16 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import os
import time
from pathlib import Path
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index e36c697a..f2b1e29e 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -120,7 +120,7 @@ def train(cfg: TrainPipelineConfig):
set_seed(cfg.seed)
# Check device is available
- device = get_safe_torch_device(cfg.device, log=True)
+ device = get_safe_torch_device(cfg.policy.device, log=True)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
@@ -138,13 +138,12 @@ def train(cfg: TrainPipelineConfig):
logging.info("Creating policy")
policy = make_policy(
cfg=cfg.policy,
- device=device,
ds_meta=dataset.meta,
)
logging.info("Creating optimizer and scheduler")
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
- grad_scaler = GradScaler(device, enabled=cfg.use_amp)
+ grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)
step = 0 # number of policy updates (forward + backward + optim)
@@ -218,7 +217,7 @@ def train(cfg: TrainPipelineConfig):
cfg.optimizer.grad_clip_norm,
grad_scaler=grad_scaler,
lr_scheduler=lr_scheduler,
- use_amp=cfg.use_amp,
+ use_amp=cfg.policy.use_amp,
)
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
@@ -249,7 +248,10 @@ def train(cfg: TrainPipelineConfig):
if cfg.env and is_eval_step:
step_id = get_step_identifier(step, cfg.steps)
logging.info(f"Eval policy at step {step}")
- with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
+ with (
+ torch.no_grad(),
+ torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
+ ):
eval_info = eval_policy(
eval_env,
policy,
diff --git a/lerobot/scripts/visualize_dataset_html.py b/lerobot/scripts/visualize_dataset_html.py
index ac91f0c8..f944144a 100644
--- a/lerobot/scripts/visualize_dataset_html.py
+++ b/lerobot/scripts/visualize_dataset_html.py
@@ -158,7 +158,7 @@ def run_server(
if major_version < 2:
return "Make sure to convert your LeRobotDataset to v2 & above."
- episode_data_csv_str, columns = get_episode_data(dataset, episode_id)
+ episode_data_csv_str, columns, ignored_columns = get_episode_data(dataset, episode_id)
dataset_info = {
"repo_id": f"{dataset_namespace}/{dataset_name}",
"num_samples": dataset.num_frames
@@ -194,7 +194,7 @@ def run_server(
]
response = requests.get(
- f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl"
+ f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl", timeout=5
)
response.raise_for_status()
# Split into lines and parse each line as JSON
@@ -218,6 +218,7 @@ def run_server(
videos_info=videos_info,
episode_data_csv_str=episode_data_csv_str,
columns=columns,
+ ignored_columns=ignored_columns,
)
app.run(host=host, port=port)
@@ -233,9 +234,17 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
This file will be loaded by Dygraph javascript to plot data in real time."""
columns = []
- selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] == "float32"]
+ selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] in ["float32", "int32"]]
selected_columns.remove("timestamp")
+ ignored_columns = []
+ for column_name in selected_columns:
+ shape = dataset.features[column_name]["shape"]
+ shape_dim = len(shape)
+ if shape_dim > 1:
+ selected_columns.remove(column_name)
+ ignored_columns.append(column_name)
+
# init header of csv with state and action names
header = ["timestamp"]
@@ -245,16 +254,17 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
if isinstance(dataset, LeRobotDataset)
else dataset.features[column_name].shape[0]
)
- header += [f"{column_name}_{i}" for i in range(dim_state)]
if "names" in dataset.features[column_name] and dataset.features[column_name]["names"]:
column_names = dataset.features[column_name]["names"]
while not isinstance(column_names, list):
column_names = list(column_names.values())[0]
else:
- column_names = [f"motor_{i}" for i in range(dim_state)]
+ column_names = [f"{column_name}_{i}" for i in range(dim_state)]
columns.append({"key": column_name, "value": column_names})
+ header += column_names
+
selected_columns.insert(0, "timestamp")
if isinstance(dataset, LeRobotDataset):
@@ -290,7 +300,7 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
csv_writer.writerows(rows)
csv_string = csv_buffer.getvalue()
- return csv_string, columns
+ return csv_string, columns, ignored_columns
def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
@@ -317,7 +327,9 @@ def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) ->
def get_dataset_info(repo_id: str) -> IterableNamespace:
- response = requests.get(f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json")
+ response = requests.get(
+ f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json", timeout=5
+ )
response.raise_for_status() # Raises an HTTPError for bad responses
dataset_info = response.json()
dataset_info["repo_id"] = repo_id
diff --git a/lerobot/templates/visualize_dataset_template.html b/lerobot/templates/visualize_dataset_template.html
index 08de3e3d..cf9d40f1 100644
--- a/lerobot/templates/visualize_dataset_template.html
+++ b/lerobot/templates/visualize_dataset_template.html
@@ -14,21 +14,7 @@
- {
- // Use the space bar to play and pause, instead of default action (e.g. scrolling)
- const { keyCode, key } = e;
- if (keyCode === 32 || key === ' ') {
- e.preventDefault();
- $refs.btnPause.classList.contains('hidden') ? $refs.btnPlay.click() : $refs.btnPause.click();
- }else if (key === 'ArrowDown' || key === 'ArrowUp'){
- const nextEpisodeId = key === 'ArrowDown' ? {{ episode_id }} + 1 : {{ episode_id }} - 1;
- const lowestEpisodeId = {{ episodes }}.at(0);
- const highestEpisodeId = {{ episodes }}.at(-1);
- if(nextEpisodeId >= lowestEpisodeId && nextEpisodeId <= highestEpisodeId){
- window.location.href = `./episode_${nextEpisodeId}`;
- }
- }
-}">
+
+ Columns {{ ignored_columns }} are NOT shown since the visualizer currently does not support 2D or 3D data.
+
+ {% endif %}
+
@@ -279,7 +306,6 @@
videosKeys: {{ videos_info | map(attribute='filename') | list | tojson }},
videosKeysSelected: [],
columns: {{ columns | tojson }},
- rowLabels: {{ columns | tojson }}.reduce((colA, colB) => colA.value.length > colB.value.length ? colA : colB).value,
// alpine initialization
init() {
@@ -452,6 +478,68 @@
}
};
}
+
+ document.addEventListener('alpine:init', () => {
+ // Episode pagination component
+ Alpine.data('episodePagination', () => ({
+ episodes: {{ episodes }},
+ pageSize: 100,
+ page: 1,
+
+ init() {
+ // Find which page contains the current episode_id
+ const currentEpisodeId = {{ episode_id }};
+ const episodeIndex = this.episodes.indexOf(currentEpisodeId);
+ if (episodeIndex !== -1) {
+ this.page = Math.floor(episodeIndex / this.pageSize) + 1;
+ }
+ },
+
+ get totalPages() {
+ return Math.ceil(this.episodes.length / this.pageSize);
+ },
+
+ get paginatedEpisodes() {
+ const start = (this.page - 1) * this.pageSize;
+ const end = start + this.pageSize;
+ return this.episodes.slice(start, end);
+ },
+
+ nextPage() {
+ if (this.page < this.totalPages) {
+ this.page++;
+ }
+ },
+
+ prevPage() {
+ if (this.page > 1) {
+ this.page--;
+ }
+ }
+ }));
+ });
+
+
+
diff --git a/media/lekiwi/kiwi.webp b/media/lekiwi/kiwi.webp
new file mode 100644
index 00000000..2dd7d925
Binary files /dev/null and b/media/lekiwi/kiwi.webp differ
diff --git a/media/tutorial/img1.jpg b/media/tutorial/img1.jpg
new file mode 100644
index 00000000..c16fbf5e
Binary files /dev/null and b/media/tutorial/img1.jpg differ
diff --git a/media/tutorial/img10.jpg b/media/tutorial/img10.jpg
new file mode 100644
index 00000000..925ed918
Binary files /dev/null and b/media/tutorial/img10.jpg differ
diff --git a/media/tutorial/img11.jpg b/media/tutorial/img11.jpg
new file mode 100644
index 00000000..ac040312
Binary files /dev/null and b/media/tutorial/img11.jpg differ
diff --git a/media/tutorial/img12.jpg b/media/tutorial/img12.jpg
new file mode 100644
index 00000000..dd500afb
Binary files /dev/null and b/media/tutorial/img12.jpg differ
diff --git a/media/tutorial/img13.jpg b/media/tutorial/img13.jpg
new file mode 100644
index 00000000..99c6836f
Binary files /dev/null and b/media/tutorial/img13.jpg differ
diff --git a/media/tutorial/img14.jpg b/media/tutorial/img14.jpg
new file mode 100644
index 00000000..8b544e24
Binary files /dev/null and b/media/tutorial/img14.jpg differ
diff --git a/media/tutorial/img15.jpg b/media/tutorial/img15.jpg
new file mode 100644
index 00000000..76cee552
Binary files /dev/null and b/media/tutorial/img15.jpg differ
diff --git a/media/tutorial/img16.jpg b/media/tutorial/img16.jpg
new file mode 100644
index 00000000..363b729d
Binary files /dev/null and b/media/tutorial/img16.jpg differ
diff --git a/media/tutorial/img17.jpg b/media/tutorial/img17.jpg
new file mode 100644
index 00000000..dc211bc6
Binary files /dev/null and b/media/tutorial/img17.jpg differ
diff --git a/media/tutorial/img18.jpg b/media/tutorial/img18.jpg
new file mode 100644
index 00000000..c9732b65
Binary files /dev/null and b/media/tutorial/img18.jpg differ
diff --git a/media/tutorial/img19.jpg b/media/tutorial/img19.jpg
new file mode 100644
index 00000000..25c5f0e3
Binary files /dev/null and b/media/tutorial/img19.jpg differ
diff --git a/media/tutorial/img2.jpg b/media/tutorial/img2.jpg
new file mode 100644
index 00000000..47d3671c
Binary files /dev/null and b/media/tutorial/img2.jpg differ
diff --git a/media/tutorial/img20.jpg b/media/tutorial/img20.jpg
new file mode 100644
index 00000000..effe9c96
Binary files /dev/null and b/media/tutorial/img20.jpg differ
diff --git a/media/tutorial/img21.jpg b/media/tutorial/img21.jpg
new file mode 100644
index 00000000..0acc5194
Binary files /dev/null and b/media/tutorial/img21.jpg differ
diff --git a/media/tutorial/img22.jpg b/media/tutorial/img22.jpg
new file mode 100644
index 00000000..3f223a8b
Binary files /dev/null and b/media/tutorial/img22.jpg differ
diff --git a/media/tutorial/img23.jpg b/media/tutorial/img23.jpg
new file mode 100644
index 00000000..b9c411f4
Binary files /dev/null and b/media/tutorial/img23.jpg differ
diff --git a/media/tutorial/img24.jpg b/media/tutorial/img24.jpg
new file mode 100644
index 00000000..4011d190
Binary files /dev/null and b/media/tutorial/img24.jpg differ
diff --git a/media/tutorial/img25.jpg b/media/tutorial/img25.jpg
new file mode 100644
index 00000000..727dbadf
Binary files /dev/null and b/media/tutorial/img25.jpg differ
diff --git a/media/tutorial/img26.jpg b/media/tutorial/img26.jpg
new file mode 100644
index 00000000..ae38e980
Binary files /dev/null and b/media/tutorial/img26.jpg differ
diff --git a/media/tutorial/img27.jpg b/media/tutorial/img27.jpg
new file mode 100644
index 00000000..628b8a84
Binary files /dev/null and b/media/tutorial/img27.jpg differ
diff --git a/media/tutorial/img28.jpg b/media/tutorial/img28.jpg
new file mode 100644
index 00000000..e9a7fb5d
Binary files /dev/null and b/media/tutorial/img28.jpg differ
diff --git a/media/tutorial/img29.jpg b/media/tutorial/img29.jpg
new file mode 100644
index 00000000..78210b29
Binary files /dev/null and b/media/tutorial/img29.jpg differ
diff --git a/media/tutorial/img3.jpg b/media/tutorial/img3.jpg
new file mode 100644
index 00000000..bf9b7bca
Binary files /dev/null and b/media/tutorial/img3.jpg differ
diff --git a/media/tutorial/img30.jpg b/media/tutorial/img30.jpg
new file mode 100644
index 00000000..0fa59bba
Binary files /dev/null and b/media/tutorial/img30.jpg differ
diff --git a/media/tutorial/img31.jpg b/media/tutorial/img31.jpg
new file mode 100644
index 00000000..1409d2f0
Binary files /dev/null and b/media/tutorial/img31.jpg differ
diff --git a/media/tutorial/img32.jpg b/media/tutorial/img32.jpg
new file mode 100644
index 00000000..74dbee5f
Binary files /dev/null and b/media/tutorial/img32.jpg differ
diff --git a/media/tutorial/img4.jpg b/media/tutorial/img4.jpg
new file mode 100644
index 00000000..9d155c16
Binary files /dev/null and b/media/tutorial/img4.jpg differ
diff --git a/media/tutorial/img5.jpg b/media/tutorial/img5.jpg
new file mode 100644
index 00000000..afd3c428
Binary files /dev/null and b/media/tutorial/img5.jpg differ
diff --git a/media/tutorial/img6.jpg b/media/tutorial/img6.jpg
new file mode 100644
index 00000000..c669d37d
Binary files /dev/null and b/media/tutorial/img6.jpg differ
diff --git a/media/tutorial/img7.jpg b/media/tutorial/img7.jpg
new file mode 100644
index 00000000..3f6c3501
Binary files /dev/null and b/media/tutorial/img7.jpg differ
diff --git a/media/tutorial/img8.jpg b/media/tutorial/img8.jpg
new file mode 100644
index 00000000..009102cb
Binary files /dev/null and b/media/tutorial/img8.jpg differ
diff --git a/media/tutorial/img9.jpg b/media/tutorial/img9.jpg
new file mode 100644
index 00000000..06a4f58e
Binary files /dev/null and b/media/tutorial/img9.jpg differ
diff --git a/pyproject.toml b/pyproject.toml
index 0bd3c029..a4458eb6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,3 +1,17 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
[project.urls]
homepage = "https://github.com/huggingface/lerobot"
issues = "https://github.com/huggingface/lerobot/issues"
@@ -8,18 +22,19 @@ name = "lerobot"
version = "0.1.0"
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
authors = [
- {name = "Rémi Cadène", email = "re.cadene@gmail.com"},
- {name = "Simon Alibert", email = "alibert.sim@gmail.com"},
- {name = "Alexander Soare", email = "alexander.soare159@gmail.com"},
- {name = "Quentin Gallouédec", email = "quentin.gallouedec@ec-lyon.fr"},
- {name = "Adil Zouitine", email = "adilzouitinegm@gmail.com"},
- {name = "Thomas Wolf", email = "thomaswolfcontact@gmail.com"},
+ { name = "Rémi Cadène", email = "re.cadene@gmail.com" },
+ { name = "Simon Alibert", email = "alibert.sim@gmail.com" },
+ { name = "Alexander Soare", email = "alexander.soare159@gmail.com" },
+ { name = "Quentin Gallouédec", email = "quentin.gallouedec@ec-lyon.fr" },
+ { name = "Adil Zouitine", email = "adilzouitinegm@gmail.com" },
+ { name = "Thomas Wolf", email = "thomaswolfcontact@gmail.com" },
+ { name = "Steven Palma", email = "imstevenpmwork@ieee.org" },
]
readme = "README.md"
-license = {text = "Apache-2.0"}
+license = { text = "Apache-2.0" }
requires-python = ">=3.10"
keywords = ["robotics", "deep learning", "pytorch"]
-classifiers=[
+classifiers = [
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
"Intended Audience :: Education",
@@ -38,7 +53,7 @@ dependencies = [
"einops>=0.8.0",
"flask>=3.0.3",
"gdown>=5.1.0",
- "gymnasium==0.29.1", # TODO(rcadene, aliberts): Make gym 1.0.0 work
+ "gymnasium==0.29.1", # TODO(rcadene, aliberts): Make gym 1.0.0 work
"h5py>=3.10.0",
"huggingface-hub[hf-transfer,cli]>=0.27.1 ; python_version < '4.0'",
"hydra-core>=1.3.2",
@@ -48,7 +63,7 @@ dependencies = [
"omegaconf>=2.3.0",
"opencv-python>=4.9.0",
"packaging>=24.2",
- "pyav>=12.0.5",
+ "av>=12.0.5",
"pymunk>=6.6.0",
"pynput>=1.7.7",
"pyzmq>=26.2.1",
@@ -63,7 +78,9 @@ dependencies = [
[project.optional-dependencies]
aloha = ["gym-aloha>=0.1.1 ; python_version < '4.0'"]
dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1"]
-dora = ["gym-dora @ git+https://github.com/dora-rs/dora-lerobot.git#subdirectory=gym_dora ; python_version < '4.0'"]
+dora = [
+ "gym-dora @ git+https://github.com/dora-rs/dora-lerobot.git#subdirectory=gym_dora ; python_version < '4.0'",
+]
dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
@@ -73,7 +90,7 @@ stretch = [
"hello-robot-stretch-body>=0.7.27 ; python_version < '4.0' and sys_platform == 'linux'",
"pyrender @ git+https://github.com/mmatl/pyrender.git ; sys_platform == 'linux'",
"pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'",
- "pynput>=1.7.7"
+ "pynput>=1.7.7",
]
test = ["pytest>=8.1.0", "pytest-cov>=5.0.0", "pyserial>=3.5"]
umi = ["imagecodecs>=2024.1.1"]
@@ -111,15 +128,24 @@ exclude = [
"venv",
]
-
[tool.ruff.lint]
select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"]
+[tool.bandit]
+exclude_dirs = [
+ "tests",
+ "benchmarks",
+ "lerobot/common/datasets/push_dataset_to_hub",
+ "lerobot/common/datasets/v2/convert_dataset_v1_to_v2",
+ "lerobot/common/policies/pi0/conversion_scripts",
+ "lerobot/scripts/push_dataset_to_hub.py",
+]
+skips = ["B101", "B311", "B404", "B603"]
[tool.typos]
default.extend-ignore-re = [
- "(?Rm)^.*(#|//)\\s*spellchecker:disable-line$", # spellchecker:disable-line
- "(?s)(#|//)\\s*spellchecker:off.*?\\n\\s*(#|//)\\s*spellchecker:on" # spellchecker:
+ "(?Rm)^.*(#|//)\\s*spellchecker:disable-line$", # spellchecker:disable-line
+ "(?s)(#|//)\\s*spellchecker:off.*?\\n\\s*(#|//)\\s*spellchecker:on", # spellchecker:
]
default.extend-ignore-identifiers-re = [
# Add individual words here to ignore them
diff --git a/tests/__init__.py b/tests/__init__.py
index e69de29b..f52df1bd 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/configs/test_plugin_loading.py b/tests/configs/test_plugin_loading.py
new file mode 100644
index 00000000..1a8cceed
--- /dev/null
+++ b/tests/configs/test_plugin_loading.py
@@ -0,0 +1,89 @@
+import sys
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Generator
+
+import pytest
+
+from lerobot.common.envs.configs import EnvConfig
+from lerobot.configs.parser import PluginLoadError, load_plugin, parse_plugin_args, wrap
+
+
+def create_plugin_code(*, base_class: str = "EnvConfig", plugin_name: str = "test_env") -> str:
+ """Creates a dummy plugin module that implements its own EnvConfig subclass."""
+ return f"""
+from dataclasses import dataclass
+from lerobot.common.envs.configs import {base_class}
+
+@{base_class}.register_subclass("{plugin_name}")
+@dataclass
+class TestPluginConfig:
+ value: int = 42
+ """
+
+
+@pytest.fixture
+def plugin_dir(tmp_path: Path) -> Generator[Path, None, None]:
+ """Creates a temporary plugin package structure."""
+ plugin_pkg = tmp_path / "test_plugin"
+ plugin_pkg.mkdir()
+ (plugin_pkg / "__init__.py").touch()
+
+ with open(plugin_pkg / "my_plugin.py", "w") as f:
+ f.write(create_plugin_code())
+
+ # Add tmp_path to Python path so we can import from it
+ sys.path.insert(0, str(tmp_path))
+ yield plugin_pkg
+ sys.path.pop(0)
+
+
+def test_parse_plugin_args():
+ cli_args = [
+ "--env.type=test",
+ "--model.discover_packages_path=some.package",
+ "--env.discover_packages_path=other.package",
+ ]
+ plugin_args = parse_plugin_args("discover_packages_path", cli_args)
+ assert plugin_args == {
+ "model.discover_packages_path": "some.package",
+ "env.discover_packages_path": "other.package",
+ }
+
+
+def test_load_plugin_success(plugin_dir: Path):
+ # Import should work and register the plugin with the real EnvConfig
+ load_plugin("test_plugin")
+
+ assert "test_env" in EnvConfig.get_known_choices()
+ plugin_cls = EnvConfig.get_choice_class("test_env")
+ plugin_instance = plugin_cls()
+ assert plugin_instance.value == 42
+
+
+def test_load_plugin_failure():
+ with pytest.raises(PluginLoadError) as exc_info:
+ load_plugin("nonexistent_plugin")
+ assert "Failed to load plugin 'nonexistent_plugin'" in str(exc_info.value)
+
+
+def test_wrap_with_plugin(plugin_dir: Path):
+ @dataclass
+ class Config:
+ env: EnvConfig
+
+ @wrap()
+ def dummy_func(cfg: Config):
+ return cfg
+
+ # Test loading plugin via CLI args
+ sys.argv = [
+ "dummy_script.py",
+ "--env.discover_packages_path=test_plugin",
+ "--env.type=test_env",
+ ]
+
+ cfg = dummy_func()
+ assert isinstance(cfg, Config)
+ assert isinstance(cfg.env, EnvConfig.get_choice_class("test_env"))
+ assert cfg.env.value == 42
diff --git a/tests/conftest.py b/tests/conftest.py
index d8d5a17c..4c823709 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -36,51 +36,27 @@ def pytest_collection_finish():
print(f"\nTesting with {DEVICE=}")
-@pytest.fixture
-def is_robot_available(robot_type):
- if robot_type not in available_robots:
+def _check_component_availability(component_type, available_components, make_component):
+ """Generic helper to check if a hardware component is available"""
+ if component_type not in available_components:
raise ValueError(
- f"The robot type '{robot_type}' is not valid. Expected one of these '{available_robots}"
+ f"The {component_type} type is not valid. Expected one of these '{available_components}'"
)
try:
- robot = make_robot(robot_type)
- robot.connect()
- del robot
+ component = make_component(component_type)
+ component.connect()
+ del component
return True
except Exception as e:
- print(f"\nA {robot_type} robot is not available.")
+ print(f"\nA {component_type} is not available.")
if isinstance(e, ModuleNotFoundError):
print(f"\nInstall module '{e.name}'")
elif isinstance(e, SerialException):
- print("\nNo physical motors bus detected.")
- else:
- traceback.print_exc()
-
- return False
-
-
-@pytest.fixture
-def is_camera_available(camera_type):
- if camera_type not in available_cameras:
- raise ValueError(
- f"The camera type '{camera_type}' is not valid. Expected one of these '{available_cameras}"
- )
-
- try:
- camera = make_camera(camera_type)
- camera.connect()
- del camera
- return True
-
- except Exception as e:
- print(f"\nA {camera_type} camera is not available.")
-
- if isinstance(e, ModuleNotFoundError):
- print(f"\nInstall module '{e.name}'")
- elif isinstance(e, ValueError) and "camera_index" in e.args[0]:
+ print("\nNo physical device detected.")
+ elif isinstance(e, ValueError) and "camera_index" in str(e):
print("\nNo physical camera detected.")
else:
traceback.print_exc()
@@ -88,30 +64,19 @@ def is_camera_available(camera_type):
return False
+@pytest.fixture
+def is_robot_available(robot_type):
+ return _check_component_availability(robot_type, available_robots, make_robot)
+
+
+@pytest.fixture
+def is_camera_available(camera_type):
+ return _check_component_availability(camera_type, available_cameras, make_camera)
+
+
@pytest.fixture
def is_motor_available(motor_type):
- if motor_type not in available_motors:
- raise ValueError(
- f"The motor type '{motor_type}' is not valid. Expected one of these '{available_motors}"
- )
-
- try:
- motors_bus = make_motors_bus(motor_type)
- motors_bus.connect()
- del motors_bus
- return True
-
- except Exception as e:
- print(f"\nA {motor_type} motor is not available.")
-
- if isinstance(e, ModuleNotFoundError):
- print(f"\nInstall module '{e.name}'")
- elif isinstance(e, SerialException):
- print("\nNo physical motors bus detected.")
- else:
- traceback.print_exc()
-
- return False
+ return _check_component_availability(motor_type, available_motors, make_motors_bus)
@pytest.fixture
diff --git a/tests/fixtures/constants.py b/tests/fixtures/constants.py
index 3201dcf2..5e5c762c 100644
--- a/tests/fixtures/constants.py
+++ b/tests/fixtures/constants.py
@@ -1,3 +1,16 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from lerobot.common.constants import HF_LEROBOT_HOME
LEROBOT_TEST_DIR = HF_LEROBOT_HOME / "_testing"
diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py
index 2259e0e6..531977da 100644
--- a/tests/fixtures/dataset_factories.py
+++ b/tests/fixtures/dataset_factories.py
@@ -1,3 +1,16 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import random
from functools import partial
from pathlib import Path
diff --git a/tests/fixtures/files.py b/tests/fixtures/files.py
index 4ef12e49..678d1f38 100644
--- a/tests/fixtures/files.py
+++ b/tests/fixtures/files.py
@@ -1,3 +1,16 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import json
from pathlib import Path
diff --git a/tests/fixtures/hub.py b/tests/fixtures/hub.py
index ae309cb4..aa2768e4 100644
--- a/tests/fixtures/hub.py
+++ b/tests/fixtures/hub.py
@@ -1,3 +1,16 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from pathlib import Path
import datasets
diff --git a/tests/fixtures/optimizers.py b/tests/fixtures/optimizers.py
index 1a9b9d11..65488566 100644
--- a/tests/fixtures/optimizers.py
+++ b/tests/fixtures/optimizers.py
@@ -1,3 +1,16 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import pytest
import torch
diff --git a/tests/mock_cv2.py b/tests/mock_cv2.py
index 806e35ed..eeaf859c 100644
--- a/tests/mock_cv2.py
+++ b/tests/mock_cv2.py
@@ -1,7 +1,25 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from functools import cache
import numpy as np
+CAP_V4L2 = 200
+CAP_DSHOW = 700
+CAP_AVFOUNDATION = 1200
+CAP_ANY = -1
+
CAP_PROP_FPS = 5
CAP_PROP_FRAME_WIDTH = 3
CAP_PROP_FRAME_HEIGHT = 4
diff --git a/tests/mock_dynamixel_sdk.py b/tests/mock_dynamixel_sdk.py
index a790dff0..ee399f96 100644
--- a/tests/mock_dynamixel_sdk.py
+++ b/tests/mock_dynamixel_sdk.py
@@ -1,3 +1,16 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
"""Mocked classes and functions from dynamixel_sdk to allow for continuous integration
and testing code logic that requires hardware and devices (e.g. robot arms, cameras)
diff --git a/tests/mock_pyrealsense2.py b/tests/mock_pyrealsense2.py
index 5a39fc2b..c477eb06 100644
--- a/tests/mock_pyrealsense2.py
+++ b/tests/mock_pyrealsense2.py
@@ -1,3 +1,16 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import enum
import numpy as np
diff --git a/tests/mock_scservo_sdk.py b/tests/mock_scservo_sdk.py
index ca9233b0..37f6d0d5 100644
--- a/tests/mock_scservo_sdk.py
+++ b/tests/mock_scservo_sdk.py
@@ -1,3 +1,16 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
"""Mocked classes and functions from dynamixel_sdk to allow for continuous integration
and testing code logic that requires hardware and devices (e.g. robot arms, cameras)
diff --git a/tests/scripts/save_policy_to_safetensors.py b/tests/scripts/save_policy_to_safetensors.py
index 03726163..60fd9fc0 100644
--- a/tests/scripts/save_policy_to_safetensors.py
+++ b/tests/scripts/save_policy_to_safetensors.py
@@ -33,12 +33,11 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
# TODO(rcadene, aliberts): remove dataset download
dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
policy=make_policy_config(policy_name, **policy_kwargs),
- device="cpu",
)
train_cfg.validate() # Needed for auto-setting some parameters
dataset = make_dataset(train_cfg)
- policy = make_policy(train_cfg.policy, ds_meta=dataset.meta, device=train_cfg.device)
+ policy = make_policy(train_cfg.policy, ds_meta=dataset.meta)
policy.train()
optimizer, _ = make_optimizer_and_scheduler(train_cfg, policy)
diff --git a/tests/test_cameras.py b/tests/test_cameras.py
index 54f39df0..9cce74c4 100644
--- a/tests/test_cameras.py
+++ b/tests/test_cameras.py
@@ -1,3 +1,16 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
"""
Tests for physical cameras and their mocked versions.
If the physical camera is not connected to the computer, or not working,
@@ -72,8 +85,8 @@ def test_camera(request, camera_type, mock):
camera.connect()
assert camera.is_connected
assert camera.fps is not None
- assert camera.width is not None
- assert camera.height is not None
+ assert camera.capture_width is not None
+ assert camera.capture_height is not None
# Test connecting twice raises an error
with pytest.raises(DeviceAlreadyConnectedError):
@@ -191,3 +204,49 @@ def test_save_images_from_cameras(tmp_path, request, camera_type, mock):
# Small `record_time_s` to speedup unit tests
save_images_from_cameras(tmp_path, record_time_s=0.02, mock=mock)
+
+
+@pytest.mark.parametrize("camera_type, mock", TEST_CAMERA_TYPES)
+@require_camera
+def test_camera_rotation(request, camera_type, mock):
+ config_kwargs = {"camera_type": camera_type, "mock": mock, "width": 640, "height": 480, "fps": 30}
+
+ # No rotation.
+ camera = make_camera(**config_kwargs, rotation=None)
+ camera.connect()
+ assert camera.capture_width == 640
+ assert camera.capture_height == 480
+ assert camera.width == 640
+ assert camera.height == 480
+ no_rot_img = camera.read()
+ h, w, c = no_rot_img.shape
+ assert h == 480 and w == 640 and c == 3
+ camera.disconnect()
+
+ # Rotation = 90 (clockwise).
+ camera = make_camera(**config_kwargs, rotation=90)
+ camera.connect()
+ # With a 90° rotation, we expect the metadata dimensions to be swapped.
+ assert camera.capture_width == 640
+ assert camera.capture_height == 480
+ assert camera.width == 480
+ assert camera.height == 640
+ import cv2
+
+ assert camera.rotation == cv2.ROTATE_90_CLOCKWISE
+ rot_img = camera.read()
+ h, w, c = rot_img.shape
+ assert h == 640 and w == 480 and c == 3
+ camera.disconnect()
+
+ # Rotation = 180.
+ camera = make_camera(**config_kwargs, rotation=None)
+ camera.connect()
+ assert camera.capture_width == 640
+ assert camera.capture_height == 480
+ assert camera.width == 640
+ assert camera.height == 480
+ no_rot_img = camera.read()
+ h, w, c = no_rot_img.shape
+ assert h == 480 and w == 640 and c == 3
+ camera.disconnect()
diff --git a/tests/test_control_robot.py b/tests/test_control_robot.py
index 6c718a09..cb54ae72 100644
--- a/tests/test_control_robot.py
+++ b/tests/test_control_robot.py
@@ -1,3 +1,16 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
"""
Tests for physical robots and their mocked versions.
If the physical robots are not connected to the computer, or not working,
@@ -39,7 +52,7 @@ from lerobot.configs.control import (
from lerobot.configs.policies import PreTrainedConfig
from lerobot.scripts.control_robot import calibrate, record, replay, teleoperate
from tests.test_robots import make_robot
-from tests.utils import DEVICE, TEST_ROBOT_TYPES, mock_calibration_dir, require_robot
+from tests.utils import TEST_ROBOT_TYPES, mock_calibration_dir, require_robot
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
@@ -171,7 +184,7 @@ def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock):
replay(robot, replay_cfg)
policy_cfg = ACTConfig()
- policy = make_policy(policy_cfg, ds_meta=dataset.meta, device=DEVICE)
+ policy = make_policy(policy_cfg, ds_meta=dataset.meta)
out_dir = tmp_path / "logger"
@@ -216,8 +229,6 @@ def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock):
display_cameras=False,
play_sounds=False,
num_image_writer_processes=num_image_writer_processes,
- device=DEVICE,
- use_amp=False,
)
rec_eval_cfg.policy = PreTrainedConfig.from_pretrained(pretrained_policy_path)
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 852a967d..1a321e24 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -45,7 +45,7 @@ from lerobot.common.robots.utils import make_robot
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
-from tests.utils import DEVICE, require_x86_64_kernel
+from tests.utils import require_x86_64_kernel
@pytest.fixture
@@ -349,7 +349,6 @@ def test_factory(env_name, repo_id, policy_name):
dataset=DatasetConfig(repo_id=repo_id, episodes=[0]),
env=make_env_config(env_name),
policy=make_policy_config(policy_name),
- device=DEVICE,
)
dataset = make_dataset(cfg)
diff --git a/tests/test_delta_timestamps.py b/tests/test_delta_timestamps.py
index b27cc1eb..35014642 100644
--- a/tests/test_delta_timestamps.py
+++ b/tests/test_delta_timestamps.py
@@ -1,3 +1,16 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from itertools import accumulate
import datasets
diff --git a/tests/test_image_writer.py b/tests/test_image_writer.py
index c7fc11f2..802fe0d3 100644
--- a/tests/test_image_writer.py
+++ b/tests/test_image_writer.py
@@ -1,3 +1,16 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import queue
import time
from multiprocessing import queues
diff --git a/tests/test_io_utils.py b/tests/test_io_utils.py
index d14f7adc..c1b776db 100644
--- a/tests/test_io_utils.py
+++ b/tests/test_io_utils.py
@@ -1,3 +1,16 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import json
from pathlib import Path
from typing import Any
diff --git a/tests/test_logging_utils.py b/tests/test_logging_utils.py
index 72385496..1ba1829e 100644
--- a/tests/test_logging_utils.py
+++ b/tests/test_logging_utils.py
@@ -1,3 +1,16 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import pytest
from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker
diff --git a/tests/test_motors.py b/tests/test_motors.py
index 0ad6b4d3..44c273d2 100644
--- a/tests/test_motors.py
+++ b/tests/test_motors.py
@@ -1,3 +1,16 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
"""
Tests for physical motors and their mocked versions.
If the physical motors are not connected to the computer, or not working,
diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py
index cf5c5b18..997e14fe 100644
--- a/tests/test_optimizers.py
+++ b/tests/test_optimizers.py
@@ -1,3 +1,16 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import pytest
import torch
diff --git a/tests/test_policies.py b/tests/test_policies.py
index 9dab6176..5d7cca8f 100644
--- a/tests/test_policies.py
+++ b/tests/test_policies.py
@@ -143,12 +143,11 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
policy=make_policy_config(policy_name, **policy_kwargs),
env=make_env_config(env_name, **env_kwargs),
- device=DEVICE,
)
# Check that we can make the policy object.
dataset = make_dataset(train_cfg)
- policy = make_policy(train_cfg.policy, ds_meta=dataset.meta, device=DEVICE)
+ policy = make_policy(train_cfg.policy, ds_meta=dataset.meta)
assert isinstance(policy, PreTrainedPolicy)
# Check that we run select_actions and get the appropriate output.
@@ -214,7 +213,6 @@ def test_act_backbone_lr():
# TODO(rcadene, aliberts): remove dataset download
dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]),
policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001),
- device=DEVICE,
)
cfg.validate() # Needed for auto-setting some parameters
@@ -222,7 +220,7 @@ def test_act_backbone_lr():
assert cfg.policy.optimizer_lr_backbone == 0.001
dataset = make_dataset(cfg)
- policy = make_policy(cfg.policy, device=DEVICE, ds_meta=dataset.meta)
+ policy = make_policy(cfg.policy, ds_meta=dataset.meta)
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
assert len(optimizer.param_groups) == 2
assert optimizer.param_groups[0]["lr"] == cfg.policy.optimizer_lr
@@ -254,10 +252,11 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name:
key: ft for key, ft in features.items() if key not in policy_cfg.output_features
}
policy = policy_cls(policy_cfg)
+ policy.to(policy_cfg.device)
save_dir = tmp_path / f"test_save_and_load_pretrained_{policy_cls.__name__}"
policy.save_pretrained(save_dir)
- policy_ = policy_cls.from_pretrained(save_dir, config=policy_cfg)
- assert all(torch.equal(p, p_) for p, p_ in zip(policy.parameters(), policy_.parameters(), strict=True))
+ loaded_policy = policy_cls.from_pretrained(save_dir, config=policy_cfg)
+ torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
@pytest.mark.parametrize("insert_temporal_dim", [False, True])
@@ -369,7 +368,7 @@ def test_normalize(insert_temporal_dim):
# was changed to true. For some reason, tests would pass locally, but not in CI. So here we override
# to test with `policy.use_mpc=false`.
("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": False}, "use_policy"),
- ("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": True}, "use_mpc"),
+ # ("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": True}, "use_mpc"),
# TODO(rcadene): the diffusion model was normalizing the image in mean=0.5 std=0.5 which is a hack supposed to
# to normalize the image at all. In our current codebase we dont normalize at all. But there is still a minor difference
# that fails the test. However, by testing to normalize the image with 0.5 0.5 in the current codebase, the test pass.
diff --git a/tests/test_random_utils.py b/tests/test_random_utils.py
index 8eee2b68..daf08a89 100644
--- a/tests/test_random_utils.py
+++ b/tests/test_random_utils.py
@@ -1,3 +1,16 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import random
import numpy as np
diff --git a/tests/test_robots.py b/tests/test_robots.py
index 6def3a87..51a80195 100644
--- a/tests/test_robots.py
+++ b/tests/test_robots.py
@@ -1,3 +1,16 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
"""
Tests for physical robots and their mocked versions.
If the physical robots are not connected to the computer, or not working,
diff --git a/tests/test_schedulers.py b/tests/test_schedulers.py
index e871fee1..17637663 100644
--- a/tests/test_schedulers.py
+++ b/tests/test_schedulers.py
@@ -1,3 +1,16 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from torch.optim.lr_scheduler import LambdaLR
from lerobot.common.constants import SCHEDULER_STATE
diff --git a/tests/test_train_utils.py b/tests/test_train_utils.py
index d6ed0063..b78f6e49 100644
--- a/tests/test_train_utils.py
+++ b/tests/test_train_utils.py
@@ -1,3 +1,16 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from pathlib import Path
from unittest.mock import Mock, patch
diff --git a/tests/test_utils.py b/tests/test_utils.py
index b2f14694..2d0efc5a 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,3 +1,16 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import torch
from datasets import Dataset