This commit is contained in:
Remi Cadene 2024-07-02 21:35:24 +02:00
parent 47aac0dff7
commit 8a7aa50e97
9 changed files with 207 additions and 140 deletions

View File

@ -1,10 +1,9 @@
import argparse
import time
from dataclasses import dataclass, replace
from pathlib import Path
from threading import Thread
import time
import cv2
import numpy as np
@ -23,10 +22,13 @@ def find_camera_indices(raise_when_empty=False, max_index_search_range=60):
camera_ids.append(camera_idx)
if raise_when_empty and len(camera_ids) == 0:
raise OSError("Not a single camera was detected. Try re-plugging, or re-installing `opencv2`, or your camera driver, or make sure your camera is compatible with opencv2.")
raise OSError(
"Not a single camera was detected. Try re-plugging, or re-installing `opencv2`, or your camera driver, or make sure your camera is compatible with opencv2."
)
return camera_ids
def benchmark_cameras(cameras, out_dir=None, save_images=False, num_warmup_frames=4):
if out_dir:
out_dir = Path(out_dir)
@ -50,7 +52,7 @@ def benchmark_cameras(cameras, out_dir=None, save_images=False, num_warmup_frame
print(f"Write to {image_path}")
save_color_image(color_image, image_path, write_shape=True)
dt_s = (time.time() - now)
dt_s = time.time() - now
dt_ms = dt_s * 1000
freq = 1 / dt_s
print(f"Latency (ms): {dt_ms:.2f}\tFrequency: {freq:.2f}")
@ -73,14 +75,14 @@ class OpenCVCameraConfig:
OpenCVCameraConfig(30, 1280, 720)
```
"""
fps: int | None = None
width: int | None = None
height: int | None = None
color: str = "rgb"
class OpenCVCamera():
class OpenCVCamera:
# TODO(rcadene): improve dosctring
"""
https://docs.opencv.org/4.x/d0/da7/videoio_overview.html
@ -93,6 +95,7 @@ class OpenCVCamera():
color_image = camera.capture_image()
```
"""
def __init__(self, camera_index: int, config: OpenCVCameraConfig | None = None, **kwargs):
if config is None:
config = OpenCVCameraConfig()
@ -109,7 +112,9 @@ class OpenCVCamera():
raise ValueError(f"Expected color values are 'rgb' or 'bgr', but {self.color} is provided.")
if self.camera_index is None:
raise ValueError(f"`camera_index` is expected to be one of these available cameras {OpenCVCamera.AVAILABLE_CAMERAS_INDICES}, but {camera_index} is provided instead.")
raise ValueError(
f"`camera_index` is expected to be one of these available cameras {OpenCVCamera.AVAILABLE_CAMERAS_INDICES}, but {camera_index} is provided instead."
)
self.camera = None
self.is_connected = False
@ -134,7 +139,9 @@ class OpenCVCamera():
if not is_camera_open:
# Verify that the provided `camera_index` is valid before printing the traceback
if self.camera_index not in find_camera_indices():
raise ValueError(f"`camera_index` is expected to be one of these available cameras {OpenCVCamera.AVAILABLE_CAMERAS_INDICES}, but {self.camera_index} is provided instead.")
raise ValueError(
f"`camera_index` is expected to be one of these available cameras {OpenCVCamera.AVAILABLE_CAMERAS_INDICES}, but {self.camera_index} is provided instead."
)
raise OSError(f"Can't access camera {self.camera_index}.")
@ -155,11 +162,17 @@ class OpenCVCamera():
actual_height = self.camera.get(cv2.CAP_PROP_FRAME_HEIGHT)
if self.fps and self.fps != actual_fps:
raise OSError(f"Can't set {self.fps=} for camera {self.camera_index}. Actual value is {actual_fps}.")
raise OSError(
f"Can't set {self.fps=} for camera {self.camera_index}. Actual value is {actual_fps}."
)
if self.width and self.width != actual_width:
raise OSError(f"Can't set {self.width=} for camera {self.camera_index}. Actual value is {actual_width}.")
raise OSError(
f"Can't set {self.width=} for camera {self.camera_index}. Actual value is {actual_width}."
)
if self.height and self.height != actual_height:
raise OSError(f"Can't set {self.height=} for camera {self.camera_index}. Actual value is {actual_height}.")
raise OSError(
f"Can't set {self.height=} for camera {self.camera_index}. Actual value is {actual_height}."
)
self.is_connected = True
self.t.start()
@ -172,10 +185,7 @@ class OpenCVCamera():
if not ret:
raise OSError(f"Can't capture color image from camera {self.camera_index}.")
if temporary_color is None:
requested_color = self.color
else:
requested_color = temporary_color
requested_color = self.color if temporary_color is None else temporary_color
if requested_color not in ["rgb", "bgr"]:
raise ValueError(f"Expected color values are 'rgb' or 'bgr', but {requested_color} is provided.")
@ -215,6 +225,7 @@ def save_images_config(config: OpenCVCameraConfig, out_dir: Path):
out_dir = out_dir.parent / f"{out_dir.name}_{config.width}x{config.height}_{config.fps}"
benchmark_cameras(cameras, out_dir, save_images=True)
def benchmark_config(config: OpenCVCameraConfig, camera_ids: list[int]):
cameras = [OpenCVCamera(idx, config) for idx in camera_ids]
benchmark_cameras(cameras)
@ -222,7 +233,7 @@ def benchmark_config(config: OpenCVCameraConfig, camera_ids: list[int]):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--mode", type=str, choices=["save_images", 'benchmark'], default="save_images")
parser.add_argument("--mode", type=str, choices=["save_images", "benchmark"], default="save_images")
parser.add_argument("--camera-ids", type=int, nargs="*", default=[16, 4, 22, 10])
parser.add_argument("--fps", type=int, default=30)
parser.add_argument("--width", type=str, default=640)
@ -242,8 +253,3 @@ if __name__ == "__main__":
benchmark_config(config, args.camera_ids)
else:
raise ValueError(args.mode)

View File

@ -1,15 +1,13 @@
from pathlib import Path
import time
import cv2
from typing import Protocol
import cv2
import numpy as np
def write_shape_on_image_inplace(image):
height, width = image.shape[:2]
text = f'Width: {width} Height: {height}'
text = f"Width: {width} Height: {height}"
# Define the font, scale, color, and thickness
font = cv2.FONT_HERSHEY_SIMPLEX

View File

@ -1,10 +1,18 @@
from copy import deepcopy
import enum
from typing import Union
import numpy as np
from copy import deepcopy
from dynamixel_sdk import PacketHandler, PortHandler, COMM_SUCCESS, GroupSyncRead, GroupSyncWrite
from dynamixel_sdk import DXL_HIBYTE, DXL_HIWORD, DXL_LOBYTE, DXL_LOWORD
import numpy as np
from dynamixel_sdk import (
COMM_SUCCESS,
DXL_HIBYTE,
DXL_HIWORD,
DXL_LOBYTE,
DXL_LOWORD,
GroupSyncRead,
GroupSyncWrite,
PacketHandler,
PortHandler,
)
PROTOCOL_VERSION = 2.0
BAUD_RATE = 1_000_000
@ -69,12 +77,12 @@ X_SERIES_CONTROL_TABLE = {
"Velocity_Trajectory": (136, 4),
"Position_Trajectory": (140, 4),
"Present_Input_Voltage": (144, 2),
"Present_Temperature": (146, 1)
"Present_Temperature": (146, 1),
}
CALIBRATION_REQUIRED = ["Goal_Position", "Present_Position"]
CONVERT_UINT32_TO_INT32_REQUIRED = ["Goal_Position", "Present_Position"]
#CONVERT_POSITION_TO_ANGLE_REQUIRED = ["Goal_Position", "Present_Position"]
# CONVERT_POSITION_TO_ANGLE_REQUIRED = ["Goal_Position", "Present_Position"]
CONVERT_POSITION_TO_ANGLE_REQUIRED = []
MODEL_CONTROL_TABLE = {
@ -86,6 +94,7 @@ MODEL_CONTROL_TABLE = {
"xm540-w270": X_SERIES_CONTROL_TABLE,
}
def uint32_to_int32(values: np.ndarray):
"""
Convert an unsigned 32-bit integer array to a signed 32-bit integer array.
@ -95,6 +104,7 @@ def uint32_to_int32(values: np.ndarray):
values[i] = values[i] - 4294967296
return values
def int32_to_uint32(values: np.ndarray):
"""
Convert a signed 32-bit integer array to an unsigned 32-bit integer array.
@ -104,12 +114,14 @@ def int32_to_uint32(values: np.ndarray):
values[i] = values[i] + 4294967296
return values
def motor_position_to_angle(position: np.ndarray) -> np.ndarray:
"""
Convert from motor position in [-2048, 2048] to radian in [-pi, pi]
"""
return (position / 2048) * 3.14
def motor_angle_to_position(angle: np.ndarray) -> np.ndarray:
"""
Convert from radian in [-pi, pi] to motor position in [-2048, 2048]
@ -134,7 +146,7 @@ def motor_angle_to_position(angle: np.ndarray) -> np.ndarray:
def get_group_sync_key(data_name, motor_names):
group_key = f"{data_name}_" + "_".join([name for name in motor_names])
group_key = f"{data_name}_" + "_".join(motor_names)
return group_key
@ -158,9 +170,12 @@ class DriveMode(enum.Enum):
class DynamixelMotorsBus:
def __init__(self, port: str, motors: dict[str, tuple[int, str]],
extra_model_control_table: dict[str, list[tuple]] | None = None):
def __init__(
self,
port: str,
motors: dict[str, tuple[int, str]],
extra_model_control_table: dict[str, list[tuple]] | None = None,
):
self.port = port
self.motors = motors
@ -307,9 +322,11 @@ class DynamixelMotorsBus:
init_group = data_name not in self.group_readers
if init_group:
self.group_writers[group_key] = GroupSyncWrite(self.port_handler, self.packet_handler, addr, bytes)
self.group_writers[group_key] = GroupSyncWrite(
self.port_handler, self.packet_handler, addr, bytes
)
for idx, value in zip(motor_ids, values):
for idx, value in zip(motor_ids, values, strict=False):
if bytes == 1:
data = [
DXL_LOBYTE(DXL_LOWORD(value)),
@ -329,7 +346,8 @@ class DynamixelMotorsBus:
else:
raise NotImplementedError(
f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but "
f"{bytes} is provided instead.")
f"{bytes} is provided instead."
)
if init_group:
self.group_writers[group_key].addParam(idx, data)

View File

@ -1,5 +1,6 @@
from typing import Protocol
class MotorsBus(Protocol):
def motor_names(self): ...
def set_calibration(self): ...

View File

@ -1,11 +1,8 @@
def make_robot(name):
if name == "koch":
from lerobot.common.robot_devices.robots.koch import KochRobot
from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus
from lerobot.common.robot_devices.robots.koch import KochRobot
robot = KochRobot(
leader_arms={
@ -38,7 +35,7 @@ def make_robot(name):
},
cameras={
"main": OpenCVCamera(1, fps=30, width=640, height=480),
}
},
)
else:
raise ValueError(f"Robot '{name}' not found.")

View File

@ -1,14 +1,18 @@
import copy
import pickle
from dataclasses import dataclass, field, replace
from pathlib import Path
import pickle
import numpy as np
import torch
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
from lerobot.common.robot_devices.cameras.utils import Camera
from lerobot.common.robot_devices.motors.dynamixel import DriveMode, DynamixelMotorsBus, OperatingMode, TorqueMode, motor_position_to_angle
from lerobot.common.robot_devices.motors.utils import MotorsBus
from lerobot.common.robot_devices.cameras.utils import Camera
from lerobot.common.robot_devices.motors.dynamixel import (
DriveMode,
DynamixelMotorsBus,
OperatingMode,
TorqueMode,
)
from lerobot.common.robot_devices.motors.utils import MotorsBus
########################################################################
# Calibration logic
@ -22,6 +26,7 @@ TARGET_HORIZONTAL_POSITION = np.array([0, -1024, 1024, 0, -1024, 0])
TARGET_90_DEGREE_POSITION = np.array([1024, 0, 0, 1024, 0, -1024])
GRIPPER_OPEN = np.array([-400])
def apply_homing_offset(values: np.array, homing_offset: np.array) -> np.array:
for i in range(len(values)):
if values[i] is not None:
@ -35,18 +40,21 @@ def apply_drive_mode(values: np.array, drive_mode: np.array) -> np.array:
values[i] = -values[i]
return values
def apply_calibration(values: np.array, homing_offset: np.array, drive_mode: np.array) -> np.array:
values = apply_drive_mode(values, drive_mode)
values = apply_homing_offset(values, homing_offset)
return values
def revert_calibration(values: np.array, homing_offset: np.array, drive_mode: np.array) -> np.array:
"""
Transform working position into real position for the robot.
"""
values = apply_homing_offset(values, np.array([
-homing_offset if homing_offset is not None else None for homing_offset in homing_offset
]))
values = apply_homing_offset(
values,
np.array([-homing_offset if homing_offset is not None else None for homing_offset in homing_offset]),
)
values = apply_drive_mode(values, drive_mode)
return values
@ -73,15 +81,20 @@ def compute_corrections(positions: np.array, drive_mode: list[bool], target_posi
def compute_nearest_rounded_positions(positions: np.array) -> np.array:
return np.array(
[round(positions[i] / 1024) * 1024 if positions[i] is not None else None for i in range(len(positions))])
[
round(positions[i] / 1024) * 1024 if positions[i] is not None else None
for i in range(len(positions))
]
)
def compute_homing_offset(arm: DynamixelMotorsBus, drive_mode: list[bool], target_position: np.array) -> np.array:
def compute_homing_offset(
arm: DynamixelMotorsBus, drive_mode: list[bool], target_position: np.array
) -> np.array:
# Get the present positions of the servos
present_positions = apply_calibration(
arm.read("Present_Position"),
np.array([0, 0, 0, 0, 0, 0]),
drive_mode)
arm.read("Present_Position"), np.array([0, 0, 0, 0, 0, 0]), drive_mode
)
nearest_positions = compute_nearest_rounded_positions(present_positions)
correction = compute_corrections(nearest_positions, drive_mode, target_position)
@ -91,9 +104,8 @@ def compute_homing_offset(arm: DynamixelMotorsBus, drive_mode: list[bool], targe
def compute_drive_mode(arm: DynamixelMotorsBus, offset: np.array):
# Get current positions
present_positions = apply_calibration(
arm.read("Present_Position"),
offset,
np.array([False, False, False, False, False, False]))
arm.read("Present_Position"), offset, np.array([False, False, False, False, False, False])
)
nearest_positions = compute_nearest_rounded_positions(present_positions)
@ -131,7 +143,9 @@ def run_arm_calibration(arm: MotorsBus, name: str):
print(f"Please move the '{name}' arm to the horizontal position (gripper fully closed)")
input("Press Enter to continue...")
horizontal_homing_offset = compute_homing_offset(arm, [False, False, False, False, False, False], TARGET_HORIZONTAL_POSITION)
horizontal_homing_offset = compute_homing_offset(
arm, [False, False, False, False, False, False], TARGET_HORIZONTAL_POSITION
)
# TODO(rcadene): document what position 2 mean
print(f"Please move the '{name}' arm to the 90 degree position (gripper fully open)")
@ -159,6 +173,7 @@ def run_arm_calibration(arm: MotorsBus, name: str):
# Alexander Koch robot arm
########################################################################
@dataclass
class KochRobotConfig:
"""
@ -201,12 +216,11 @@ class KochRobotConfig:
),
}
)
cameras: dict[str, Camera] = field(
default_factory=lambda: {}
)
cameras: dict[str, Camera] = field(default_factory=lambda: {})
class KochRobot():
""" Tau Robotics: https://tau-robotics.com
class KochRobot:
"""Tau Robotics: https://tau-robotics.com
Example of usage:
```python
@ -214,7 +228,12 @@ class KochRobot():
```
"""
def __init__(self, config: KochRobotConfig | None = None, calibration_path: Path = ".cache/calibration/koch.pkl", **kwargs):
def __init__(
self,
config: KochRobotConfig | None = None,
calibration_path: Path = ".cache/calibration/koch.pkl",
**kwargs,
):
if config is None:
config = KochRobotConfig()
# Overwrite config arguments using kwargs
@ -234,13 +253,13 @@ class KochRobot():
for name in self.leader_arms:
reset_arm(self.leader_arms[name])
with open(self.calibration_path, 'rb') as f:
with open(self.calibration_path, "rb") as f:
calibration = pickle.load(f)
else:
calibration = self.run_calibration()
self.calibration_path.parent.mkdir(parents=True, exist_ok=True)
with open(self.calibration_path, 'wb') as f:
with open(self.calibration_path, "wb") as f:
pickle.dump(calibration, f)
for name in self.follower_arms:
@ -275,7 +294,9 @@ class KochRobot():
return calibration
def teleop_step(self, record_data=False) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
def teleop_step(
self, record_data=False
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
# Prepare to assign the positions of the leader to the follower
leader_pos = {}
for name in self.leader_arms:

View File

@ -1,5 +1,6 @@
from typing import Protocol
class Robot(Protocol):
def init_teleop(self): ...
def run_calibration(self): ...

View File

@ -62,20 +62,22 @@ python lerobot/scripts/control_robot.py run_policy \
"""
import argparse
from contextlib import nullcontext
import concurrent.futures
import os
from pathlib import Path
import shutil
import time
from contextlib import nullcontext
from pathlib import Path
from PIL import Image
from omegaconf import DictConfig
import torch
from omegaconf import DictConfig
from PIL import Image
from lerobot.common.datasets.compute_stats import compute_stats
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import to_hf_dataset
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
from lerobot.common.datasets.utils import calculate_episode_data_index, load_hf_dataset
from lerobot.common.datasets.utils import calculate_episode_data_index
from lerobot.common.datasets.video_utils import encode_video_frames
from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.robots.factory import make_robot
@ -83,14 +85,12 @@ from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, set_global_seed
from lerobot.scripts.eval import get_pretrained_policy_path
from lerobot.scripts.push_dataset_to_hub import save_meta_data
from lerobot.scripts.robot_controls.record_dataset import record_dataset
import concurrent.futures
########################################################################################
# Utilities
########################################################################################
def save_image(img_tensor, key, frame_index, episode_index, videos_dir):
img = Image.fromarray(img_tensor.numpy())
path = videos_dir / f"{key}_episode_{episode_index:06d}" / f"frame_{frame_index:06d}.png"
@ -106,15 +106,18 @@ def busy_wait(seconds):
while time.perf_counter() < end_time:
pass
def none_or_int(value):
if value == 'None':
if value == "None":
return None
return int(value)
########################################################################################
# Control modes
########################################################################################
def teleoperate(robot: Robot, fps: int | None = None):
robot.init_teleop()
@ -123,14 +126,24 @@ def teleoperate(robot: Robot, fps: int | None = None):
robot.teleop_step()
if fps is not None:
dt_s = (time.perf_counter() - now)
dt_s = time.perf_counter() - now
busy_wait(1 / fps - dt_s)
dt_s = (time.perf_counter() - now)
dt_s = time.perf_counter() - now
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")
def record_dataset(robot: Robot, fps: int | None = None, root="data", repo_id="lerobot/debug", warmup_time_s=2, episode_time_s=10, num_episodes=50, video=True, run_compute_stats=True):
def record_dataset(
robot: Robot,
fps: int | None = None,
root="data",
repo_id="lerobot/debug",
warmup_time_s=2,
episode_time_s=10,
num_episodes=50,
video=True,
run_compute_stats=True,
):
if not video:
raise NotImplementedError()
@ -143,7 +156,6 @@ def record_dataset(robot: Robot, fps: int | None = None, root="data", repo_id="l
videos_dir = local_dir / "videos"
videos_dir.mkdir(parents=True, exist_ok=True)
start_time = time.perf_counter()
is_warmup_print = False
@ -154,7 +166,6 @@ def record_dataset(robot: Robot, fps: int | None = None, root="data", repo_id="l
# Using `with` ensures the program exists smoothly if an execption is raised.
with concurrent.futures.ThreadPoolExecutor() as executor:
for episode_index in range(num_episodes):
ep_dict = {}
frame_index = 0
@ -169,10 +180,10 @@ def record_dataset(robot: Robot, fps: int | None = None, root="data", repo_id="l
timestamp = time.perf_counter() - start_time
if timestamp < warmup_time_s:
dt_s = (time.perf_counter() - now)
dt_s = time.perf_counter() - now
busy_wait(1 / fps - dt_s)
dt_s = (time.perf_counter() - now)
dt_s = time.perf_counter() - now
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f} (Warmup)")
continue
@ -199,10 +210,10 @@ def record_dataset(robot: Robot, fps: int | None = None, root="data", repo_id="l
frame_index += 1
dt_s = (time.perf_counter() - now)
dt_s = time.perf_counter() - now
busy_wait(1 / fps - dt_s)
dt_s = (time.perf_counter() - now)
dt_s = time.perf_counter() - now
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")
if timestamp > episode_time_s - warmup_time_s:
@ -269,10 +280,7 @@ def record_dataset(robot: Robot, fps: int | None = None, root="data", repo_id="l
info=info,
videos_dir=videos_dir,
)
if run_compute_stats:
stats = compute_stats(lerobot_dataset)
else:
stats = {}
stats = compute_stats(lerobot_dataset) if run_compute_stats else {}
hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
hf_dataset.save_to_disk(str(local_dir / "train"))
@ -303,10 +311,10 @@ def replay_episode(robot: Robot, episode: int, fps: int | None = None, root="dat
action = items[idx]["action"]
robot.send_action(action)
dt_s = (time.perf_counter() - now)
dt_s = time.perf_counter() - now
busy_wait(1 / fps - dt_s)
dt_s = (time.perf_counter() - now)
dt_s = time.perf_counter() - now
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")
@ -327,15 +335,18 @@ def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig):
observation = robot.capture_observation()
with torch.inference_mode(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext():
with (
torch.inference_mode(),
torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext(),
):
action = policy.select_action(observation)
robot.send_action(action)
dt_s = (time.perf_counter() - now)
dt_s = time.perf_counter() - now
busy_wait(1 / fps - dt_s)
dt_s = (time.perf_counter() - now)
dt_s = time.perf_counter() - now
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")
@ -345,32 +356,46 @@ if __name__ == "__main__":
# Set common options for all the subparsers
base_parser = argparse.ArgumentParser(add_help=False)
base_parser.add_argument("--robot", type=str, default="koch", help="Name of the robot provided to the `make_robot(name)` factory function.")
base_parser.add_argument(
"--robot",
type=str,
default="koch",
help="Name of the robot provided to the `make_robot(name)` factory function.",
)
parser_teleop = subparsers.add_parser("teleoperate", parents=[base_parser])
parser_teleop.add_argument('--fps', type=none_or_int, default=None, help='Frames per second (set to None to disable)')
parser_teleop.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
)
parser_record = subparsers.add_parser("record_dataset", parents=[base_parser])
parser_record.add_argument('--fps', type=none_or_int, default=None, help='Frames per second (set to None to disable)')
parser_record.add_argument('--root', type=Path, default="data", help='')
parser_record.add_argument('--repo-id', type=str, default="lerobot/test", help='')
parser_record.add_argument('--warmup-time-s', type=int, default=2, help='')
parser_record.add_argument('--episode-time-s', type=int, default=10, help='')
parser_record.add_argument('--num-episodes', type=int, default=50, help='')
parser_record.add_argument('--run-compute-stats', type=int, default=1, help='')
parser_record.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
)
parser_record.add_argument("--root", type=Path, default="data", help="")
parser_record.add_argument("--repo-id", type=str, default="lerobot/test", help="")
parser_record.add_argument("--warmup-time-s", type=int, default=2, help="")
parser_record.add_argument("--episode-time-s", type=int, default=10, help="")
parser_record.add_argument("--num-episodes", type=int, default=50, help="")
parser_record.add_argument("--run-compute-stats", type=int, default=1, help="")
parser_replay = subparsers.add_parser("replay_episode", parents=[base_parser])
parser_replay.add_argument('--fps', type=none_or_int, default=None, help='Frames per second (set to None to disable)')
parser_replay.add_argument('--root', type=Path, default="data", help='')
parser_replay.add_argument('--repo-id', type=str, default="lerobot/test", help='')
parser_replay.add_argument('--episode', type=int, default=0, help='')
parser_replay.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
)
parser_replay.add_argument("--root", type=Path, default="data", help="")
parser_replay.add_argument("--repo-id", type=str, default="lerobot/test", help="")
parser_replay.add_argument("--episode", type=int, default=0, help="")
parser_policy = subparsers.add_parser("run_policy", parents=[base_parser])
parser_policy.add_argument('-p', '--pretrained-policy-name-or-path', type=str,
parser_policy.add_argument(
"-p",
"--pretrained-policy-name-or-path",
type=str,
help=(
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
"saved using `Policy.save_pretrained`."
)
),
)
parser_policy.add_argument(
"overrides",

View File

@ -580,9 +580,7 @@ def main(
def get_pretrained_policy_path(pretrained_policy_name_or_path, revision=None):
try:
pretrained_policy_path = Path(
snapshot_download(pretrained_policy_name_or_path, revision=revision)
)
pretrained_policy_path = Path(snapshot_download(pretrained_policy_name_or_path, revision=revision))
except (HFValidationError, RepositoryNotFoundError) as e:
if isinstance(e, HFValidationError):
error_message = (
@ -644,7 +642,9 @@ if __name__ == "__main__":
if args.pretrained_policy_name_or_path is None:
main(hydra_cfg_path=args.config, out_dir=args.out_dir, config_overrides=args.overrides)
else:
pretrained_policy_path = get_pretrained_policy_path(args.pretrained_policy_name_or_path, revision=args.revision)
pretrained_policy_path = get_pretrained_policy_path(
args.pretrained_policy_name_or_path, revision=args.revision
)
main(
pretrained_policy_path=pretrained_policy_path,