Style
This commit is contained in:
parent
47aac0dff7
commit
8a7aa50e97
|
@ -1,10 +1,9 @@
|
|||
|
||||
|
||||
import argparse
|
||||
import time
|
||||
from dataclasses import dataclass, replace
|
||||
from pathlib import Path
|
||||
from threading import Thread
|
||||
import time
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
@ -23,10 +22,13 @@ def find_camera_indices(raise_when_empty=False, max_index_search_range=60):
|
|||
camera_ids.append(camera_idx)
|
||||
|
||||
if raise_when_empty and len(camera_ids) == 0:
|
||||
raise OSError("Not a single camera was detected. Try re-plugging, or re-installing `opencv2`, or your camera driver, or make sure your camera is compatible with opencv2.")
|
||||
raise OSError(
|
||||
"Not a single camera was detected. Try re-plugging, or re-installing `opencv2`, or your camera driver, or make sure your camera is compatible with opencv2."
|
||||
)
|
||||
|
||||
return camera_ids
|
||||
|
||||
|
||||
def benchmark_cameras(cameras, out_dir=None, save_images=False, num_warmup_frames=4):
|
||||
if out_dir:
|
||||
out_dir = Path(out_dir)
|
||||
|
@ -50,7 +52,7 @@ def benchmark_cameras(cameras, out_dir=None, save_images=False, num_warmup_frame
|
|||
print(f"Write to {image_path}")
|
||||
save_color_image(color_image, image_path, write_shape=True)
|
||||
|
||||
dt_s = (time.time() - now)
|
||||
dt_s = time.time() - now
|
||||
dt_ms = dt_s * 1000
|
||||
freq = 1 / dt_s
|
||||
print(f"Latency (ms): {dt_ms:.2f}\tFrequency: {freq:.2f}")
|
||||
|
@ -73,14 +75,14 @@ class OpenCVCameraConfig:
|
|||
OpenCVCameraConfig(30, 1280, 720)
|
||||
```
|
||||
"""
|
||||
|
||||
fps: int | None = None
|
||||
width: int | None = None
|
||||
height: int | None = None
|
||||
color: str = "rgb"
|
||||
|
||||
|
||||
|
||||
class OpenCVCamera():
|
||||
class OpenCVCamera:
|
||||
# TODO(rcadene): improve dosctring
|
||||
"""
|
||||
https://docs.opencv.org/4.x/d0/da7/videoio_overview.html
|
||||
|
@ -93,6 +95,7 @@ class OpenCVCamera():
|
|||
color_image = camera.capture_image()
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, camera_index: int, config: OpenCVCameraConfig | None = None, **kwargs):
|
||||
if config is None:
|
||||
config = OpenCVCameraConfig()
|
||||
|
@ -109,7 +112,9 @@ class OpenCVCamera():
|
|||
raise ValueError(f"Expected color values are 'rgb' or 'bgr', but {self.color} is provided.")
|
||||
|
||||
if self.camera_index is None:
|
||||
raise ValueError(f"`camera_index` is expected to be one of these available cameras {OpenCVCamera.AVAILABLE_CAMERAS_INDICES}, but {camera_index} is provided instead.")
|
||||
raise ValueError(
|
||||
f"`camera_index` is expected to be one of these available cameras {OpenCVCamera.AVAILABLE_CAMERAS_INDICES}, but {camera_index} is provided instead."
|
||||
)
|
||||
|
||||
self.camera = None
|
||||
self.is_connected = False
|
||||
|
@ -134,7 +139,9 @@ class OpenCVCamera():
|
|||
if not is_camera_open:
|
||||
# Verify that the provided `camera_index` is valid before printing the traceback
|
||||
if self.camera_index not in find_camera_indices():
|
||||
raise ValueError(f"`camera_index` is expected to be one of these available cameras {OpenCVCamera.AVAILABLE_CAMERAS_INDICES}, but {self.camera_index} is provided instead.")
|
||||
raise ValueError(
|
||||
f"`camera_index` is expected to be one of these available cameras {OpenCVCamera.AVAILABLE_CAMERAS_INDICES}, but {self.camera_index} is provided instead."
|
||||
)
|
||||
|
||||
raise OSError(f"Can't access camera {self.camera_index}.")
|
||||
|
||||
|
@ -155,11 +162,17 @@ class OpenCVCamera():
|
|||
actual_height = self.camera.get(cv2.CAP_PROP_FRAME_HEIGHT)
|
||||
|
||||
if self.fps and self.fps != actual_fps:
|
||||
raise OSError(f"Can't set {self.fps=} for camera {self.camera_index}. Actual value is {actual_fps}.")
|
||||
raise OSError(
|
||||
f"Can't set {self.fps=} for camera {self.camera_index}. Actual value is {actual_fps}."
|
||||
)
|
||||
if self.width and self.width != actual_width:
|
||||
raise OSError(f"Can't set {self.width=} for camera {self.camera_index}. Actual value is {actual_width}.")
|
||||
raise OSError(
|
||||
f"Can't set {self.width=} for camera {self.camera_index}. Actual value is {actual_width}."
|
||||
)
|
||||
if self.height and self.height != actual_height:
|
||||
raise OSError(f"Can't set {self.height=} for camera {self.camera_index}. Actual value is {actual_height}.")
|
||||
raise OSError(
|
||||
f"Can't set {self.height=} for camera {self.camera_index}. Actual value is {actual_height}."
|
||||
)
|
||||
|
||||
self.is_connected = True
|
||||
self.t.start()
|
||||
|
@ -172,10 +185,7 @@ class OpenCVCamera():
|
|||
if not ret:
|
||||
raise OSError(f"Can't capture color image from camera {self.camera_index}.")
|
||||
|
||||
if temporary_color is None:
|
||||
requested_color = self.color
|
||||
else:
|
||||
requested_color = temporary_color
|
||||
requested_color = self.color if temporary_color is None else temporary_color
|
||||
|
||||
if requested_color not in ["rgb", "bgr"]:
|
||||
raise ValueError(f"Expected color values are 'rgb' or 'bgr', but {requested_color} is provided.")
|
||||
|
@ -215,6 +225,7 @@ def save_images_config(config: OpenCVCameraConfig, out_dir: Path):
|
|||
out_dir = out_dir.parent / f"{out_dir.name}_{config.width}x{config.height}_{config.fps}"
|
||||
benchmark_cameras(cameras, out_dir, save_images=True)
|
||||
|
||||
|
||||
def benchmark_config(config: OpenCVCameraConfig, camera_ids: list[int]):
|
||||
cameras = [OpenCVCamera(idx, config) for idx in camera_ids]
|
||||
benchmark_cameras(cameras)
|
||||
|
@ -222,7 +233,7 @@ def benchmark_config(config: OpenCVCameraConfig, camera_ids: list[int]):
|
|||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--mode", type=str, choices=["save_images", 'benchmark'], default="save_images")
|
||||
parser.add_argument("--mode", type=str, choices=["save_images", "benchmark"], default="save_images")
|
||||
parser.add_argument("--camera-ids", type=int, nargs="*", default=[16, 4, 22, 10])
|
||||
parser.add_argument("--fps", type=int, default=30)
|
||||
parser.add_argument("--width", type=str, default=640)
|
||||
|
@ -242,8 +253,3 @@ if __name__ == "__main__":
|
|||
benchmark_config(config, args.camera_ids)
|
||||
else:
|
||||
raise ValueError(args.mode)
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -1,15 +1,13 @@
|
|||
|
||||
from pathlib import Path
|
||||
import time
|
||||
import cv2
|
||||
from typing import Protocol
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
def write_shape_on_image_inplace(image):
|
||||
height, width = image.shape[:2]
|
||||
text = f'Width: {width} Height: {height}'
|
||||
text = f"Width: {width} Height: {height}"
|
||||
|
||||
# Define the font, scale, color, and thickness
|
||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||
|
|
|
@ -1,10 +1,18 @@
|
|||
from copy import deepcopy
|
||||
import enum
|
||||
from typing import Union
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
|
||||
from dynamixel_sdk import PacketHandler, PortHandler, COMM_SUCCESS, GroupSyncRead, GroupSyncWrite
|
||||
from dynamixel_sdk import DXL_HIBYTE, DXL_HIWORD, DXL_LOBYTE, DXL_LOWORD
|
||||
import numpy as np
|
||||
from dynamixel_sdk import (
|
||||
COMM_SUCCESS,
|
||||
DXL_HIBYTE,
|
||||
DXL_HIWORD,
|
||||
DXL_LOBYTE,
|
||||
DXL_LOWORD,
|
||||
GroupSyncRead,
|
||||
GroupSyncWrite,
|
||||
PacketHandler,
|
||||
PortHandler,
|
||||
)
|
||||
|
||||
PROTOCOL_VERSION = 2.0
|
||||
BAUD_RATE = 1_000_000
|
||||
|
@ -69,12 +77,12 @@ X_SERIES_CONTROL_TABLE = {
|
|||
"Velocity_Trajectory": (136, 4),
|
||||
"Position_Trajectory": (140, 4),
|
||||
"Present_Input_Voltage": (144, 2),
|
||||
"Present_Temperature": (146, 1)
|
||||
"Present_Temperature": (146, 1),
|
||||
}
|
||||
|
||||
CALIBRATION_REQUIRED = ["Goal_Position", "Present_Position"]
|
||||
CONVERT_UINT32_TO_INT32_REQUIRED = ["Goal_Position", "Present_Position"]
|
||||
#CONVERT_POSITION_TO_ANGLE_REQUIRED = ["Goal_Position", "Present_Position"]
|
||||
# CONVERT_POSITION_TO_ANGLE_REQUIRED = ["Goal_Position", "Present_Position"]
|
||||
CONVERT_POSITION_TO_ANGLE_REQUIRED = []
|
||||
|
||||
MODEL_CONTROL_TABLE = {
|
||||
|
@ -86,6 +94,7 @@ MODEL_CONTROL_TABLE = {
|
|||
"xm540-w270": X_SERIES_CONTROL_TABLE,
|
||||
}
|
||||
|
||||
|
||||
def uint32_to_int32(values: np.ndarray):
|
||||
"""
|
||||
Convert an unsigned 32-bit integer array to a signed 32-bit integer array.
|
||||
|
@ -95,6 +104,7 @@ def uint32_to_int32(values: np.ndarray):
|
|||
values[i] = values[i] - 4294967296
|
||||
return values
|
||||
|
||||
|
||||
def int32_to_uint32(values: np.ndarray):
|
||||
"""
|
||||
Convert a signed 32-bit integer array to an unsigned 32-bit integer array.
|
||||
|
@ -104,12 +114,14 @@ def int32_to_uint32(values: np.ndarray):
|
|||
values[i] = values[i] + 4294967296
|
||||
return values
|
||||
|
||||
|
||||
def motor_position_to_angle(position: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Convert from motor position in [-2048, 2048] to radian in [-pi, pi]
|
||||
"""
|
||||
return (position / 2048) * 3.14
|
||||
|
||||
|
||||
def motor_angle_to_position(angle: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Convert from radian in [-pi, pi] to motor position in [-2048, 2048]
|
||||
|
@ -134,7 +146,7 @@ def motor_angle_to_position(angle: np.ndarray) -> np.ndarray:
|
|||
|
||||
|
||||
def get_group_sync_key(data_name, motor_names):
|
||||
group_key = f"{data_name}_" + "_".join([name for name in motor_names])
|
||||
group_key = f"{data_name}_" + "_".join(motor_names)
|
||||
return group_key
|
||||
|
||||
|
||||
|
@ -158,9 +170,12 @@ class DriveMode(enum.Enum):
|
|||
|
||||
|
||||
class DynamixelMotorsBus:
|
||||
|
||||
def __init__(self, port: str, motors: dict[str, tuple[int, str]],
|
||||
extra_model_control_table: dict[str, list[tuple]] | None = None):
|
||||
def __init__(
|
||||
self,
|
||||
port: str,
|
||||
motors: dict[str, tuple[int, str]],
|
||||
extra_model_control_table: dict[str, list[tuple]] | None = None,
|
||||
):
|
||||
self.port = port
|
||||
self.motors = motors
|
||||
|
||||
|
@ -307,9 +322,11 @@ class DynamixelMotorsBus:
|
|||
|
||||
init_group = data_name not in self.group_readers
|
||||
if init_group:
|
||||
self.group_writers[group_key] = GroupSyncWrite(self.port_handler, self.packet_handler, addr, bytes)
|
||||
self.group_writers[group_key] = GroupSyncWrite(
|
||||
self.port_handler, self.packet_handler, addr, bytes
|
||||
)
|
||||
|
||||
for idx, value in zip(motor_ids, values):
|
||||
for idx, value in zip(motor_ids, values, strict=False):
|
||||
if bytes == 1:
|
||||
data = [
|
||||
DXL_LOBYTE(DXL_LOWORD(value)),
|
||||
|
@ -329,7 +346,8 @@ class DynamixelMotorsBus:
|
|||
else:
|
||||
raise NotImplementedError(
|
||||
f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but "
|
||||
f"{bytes} is provided instead.")
|
||||
f"{bytes} is provided instead."
|
||||
)
|
||||
|
||||
if init_group:
|
||||
self.group_writers[group_key].addParam(idx, data)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from typing import Protocol
|
||||
|
||||
|
||||
class MotorsBus(Protocol):
|
||||
def motor_names(self): ...
|
||||
def set_calibration(self): ...
|
||||
|
|
|
@ -1,11 +1,8 @@
|
|||
|
||||
|
||||
def make_robot(name):
|
||||
|
||||
if name == "koch":
|
||||
from lerobot.common.robot_devices.robots.koch import KochRobot
|
||||
from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus
|
||||
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
|
||||
from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus
|
||||
from lerobot.common.robot_devices.robots.koch import KochRobot
|
||||
|
||||
robot = KochRobot(
|
||||
leader_arms={
|
||||
|
@ -38,7 +35,7 @@ def make_robot(name):
|
|||
},
|
||||
cameras={
|
||||
"main": OpenCVCamera(1, fps=30, width=640, height=480),
|
||||
}
|
||||
},
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Robot '{name}' not found.")
|
||||
|
|
|
@ -1,14 +1,18 @@
|
|||
import copy
|
||||
import pickle
|
||||
from dataclasses import dataclass, field, replace
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
|
||||
from lerobot.common.robot_devices.cameras.utils import Camera
|
||||
from lerobot.common.robot_devices.motors.dynamixel import DriveMode, DynamixelMotorsBus, OperatingMode, TorqueMode, motor_position_to_angle
|
||||
from lerobot.common.robot_devices.motors.utils import MotorsBus
|
||||
|
||||
from lerobot.common.robot_devices.cameras.utils import Camera
|
||||
from lerobot.common.robot_devices.motors.dynamixel import (
|
||||
DriveMode,
|
||||
DynamixelMotorsBus,
|
||||
OperatingMode,
|
||||
TorqueMode,
|
||||
)
|
||||
from lerobot.common.robot_devices.motors.utils import MotorsBus
|
||||
|
||||
########################################################################
|
||||
# Calibration logic
|
||||
|
@ -22,6 +26,7 @@ TARGET_HORIZONTAL_POSITION = np.array([0, -1024, 1024, 0, -1024, 0])
|
|||
TARGET_90_DEGREE_POSITION = np.array([1024, 0, 0, 1024, 0, -1024])
|
||||
GRIPPER_OPEN = np.array([-400])
|
||||
|
||||
|
||||
def apply_homing_offset(values: np.array, homing_offset: np.array) -> np.array:
|
||||
for i in range(len(values)):
|
||||
if values[i] is not None:
|
||||
|
@ -35,18 +40,21 @@ def apply_drive_mode(values: np.array, drive_mode: np.array) -> np.array:
|
|||
values[i] = -values[i]
|
||||
return values
|
||||
|
||||
|
||||
def apply_calibration(values: np.array, homing_offset: np.array, drive_mode: np.array) -> np.array:
|
||||
values = apply_drive_mode(values, drive_mode)
|
||||
values = apply_homing_offset(values, homing_offset)
|
||||
return values
|
||||
|
||||
|
||||
def revert_calibration(values: np.array, homing_offset: np.array, drive_mode: np.array) -> np.array:
|
||||
"""
|
||||
Transform working position into real position for the robot.
|
||||
"""
|
||||
values = apply_homing_offset(values, np.array([
|
||||
-homing_offset if homing_offset is not None else None for homing_offset in homing_offset
|
||||
]))
|
||||
values = apply_homing_offset(
|
||||
values,
|
||||
np.array([-homing_offset if homing_offset is not None else None for homing_offset in homing_offset]),
|
||||
)
|
||||
values = apply_drive_mode(values, drive_mode)
|
||||
return values
|
||||
|
||||
|
@ -73,15 +81,20 @@ def compute_corrections(positions: np.array, drive_mode: list[bool], target_posi
|
|||
|
||||
def compute_nearest_rounded_positions(positions: np.array) -> np.array:
|
||||
return np.array(
|
||||
[round(positions[i] / 1024) * 1024 if positions[i] is not None else None for i in range(len(positions))])
|
||||
[
|
||||
round(positions[i] / 1024) * 1024 if positions[i] is not None else None
|
||||
for i in range(len(positions))
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def compute_homing_offset(arm: DynamixelMotorsBus, drive_mode: list[bool], target_position: np.array) -> np.array:
|
||||
def compute_homing_offset(
|
||||
arm: DynamixelMotorsBus, drive_mode: list[bool], target_position: np.array
|
||||
) -> np.array:
|
||||
# Get the present positions of the servos
|
||||
present_positions = apply_calibration(
|
||||
arm.read("Present_Position"),
|
||||
np.array([0, 0, 0, 0, 0, 0]),
|
||||
drive_mode)
|
||||
arm.read("Present_Position"), np.array([0, 0, 0, 0, 0, 0]), drive_mode
|
||||
)
|
||||
|
||||
nearest_positions = compute_nearest_rounded_positions(present_positions)
|
||||
correction = compute_corrections(nearest_positions, drive_mode, target_position)
|
||||
|
@ -91,9 +104,8 @@ def compute_homing_offset(arm: DynamixelMotorsBus, drive_mode: list[bool], targe
|
|||
def compute_drive_mode(arm: DynamixelMotorsBus, offset: np.array):
|
||||
# Get current positions
|
||||
present_positions = apply_calibration(
|
||||
arm.read("Present_Position"),
|
||||
offset,
|
||||
np.array([False, False, False, False, False, False]))
|
||||
arm.read("Present_Position"), offset, np.array([False, False, False, False, False, False])
|
||||
)
|
||||
|
||||
nearest_positions = compute_nearest_rounded_positions(present_positions)
|
||||
|
||||
|
@ -131,7 +143,9 @@ def run_arm_calibration(arm: MotorsBus, name: str):
|
|||
print(f"Please move the '{name}' arm to the horizontal position (gripper fully closed)")
|
||||
input("Press Enter to continue...")
|
||||
|
||||
horizontal_homing_offset = compute_homing_offset(arm, [False, False, False, False, False, False], TARGET_HORIZONTAL_POSITION)
|
||||
horizontal_homing_offset = compute_homing_offset(
|
||||
arm, [False, False, False, False, False, False], TARGET_HORIZONTAL_POSITION
|
||||
)
|
||||
|
||||
# TODO(rcadene): document what position 2 mean
|
||||
print(f"Please move the '{name}' arm to the 90 degree position (gripper fully open)")
|
||||
|
@ -159,6 +173,7 @@ def run_arm_calibration(arm: MotorsBus, name: str):
|
|||
# Alexander Koch robot arm
|
||||
########################################################################
|
||||
|
||||
|
||||
@dataclass
|
||||
class KochRobotConfig:
|
||||
"""
|
||||
|
@ -201,12 +216,11 @@ class KochRobotConfig:
|
|||
),
|
||||
}
|
||||
)
|
||||
cameras: dict[str, Camera] = field(
|
||||
default_factory=lambda: {}
|
||||
)
|
||||
cameras: dict[str, Camera] = field(default_factory=lambda: {})
|
||||
|
||||
class KochRobot():
|
||||
""" Tau Robotics: https://tau-robotics.com
|
||||
|
||||
class KochRobot:
|
||||
"""Tau Robotics: https://tau-robotics.com
|
||||
|
||||
Example of usage:
|
||||
```python
|
||||
|
@ -214,7 +228,12 @@ class KochRobot():
|
|||
```
|
||||
"""
|
||||
|
||||
def __init__(self, config: KochRobotConfig | None = None, calibration_path: Path = ".cache/calibration/koch.pkl", **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
config: KochRobotConfig | None = None,
|
||||
calibration_path: Path = ".cache/calibration/koch.pkl",
|
||||
**kwargs,
|
||||
):
|
||||
if config is None:
|
||||
config = KochRobotConfig()
|
||||
# Overwrite config arguments using kwargs
|
||||
|
@ -234,13 +253,13 @@ class KochRobot():
|
|||
for name in self.leader_arms:
|
||||
reset_arm(self.leader_arms[name])
|
||||
|
||||
with open(self.calibration_path, 'rb') as f:
|
||||
with open(self.calibration_path, "rb") as f:
|
||||
calibration = pickle.load(f)
|
||||
else:
|
||||
calibration = self.run_calibration()
|
||||
|
||||
self.calibration_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(self.calibration_path, 'wb') as f:
|
||||
with open(self.calibration_path, "wb") as f:
|
||||
pickle.dump(calibration, f)
|
||||
|
||||
for name in self.follower_arms:
|
||||
|
@ -275,7 +294,9 @@ class KochRobot():
|
|||
|
||||
return calibration
|
||||
|
||||
def teleop_step(self, record_data=False) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
|
||||
def teleop_step(
|
||||
self, record_data=False
|
||||
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
|
||||
# Prepare to assign the positions of the leader to the follower
|
||||
leader_pos = {}
|
||||
for name in self.leader_arms:
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from typing import Protocol
|
||||
|
||||
|
||||
class Robot(Protocol):
|
||||
def init_teleop(self): ...
|
||||
def run_calibration(self): ...
|
||||
|
|
|
@ -62,20 +62,22 @@ python lerobot/scripts/control_robot.py run_policy \
|
|||
"""
|
||||
|
||||
import argparse
|
||||
from contextlib import nullcontext
|
||||
import concurrent.futures
|
||||
import os
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
from omegaconf import DictConfig
|
||||
import torch
|
||||
from omegaconf import DictConfig
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.common.datasets.compute_stats import compute_stats
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import to_hf_dataset
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
|
||||
from lerobot.common.datasets.utils import calculate_episode_data_index, load_hf_dataset
|
||||
from lerobot.common.datasets.utils import calculate_episode_data_index
|
||||
from lerobot.common.datasets.video_utils import encode_video_frames
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.robot_devices.robots.factory import make_robot
|
||||
|
@ -83,14 +85,12 @@ from lerobot.common.robot_devices.robots.utils import Robot
|
|||
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, set_global_seed
|
||||
from lerobot.scripts.eval import get_pretrained_policy_path
|
||||
from lerobot.scripts.push_dataset_to_hub import save_meta_data
|
||||
from lerobot.scripts.robot_controls.record_dataset import record_dataset
|
||||
import concurrent.futures
|
||||
|
||||
|
||||
########################################################################################
|
||||
# Utilities
|
||||
########################################################################################
|
||||
|
||||
|
||||
def save_image(img_tensor, key, frame_index, episode_index, videos_dir):
|
||||
img = Image.fromarray(img_tensor.numpy())
|
||||
path = videos_dir / f"{key}_episode_{episode_index:06d}" / f"frame_{frame_index:06d}.png"
|
||||
|
@ -106,15 +106,18 @@ def busy_wait(seconds):
|
|||
while time.perf_counter() < end_time:
|
||||
pass
|
||||
|
||||
|
||||
def none_or_int(value):
|
||||
if value == 'None':
|
||||
if value == "None":
|
||||
return None
|
||||
return int(value)
|
||||
|
||||
|
||||
########################################################################################
|
||||
# Control modes
|
||||
########################################################################################
|
||||
|
||||
|
||||
def teleoperate(robot: Robot, fps: int | None = None):
|
||||
robot.init_teleop()
|
||||
|
||||
|
@ -123,14 +126,24 @@ def teleoperate(robot: Robot, fps: int | None = None):
|
|||
robot.teleop_step()
|
||||
|
||||
if fps is not None:
|
||||
dt_s = (time.perf_counter() - now)
|
||||
dt_s = time.perf_counter() - now
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
dt_s = (time.perf_counter() - now)
|
||||
dt_s = time.perf_counter() - now
|
||||
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")
|
||||
|
||||
|
||||
def record_dataset(robot: Robot, fps: int | None = None, root="data", repo_id="lerobot/debug", warmup_time_s=2, episode_time_s=10, num_episodes=50, video=True, run_compute_stats=True):
|
||||
def record_dataset(
|
||||
robot: Robot,
|
||||
fps: int | None = None,
|
||||
root="data",
|
||||
repo_id="lerobot/debug",
|
||||
warmup_time_s=2,
|
||||
episode_time_s=10,
|
||||
num_episodes=50,
|
||||
video=True,
|
||||
run_compute_stats=True,
|
||||
):
|
||||
if not video:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
@ -143,7 +156,6 @@ def record_dataset(robot: Robot, fps: int | None = None, root="data", repo_id="l
|
|||
videos_dir = local_dir / "videos"
|
||||
videos_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
is_warmup_print = False
|
||||
|
@ -154,7 +166,6 @@ def record_dataset(robot: Robot, fps: int | None = None, root="data", repo_id="l
|
|||
# Using `with` ensures the program exists smoothly if an execption is raised.
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
for episode_index in range(num_episodes):
|
||||
|
||||
ep_dict = {}
|
||||
frame_index = 0
|
||||
|
||||
|
@ -169,10 +180,10 @@ def record_dataset(robot: Robot, fps: int | None = None, root="data", repo_id="l
|
|||
timestamp = time.perf_counter() - start_time
|
||||
|
||||
if timestamp < warmup_time_s:
|
||||
dt_s = (time.perf_counter() - now)
|
||||
dt_s = time.perf_counter() - now
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
dt_s = (time.perf_counter() - now)
|
||||
dt_s = time.perf_counter() - now
|
||||
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f} (Warmup)")
|
||||
continue
|
||||
|
||||
|
@ -199,10 +210,10 @@ def record_dataset(robot: Robot, fps: int | None = None, root="data", repo_id="l
|
|||
|
||||
frame_index += 1
|
||||
|
||||
dt_s = (time.perf_counter() - now)
|
||||
dt_s = time.perf_counter() - now
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
dt_s = (time.perf_counter() - now)
|
||||
dt_s = time.perf_counter() - now
|
||||
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")
|
||||
|
||||
if timestamp > episode_time_s - warmup_time_s:
|
||||
|
@ -269,10 +280,7 @@ def record_dataset(robot: Robot, fps: int | None = None, root="data", repo_id="l
|
|||
info=info,
|
||||
videos_dir=videos_dir,
|
||||
)
|
||||
if run_compute_stats:
|
||||
stats = compute_stats(lerobot_dataset)
|
||||
else:
|
||||
stats = {}
|
||||
stats = compute_stats(lerobot_dataset) if run_compute_stats else {}
|
||||
|
||||
hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
|
||||
hf_dataset.save_to_disk(str(local_dir / "train"))
|
||||
|
@ -303,10 +311,10 @@ def replay_episode(robot: Robot, episode: int, fps: int | None = None, root="dat
|
|||
action = items[idx]["action"]
|
||||
robot.send_action(action)
|
||||
|
||||
dt_s = (time.perf_counter() - now)
|
||||
dt_s = time.perf_counter() - now
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
dt_s = (time.perf_counter() - now)
|
||||
dt_s = time.perf_counter() - now
|
||||
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")
|
||||
|
||||
|
||||
|
@ -327,15 +335,18 @@ def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig):
|
|||
|
||||
observation = robot.capture_observation()
|
||||
|
||||
with torch.inference_mode(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext():
|
||||
with (
|
||||
torch.inference_mode(),
|
||||
torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext(),
|
||||
):
|
||||
action = policy.select_action(observation)
|
||||
|
||||
robot.send_action(action)
|
||||
|
||||
dt_s = (time.perf_counter() - now)
|
||||
dt_s = time.perf_counter() - now
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
dt_s = (time.perf_counter() - now)
|
||||
dt_s = time.perf_counter() - now
|
||||
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")
|
||||
|
||||
|
||||
|
@ -345,32 +356,46 @@ if __name__ == "__main__":
|
|||
|
||||
# Set common options for all the subparsers
|
||||
base_parser = argparse.ArgumentParser(add_help=False)
|
||||
base_parser.add_argument("--robot", type=str, default="koch", help="Name of the robot provided to the `make_robot(name)` factory function.")
|
||||
base_parser.add_argument(
|
||||
"--robot",
|
||||
type=str,
|
||||
default="koch",
|
||||
help="Name of the robot provided to the `make_robot(name)` factory function.",
|
||||
)
|
||||
|
||||
parser_teleop = subparsers.add_parser("teleoperate", parents=[base_parser])
|
||||
parser_teleop.add_argument('--fps', type=none_or_int, default=None, help='Frames per second (set to None to disable)')
|
||||
parser_teleop.add_argument(
|
||||
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
|
||||
)
|
||||
|
||||
parser_record = subparsers.add_parser("record_dataset", parents=[base_parser])
|
||||
parser_record.add_argument('--fps', type=none_or_int, default=None, help='Frames per second (set to None to disable)')
|
||||
parser_record.add_argument('--root', type=Path, default="data", help='')
|
||||
parser_record.add_argument('--repo-id', type=str, default="lerobot/test", help='')
|
||||
parser_record.add_argument('--warmup-time-s', type=int, default=2, help='')
|
||||
parser_record.add_argument('--episode-time-s', type=int, default=10, help='')
|
||||
parser_record.add_argument('--num-episodes', type=int, default=50, help='')
|
||||
parser_record.add_argument('--run-compute-stats', type=int, default=1, help='')
|
||||
parser_record.add_argument(
|
||||
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
|
||||
)
|
||||
parser_record.add_argument("--root", type=Path, default="data", help="")
|
||||
parser_record.add_argument("--repo-id", type=str, default="lerobot/test", help="")
|
||||
parser_record.add_argument("--warmup-time-s", type=int, default=2, help="")
|
||||
parser_record.add_argument("--episode-time-s", type=int, default=10, help="")
|
||||
parser_record.add_argument("--num-episodes", type=int, default=50, help="")
|
||||
parser_record.add_argument("--run-compute-stats", type=int, default=1, help="")
|
||||
|
||||
parser_replay = subparsers.add_parser("replay_episode", parents=[base_parser])
|
||||
parser_replay.add_argument('--fps', type=none_or_int, default=None, help='Frames per second (set to None to disable)')
|
||||
parser_replay.add_argument('--root', type=Path, default="data", help='')
|
||||
parser_replay.add_argument('--repo-id', type=str, default="lerobot/test", help='')
|
||||
parser_replay.add_argument('--episode', type=int, default=0, help='')
|
||||
parser_replay.add_argument(
|
||||
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
|
||||
)
|
||||
parser_replay.add_argument("--root", type=Path, default="data", help="")
|
||||
parser_replay.add_argument("--repo-id", type=str, default="lerobot/test", help="")
|
||||
parser_replay.add_argument("--episode", type=int, default=0, help="")
|
||||
|
||||
parser_policy = subparsers.add_parser("run_policy", parents=[base_parser])
|
||||
parser_policy.add_argument('-p', '--pretrained-policy-name-or-path', type=str,
|
||||
parser_policy.add_argument(
|
||||
"-p",
|
||||
"--pretrained-policy-name-or-path",
|
||||
type=str,
|
||||
help=(
|
||||
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
|
||||
"saved using `Policy.save_pretrained`."
|
||||
)
|
||||
),
|
||||
)
|
||||
parser_policy.add_argument(
|
||||
"overrides",
|
||||
|
|
|
@ -580,9 +580,7 @@ def main(
|
|||
|
||||
def get_pretrained_policy_path(pretrained_policy_name_or_path, revision=None):
|
||||
try:
|
||||
pretrained_policy_path = Path(
|
||||
snapshot_download(pretrained_policy_name_or_path, revision=revision)
|
||||
)
|
||||
pretrained_policy_path = Path(snapshot_download(pretrained_policy_name_or_path, revision=revision))
|
||||
except (HFValidationError, RepositoryNotFoundError) as e:
|
||||
if isinstance(e, HFValidationError):
|
||||
error_message = (
|
||||
|
@ -644,7 +642,9 @@ if __name__ == "__main__":
|
|||
if args.pretrained_policy_name_or_path is None:
|
||||
main(hydra_cfg_path=args.config, out_dir=args.out_dir, config_overrides=args.overrides)
|
||||
else:
|
||||
pretrained_policy_path = get_pretrained_policy_path(args.pretrained_policy_name_or_path, revision=args.revision)
|
||||
pretrained_policy_path = get_pretrained_policy_path(
|
||||
args.pretrained_policy_name_or_path, revision=args.revision
|
||||
)
|
||||
|
||||
main(
|
||||
pretrained_policy_path=pretrained_policy_path,
|
||||
|
|
Loading…
Reference in New Issue