Style
This commit is contained in:
parent
47aac0dff7
commit
8a7aa50e97
|
@ -1,10 +1,9 @@
|
||||||
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import time
|
||||||
from dataclasses import dataclass, replace
|
from dataclasses import dataclass, replace
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
import time
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
@ -23,10 +22,13 @@ def find_camera_indices(raise_when_empty=False, max_index_search_range=60):
|
||||||
camera_ids.append(camera_idx)
|
camera_ids.append(camera_idx)
|
||||||
|
|
||||||
if raise_when_empty and len(camera_ids) == 0:
|
if raise_when_empty and len(camera_ids) == 0:
|
||||||
raise OSError("Not a single camera was detected. Try re-plugging, or re-installing `opencv2`, or your camera driver, or make sure your camera is compatible with opencv2.")
|
raise OSError(
|
||||||
|
"Not a single camera was detected. Try re-plugging, or re-installing `opencv2`, or your camera driver, or make sure your camera is compatible with opencv2."
|
||||||
|
)
|
||||||
|
|
||||||
return camera_ids
|
return camera_ids
|
||||||
|
|
||||||
|
|
||||||
def benchmark_cameras(cameras, out_dir=None, save_images=False, num_warmup_frames=4):
|
def benchmark_cameras(cameras, out_dir=None, save_images=False, num_warmup_frames=4):
|
||||||
if out_dir:
|
if out_dir:
|
||||||
out_dir = Path(out_dir)
|
out_dir = Path(out_dir)
|
||||||
|
@ -50,7 +52,7 @@ def benchmark_cameras(cameras, out_dir=None, save_images=False, num_warmup_frame
|
||||||
print(f"Write to {image_path}")
|
print(f"Write to {image_path}")
|
||||||
save_color_image(color_image, image_path, write_shape=True)
|
save_color_image(color_image, image_path, write_shape=True)
|
||||||
|
|
||||||
dt_s = (time.time() - now)
|
dt_s = time.time() - now
|
||||||
dt_ms = dt_s * 1000
|
dt_ms = dt_s * 1000
|
||||||
freq = 1 / dt_s
|
freq = 1 / dt_s
|
||||||
print(f"Latency (ms): {dt_ms:.2f}\tFrequency: {freq:.2f}")
|
print(f"Latency (ms): {dt_ms:.2f}\tFrequency: {freq:.2f}")
|
||||||
|
@ -73,14 +75,14 @@ class OpenCVCameraConfig:
|
||||||
OpenCVCameraConfig(30, 1280, 720)
|
OpenCVCameraConfig(30, 1280, 720)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
fps: int | None = None
|
fps: int | None = None
|
||||||
width: int | None = None
|
width: int | None = None
|
||||||
height: int | None = None
|
height: int | None = None
|
||||||
color: str = "rgb"
|
color: str = "rgb"
|
||||||
|
|
||||||
|
|
||||||
|
class OpenCVCamera:
|
||||||
class OpenCVCamera():
|
|
||||||
# TODO(rcadene): improve dosctring
|
# TODO(rcadene): improve dosctring
|
||||||
"""
|
"""
|
||||||
https://docs.opencv.org/4.x/d0/da7/videoio_overview.html
|
https://docs.opencv.org/4.x/d0/da7/videoio_overview.html
|
||||||
|
@ -93,6 +95,7 @@ class OpenCVCamera():
|
||||||
color_image = camera.capture_image()
|
color_image = camera.capture_image()
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, camera_index: int, config: OpenCVCameraConfig | None = None, **kwargs):
|
def __init__(self, camera_index: int, config: OpenCVCameraConfig | None = None, **kwargs):
|
||||||
if config is None:
|
if config is None:
|
||||||
config = OpenCVCameraConfig()
|
config = OpenCVCameraConfig()
|
||||||
|
@ -109,7 +112,9 @@ class OpenCVCamera():
|
||||||
raise ValueError(f"Expected color values are 'rgb' or 'bgr', but {self.color} is provided.")
|
raise ValueError(f"Expected color values are 'rgb' or 'bgr', but {self.color} is provided.")
|
||||||
|
|
||||||
if self.camera_index is None:
|
if self.camera_index is None:
|
||||||
raise ValueError(f"`camera_index` is expected to be one of these available cameras {OpenCVCamera.AVAILABLE_CAMERAS_INDICES}, but {camera_index} is provided instead.")
|
raise ValueError(
|
||||||
|
f"`camera_index` is expected to be one of these available cameras {OpenCVCamera.AVAILABLE_CAMERAS_INDICES}, but {camera_index} is provided instead."
|
||||||
|
)
|
||||||
|
|
||||||
self.camera = None
|
self.camera = None
|
||||||
self.is_connected = False
|
self.is_connected = False
|
||||||
|
@ -134,7 +139,9 @@ class OpenCVCamera():
|
||||||
if not is_camera_open:
|
if not is_camera_open:
|
||||||
# Verify that the provided `camera_index` is valid before printing the traceback
|
# Verify that the provided `camera_index` is valid before printing the traceback
|
||||||
if self.camera_index not in find_camera_indices():
|
if self.camera_index not in find_camera_indices():
|
||||||
raise ValueError(f"`camera_index` is expected to be one of these available cameras {OpenCVCamera.AVAILABLE_CAMERAS_INDICES}, but {self.camera_index} is provided instead.")
|
raise ValueError(
|
||||||
|
f"`camera_index` is expected to be one of these available cameras {OpenCVCamera.AVAILABLE_CAMERAS_INDICES}, but {self.camera_index} is provided instead."
|
||||||
|
)
|
||||||
|
|
||||||
raise OSError(f"Can't access camera {self.camera_index}.")
|
raise OSError(f"Can't access camera {self.camera_index}.")
|
||||||
|
|
||||||
|
@ -155,11 +162,17 @@ class OpenCVCamera():
|
||||||
actual_height = self.camera.get(cv2.CAP_PROP_FRAME_HEIGHT)
|
actual_height = self.camera.get(cv2.CAP_PROP_FRAME_HEIGHT)
|
||||||
|
|
||||||
if self.fps and self.fps != actual_fps:
|
if self.fps and self.fps != actual_fps:
|
||||||
raise OSError(f"Can't set {self.fps=} for camera {self.camera_index}. Actual value is {actual_fps}.")
|
raise OSError(
|
||||||
|
f"Can't set {self.fps=} for camera {self.camera_index}. Actual value is {actual_fps}."
|
||||||
|
)
|
||||||
if self.width and self.width != actual_width:
|
if self.width and self.width != actual_width:
|
||||||
raise OSError(f"Can't set {self.width=} for camera {self.camera_index}. Actual value is {actual_width}.")
|
raise OSError(
|
||||||
|
f"Can't set {self.width=} for camera {self.camera_index}. Actual value is {actual_width}."
|
||||||
|
)
|
||||||
if self.height and self.height != actual_height:
|
if self.height and self.height != actual_height:
|
||||||
raise OSError(f"Can't set {self.height=} for camera {self.camera_index}. Actual value is {actual_height}.")
|
raise OSError(
|
||||||
|
f"Can't set {self.height=} for camera {self.camera_index}. Actual value is {actual_height}."
|
||||||
|
)
|
||||||
|
|
||||||
self.is_connected = True
|
self.is_connected = True
|
||||||
self.t.start()
|
self.t.start()
|
||||||
|
@ -172,10 +185,7 @@ class OpenCVCamera():
|
||||||
if not ret:
|
if not ret:
|
||||||
raise OSError(f"Can't capture color image from camera {self.camera_index}.")
|
raise OSError(f"Can't capture color image from camera {self.camera_index}.")
|
||||||
|
|
||||||
if temporary_color is None:
|
requested_color = self.color if temporary_color is None else temporary_color
|
||||||
requested_color = self.color
|
|
||||||
else:
|
|
||||||
requested_color = temporary_color
|
|
||||||
|
|
||||||
if requested_color not in ["rgb", "bgr"]:
|
if requested_color not in ["rgb", "bgr"]:
|
||||||
raise ValueError(f"Expected color values are 'rgb' or 'bgr', but {requested_color} is provided.")
|
raise ValueError(f"Expected color values are 'rgb' or 'bgr', but {requested_color} is provided.")
|
||||||
|
@ -215,6 +225,7 @@ def save_images_config(config: OpenCVCameraConfig, out_dir: Path):
|
||||||
out_dir = out_dir.parent / f"{out_dir.name}_{config.width}x{config.height}_{config.fps}"
|
out_dir = out_dir.parent / f"{out_dir.name}_{config.width}x{config.height}_{config.fps}"
|
||||||
benchmark_cameras(cameras, out_dir, save_images=True)
|
benchmark_cameras(cameras, out_dir, save_images=True)
|
||||||
|
|
||||||
|
|
||||||
def benchmark_config(config: OpenCVCameraConfig, camera_ids: list[int]):
|
def benchmark_config(config: OpenCVCameraConfig, camera_ids: list[int]):
|
||||||
cameras = [OpenCVCamera(idx, config) for idx in camera_ids]
|
cameras = [OpenCVCamera(idx, config) for idx in camera_ids]
|
||||||
benchmark_cameras(cameras)
|
benchmark_cameras(cameras)
|
||||||
|
@ -222,7 +233,7 @@ def benchmark_config(config: OpenCVCameraConfig, camera_ids: list[int]):
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--mode", type=str, choices=["save_images", 'benchmark'], default="save_images")
|
parser.add_argument("--mode", type=str, choices=["save_images", "benchmark"], default="save_images")
|
||||||
parser.add_argument("--camera-ids", type=int, nargs="*", default=[16, 4, 22, 10])
|
parser.add_argument("--camera-ids", type=int, nargs="*", default=[16, 4, 22, 10])
|
||||||
parser.add_argument("--fps", type=int, default=30)
|
parser.add_argument("--fps", type=int, default=30)
|
||||||
parser.add_argument("--width", type=str, default=640)
|
parser.add_argument("--width", type=str, default=640)
|
||||||
|
@ -242,8 +253,3 @@ if __name__ == "__main__":
|
||||||
benchmark_config(config, args.camera_ids)
|
benchmark_config(config, args.camera_ids)
|
||||||
else:
|
else:
|
||||||
raise ValueError(args.mode)
|
raise ValueError(args.mode)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,15 +1,13 @@
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import time
|
|
||||||
import cv2
|
|
||||||
from typing import Protocol
|
from typing import Protocol
|
||||||
|
|
||||||
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def write_shape_on_image_inplace(image):
|
def write_shape_on_image_inplace(image):
|
||||||
height, width = image.shape[:2]
|
height, width = image.shape[:2]
|
||||||
text = f'Width: {width} Height: {height}'
|
text = f"Width: {width} Height: {height}"
|
||||||
|
|
||||||
# Define the font, scale, color, and thickness
|
# Define the font, scale, color, and thickness
|
||||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||||
|
|
|
@ -1,10 +1,18 @@
|
||||||
from copy import deepcopy
|
|
||||||
import enum
|
import enum
|
||||||
from typing import Union
|
from copy import deepcopy
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from dynamixel_sdk import PacketHandler, PortHandler, COMM_SUCCESS, GroupSyncRead, GroupSyncWrite
|
import numpy as np
|
||||||
from dynamixel_sdk import DXL_HIBYTE, DXL_HIWORD, DXL_LOBYTE, DXL_LOWORD
|
from dynamixel_sdk import (
|
||||||
|
COMM_SUCCESS,
|
||||||
|
DXL_HIBYTE,
|
||||||
|
DXL_HIWORD,
|
||||||
|
DXL_LOBYTE,
|
||||||
|
DXL_LOWORD,
|
||||||
|
GroupSyncRead,
|
||||||
|
GroupSyncWrite,
|
||||||
|
PacketHandler,
|
||||||
|
PortHandler,
|
||||||
|
)
|
||||||
|
|
||||||
PROTOCOL_VERSION = 2.0
|
PROTOCOL_VERSION = 2.0
|
||||||
BAUD_RATE = 1_000_000
|
BAUD_RATE = 1_000_000
|
||||||
|
@ -69,7 +77,7 @@ X_SERIES_CONTROL_TABLE = {
|
||||||
"Velocity_Trajectory": (136, 4),
|
"Velocity_Trajectory": (136, 4),
|
||||||
"Position_Trajectory": (140, 4),
|
"Position_Trajectory": (140, 4),
|
||||||
"Present_Input_Voltage": (144, 2),
|
"Present_Input_Voltage": (144, 2),
|
||||||
"Present_Temperature": (146, 1)
|
"Present_Temperature": (146, 1),
|
||||||
}
|
}
|
||||||
|
|
||||||
CALIBRATION_REQUIRED = ["Goal_Position", "Present_Position"]
|
CALIBRATION_REQUIRED = ["Goal_Position", "Present_Position"]
|
||||||
|
@ -86,6 +94,7 @@ MODEL_CONTROL_TABLE = {
|
||||||
"xm540-w270": X_SERIES_CONTROL_TABLE,
|
"xm540-w270": X_SERIES_CONTROL_TABLE,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def uint32_to_int32(values: np.ndarray):
|
def uint32_to_int32(values: np.ndarray):
|
||||||
"""
|
"""
|
||||||
Convert an unsigned 32-bit integer array to a signed 32-bit integer array.
|
Convert an unsigned 32-bit integer array to a signed 32-bit integer array.
|
||||||
|
@ -95,6 +104,7 @@ def uint32_to_int32(values: np.ndarray):
|
||||||
values[i] = values[i] - 4294967296
|
values[i] = values[i] - 4294967296
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
|
||||||
def int32_to_uint32(values: np.ndarray):
|
def int32_to_uint32(values: np.ndarray):
|
||||||
"""
|
"""
|
||||||
Convert a signed 32-bit integer array to an unsigned 32-bit integer array.
|
Convert a signed 32-bit integer array to an unsigned 32-bit integer array.
|
||||||
|
@ -104,12 +114,14 @@ def int32_to_uint32(values: np.ndarray):
|
||||||
values[i] = values[i] + 4294967296
|
values[i] = values[i] + 4294967296
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
|
||||||
def motor_position_to_angle(position: np.ndarray) -> np.ndarray:
|
def motor_position_to_angle(position: np.ndarray) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Convert from motor position in [-2048, 2048] to radian in [-pi, pi]
|
Convert from motor position in [-2048, 2048] to radian in [-pi, pi]
|
||||||
"""
|
"""
|
||||||
return (position / 2048) * 3.14
|
return (position / 2048) * 3.14
|
||||||
|
|
||||||
|
|
||||||
def motor_angle_to_position(angle: np.ndarray) -> np.ndarray:
|
def motor_angle_to_position(angle: np.ndarray) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Convert from radian in [-pi, pi] to motor position in [-2048, 2048]
|
Convert from radian in [-pi, pi] to motor position in [-2048, 2048]
|
||||||
|
@ -134,7 +146,7 @@ def motor_angle_to_position(angle: np.ndarray) -> np.ndarray:
|
||||||
|
|
||||||
|
|
||||||
def get_group_sync_key(data_name, motor_names):
|
def get_group_sync_key(data_name, motor_names):
|
||||||
group_key = f"{data_name}_" + "_".join([name for name in motor_names])
|
group_key = f"{data_name}_" + "_".join(motor_names)
|
||||||
return group_key
|
return group_key
|
||||||
|
|
||||||
|
|
||||||
|
@ -158,9 +170,12 @@ class DriveMode(enum.Enum):
|
||||||
|
|
||||||
|
|
||||||
class DynamixelMotorsBus:
|
class DynamixelMotorsBus:
|
||||||
|
def __init__(
|
||||||
def __init__(self, port: str, motors: dict[str, tuple[int, str]],
|
self,
|
||||||
extra_model_control_table: dict[str, list[tuple]] | None = None):
|
port: str,
|
||||||
|
motors: dict[str, tuple[int, str]],
|
||||||
|
extra_model_control_table: dict[str, list[tuple]] | None = None,
|
||||||
|
):
|
||||||
self.port = port
|
self.port = port
|
||||||
self.motors = motors
|
self.motors = motors
|
||||||
|
|
||||||
|
@ -307,9 +322,11 @@ class DynamixelMotorsBus:
|
||||||
|
|
||||||
init_group = data_name not in self.group_readers
|
init_group = data_name not in self.group_readers
|
||||||
if init_group:
|
if init_group:
|
||||||
self.group_writers[group_key] = GroupSyncWrite(self.port_handler, self.packet_handler, addr, bytes)
|
self.group_writers[group_key] = GroupSyncWrite(
|
||||||
|
self.port_handler, self.packet_handler, addr, bytes
|
||||||
|
)
|
||||||
|
|
||||||
for idx, value in zip(motor_ids, values):
|
for idx, value in zip(motor_ids, values, strict=False):
|
||||||
if bytes == 1:
|
if bytes == 1:
|
||||||
data = [
|
data = [
|
||||||
DXL_LOBYTE(DXL_LOWORD(value)),
|
DXL_LOBYTE(DXL_LOWORD(value)),
|
||||||
|
@ -329,7 +346,8 @@ class DynamixelMotorsBus:
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but "
|
f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but "
|
||||||
f"{bytes} is provided instead.")
|
f"{bytes} is provided instead."
|
||||||
|
)
|
||||||
|
|
||||||
if init_group:
|
if init_group:
|
||||||
self.group_writers[group_key].addParam(idx, data)
|
self.group_writers[group_key].addParam(idx, data)
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
from typing import Protocol
|
from typing import Protocol
|
||||||
|
|
||||||
|
|
||||||
class MotorsBus(Protocol):
|
class MotorsBus(Protocol):
|
||||||
def motor_names(self): ...
|
def motor_names(self): ...
|
||||||
def set_calibration(self): ...
|
def set_calibration(self): ...
|
||||||
|
|
|
@ -1,11 +1,8 @@
|
||||||
|
|
||||||
|
|
||||||
def make_robot(name):
|
def make_robot(name):
|
||||||
|
|
||||||
if name == "koch":
|
if name == "koch":
|
||||||
from lerobot.common.robot_devices.robots.koch import KochRobot
|
|
||||||
from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus
|
|
||||||
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
|
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
|
||||||
|
from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus
|
||||||
|
from lerobot.common.robot_devices.robots.koch import KochRobot
|
||||||
|
|
||||||
robot = KochRobot(
|
robot = KochRobot(
|
||||||
leader_arms={
|
leader_arms={
|
||||||
|
@ -38,7 +35,7 @@ def make_robot(name):
|
||||||
},
|
},
|
||||||
cameras={
|
cameras={
|
||||||
"main": OpenCVCamera(1, fps=30, width=640, height=480),
|
"main": OpenCVCamera(1, fps=30, width=640, height=480),
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Robot '{name}' not found.")
|
raise ValueError(f"Robot '{name}' not found.")
|
||||||
|
|
|
@ -1,14 +1,18 @@
|
||||||
import copy
|
import pickle
|
||||||
from dataclasses import dataclass, field, replace
|
from dataclasses import dataclass, field, replace
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import pickle
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
|
|
||||||
from lerobot.common.robot_devices.cameras.utils import Camera
|
|
||||||
from lerobot.common.robot_devices.motors.dynamixel import DriveMode, DynamixelMotorsBus, OperatingMode, TorqueMode, motor_position_to_angle
|
|
||||||
from lerobot.common.robot_devices.motors.utils import MotorsBus
|
|
||||||
|
|
||||||
|
from lerobot.common.robot_devices.cameras.utils import Camera
|
||||||
|
from lerobot.common.robot_devices.motors.dynamixel import (
|
||||||
|
DriveMode,
|
||||||
|
DynamixelMotorsBus,
|
||||||
|
OperatingMode,
|
||||||
|
TorqueMode,
|
||||||
|
)
|
||||||
|
from lerobot.common.robot_devices.motors.utils import MotorsBus
|
||||||
|
|
||||||
########################################################################
|
########################################################################
|
||||||
# Calibration logic
|
# Calibration logic
|
||||||
|
@ -22,6 +26,7 @@ TARGET_HORIZONTAL_POSITION = np.array([0, -1024, 1024, 0, -1024, 0])
|
||||||
TARGET_90_DEGREE_POSITION = np.array([1024, 0, 0, 1024, 0, -1024])
|
TARGET_90_DEGREE_POSITION = np.array([1024, 0, 0, 1024, 0, -1024])
|
||||||
GRIPPER_OPEN = np.array([-400])
|
GRIPPER_OPEN = np.array([-400])
|
||||||
|
|
||||||
|
|
||||||
def apply_homing_offset(values: np.array, homing_offset: np.array) -> np.array:
|
def apply_homing_offset(values: np.array, homing_offset: np.array) -> np.array:
|
||||||
for i in range(len(values)):
|
for i in range(len(values)):
|
||||||
if values[i] is not None:
|
if values[i] is not None:
|
||||||
|
@ -35,18 +40,21 @@ def apply_drive_mode(values: np.array, drive_mode: np.array) -> np.array:
|
||||||
values[i] = -values[i]
|
values[i] = -values[i]
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
|
||||||
def apply_calibration(values: np.array, homing_offset: np.array, drive_mode: np.array) -> np.array:
|
def apply_calibration(values: np.array, homing_offset: np.array, drive_mode: np.array) -> np.array:
|
||||||
values = apply_drive_mode(values, drive_mode)
|
values = apply_drive_mode(values, drive_mode)
|
||||||
values = apply_homing_offset(values, homing_offset)
|
values = apply_homing_offset(values, homing_offset)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
|
||||||
def revert_calibration(values: np.array, homing_offset: np.array, drive_mode: np.array) -> np.array:
|
def revert_calibration(values: np.array, homing_offset: np.array, drive_mode: np.array) -> np.array:
|
||||||
"""
|
"""
|
||||||
Transform working position into real position for the robot.
|
Transform working position into real position for the robot.
|
||||||
"""
|
"""
|
||||||
values = apply_homing_offset(values, np.array([
|
values = apply_homing_offset(
|
||||||
-homing_offset if homing_offset is not None else None for homing_offset in homing_offset
|
values,
|
||||||
]))
|
np.array([-homing_offset if homing_offset is not None else None for homing_offset in homing_offset]),
|
||||||
|
)
|
||||||
values = apply_drive_mode(values, drive_mode)
|
values = apply_drive_mode(values, drive_mode)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
@ -73,15 +81,20 @@ def compute_corrections(positions: np.array, drive_mode: list[bool], target_posi
|
||||||
|
|
||||||
def compute_nearest_rounded_positions(positions: np.array) -> np.array:
|
def compute_nearest_rounded_positions(positions: np.array) -> np.array:
|
||||||
return np.array(
|
return np.array(
|
||||||
[round(positions[i] / 1024) * 1024 if positions[i] is not None else None for i in range(len(positions))])
|
[
|
||||||
|
round(positions[i] / 1024) * 1024 if positions[i] is not None else None
|
||||||
|
for i in range(len(positions))
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def compute_homing_offset(arm: DynamixelMotorsBus, drive_mode: list[bool], target_position: np.array) -> np.array:
|
def compute_homing_offset(
|
||||||
|
arm: DynamixelMotorsBus, drive_mode: list[bool], target_position: np.array
|
||||||
|
) -> np.array:
|
||||||
# Get the present positions of the servos
|
# Get the present positions of the servos
|
||||||
present_positions = apply_calibration(
|
present_positions = apply_calibration(
|
||||||
arm.read("Present_Position"),
|
arm.read("Present_Position"), np.array([0, 0, 0, 0, 0, 0]), drive_mode
|
||||||
np.array([0, 0, 0, 0, 0, 0]),
|
)
|
||||||
drive_mode)
|
|
||||||
|
|
||||||
nearest_positions = compute_nearest_rounded_positions(present_positions)
|
nearest_positions = compute_nearest_rounded_positions(present_positions)
|
||||||
correction = compute_corrections(nearest_positions, drive_mode, target_position)
|
correction = compute_corrections(nearest_positions, drive_mode, target_position)
|
||||||
|
@ -91,9 +104,8 @@ def compute_homing_offset(arm: DynamixelMotorsBus, drive_mode: list[bool], targe
|
||||||
def compute_drive_mode(arm: DynamixelMotorsBus, offset: np.array):
|
def compute_drive_mode(arm: DynamixelMotorsBus, offset: np.array):
|
||||||
# Get current positions
|
# Get current positions
|
||||||
present_positions = apply_calibration(
|
present_positions = apply_calibration(
|
||||||
arm.read("Present_Position"),
|
arm.read("Present_Position"), offset, np.array([False, False, False, False, False, False])
|
||||||
offset,
|
)
|
||||||
np.array([False, False, False, False, False, False]))
|
|
||||||
|
|
||||||
nearest_positions = compute_nearest_rounded_positions(present_positions)
|
nearest_positions = compute_nearest_rounded_positions(present_positions)
|
||||||
|
|
||||||
|
@ -131,7 +143,9 @@ def run_arm_calibration(arm: MotorsBus, name: str):
|
||||||
print(f"Please move the '{name}' arm to the horizontal position (gripper fully closed)")
|
print(f"Please move the '{name}' arm to the horizontal position (gripper fully closed)")
|
||||||
input("Press Enter to continue...")
|
input("Press Enter to continue...")
|
||||||
|
|
||||||
horizontal_homing_offset = compute_homing_offset(arm, [False, False, False, False, False, False], TARGET_HORIZONTAL_POSITION)
|
horizontal_homing_offset = compute_homing_offset(
|
||||||
|
arm, [False, False, False, False, False, False], TARGET_HORIZONTAL_POSITION
|
||||||
|
)
|
||||||
|
|
||||||
# TODO(rcadene): document what position 2 mean
|
# TODO(rcadene): document what position 2 mean
|
||||||
print(f"Please move the '{name}' arm to the 90 degree position (gripper fully open)")
|
print(f"Please move the '{name}' arm to the 90 degree position (gripper fully open)")
|
||||||
|
@ -159,6 +173,7 @@ def run_arm_calibration(arm: MotorsBus, name: str):
|
||||||
# Alexander Koch robot arm
|
# Alexander Koch robot arm
|
||||||
########################################################################
|
########################################################################
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class KochRobotConfig:
|
class KochRobotConfig:
|
||||||
"""
|
"""
|
||||||
|
@ -201,11 +216,10 @@ class KochRobotConfig:
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
cameras: dict[str, Camera] = field(
|
cameras: dict[str, Camera] = field(default_factory=lambda: {})
|
||||||
default_factory=lambda: {}
|
|
||||||
)
|
|
||||||
|
|
||||||
class KochRobot():
|
|
||||||
|
class KochRobot:
|
||||||
"""Tau Robotics: https://tau-robotics.com
|
"""Tau Robotics: https://tau-robotics.com
|
||||||
|
|
||||||
Example of usage:
|
Example of usage:
|
||||||
|
@ -214,7 +228,12 @@ class KochRobot():
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config: KochRobotConfig | None = None, calibration_path: Path = ".cache/calibration/koch.pkl", **kwargs):
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: KochRobotConfig | None = None,
|
||||||
|
calibration_path: Path = ".cache/calibration/koch.pkl",
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
if config is None:
|
if config is None:
|
||||||
config = KochRobotConfig()
|
config = KochRobotConfig()
|
||||||
# Overwrite config arguments using kwargs
|
# Overwrite config arguments using kwargs
|
||||||
|
@ -234,13 +253,13 @@ class KochRobot():
|
||||||
for name in self.leader_arms:
|
for name in self.leader_arms:
|
||||||
reset_arm(self.leader_arms[name])
|
reset_arm(self.leader_arms[name])
|
||||||
|
|
||||||
with open(self.calibration_path, 'rb') as f:
|
with open(self.calibration_path, "rb") as f:
|
||||||
calibration = pickle.load(f)
|
calibration = pickle.load(f)
|
||||||
else:
|
else:
|
||||||
calibration = self.run_calibration()
|
calibration = self.run_calibration()
|
||||||
|
|
||||||
self.calibration_path.parent.mkdir(parents=True, exist_ok=True)
|
self.calibration_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
with open(self.calibration_path, 'wb') as f:
|
with open(self.calibration_path, "wb") as f:
|
||||||
pickle.dump(calibration, f)
|
pickle.dump(calibration, f)
|
||||||
|
|
||||||
for name in self.follower_arms:
|
for name in self.follower_arms:
|
||||||
|
@ -275,7 +294,9 @@ class KochRobot():
|
||||||
|
|
||||||
return calibration
|
return calibration
|
||||||
|
|
||||||
def teleop_step(self, record_data=False) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
|
def teleop_step(
|
||||||
|
self, record_data=False
|
||||||
|
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
|
||||||
# Prepare to assign the positions of the leader to the follower
|
# Prepare to assign the positions of the leader to the follower
|
||||||
leader_pos = {}
|
leader_pos = {}
|
||||||
for name in self.leader_arms:
|
for name in self.leader_arms:
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
from typing import Protocol
|
from typing import Protocol
|
||||||
|
|
||||||
|
|
||||||
class Robot(Protocol):
|
class Robot(Protocol):
|
||||||
def init_teleop(self): ...
|
def init_teleop(self): ...
|
||||||
def run_calibration(self): ...
|
def run_calibration(self): ...
|
||||||
|
|
|
@ -62,20 +62,22 @@ python lerobot/scripts/control_robot.py run_policy \
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
from contextlib import nullcontext
|
import concurrent.futures
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
|
||||||
import shutil
|
import shutil
|
||||||
import time
|
import time
|
||||||
|
from contextlib import nullcontext
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from PIL import Image
|
|
||||||
from omegaconf import DictConfig
|
|
||||||
import torch
|
import torch
|
||||||
|
from omegaconf import DictConfig
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
from lerobot.common.datasets.compute_stats import compute_stats
|
from lerobot.common.datasets.compute_stats import compute_stats
|
||||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import to_hf_dataset
|
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import to_hf_dataset
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
|
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
|
||||||
from lerobot.common.datasets.utils import calculate_episode_data_index, load_hf_dataset
|
from lerobot.common.datasets.utils import calculate_episode_data_index
|
||||||
from lerobot.common.datasets.video_utils import encode_video_frames
|
from lerobot.common.datasets.video_utils import encode_video_frames
|
||||||
from lerobot.common.policies.factory import make_policy
|
from lerobot.common.policies.factory import make_policy
|
||||||
from lerobot.common.robot_devices.robots.factory import make_robot
|
from lerobot.common.robot_devices.robots.factory import make_robot
|
||||||
|
@ -83,14 +85,12 @@ from lerobot.common.robot_devices.robots.utils import Robot
|
||||||
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, set_global_seed
|
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, set_global_seed
|
||||||
from lerobot.scripts.eval import get_pretrained_policy_path
|
from lerobot.scripts.eval import get_pretrained_policy_path
|
||||||
from lerobot.scripts.push_dataset_to_hub import save_meta_data
|
from lerobot.scripts.push_dataset_to_hub import save_meta_data
|
||||||
from lerobot.scripts.robot_controls.record_dataset import record_dataset
|
|
||||||
import concurrent.futures
|
|
||||||
|
|
||||||
|
|
||||||
########################################################################################
|
########################################################################################
|
||||||
# Utilities
|
# Utilities
|
||||||
########################################################################################
|
########################################################################################
|
||||||
|
|
||||||
|
|
||||||
def save_image(img_tensor, key, frame_index, episode_index, videos_dir):
|
def save_image(img_tensor, key, frame_index, episode_index, videos_dir):
|
||||||
img = Image.fromarray(img_tensor.numpy())
|
img = Image.fromarray(img_tensor.numpy())
|
||||||
path = videos_dir / f"{key}_episode_{episode_index:06d}" / f"frame_{frame_index:06d}.png"
|
path = videos_dir / f"{key}_episode_{episode_index:06d}" / f"frame_{frame_index:06d}.png"
|
||||||
|
@ -106,15 +106,18 @@ def busy_wait(seconds):
|
||||||
while time.perf_counter() < end_time:
|
while time.perf_counter() < end_time:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def none_or_int(value):
|
def none_or_int(value):
|
||||||
if value == 'None':
|
if value == "None":
|
||||||
return None
|
return None
|
||||||
return int(value)
|
return int(value)
|
||||||
|
|
||||||
|
|
||||||
########################################################################################
|
########################################################################################
|
||||||
# Control modes
|
# Control modes
|
||||||
########################################################################################
|
########################################################################################
|
||||||
|
|
||||||
|
|
||||||
def teleoperate(robot: Robot, fps: int | None = None):
|
def teleoperate(robot: Robot, fps: int | None = None):
|
||||||
robot.init_teleop()
|
robot.init_teleop()
|
||||||
|
|
||||||
|
@ -123,14 +126,24 @@ def teleoperate(robot: Robot, fps: int | None = None):
|
||||||
robot.teleop_step()
|
robot.teleop_step()
|
||||||
|
|
||||||
if fps is not None:
|
if fps is not None:
|
||||||
dt_s = (time.perf_counter() - now)
|
dt_s = time.perf_counter() - now
|
||||||
busy_wait(1 / fps - dt_s)
|
busy_wait(1 / fps - dt_s)
|
||||||
|
|
||||||
dt_s = (time.perf_counter() - now)
|
dt_s = time.perf_counter() - now
|
||||||
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")
|
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")
|
||||||
|
|
||||||
|
|
||||||
def record_dataset(robot: Robot, fps: int | None = None, root="data", repo_id="lerobot/debug", warmup_time_s=2, episode_time_s=10, num_episodes=50, video=True, run_compute_stats=True):
|
def record_dataset(
|
||||||
|
robot: Robot,
|
||||||
|
fps: int | None = None,
|
||||||
|
root="data",
|
||||||
|
repo_id="lerobot/debug",
|
||||||
|
warmup_time_s=2,
|
||||||
|
episode_time_s=10,
|
||||||
|
num_episodes=50,
|
||||||
|
video=True,
|
||||||
|
run_compute_stats=True,
|
||||||
|
):
|
||||||
if not video:
|
if not video:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@ -143,7 +156,6 @@ def record_dataset(robot: Robot, fps: int | None = None, root="data", repo_id="l
|
||||||
videos_dir = local_dir / "videos"
|
videos_dir = local_dir / "videos"
|
||||||
videos_dir.mkdir(parents=True, exist_ok=True)
|
videos_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
is_warmup_print = False
|
is_warmup_print = False
|
||||||
|
@ -154,7 +166,6 @@ def record_dataset(robot: Robot, fps: int | None = None, root="data", repo_id="l
|
||||||
# Using `with` ensures the program exists smoothly if an execption is raised.
|
# Using `with` ensures the program exists smoothly if an execption is raised.
|
||||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||||
for episode_index in range(num_episodes):
|
for episode_index in range(num_episodes):
|
||||||
|
|
||||||
ep_dict = {}
|
ep_dict = {}
|
||||||
frame_index = 0
|
frame_index = 0
|
||||||
|
|
||||||
|
@ -169,10 +180,10 @@ def record_dataset(robot: Robot, fps: int | None = None, root="data", repo_id="l
|
||||||
timestamp = time.perf_counter() - start_time
|
timestamp = time.perf_counter() - start_time
|
||||||
|
|
||||||
if timestamp < warmup_time_s:
|
if timestamp < warmup_time_s:
|
||||||
dt_s = (time.perf_counter() - now)
|
dt_s = time.perf_counter() - now
|
||||||
busy_wait(1 / fps - dt_s)
|
busy_wait(1 / fps - dt_s)
|
||||||
|
|
||||||
dt_s = (time.perf_counter() - now)
|
dt_s = time.perf_counter() - now
|
||||||
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f} (Warmup)")
|
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f} (Warmup)")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -199,10 +210,10 @@ def record_dataset(robot: Robot, fps: int | None = None, root="data", repo_id="l
|
||||||
|
|
||||||
frame_index += 1
|
frame_index += 1
|
||||||
|
|
||||||
dt_s = (time.perf_counter() - now)
|
dt_s = time.perf_counter() - now
|
||||||
busy_wait(1 / fps - dt_s)
|
busy_wait(1 / fps - dt_s)
|
||||||
|
|
||||||
dt_s = (time.perf_counter() - now)
|
dt_s = time.perf_counter() - now
|
||||||
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")
|
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")
|
||||||
|
|
||||||
if timestamp > episode_time_s - warmup_time_s:
|
if timestamp > episode_time_s - warmup_time_s:
|
||||||
|
@ -269,10 +280,7 @@ def record_dataset(robot: Robot, fps: int | None = None, root="data", repo_id="l
|
||||||
info=info,
|
info=info,
|
||||||
videos_dir=videos_dir,
|
videos_dir=videos_dir,
|
||||||
)
|
)
|
||||||
if run_compute_stats:
|
stats = compute_stats(lerobot_dataset) if run_compute_stats else {}
|
||||||
stats = compute_stats(lerobot_dataset)
|
|
||||||
else:
|
|
||||||
stats = {}
|
|
||||||
|
|
||||||
hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
|
hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
|
||||||
hf_dataset.save_to_disk(str(local_dir / "train"))
|
hf_dataset.save_to_disk(str(local_dir / "train"))
|
||||||
|
@ -303,10 +311,10 @@ def replay_episode(robot: Robot, episode: int, fps: int | None = None, root="dat
|
||||||
action = items[idx]["action"]
|
action = items[idx]["action"]
|
||||||
robot.send_action(action)
|
robot.send_action(action)
|
||||||
|
|
||||||
dt_s = (time.perf_counter() - now)
|
dt_s = time.perf_counter() - now
|
||||||
busy_wait(1 / fps - dt_s)
|
busy_wait(1 / fps - dt_s)
|
||||||
|
|
||||||
dt_s = (time.perf_counter() - now)
|
dt_s = time.perf_counter() - now
|
||||||
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")
|
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")
|
||||||
|
|
||||||
|
|
||||||
|
@ -327,15 +335,18 @@ def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig):
|
||||||
|
|
||||||
observation = robot.capture_observation()
|
observation = robot.capture_observation()
|
||||||
|
|
||||||
with torch.inference_mode(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext():
|
with (
|
||||||
|
torch.inference_mode(),
|
||||||
|
torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext(),
|
||||||
|
):
|
||||||
action = policy.select_action(observation)
|
action = policy.select_action(observation)
|
||||||
|
|
||||||
robot.send_action(action)
|
robot.send_action(action)
|
||||||
|
|
||||||
dt_s = (time.perf_counter() - now)
|
dt_s = time.perf_counter() - now
|
||||||
busy_wait(1 / fps - dt_s)
|
busy_wait(1 / fps - dt_s)
|
||||||
|
|
||||||
dt_s = (time.perf_counter() - now)
|
dt_s = time.perf_counter() - now
|
||||||
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")
|
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")
|
||||||
|
|
||||||
|
|
||||||
|
@ -345,32 +356,46 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
# Set common options for all the subparsers
|
# Set common options for all the subparsers
|
||||||
base_parser = argparse.ArgumentParser(add_help=False)
|
base_parser = argparse.ArgumentParser(add_help=False)
|
||||||
base_parser.add_argument("--robot", type=str, default="koch", help="Name of the robot provided to the `make_robot(name)` factory function.")
|
base_parser.add_argument(
|
||||||
|
"--robot",
|
||||||
|
type=str,
|
||||||
|
default="koch",
|
||||||
|
help="Name of the robot provided to the `make_robot(name)` factory function.",
|
||||||
|
)
|
||||||
|
|
||||||
parser_teleop = subparsers.add_parser("teleoperate", parents=[base_parser])
|
parser_teleop = subparsers.add_parser("teleoperate", parents=[base_parser])
|
||||||
parser_teleop.add_argument('--fps', type=none_or_int, default=None, help='Frames per second (set to None to disable)')
|
parser_teleop.add_argument(
|
||||||
|
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
|
||||||
|
)
|
||||||
|
|
||||||
parser_record = subparsers.add_parser("record_dataset", parents=[base_parser])
|
parser_record = subparsers.add_parser("record_dataset", parents=[base_parser])
|
||||||
parser_record.add_argument('--fps', type=none_or_int, default=None, help='Frames per second (set to None to disable)')
|
parser_record.add_argument(
|
||||||
parser_record.add_argument('--root', type=Path, default="data", help='')
|
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
|
||||||
parser_record.add_argument('--repo-id', type=str, default="lerobot/test", help='')
|
)
|
||||||
parser_record.add_argument('--warmup-time-s', type=int, default=2, help='')
|
parser_record.add_argument("--root", type=Path, default="data", help="")
|
||||||
parser_record.add_argument('--episode-time-s', type=int, default=10, help='')
|
parser_record.add_argument("--repo-id", type=str, default="lerobot/test", help="")
|
||||||
parser_record.add_argument('--num-episodes', type=int, default=50, help='')
|
parser_record.add_argument("--warmup-time-s", type=int, default=2, help="")
|
||||||
parser_record.add_argument('--run-compute-stats', type=int, default=1, help='')
|
parser_record.add_argument("--episode-time-s", type=int, default=10, help="")
|
||||||
|
parser_record.add_argument("--num-episodes", type=int, default=50, help="")
|
||||||
|
parser_record.add_argument("--run-compute-stats", type=int, default=1, help="")
|
||||||
|
|
||||||
parser_replay = subparsers.add_parser("replay_episode", parents=[base_parser])
|
parser_replay = subparsers.add_parser("replay_episode", parents=[base_parser])
|
||||||
parser_replay.add_argument('--fps', type=none_or_int, default=None, help='Frames per second (set to None to disable)')
|
parser_replay.add_argument(
|
||||||
parser_replay.add_argument('--root', type=Path, default="data", help='')
|
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
|
||||||
parser_replay.add_argument('--repo-id', type=str, default="lerobot/test", help='')
|
)
|
||||||
parser_replay.add_argument('--episode', type=int, default=0, help='')
|
parser_replay.add_argument("--root", type=Path, default="data", help="")
|
||||||
|
parser_replay.add_argument("--repo-id", type=str, default="lerobot/test", help="")
|
||||||
|
parser_replay.add_argument("--episode", type=int, default=0, help="")
|
||||||
|
|
||||||
parser_policy = subparsers.add_parser("run_policy", parents=[base_parser])
|
parser_policy = subparsers.add_parser("run_policy", parents=[base_parser])
|
||||||
parser_policy.add_argument('-p', '--pretrained-policy-name-or-path', type=str,
|
parser_policy.add_argument(
|
||||||
|
"-p",
|
||||||
|
"--pretrained-policy-name-or-path",
|
||||||
|
type=str,
|
||||||
help=(
|
help=(
|
||||||
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
|
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
|
||||||
"saved using `Policy.save_pretrained`."
|
"saved using `Policy.save_pretrained`."
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
parser_policy.add_argument(
|
parser_policy.add_argument(
|
||||||
"overrides",
|
"overrides",
|
||||||
|
|
|
@ -580,9 +580,7 @@ def main(
|
||||||
|
|
||||||
def get_pretrained_policy_path(pretrained_policy_name_or_path, revision=None):
|
def get_pretrained_policy_path(pretrained_policy_name_or_path, revision=None):
|
||||||
try:
|
try:
|
||||||
pretrained_policy_path = Path(
|
pretrained_policy_path = Path(snapshot_download(pretrained_policy_name_or_path, revision=revision))
|
||||||
snapshot_download(pretrained_policy_name_or_path, revision=revision)
|
|
||||||
)
|
|
||||||
except (HFValidationError, RepositoryNotFoundError) as e:
|
except (HFValidationError, RepositoryNotFoundError) as e:
|
||||||
if isinstance(e, HFValidationError):
|
if isinstance(e, HFValidationError):
|
||||||
error_message = (
|
error_message = (
|
||||||
|
@ -644,7 +642,9 @@ if __name__ == "__main__":
|
||||||
if args.pretrained_policy_name_or_path is None:
|
if args.pretrained_policy_name_or_path is None:
|
||||||
main(hydra_cfg_path=args.config, out_dir=args.out_dir, config_overrides=args.overrides)
|
main(hydra_cfg_path=args.config, out_dir=args.out_dir, config_overrides=args.overrides)
|
||||||
else:
|
else:
|
||||||
pretrained_policy_path = get_pretrained_policy_path(args.pretrained_policy_name_or_path, revision=args.revision)
|
pretrained_policy_path = get_pretrained_policy_path(
|
||||||
|
args.pretrained_policy_name_or_path, revision=args.revision
|
||||||
|
)
|
||||||
|
|
||||||
main(
|
main(
|
||||||
pretrained_policy_path=pretrained_policy_path,
|
pretrained_policy_path=pretrained_policy_path,
|
||||||
|
|
Loading…
Reference in New Issue