130 lines
5.0 KiB
Python
130 lines
5.0 KiB
Python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
|
|
import draccus
|
|
|
|
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
|
from lerobot.configs import parser
|
|
from lerobot.configs.policies import PreTrainedConfig
|
|
|
|
|
|
@dataclass
class ControlConfig(draccus.ChoiceRegistry):
    """Common base class for the control-mode configurations.

    It declares no fields of its own. Concrete control modes subclass it and
    register themselves under a string name with
    `ControlConfig.register_subclass(...)`.
    """
|
|
|
|
|
|
@ControlConfig.register_subclass("calibrate")
@dataclass
class CalibrateControlConfig(ControlConfig):
    """Options for the control mode registered as "calibrate"."""

    # Names of the arms to calibrate
    # (e.g. `--arms='["left_follower","right_follower"]' left_leader`).
    # Defaults to None.
    arms: list[str] | None = None
|
|
|
|
|
|
@ControlConfig.register_subclass("teleoperate")
@dataclass
class TeleoperateControlConfig(ControlConfig):
    """Options for the control mode registered as "teleoperate"."""

    # Upper bound on the frames per second. No limit by default.
    fps: int | None = None
    # NOTE(review): presumably the duration of the teleoperation session in seconds,
    # with None meaning no time limit -- confirm against the code consuming this config.
    teleop_time_s: float | None = None
    # Show every camera feed on screen.
    display_cameras: bool = True
|
|
|
|
|
|
@ControlConfig.register_subclass("record")
@dataclass
class RecordControlConfig(ControlConfig):
    """Options for the control mode registered as "record"."""

    # Identifier of the dataset. The convention is '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
    repo_id: str
    # Short, accurate description of the task carried out during the recording
    # (e.g. "Pick the Lego block and drop it in the box on the right.")
    single_task: str
    # Root directory in which the dataset will be stored (e.g. 'dataset/path').
    root: str | Path | None = None
    # Policy configuration. Overwritten in `__post_init__` when a pretrained path
    # is found for "control.policy" among the CLI args.
    policy: PreTrainedConfig | None = None
    # Upper bound on the frames per second. By default, the policy fps is used.
    fps: int | None = None
    # Seconds to wait before data collection starts, so the robot devices can warm up and synchronize.
    warmup_time_s: int | float = 10
    # Seconds of data recording per episode.
    episode_time_s: int | float = 60
    # Seconds allotted to resetting the environment after each episode.
    reset_time_s: int | float = 60
    # How many episodes to record.
    num_episodes: int = 50
    # Encode the dataset frames into video.
    video: bool = True
    # Upload the dataset to the Hugging Face hub.
    push_to_hub: bool = True
    # Upload to a private repository on the Hugging Face hub.
    private: bool = False
    # Tags to attach to the dataset on the hub.
    tags: list[str] | None = None
    # Number of subprocesses that save frames as PNG. 0 means threads only; >= 1 means
    # subprocesses, each of which uses threads to write images. The best combination of
    # processes and threads depends on the system. Recommended: 4 threads per camera with
    # 0 processes. If fps is unstable, adjust the thread count first; if it is still
    # unstable, try 1 or more subprocesses.
    num_image_writer_processes: int = 0
    # Number of threads, per camera, that write frames to disk as png images.
    # Too many threads can block the main thread and make the teleoperation fps unstable.
    # Too few threads can lower the camera fps.
    num_image_writer_threads_per_camera: int = 4
    # Show every camera feed on screen.
    display_cameras: bool = True
    # Read events aloud using vocal synthesis.
    play_sounds: bool = True
    # Continue recording into an existing dataset.
    resume: bool = False

    def __post_init__(self):
        """Load the policy config from a pretrained path given on the CLI, if any.

        When `parser.get_path_arg("control.policy")` yields a path, `self.policy`
        is replaced by the config loaded from that path (with the CLI overrides
        for "control.policy" applied) and the path is recorded on it as
        `pretrained_path`. Otherwise `self.policy` is left untouched.
        """
        # HACK: the cli args are parsed a second time here to recover the pretrained path, if one was given.
        pretrained_path = parser.get_path_arg("control.policy")
        if not pretrained_path:
            return

        overrides = parser.get_cli_overrides("control.policy")
        self.policy = PreTrainedConfig.from_pretrained(pretrained_path, cli_overrides=overrides)
        self.policy.pretrained_path = pretrained_path
|
|
|
|
|
|
@ControlConfig.register_subclass("replay")
@dataclass
class ReplayControlConfig(ControlConfig):
    """Options for the control mode registered as "replay"."""

    # Identifier of the dataset. The convention is '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
    repo_id: str
    # Index of the episode that should be replayed.
    episode: int
    # Root directory where the dataset is stored (e.g. 'dataset/path').
    root: str | Path | None = None
    # Upper bound on the frames per second. By default, the dataset fps is used.
    fps: int | None = None
    # Read events aloud using vocal synthesis.
    play_sounds: bool = True
|
|
|
|
|
|
@ControlConfig.register_subclass("remote_robot")
@dataclass
class RemoteRobotConfig(ControlConfig):
    """Options for the control mode registered as "remote_robot"."""

    # NOTE(review): logging interval; the unit (steps, iterations, ...) is not
    # visible in this file -- confirm against the code consuming this config.
    log_interval: int = 100
|
|
|
|
|
|
@dataclass
class ControlPipelineConfig:
    """Top-level configuration pairing a robot config with a control-mode config."""

    # Configuration of the robot to control.
    robot: RobotConfig
    # Configuration of the selected control mode (a registered `ControlConfig` subclass).
    control: ControlConfig

    @classmethod
    def __get_path_fields__(cls) -> list[str]:
        """Return the dotted names of the fields that can be loaded from a path.

        This lets the parser load the policy config from a directory for the
        "control.policy" field.
        NOTE(review): the matching CLI flag is presumably
        `--control.policy.path=local/dir` -- confirm against the parser.
        """
        path_fields = ["control.policy"]
        return path_fields
|