This commit is contained in:
DUDULRX 2025-03-28 18:15:08 +08:00
commit bb9e77e712
98 changed files with 1287 additions and 161 deletions

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Misc # Misc
.git .git
tmp tmp

14
.gitattributes vendored
View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*.memmap filter=lfs diff=lfs merge=lfs -text *.memmap filter=lfs diff=lfs merge=lfs -text
*.stl filter=lfs diff=lfs merge=lfs -text *.stl filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text *.safetensors filter=lfs diff=lfs merge=lfs -text

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: "\U0001F41B Bug Report" name: "\U0001F41B Bug Report"
description: Submit a bug report to help us improve LeRobot description: Submit a bug report to help us improve LeRobot
body: body:

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Inspired by # Inspired by
# https://github.com/huggingface/peft/blob/main/.github/workflows/build_docker_images.yml # https://github.com/huggingface/peft/blob/main/.github/workflows/build_docker_images.yml
name: Builds name: Builds

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Inspired by # Inspired by
# https://github.com/huggingface/peft/blob/main/.github/workflows/nightly.yml # https://github.com/huggingface/peft/blob/main/.github/workflows/nightly.yml
name: Nightly name: Nightly

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/huggingface/diffusers/blob/main/.github/workflows/pr_style_bot.yml # Adapted from https://github.com/huggingface/diffusers/blob/main/.github/workflows/pr_style_bot.yml
name: PR Style Bot name: PR Style Bot

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: Quality name: Quality
on: on:

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Inspired by # Inspired by
# https://github.com/huggingface/peft/blob/main/.github/workflows/test-docker-build.yml # https://github.com/huggingface/peft/blob/main/.github/workflows/test-docker-build.yml
name: Test Dockerfiles name: Test Dockerfiles

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: Tests name: Tests
on: on:

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
on: on:
push: push:

14
.gitignore vendored
View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Logging # Logging
logs logs
tmp tmp

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
exclude: ^(tests/data) exclude: ^(tests/data)
default_language_version: default_language_version:
python: python3.10 python: python3.10

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
.PHONY: tests .PHONY: tests
PYTHON_PATH := $(shell which python) PYTHON_PATH := $(shell which python)
@ -33,6 +47,7 @@ test-act-ete-train:
--policy.dim_model=64 \ --policy.dim_model=64 \
--policy.n_action_steps=20 \ --policy.n_action_steps=20 \
--policy.chunk_size=20 \ --policy.chunk_size=20 \
--policy.device=$(DEVICE) \
--env.type=aloha \ --env.type=aloha \
--env.episode_length=5 \ --env.episode_length=5 \
--dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \ --dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \
@ -47,7 +62,6 @@ test-act-ete-train:
--save_checkpoint=true \ --save_checkpoint=true \
--log_freq=1 \ --log_freq=1 \
--wandb.enable=false \ --wandb.enable=false \
--device=$(DEVICE) \
--output_dir=tests/outputs/act/ --output_dir=tests/outputs/act/
test-act-ete-train-resume: test-act-ete-train-resume:
@ -58,11 +72,11 @@ test-act-ete-train-resume:
test-act-ete-eval: test-act-ete-eval:
python lerobot/scripts/eval.py \ python lerobot/scripts/eval.py \
--policy.path=tests/outputs/act/checkpoints/000004/pretrained_model \ --policy.path=tests/outputs/act/checkpoints/000004/pretrained_model \
--policy.device=$(DEVICE) \
--env.type=aloha \ --env.type=aloha \
--env.episode_length=5 \ --env.episode_length=5 \
--eval.n_episodes=1 \ --eval.n_episodes=1 \
--eval.batch_size=1 \ --eval.batch_size=1
--device=$(DEVICE)
test-diffusion-ete-train: test-diffusion-ete-train:
python lerobot/scripts/train.py \ python lerobot/scripts/train.py \
@ -70,6 +84,7 @@ test-diffusion-ete-train:
--policy.down_dims='[64,128,256]' \ --policy.down_dims='[64,128,256]' \
--policy.diffusion_step_embed_dim=32 \ --policy.diffusion_step_embed_dim=32 \
--policy.num_inference_steps=10 \ --policy.num_inference_steps=10 \
--policy.device=$(DEVICE) \
--env.type=pusht \ --env.type=pusht \
--env.episode_length=5 \ --env.episode_length=5 \
--dataset.repo_id=lerobot/pusht \ --dataset.repo_id=lerobot/pusht \
@ -84,21 +99,21 @@ test-diffusion-ete-train:
--save_freq=2 \ --save_freq=2 \
--log_freq=1 \ --log_freq=1 \
--wandb.enable=false \ --wandb.enable=false \
--device=$(DEVICE) \
--output_dir=tests/outputs/diffusion/ --output_dir=tests/outputs/diffusion/
test-diffusion-ete-eval: test-diffusion-ete-eval:
python lerobot/scripts/eval.py \ python lerobot/scripts/eval.py \
--policy.path=tests/outputs/diffusion/checkpoints/000002/pretrained_model \ --policy.path=tests/outputs/diffusion/checkpoints/000002/pretrained_model \
--policy.device=$(DEVICE) \
--env.type=pusht \ --env.type=pusht \
--env.episode_length=5 \ --env.episode_length=5 \
--eval.n_episodes=1 \ --eval.n_episodes=1 \
--eval.batch_size=1 \ --eval.batch_size=1
--device=$(DEVICE)
test-tdmpc-ete-train: test-tdmpc-ete-train:
python lerobot/scripts/train.py \ python lerobot/scripts/train.py \
--policy.type=tdmpc \ --policy.type=tdmpc \
--policy.device=$(DEVICE) \
--env.type=xarm \ --env.type=xarm \
--env.task=XarmLift-v0 \ --env.task=XarmLift-v0 \
--env.episode_length=5 \ --env.episode_length=5 \
@ -114,15 +129,14 @@ test-tdmpc-ete-train:
--save_freq=2 \ --save_freq=2 \
--log_freq=1 \ --log_freq=1 \
--wandb.enable=false \ --wandb.enable=false \
--device=$(DEVICE) \
--output_dir=tests/outputs/tdmpc/ --output_dir=tests/outputs/tdmpc/
test-tdmpc-ete-eval: test-tdmpc-ete-eval:
python lerobot/scripts/eval.py \ python lerobot/scripts/eval.py \
--policy.path=tests/outputs/tdmpc/checkpoints/000002/pretrained_model \ --policy.path=tests/outputs/tdmpc/checkpoints/000002/pretrained_model \
--policy.device=$(DEVICE) \
--env.type=xarm \ --env.type=xarm \
--env.episode_length=5 \ --env.episode_length=5 \
--env.task=XarmLift-v0 \ --env.task=XarmLift-v0 \
--eval.n_episodes=1 \ --eval.n_episodes=1 \
--eval.batch_size=1 \ --eval.batch_size=1
--device=$(DEVICE)

View File

@ -92,20 +92,15 @@ git clone https://github.com/huggingface/lerobot.git
cd lerobot cd lerobot
``` ```
Create a virtual environment with Python 3.10 and activate it using [`uv`](https://github.com/astral-sh/uv): Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniconda`](https://docs.anaconda.com/free/miniconda/index.html):
```bash ```bash
# Install uv if you haven't already conda create -y -n lerobot python=3.10
curl -LsSf https://astral.sh/uv/install.sh | sh conda activate lerobot
# Create and activate virtual environment with Python 3.10
uv venv .venv --python=3.10
source .venv/bin/activate # On Unix/macOS
# .venv\Scripts\activate # On Windows
``` ```
Install 🤗 LeRobot: Install 🤗 LeRobot:
```bash ```bash
uv pip install -e . pip install -e .
``` ```
> **NOTE:** Depending on your platform, If you encounter any build errors during this step > **NOTE:** Depending on your platform, If you encounter any build errors during this step
@ -389,3 +384,6 @@ Additionally, if you are using any of the particular policy architecture, pretra
year={2024} year={2024}
} }
``` ```
## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=huggingface/lerobot&type=Timeline)](https://star-history.com/#huggingface/lerobot&Timeline)

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
This script demonstrates the use of `LeRobotDataset` class for handling and processing robotic datasets from Hugging Face. This script demonstrates the use of `LeRobotDataset` class for handling and processing robotic datasets from Hugging Face.
It illustrates how to load datasets, manipulate them, and apply transformations suitable for machine learning tasks in PyTorch. It illustrates how to load datasets, manipulate them, and apply transformations suitable for machine learning tasks in PyTorch.

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
This scripts demonstrates how to evaluate a pretrained policy from the HuggingFace Hub or from your local This scripts demonstrates how to evaluate a pretrained policy from the HuggingFace Hub or from your local
training outputs directory. In the latter case, you might want to run examples/3_train_policy.py first. training outputs directory. In the latter case, you might want to run examples/3_train_policy.py first.

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This scripts demonstrates how to train Diffusion Policy on the PushT environment. """This scripts demonstrates how to train Diffusion Policy on the PushT environment.
Once you have trained a model with this script, you can try to evaluate it on Once you have trained a model with this script, you can try to evaluate it on

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
This script demonstrates how to use torchvision's image transformation with LeRobotDataset for data This script demonstrates how to use torchvision's image transformation with LeRobotDataset for data
augmentation purposes. The transformations are passed to the dataset as an argument upon creation, and augmentation purposes. The transformations are passed to the dataset as an argument upon creation, and

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This script demonstrates how to slice a dataset and calculate the loss on a subset of the data. """This script demonstrates how to slice a dataset and calculate the loss on a subset of the data.
This technique can be useful for debugging and testing purposes, as well as identifying whether a policy This technique can be useful for debugging and testing purposes, as well as identifying whether a policy

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil import shutil
from pathlib import Path from pathlib import Path

View File

@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# keys # keys
import os import os
from pathlib import Path from pathlib import Path

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import packaging.version import packaging.version
V2_MESSAGE = """ V2_MESSAGE = """

View File

@ -167,8 +167,8 @@ def parse_robot_config(robot_cfg: RobotConfig) -> tuple[str, dict]:
for motor in robot_cfg.leader_arms[arm].motors for motor in robot_cfg.leader_arms[arm].motors
] ]
elif robot_cfg.type == "roarm_m3": elif robot_cfg.type == "roarm_m3":
state_names = ["roam_m3","roam_m3","roam_m3","roam_m3","roam_m3","roam_m3"] state_names = ["roam_m3", "roam_m3", "roam_m3", "roam_m3", "roam_m3", "roam_m3"]
action_names = ["roam_m3","roam_m3","roam_m3","roam_m3","roam_m3","roam_m3"] action_names = ["roam_m3", "roam_m3", "roam_m3", "roam_m3", "roam_m3", "roam_m3"]
# elif robot_cfg["robot_type"] == "stretch3": TODO # elif robot_cfg["robot_type"] == "stretch3": TODO
else: else:
raise NotImplementedError( raise NotImplementedError(

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging import logging
import traceback import traceback
from pathlib import Path from pathlib import Path

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
2.1. It will: 2.1. It will:

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np import numpy as np

View File

@ -1 +1,15 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv # noqa: F401 from .configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv # noqa: F401

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc import abc
from dataclasses import dataclass, field from dataclasses import dataclass, field

View File

@ -1 +1,15 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .optimizers import OptimizerConfig as OptimizerConfig from .optimizers import OptimizerConfig as OptimizerConfig

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .act.configuration_act import ACTConfig as ACTConfig from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config from .pi0.configuration_pi0 import PI0Config as PI0Config

View File

@ -16,7 +16,6 @@
import logging import logging
import torch
from torch import nn from torch import nn
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
@ -76,7 +75,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
def make_policy( def make_policy(
cfg: PreTrainedConfig, cfg: PreTrainedConfig,
device: str | torch.device,
ds_meta: LeRobotDatasetMetadata | None = None, ds_meta: LeRobotDatasetMetadata | None = None,
env_cfg: EnvConfig | None = None, env_cfg: EnvConfig | None = None,
) -> PreTrainedPolicy: ) -> PreTrainedPolicy:
@ -88,7 +86,6 @@ def make_policy(
Args: Args:
cfg (PreTrainedConfig): The config of the policy to make. If `pretrained_path` is set, the policy will cfg (PreTrainedConfig): The config of the policy to make. If `pretrained_path` is set, the policy will
be loaded with the weights from that path. be loaded with the weights from that path.
device (str): the device to load the policy onto.
ds_meta (LeRobotDatasetMetadata | None, optional): Dataset metadata to take input/output shapes and ds_meta (LeRobotDatasetMetadata | None, optional): Dataset metadata to take input/output shapes and
statistics to use for (un)normalization of inputs/outputs in the policy. Defaults to None. statistics to use for (un)normalization of inputs/outputs in the policy. Defaults to None.
env_cfg (EnvConfig | None, optional): The config of a gym environment to parse features from. Must be env_cfg (EnvConfig | None, optional): The config of a gym environment to parse features from. Must be
@ -96,7 +93,7 @@ def make_policy(
Raises: Raises:
ValueError: Either ds_meta or env and env_cfg must be provided. ValueError: Either ds_meta or env and env_cfg must be provided.
NotImplementedError: if the policy.type is 'vqbet' and the device 'mps' (due to an incompatibility) NotImplementedError: if the policy.type is 'vqbet' and the policy device 'mps' (due to an incompatibility)
Returns: Returns:
PreTrainedPolicy: _description_ PreTrainedPolicy: _description_
@ -111,7 +108,7 @@ def make_policy(
# https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment # https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment
# variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be # variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be
# slower than running natively on MPS. # slower than running natively on MPS.
if cfg.type == "vqbet" and str(device) == "mps": if cfg.type == "vqbet" and cfg.device == "mps":
raise NotImplementedError( raise NotImplementedError(
"Current implementation of VQBeT does not support `mps` backend. " "Current implementation of VQBeT does not support `mps` backend. "
"Please use `cpu` or `cuda` backend." "Please use `cpu` or `cuda` backend."
@ -145,7 +142,7 @@ def make_policy(
# Make a fresh policy. # Make a fresh policy.
policy = policy_cls(**kwargs) policy = policy_cls(**kwargs)
policy.to(device) policy.to(cfg.device)
assert isinstance(policy, nn.Module) assert isinstance(policy, nn.Module)
# policy = torch.compile(policy, mode="reduce-overhead") # policy = torch.compile(policy, mode="reduce-overhead")

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field from dataclasses import dataclass, field
from lerobot.common.optim.optimizers import AdamWConfig from lerobot.common.optim.optimizers import AdamWConfig
@ -76,6 +90,7 @@ class PI0Config(PreTrainedConfig):
def __post_init__(self): def __post_init__(self):
super().__post_init__() super().__post_init__()
# TODO(Steven): Validate device and amp? in all policy configs?
"""Input validation (not exhaustive).""" """Input validation (not exhaustive)."""
if self.n_action_steps > self.chunk_size: if self.n_action_steps > self.chunk_size:
raise ValueError( raise ValueError(

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
@ -31,7 +45,7 @@ def main():
cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir) cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
cfg.pretrained_path = ckpt_torch_dir cfg.pretrained_path = ckpt_torch_dir
policy = make_policy(cfg, device, ds_meta=dataset.meta) policy = make_policy(cfg, ds_meta=dataset.meta)
# policy = torch.compile(policy, mode="reduce-overhead") # policy = torch.compile(policy, mode="reduce-overhead")

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json import json
import pickle import pickle
from pathlib import Path from pathlib import Path
@ -87,7 +101,7 @@ def main():
cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir) cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
cfg.pretrained_path = ckpt_torch_dir cfg.pretrained_path = ckpt_torch_dir
policy = make_policy(cfg, device, dataset_meta) policy = make_policy(cfg, dataset_meta)
# loss_dict = policy.forward(batch, noise=noise, time=time_beta) # loss_dict = policy.forward(batch, noise=noise, time=time_beta)
# loss_dict["loss"].backward() # loss_dict["loss"].backward()

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers import GemmaConfig, PaliGemmaConfig from transformers import GemmaConfig, PaliGemmaConfig

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
Convert pi0 parameters from Jax to Pytorch Convert pi0 parameters from Jax to Pytorch

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch import torch
import torch.nn.functional as F # noqa: N812 import torch.nn.functional as F # noqa: N812
from packaging.version import Version from packaging.version import Version

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union from typing import List, Optional, Union
import torch import torch

View File

@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc import abc
import logging import logging
import os import os
@ -73,7 +86,6 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
cache_dir: str | Path | None = None, cache_dir: str | Path | None = None,
local_files_only: bool = False, local_files_only: bool = False,
revision: str | None = None, revision: str | None = None,
map_location: str = "cpu",
strict: bool = False, strict: bool = False,
**kwargs, **kwargs,
) -> T: ) -> T:
@ -98,7 +110,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
if os.path.isdir(model_id): if os.path.isdir(model_id):
print("Loading weights from local directory") print("Loading weights from local directory")
model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE) model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
policy = cls._load_as_safetensor(instance, model_file, map_location, strict) policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
else: else:
try: try:
model_file = hf_hub_download( model_file = hf_hub_download(
@ -112,13 +124,13 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
token=token, token=token,
local_files_only=local_files_only, local_files_only=local_files_only,
) )
policy = cls._load_as_safetensor(instance, model_file, map_location, strict) policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
except HfHubHTTPError as e: except HfHubHTTPError as e:
raise FileNotFoundError( raise FileNotFoundError(
f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}" f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
) from e ) from e
policy.to(map_location) policy.to(config.device)
policy.eval() policy.eval()
return policy return policy

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc import abc
from dataclasses import dataclass from dataclasses import dataclass

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
This file contains utilities for recording frames from Intel Realsense cameras. This file contains utilities for recording frames from Intel Realsense cameras.
""" """

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
This file contains utilities for recording frames from cameras. For more info look at `OpenCVCamera` docstring. This file contains utilities for recording frames from cameras. For more info look at `OpenCVCamera` docstring.
""" """

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Protocol from typing import Protocol
import numpy as np import numpy as np

View File

@ -1,14 +1,25 @@
import logging # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
import draccus import draccus
from lerobot.common.robot_devices.robots.configs import RobotConfig from lerobot.common.robot_devices.robots.configs import RobotConfig
from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
from lerobot.configs import parser from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.train import TrainPipelineConfig
@dataclass @dataclass
@ -43,11 +54,6 @@ class RecordControlConfig(ControlConfig):
# Root directory where the dataset will be stored (e.g. 'dataset/path'). # Root directory where the dataset will be stored (e.g. 'dataset/path').
root: str | Path | None = None root: str | Path | None = None
policy: PreTrainedConfig | None = None policy: PreTrainedConfig | None = None
# TODO(rcadene, aliberts): By default, use device and use_amp values from policy checkpoint.
device: str | None = None # cuda | cpu | mps
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
# automatic gradient scaling is used.
use_amp: bool | None = None
# Limit the frames per second. By default, uses the policy fps. # Limit the frames per second. By default, uses the policy fps.
fps: int | None = None fps: int | None = None
# Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize. # Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize.
@ -90,27 +96,6 @@ class RecordControlConfig(ControlConfig):
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path self.policy.pretrained_path = policy_path
# When no device or use_amp are given, use the one from training config.
if self.device is None or self.use_amp is None:
train_cfg = TrainPipelineConfig.from_pretrained(policy_path)
if self.device is None:
self.device = train_cfg.device
if self.use_amp is None:
self.use_amp = train_cfg.use_amp
# Automatically switch to available device if necessary
if not is_torch_device_available(self.device):
auto_device = auto_select_torch_device()
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
self.device = auto_device
# Automatically deactivate AMP if necessary
if self.use_amp and not is_amp_available(self.device):
logging.warning(
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
)
self.use_amp = False
@ControlConfig.register_subclass("replay") @ControlConfig.register_subclass("replay")
@dataclass @dataclass

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
######################################################################################## ########################################################################################
# Utilities # Utilities
######################################################################################## ########################################################################################
@ -18,6 +32,7 @@ from termcolor import colored
from lerobot.common.datasets.image_writer import safe_stop_image_writer from lerobot.common.datasets.image_writer import safe_stop_image_writer
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import get_features_from_robot from lerobot.common.datasets.utils import get_features_from_robot
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.robot_devices.robots.utils import Robot from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.robot_devices.utils import busy_wait from lerobot.common.robot_devices.utils import busy_wait
from lerobot.common.utils.utils import get_safe_torch_device, has_method from lerobot.common.utils.utils import get_safe_torch_device, has_method
@ -210,8 +225,6 @@ def record_episode(
episode_time_s, episode_time_s,
display_cameras, display_cameras,
policy, policy,
device,
use_amp,
fps, fps,
single_task, single_task,
): ):
@ -222,8 +235,6 @@ def record_episode(
dataset=dataset, dataset=dataset,
events=events, events=events,
policy=policy, policy=policy,
device=device,
use_amp=use_amp,
fps=fps, fps=fps,
teleoperate=policy is None, teleoperate=policy is None,
single_task=single_task, single_task=single_task,
@ -238,9 +249,7 @@ def control_loop(
display_cameras=False, display_cameras=False,
dataset: LeRobotDataset | None = None, dataset: LeRobotDataset | None = None,
events=None, events=None,
policy=None, policy: PreTrainedPolicy = None,
device: torch.device | str | None = None,
use_amp: bool | None = None,
fps: int | None = None, fps: int | None = None,
single_task: str | None = None, single_task: str | None = None,
): ):
@ -263,9 +272,6 @@ def control_loop(
if dataset is not None and fps is not None and dataset.fps != fps: if dataset is not None and fps is not None and dataset.fps != fps:
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).") raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
if isinstance(device, str):
device = get_safe_torch_device(device)
timestamp = 0 timestamp = 0
start_episode_t = time.perf_counter() start_episode_t = time.perf_counter()
while timestamp < control_time_s: while timestamp < control_time_s:
@ -277,7 +283,9 @@ def control_loop(
observation = robot.capture_observation() observation = robot.capture_observation()
if policy is not None: if policy is not None:
pred_action = predict_action(observation, policy, device, use_amp) pred_action = predict_action(
observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
)
# Action can eventually be clipped using `max_relative_target`, # Action can eventually be clipped using `max_relative_target`,
# so action actually sent is saved in the dataset. # so action actually sent is saved in the dataset.
action = robot.send_action(pred_action) action = robot.send_action(pred_action)

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc import abc
from dataclasses import dataclass from dataclasses import dataclass

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum import enum
import logging import logging
import math import math

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum import enum
import logging import logging
import math import math

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Protocol from typing import Protocol
from lerobot.common.robot_devices.motors.configs import ( from lerobot.common.robot_devices.motors.configs import (

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc import abc
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Sequence from typing import Sequence

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Logic to calibrate a robot arm built with dynamixel motors""" """Logic to calibrate a robot arm built with dynamixel motors"""
# TODO(rcadene, aliberts): move this logic into the robot code when refactoring # TODO(rcadene, aliberts): move this logic into the robot code when refactoring

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Logic to calibrate a robot arm built with feetech motors""" """Logic to calibrate a robot arm built with feetech motors"""
# TODO(rcadene, aliberts): move this logic into the robot code when refactoring # TODO(rcadene, aliberts): move this logic into the robot code when refactoring

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64 import base64
import json import json
import threading import threading

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains logic to instantiate a robot, read information from its motors and cameras, """Contains logic to instantiate a robot, read information from its motors and cameras,
and send orders to its motors. and send orders to its motors.
""" """

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64 import base64
import json import json
import os import os

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Protocol from typing import Protocol
from lerobot.common.robot_devices.robots.configs import ( from lerobot.common.robot_devices.robots.configs import (

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import platform import platform
import time import time

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path from pathlib import Path
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import Any, Type, TypeVar from typing import Any, Type, TypeVar

View File

@ -51,8 +51,10 @@ def auto_select_torch_device() -> torch.device:
return torch.device("cpu") return torch.device("cpu")
# TODO(Steven): Remove log. log shouldn't be an argument, this should be handled by the logger level
def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device: def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
"""Given a string, return a torch.device with checks on whether the device is available.""" """Given a string, return a torch.device with checks on whether the device is available."""
try_device = str(try_device)
match try_device: match try_device:
case "cuda": case "cuda":
assert torch.cuda.is_available() assert torch.cuda.is_available()
@ -85,6 +87,7 @@ def get_safe_dtype(dtype: torch.dtype, device: str | torch.device):
def is_torch_device_available(try_device: str) -> bool: def is_torch_device_available(try_device: str) -> bool:
try_device = str(try_device) # Ensure try_device is a string
if try_device == "cuda": if try_device == "cuda":
return torch.cuda.is_available() return torch.cuda.is_available()
elif try_device == "mps": elif try_device == "mps":
@ -92,7 +95,7 @@ def is_torch_device_available(try_device: str) -> bool:
elif try_device == "cpu": elif try_device == "cpu":
return True return True
else: else:
raise ValueError(f"Unknown device '{try_device}.") raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu.")
def is_amp_available(device: str): def is_amp_available(device: str):

View File

@ -1,14 +1,26 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime as dt import datetime as dt
import logging import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from lerobot.common import envs, policies # noqa: F401 from lerobot.common import envs, policies # noqa: F401
from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
from lerobot.configs import parser from lerobot.configs import parser
from lerobot.configs.default import EvalConfig from lerobot.configs.default import EvalConfig
from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.train import TrainPipelineConfig
@dataclass @dataclass
@ -21,11 +33,6 @@ class EvalPipelineConfig:
policy: PreTrainedConfig | None = None policy: PreTrainedConfig | None = None
output_dir: Path | None = None output_dir: Path | None = None
job_name: str | None = None job_name: str | None = None
# TODO(rcadene, aliberts): By default, use device and use_amp values from policy checkpoint.
device: str | None = None # cuda | cpu | mps
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
# automatic gradient scaling is used.
use_amp: bool = False
seed: int | None = 1000 seed: int | None = 1000
def __post_init__(self): def __post_init__(self):
@ -36,27 +43,6 @@ class EvalPipelineConfig:
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path self.policy.pretrained_path = policy_path
# When no device or use_amp are given, use the one from training config.
if self.device is None or self.use_amp is None:
train_cfg = TrainPipelineConfig.from_pretrained(policy_path)
if self.device is None:
self.device = train_cfg.device
if self.use_amp is None:
self.use_amp = train_cfg.use_amp
# Automatically switch to available device if necessary
if not is_torch_device_available(self.device):
auto_device = auto_select_torch_device()
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
self.device = auto_device
# Automatically deactivate AMP if necessary
if self.use_amp and not is_amp_available(self.device):
logging.warning(
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
)
self.use_amp = False
else: else:
logging.warning( logging.warning(
"No pretrained path was provided, evaluated policy will be built from scratch (random weights)." "No pretrained path was provided, evaluated policy will be built from scratch (random weights)."
@ -73,11 +59,6 @@ class EvalPipelineConfig:
eval_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}" eval_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
self.output_dir = Path("outputs/eval") / eval_dir self.output_dir = Path("outputs/eval") / eval_dir
if self.device is None:
raise ValueError("Set one of the following device: cuda, cpu or mps")
elif self.device == "cuda" and self.use_amp is None:
raise ValueError("Set 'use_amp' to True or False.")
@classmethod @classmethod
def __get_path_fields__(cls) -> list[str]: def __get_path_fields__(cls) -> list[str]:
"""This enables the parser to load config from the policy using `--policy.path=local/dir`""" """This enables the parser to load config from the policy using `--policy.path=local/dir`"""

View File

@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect import inspect
import sys import sys
from argparse import ArgumentError from argparse import ArgumentError

View File

@ -1,4 +1,18 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc import abc
import logging
import os import os
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
@ -12,6 +26,7 @@ from huggingface_hub.errors import HfHubHTTPError
from lerobot.common.optim.optimizers import OptimizerConfig from lerobot.common.optim.optimizers import OptimizerConfig
from lerobot.common.optim.schedulers import LRSchedulerConfig from lerobot.common.optim.schedulers import LRSchedulerConfig
from lerobot.common.utils.hub import HubMixin from lerobot.common.utils.hub import HubMixin
from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
# Generic variable that is either PreTrainedConfig or a subclass thereof # Generic variable that is either PreTrainedConfig or a subclass thereof
@ -40,8 +55,24 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
input_features: dict[str, PolicyFeature] = field(default_factory=dict) input_features: dict[str, PolicyFeature] = field(default_factory=dict)
output_features: dict[str, PolicyFeature] = field(default_factory=dict) output_features: dict[str, PolicyFeature] = field(default_factory=dict)
device: str | None = None # cuda | cpu | mp
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
# automatic gradient scaling is used.
use_amp: bool = False
def __post_init__(self): def __post_init__(self):
self.pretrained_path = None self.pretrained_path = None
if not self.device or not is_torch_device_available(self.device):
auto_device = auto_select_torch_device()
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
self.device = auto_device.type
# Automatically deactivate AMP if necessary
if self.use_amp and not is_amp_available(self.device):
logging.warning(
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
)
self.use_amp = False
@property @property
def type(self) -> str: def type(self) -> str:

View File

@ -1,5 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime as dt import datetime as dt
import logging
import os import os
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
@ -13,7 +25,6 @@ from lerobot.common import envs
from lerobot.common.optim import OptimizerConfig from lerobot.common.optim import OptimizerConfig
from lerobot.common.optim.schedulers import LRSchedulerConfig from lerobot.common.optim.schedulers import LRSchedulerConfig
from lerobot.common.utils.hub import HubMixin from lerobot.common.utils.hub import HubMixin
from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available
from lerobot.configs import parser from lerobot.configs import parser
from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig
from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.policies import PreTrainedConfig
@ -35,10 +46,6 @@ class TrainPipelineConfig(HubMixin):
# Note that when resuming a run, the default behavior is to use the configuration from the checkpoint, # Note that when resuming a run, the default behavior is to use the configuration from the checkpoint,
# regardless of what's provided with the training command at the time of resumption. # regardless of what's provided with the training command at the time of resumption.
resume: bool = False resume: bool = False
device: str | None = None # cuda | cpu | mp
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
# automatic gradient scaling is used.
use_amp: bool = False
# `seed` is used for training (eg: model initialization, dataset shuffling) # `seed` is used for training (eg: model initialization, dataset shuffling)
# AND for the evaluation environments. # AND for the evaluation environments.
seed: int | None = 1000 seed: int | None = 1000
@ -61,18 +68,6 @@ class TrainPipelineConfig(HubMixin):
self.checkpoint_path = None self.checkpoint_path = None
def validate(self): def validate(self):
if not self.device:
logging.warning("No device specified, trying to infer device automatically")
device = auto_select_torch_device()
self.device = device.type
# Automatically deactivate AMP if necessary
if self.use_amp and not is_amp_available(self.device):
logging.warning(
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
)
self.use_amp = False
# HACK: We parse again the cli args here to get the pretrained paths if there was some. # HACK: We parse again the cli args here to get the pretrained paths if there was some.
policy_path = parser.get_path_arg("policy") policy_path = parser.get_path_arg("policy")
if policy_path: if policy_path:

View File

@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Note: We subclass str so that serialization is straightforward # Note: We subclass str so that serialization is straightforward
# https://stackoverflow.com/questions/24481852/serialising-an-enum-member-to-json # https://stackoverflow.com/questions/24481852/serialising-an-enum-member-to-json
from dataclasses import dataclass from dataclasses import dataclass

View File

@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
This script configure a single motor at a time to a given ID and baudrate. This script configure a single motor at a time to a given ID and baudrate.

View File

@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
Utilities to control a robot. Utilities to control a robot.
@ -254,7 +267,7 @@ def record(
) )
# Load pretrained policy # Load pretrained policy
policy = None if cfg.policy is None else make_policy(cfg.policy, cfg.device, ds_meta=dataset.meta) policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
if not robot.is_connected: if not robot.is_connected:
robot.connect() robot.connect()
@ -285,8 +298,6 @@ def record(
episode_time_s=cfg.episode_time_s, episode_time_s=cfg.episode_time_s,
display_cameras=cfg.display_cameras, display_cameras=cfg.display_cameras,
policy=policy, policy=policy,
device=cfg.device,
use_amp=cfg.use_amp,
fps=cfg.fps, fps=cfg.fps,
single_task=cfg.single_task, single_task=cfg.single_task,
) )

View File

@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
Utilities to control a robot in simulation. Utilities to control a robot in simulation.

View File

@ -458,7 +458,7 @@ def eval_main(cfg: EvalPipelineConfig):
logging.info(pformat(asdict(cfg))) logging.info(pformat(asdict(cfg)))
# Check device is available # Check device is available
device = get_safe_torch_device(cfg.device, log=True) device = get_safe_torch_device(cfg.policy.device, log=True)
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True
@ -470,14 +470,14 @@ def eval_main(cfg: EvalPipelineConfig):
env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs) env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
logging.info("Making policy.") logging.info("Making policy.")
policy = make_policy( policy = make_policy(
cfg=cfg.policy, cfg=cfg.policy,
device=device,
env_cfg=cfg.env, env_cfg=cfg.env,
) )
policy.eval() policy.eval()
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext(): with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
info = eval_policy( info = eval_policy(
env, env,
policy, policy,

View File

@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os import os
import time import time
from pathlib import Path from pathlib import Path

View File

@ -120,7 +120,7 @@ def train(cfg: TrainPipelineConfig):
set_seed(cfg.seed) set_seed(cfg.seed)
# Check device is available # Check device is available
device = get_safe_torch_device(cfg.device, log=True) device = get_safe_torch_device(cfg.policy.device, log=True)
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True
@ -138,13 +138,12 @@ def train(cfg: TrainPipelineConfig):
logging.info("Creating policy") logging.info("Creating policy")
policy = make_policy( policy = make_policy(
cfg=cfg.policy, cfg=cfg.policy,
device=device,
ds_meta=dataset.meta, ds_meta=dataset.meta,
) )
logging.info("Creating optimizer and scheduler") logging.info("Creating optimizer and scheduler")
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy) optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
grad_scaler = GradScaler(device, enabled=cfg.use_amp) grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)
step = 0 # number of policy updates (forward + backward + optim) step = 0 # number of policy updates (forward + backward + optim)
@ -218,7 +217,7 @@ def train(cfg: TrainPipelineConfig):
cfg.optimizer.grad_clip_norm, cfg.optimizer.grad_clip_norm,
grad_scaler=grad_scaler, grad_scaler=grad_scaler,
lr_scheduler=lr_scheduler, lr_scheduler=lr_scheduler,
use_amp=cfg.use_amp, use_amp=cfg.policy.use_amp,
) )
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
@ -249,7 +248,10 @@ def train(cfg: TrainPipelineConfig):
if cfg.env and is_eval_step: if cfg.env and is_eval_step:
step_id = get_step_identifier(step, cfg.steps) step_id = get_step_identifier(step, cfg.steps)
logging.info(f"Eval policy at step {step}") logging.info(f"Eval policy at step {step}")
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext(): with (
torch.no_grad(),
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
):
eval_info = eval_policy( eval_info = eval_policy(
eval_env, eval_env,
policy, policy,

View File

@ -234,7 +234,7 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
This file will be loaded by Dygraph javascript to plot data in real time.""" This file will be loaded by Dygraph javascript to plot data in real time."""
columns = [] columns = []
selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] == "float32"] selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] in ["float32", "int32"]]
selected_columns.remove("timestamp") selected_columns.remove("timestamp")
ignored_columns = [] ignored_columns = []

View File

@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
[project.urls] [project.urls]
homepage = "https://github.com/huggingface/lerobot" homepage = "https://github.com/huggingface/lerobot"
issues = "https://github.com/huggingface/lerobot/issues" issues = "https://github.com/huggingface/lerobot/issues"
@ -8,18 +22,19 @@ name = "lerobot"
version = "0.1.0" version = "0.1.0"
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch" description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
authors = [ authors = [
{name = "Rémi Cadène", email = "re.cadene@gmail.com"}, { name = "Rémi Cadène", email = "re.cadene@gmail.com" },
{name = "Simon Alibert", email = "alibert.sim@gmail.com"}, { name = "Simon Alibert", email = "alibert.sim@gmail.com" },
{name = "Alexander Soare", email = "alexander.soare159@gmail.com"}, { name = "Alexander Soare", email = "alexander.soare159@gmail.com" },
{name = "Quentin Gallouédec", email = "quentin.gallouedec@ec-lyon.fr"}, { name = "Quentin Gallouédec", email = "quentin.gallouedec@ec-lyon.fr" },
{name = "Adil Zouitine", email = "adilzouitinegm@gmail.com"}, { name = "Adil Zouitine", email = "adilzouitinegm@gmail.com" },
{name = "Thomas Wolf", email = "thomaswolfcontact@gmail.com"}, { name = "Thomas Wolf", email = "thomaswolfcontact@gmail.com" },
{ name = "Steven Palma", email = "imstevenpmwork@ieee.org" },
] ]
readme = "README.md" readme = "README.md"
license = {text = "Apache-2.0"} license = { text = "Apache-2.0" }
requires-python = ">=3.10" requires-python = ">=3.10"
keywords = ["robotics", "deep learning", "pytorch"] keywords = ["robotics", "deep learning", "pytorch"]
classifiers=[ classifiers = [
"Development Status :: 3 - Alpha", "Development Status :: 3 - Alpha",
"Intended Audience :: Developers", "Intended Audience :: Developers",
"Intended Audience :: Education", "Intended Audience :: Education",
@ -38,7 +53,7 @@ dependencies = [
"einops>=0.8.0", "einops>=0.8.0",
"flask>=3.0.3", "flask>=3.0.3",
"gdown>=5.1.0", "gdown>=5.1.0",
"gymnasium==0.29.1", # TODO(rcadene, aliberts): Make gym 1.0.0 work "gymnasium==0.29.1", # TODO(rcadene, aliberts): Make gym 1.0.0 work
"h5py>=3.10.0", "h5py>=3.10.0",
"huggingface-hub[hf-transfer,cli]>=0.27.1 ; python_version < '4.0'", "huggingface-hub[hf-transfer,cli]>=0.27.1 ; python_version < '4.0'",
"hydra-core>=1.3.2", "hydra-core>=1.3.2",
@ -64,7 +79,9 @@ dependencies = [
[project.optional-dependencies] [project.optional-dependencies]
aloha = ["gym-aloha>=0.1.1 ; python_version < '4.0'"] aloha = ["gym-aloha>=0.1.1 ; python_version < '4.0'"]
dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1"] dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1"]
dora = ["gym-dora @ git+https://github.com/dora-rs/dora-lerobot.git#subdirectory=gym_dora ; python_version < '4.0'"] dora = [
"gym-dora @ git+https://github.com/dora-rs/dora-lerobot.git#subdirectory=gym_dora ; python_version < '4.0'",
]
dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"] dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"] feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"] intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
@ -74,7 +91,7 @@ stretch = [
"hello-robot-stretch-body>=0.7.27 ; python_version < '4.0' and sys_platform == 'linux'", "hello-robot-stretch-body>=0.7.27 ; python_version < '4.0' and sys_platform == 'linux'",
"pyrender @ git+https://github.com/mmatl/pyrender.git ; sys_platform == 'linux'", "pyrender @ git+https://github.com/mmatl/pyrender.git ; sys_platform == 'linux'",
"pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'", "pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'",
"pynput>=1.7.7" "pynput>=1.7.7",
] ]
test = ["pytest>=8.1.0", "pytest-cov>=5.0.0", "pyserial>=3.5"] test = ["pytest>=8.1.0", "pytest-cov>=5.0.0", "pyserial>=3.5"]
umi = ["imagecodecs>=2024.1.1"] umi = ["imagecodecs>=2024.1.1"]
@ -129,8 +146,8 @@ skips = ["B101", "B311", "B404", "B603"]
[tool.typos] [tool.typos]
default.extend-ignore-re = [ default.extend-ignore-re = [
"(?Rm)^.*(#|//)\\s*spellchecker:disable-line$", # spellchecker:disable-line "(?Rm)^.*(#|//)\\s*spellchecker:disable-line$", # spellchecker:disable-line
"(?s)(#|//)\\s*spellchecker:off.*?\\n\\s*(#|//)\\s*spellchecker:on" # spellchecker:<on|off> "(?s)(#|//)\\s*spellchecker:off.*?\\n\\s*(#|//)\\s*spellchecker:on", # spellchecker:<on|off>
] ]
default.extend-ignore-identifiers-re = [ default.extend-ignore-identifiers-re = [
# Add individual words here to ignore them # Add individual words here to ignore them

View File

@ -0,0 +1,13 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.common.constants import HF_LEROBOT_HOME from lerobot.common.constants import HF_LEROBOT_HOME
LEROBOT_TEST_DIR = HF_LEROBOT_HOME / "_testing" LEROBOT_TEST_DIR = HF_LEROBOT_HOME / "_testing"

View File

@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random import random
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path

View File

@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json import json
from pathlib import Path from pathlib import Path

13
tests/fixtures/hub.py vendored
View File

@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path from pathlib import Path
import datasets import datasets

View File

@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest import pytest
import torch import torch

View File

@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import cache from functools import cache
import numpy as np import numpy as np

View File

@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mocked classes and functions from dynamixel_sdk to allow for continuous integration """Mocked classes and functions from dynamixel_sdk to allow for continuous integration
and testing code logic that requires hardware and devices (e.g. robot arms, cameras) and testing code logic that requires hardware and devices (e.g. robot arms, cameras)

View File

@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum import enum
import numpy as np import numpy as np

View File

@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mocked classes and functions from dynamixel_sdk to allow for continuous integration """Mocked classes and functions from dynamixel_sdk to allow for continuous integration
and testing code logic that requires hardware and devices (e.g. robot arms, cameras) and testing code logic that requires hardware and devices (e.g. robot arms, cameras)

View File

@ -33,12 +33,11 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
# TODO(rcadene, aliberts): remove dataset download # TODO(rcadene, aliberts): remove dataset download
dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]), dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
policy=make_policy_config(policy_name, **policy_kwargs), policy=make_policy_config(policy_name, **policy_kwargs),
device="cpu",
) )
train_cfg.validate() # Needed for auto-setting some parameters train_cfg.validate() # Needed for auto-setting some parameters
dataset = make_dataset(train_cfg) dataset = make_dataset(train_cfg)
policy = make_policy(train_cfg.policy, ds_meta=dataset.meta, device=train_cfg.device) policy = make_policy(train_cfg.policy, ds_meta=dataset.meta)
policy.train() policy.train()
optimizer, _ = make_optimizer_and_scheduler(train_cfg, policy) optimizer, _ = make_optimizer_and_scheduler(train_cfg, policy)

View File

@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
Tests for physical cameras and their mocked versions. Tests for physical cameras and their mocked versions.
If the physical camera is not connected to the computer, or not working, If the physical camera is not connected to the computer, or not working,

View File

@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
Tests for physical robots and their mocked versions. Tests for physical robots and their mocked versions.
If the physical robots are not connected to the computer, or not working, If the physical robots are not connected to the computer, or not working,
@ -39,7 +52,7 @@ from lerobot.common.robot_devices.control_configs import (
from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.policies import PreTrainedConfig
from lerobot.scripts.control_robot import calibrate, record, replay, teleoperate from lerobot.scripts.control_robot import calibrate, record, replay, teleoperate
from tests.test_robots import make_robot from tests.test_robots import make_robot
from tests.utils import DEVICE, TEST_ROBOT_TYPES, mock_calibration_dir, require_robot from tests.utils import TEST_ROBOT_TYPES, mock_calibration_dir, require_robot
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES) @pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
@ -171,7 +184,7 @@ def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock):
replay(robot, replay_cfg) replay(robot, replay_cfg)
policy_cfg = ACTConfig() policy_cfg = ACTConfig()
policy = make_policy(policy_cfg, ds_meta=dataset.meta, device=DEVICE) policy = make_policy(policy_cfg, ds_meta=dataset.meta)
out_dir = tmp_path / "logger" out_dir = tmp_path / "logger"
@ -216,8 +229,6 @@ def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock):
display_cameras=False, display_cameras=False,
play_sounds=False, play_sounds=False,
num_image_writer_processes=num_image_writer_processes, num_image_writer_processes=num_image_writer_processes,
device=DEVICE,
use_amp=False,
) )
rec_eval_cfg.policy = PreTrainedConfig.from_pretrained(pretrained_policy_path) rec_eval_cfg.policy = PreTrainedConfig.from_pretrained(pretrained_policy_path)

View File

@ -45,7 +45,7 @@ from lerobot.common.robot_devices.robots.utils import make_robot
from lerobot.configs.default import DatasetConfig from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig from lerobot.configs.train import TrainPipelineConfig
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
from tests.utils import DEVICE, require_x86_64_kernel from tests.utils import require_x86_64_kernel
@pytest.fixture @pytest.fixture
@ -349,7 +349,6 @@ def test_factory(env_name, repo_id, policy_name):
dataset=DatasetConfig(repo_id=repo_id, episodes=[0]), dataset=DatasetConfig(repo_id=repo_id, episodes=[0]),
env=make_env_config(env_name), env=make_env_config(env_name),
policy=make_policy_config(policy_name), policy=make_policy_config(policy_name),
device=DEVICE,
) )
dataset = make_dataset(cfg) dataset = make_dataset(cfg)

View File

@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from itertools import accumulate from itertools import accumulate
import datasets import datasets

View File

@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import queue import queue
import time import time
from multiprocessing import queues from multiprocessing import queues

View File

@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json import json
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any

View File

@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest import pytest
from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker

View File

@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
Tests for physical motors and their mocked versions. Tests for physical motors and their mocked versions.
If the physical motors are not connected to the computer, or not working, If the physical motors are not connected to the computer, or not working,

View File

@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest import pytest
import torch import torch

View File

@ -143,12 +143,11 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]), dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
policy=make_policy_config(policy_name, **policy_kwargs), policy=make_policy_config(policy_name, **policy_kwargs),
env=make_env_config(env_name, **env_kwargs), env=make_env_config(env_name, **env_kwargs),
device=DEVICE,
) )
# Check that we can make the policy object. # Check that we can make the policy object.
dataset = make_dataset(train_cfg) dataset = make_dataset(train_cfg)
policy = make_policy(train_cfg.policy, ds_meta=dataset.meta, device=DEVICE) policy = make_policy(train_cfg.policy, ds_meta=dataset.meta)
assert isinstance(policy, PreTrainedPolicy) assert isinstance(policy, PreTrainedPolicy)
# Check that we run select_actions and get the appropriate output. # Check that we run select_actions and get the appropriate output.
@ -214,7 +213,6 @@ def test_act_backbone_lr():
# TODO(rcadene, aliberts): remove dataset download # TODO(rcadene, aliberts): remove dataset download
dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]), dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]),
policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001), policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001),
device=DEVICE,
) )
cfg.validate() # Needed for auto-setting some parameters cfg.validate() # Needed for auto-setting some parameters
@ -222,7 +220,7 @@ def test_act_backbone_lr():
assert cfg.policy.optimizer_lr_backbone == 0.001 assert cfg.policy.optimizer_lr_backbone == 0.001
dataset = make_dataset(cfg) dataset = make_dataset(cfg)
policy = make_policy(cfg.policy, device=DEVICE, ds_meta=dataset.meta) policy = make_policy(cfg.policy, ds_meta=dataset.meta)
optimizer, _ = make_optimizer_and_scheduler(cfg, policy) optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
assert len(optimizer.param_groups) == 2 assert len(optimizer.param_groups) == 2
assert optimizer.param_groups[0]["lr"] == cfg.policy.optimizer_lr assert optimizer.param_groups[0]["lr"] == cfg.policy.optimizer_lr

View File

@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random import random
import numpy as np import numpy as np

View File

@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
Tests for physical robots and their mocked versions. Tests for physical robots and their mocked versions.
If the physical robots are not connected to the computer, or not working, If the physical robots are not connected to the computer, or not working,

View File

@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torch.optim.lr_scheduler import LambdaLR from torch.optim.lr_scheduler import LambdaLR
from lerobot.common.constants import SCHEDULER_STATE from lerobot.common.constants import SCHEDULER_STATE

View File

@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path from pathlib import Path
from unittest.mock import Mock, patch from unittest.mock import Mock, patch

View File

@ -1,3 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch import torch
from datasets import Dataset from datasets import Dataset