Rename record replay; Remove run_policy; Add koch.yaml; Move to degree

This commit is contained in:
Remi Cadene 2024-07-23 11:31:12 +02:00
parent b0f4f4f9a1
commit 875c1fbb2a
5 changed files with 222 additions and 246 deletions

View File

@ -98,6 +98,15 @@ MODEL_CONTROL_TABLE = {
"xm540-w270": X_SERIES_CONTROL_TABLE, "xm540-w270": X_SERIES_CONTROL_TABLE,
} }
MODEL_RESOLUTION = {
"x_series": 4096,
"xl330-m077": 4096,
"xl330-m288": 4096,
"xl430-w250": 4096,
"xm430-w350": 4096,
"xm540-w270": 4096,
}
NUM_READ_RETRY = 10 NUM_READ_RETRY = 10
@ -233,6 +242,7 @@ class DynamixelMotorsBus:
port: str, port: str,
motors: dict[str, tuple[int, str]], motors: dict[str, tuple[int, str]],
extra_model_control_table: dict[str, list[tuple]] | None = None, extra_model_control_table: dict[str, list[tuple]] | None = None,
extra_model_resolution: dict[str, int] | None = None,
): ):
self.port = port self.port = port
self.motors = motors self.motors = motors
@ -241,6 +251,10 @@ class DynamixelMotorsBus:
if extra_model_control_table: if extra_model_control_table:
self.model_ctrl_table.update(extra_model_control_table) self.model_ctrl_table.update(extra_model_control_table)
self.model_resolution = deepcopy(MODEL_RESOLUTION)
if extra_model_resolution:
self.model_resolution.update(extra_model_resolution)
self.port_handler = None self.port_handler = None
self.packet_handler = None self.packet_handler = None
self.calibration = None self.calibration = None
@ -281,34 +295,78 @@ class DynamixelMotorsBus:
self.calibration = calibration self.calibration = calibration
def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None): def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
if not self.calibration: """Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 centered degree range [-180.0, 180.0[
return values
Joints values are original in [0, 2**32[ (unsigned int32). Each motor are expected to complete a full rotation
when given a goal position that is + or - their resolution. For instance, dynamixel xl330-m077 have a resolution of 4096, and
at any position in their original range, let's say the position 56734, they complete a full rotation clockwise by moving to 60830,
or anticlockwise by moving to 42638. The position in the original range is arbitrary and might change a lot between each motor.
To harmonize between motors of the same model, different robots, or even models of different brands, we propose to work
in the centered degree range [-180, 180[. This function first applies the pre-computed calibration to convert
from [0, 2**32[ to [-2048, 2048[, then divide by 2048.
"""
if motor_names is None: if motor_names is None:
motor_names = self.motor_names motor_names = self.motor_names
# Convert from unsigned int32 original range [0, 2**32[ to centered signed int32 range [-2**31, 2**31[
values = values.astype(np.int32)
for i, name in enumerate(motor_names): for i, name in enumerate(motor_names):
homing_offset, drive_mode = self.calibration[name] homing_offset, drive_mode = self.calibration[name]
if values[i] is not None: # Update direction of rotation of the motor to match between leader and follower. In fact, the motor of the leader for a given joint
# can be assembled in an opposite direction in term of rotation than the motor of the follower on the same joint.
if drive_mode: if drive_mode:
values[i] *= -1 values[i] *= -1
# Convert from range [-2**31, 2**31[ to centered resolution range [-resolution, resolution[ (e.g. [-2048, 2048[)
values[i] += homing_offset values[i] += homing_offset
# Convert from range [-resolution, resolution[ to the universal float32 centered degree range [-180, 180[
values = values.astype(np.float32)
for i, name in enumerate(motor_names):
_, model = self.motors[name]
resolution = self.model_resolution[model]
values[i] = values[i] / (resolution // 2) * 180
if (values < -180).any() or (values >= 180).any():
raise ValueError(
f"At least one of the motor has a joint value outside of its centered degree range of [-180, 180[."
'This "jump of range" can be caused by a hardware issue, or you might have unexpectedly completed a full rotation of the motor '
"during manipulation or transportation of your robot. Try to recalibrate all motors by setting a different "
"`calibration_path` during the instatiation of your robot. "
f"The values and motors: {values} {motor_names}"
)
return values return values
def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None): def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
if not self.calibration:
return values
if motor_names is None: if motor_names is None:
motor_names = self.motor_names motor_names = self.motor_names
if (values < -180).any() or (values >= 180).any():
raise ValueError(
f"At least one of the motor has a joint value outside of its centered degree range of [-180, 180[. "
f"The values and motors: {values} {motor_names}"
)
# Convert from the universal float32 centered degree range [-180, 180[ to centered resolution range [-resolution, resolution[
for i, name in enumerate(motor_names):
_, model = self.motors[name]
resolution = self.model_resolution[model]
values[i] = values[i] / 180 * (resolution // 2)
values = np.round(values).astype(np.int32)
# Convert from range [-resolution, resolution[ to centered signed int32 range [-2**31, 2**31[
for i, name in enumerate(motor_names): for i, name in enumerate(motor_names):
homing_offset, drive_mode = self.calibration[name] homing_offset, drive_mode = self.calibration[name]
if values[i] is not None:
values[i] -= homing_offset values[i] -= homing_offset
# Update direction of rotation of the motor that was matching between leader and follower to their original direction.
# In fact, the motor of the leader for a given joint can be assembled in an opposite direction in term of rotation
# than the motor of the follower on the same joint.
if drive_mode: if drive_mode:
values[i] *= -1 values[i] *= -1
@ -367,7 +425,7 @@ class DynamixelMotorsBus:
if data_name in CONVERT_UINT32_TO_INT32_REQUIRED: if data_name in CONVERT_UINT32_TO_INT32_REQUIRED:
values = values.astype(np.int32) values = values.astype(np.int32)
if data_name in CALIBRATION_REQUIRED: if data_name in CALIBRATION_REQUIRED and self.calibration is not None:
values = self.apply_calibration(values, motor_names) values = self.apply_calibration(values, motor_names)
# log the number of seconds it took to read the data from the motors # log the number of seconds it took to read the data from the motors
@ -406,7 +464,7 @@ class DynamixelMotorsBus:
motor_ids.append(motor_idx) motor_ids.append(motor_idx)
models.append(model) models.append(model)
if data_name in CALIBRATION_REQUIRED: if data_name in CALIBRATION_REQUIRED and self.calibration is not None:
values = self.revert_calibration(values, motor_names) values = self.revert_calibration(values, motor_names)
values = values.tolist() values = values.tolist()

View File

@ -1,46 +1,7 @@
def make_robot(name): import hydra
if name == "koch": from omegaconf import DictConfig
# TODO(rcadene): Add configurable robot from command line and yaml config
# TODO(rcadene): Add example with and without cameras
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus
from lerobot.common.robot_devices.robots.koch import KochRobot
robot = KochRobot(
leader_arms={
"main": DynamixelMotorsBus(
port="/dev/tty.usbmodem575E0031751",
motors={
# name: (index, model)
"shoulder_pan": (1, "xl330-m077"),
"shoulder_lift": (2, "xl330-m077"),
"elbow_flex": (3, "xl330-m077"),
"wrist_flex": (4, "xl330-m077"),
"wrist_roll": (5, "xl330-m077"),
"gripper": (6, "xl330-m077"),
},
),
},
follower_arms={
"main": DynamixelMotorsBus(
port="/dev/tty.usbmodem575E0032081",
motors={
# name: (index, model)
"shoulder_pan": (1, "xl430-w250"),
"shoulder_lift": (2, "xl430-w250"),
"elbow_flex": (3, "xl330-m288"),
"wrist_flex": (4, "xl330-m288"),
"wrist_roll": (5, "xl330-m288"),
"gripper": (6, "xl330-m288"),
},
),
},
cameras={
"laptop": OpenCVCamera(0, fps=30, width=640, height=480),
"phone": OpenCVCamera(1, fps=30, width=640, height=480),
},
)
else:
raise ValueError(f"Robot '{name}' not found.")
def make_robot(cfg: DictConfig):
robot = hydra.utils.instantiate(cfg)
return robot return robot

View File

@ -29,9 +29,12 @@ URL_90_DEGREE_POSITION = {
# Calibration logic # Calibration logic
######################################################################## ########################################################################
# In range ]-2048, 2048[
TARGET_HORIZONTAL_POSITION = np.array([0, -1024, 1024, 0, -1024, 0]) TARGET_HORIZONTAL_POSITION = np.array([0, -1024, 1024, 0, -1024, 0])
TARGET_90_DEGREE_POSITION = np.array([1024, 0, 0, 1024, 0, -1024]) TARGET_90_DEGREE_POSITION = np.array([1024, 0, 0, 1024, 0, -1024])
GRIPPER_OPEN = np.array([-400])
# In range ]-180, 180[
GRIPPER_OPEN = np.array([-35.156])
def apply_homing_offset(values: np.array, homing_offset: np.array) -> np.array: def apply_homing_offset(values: np.array, homing_offset: np.array) -> np.array:
@ -500,11 +503,7 @@ class KochRobot:
obs_dict = {} obs_dict = {}
obs_dict["observation.state"] = torch.from_numpy(state) obs_dict["observation.state"] = torch.from_numpy(state)
for name in self.cameras: for name in self.cameras:
# Convert to pytorch format: channel first and float32 in [0,1] obs_dict[f"observation.images.{name}"] = torch.from_numpy(images[name])
img = torch.from_numpy(images[name])
img = img.type(torch.float32) / 255
img = img.permute(2, 0, 1).contiguous()
obs_dict[f"observation.images.{name}"] = img
return obs_dict return obs_dict
def send_action(self, action: torch.Tensor): def send_action(self, action: torch.Tensor):

View File

@ -0,0 +1,38 @@
_target_: lerobot.common.robot_devices.robots.koch.KochRobot
leader_arms:
main:
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
port: /dev/tty.usbmodem575E0031751
motors:
# name: (index, model)
shoulder_pan: [1, "xl330-m077"]
shoulder_lift: [2, "xl330-m077"]
elbow_flex: [3, "xl330-m077"]
wrist_flex: [4, "xl330-m077"]
wrist_roll: [5, "xl330-m077"]
gripper: [6, "xl330-m077"]
follower_arms:
main:
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
port: /dev/tty.usbmodem575E0032081
motors:
# name: (index, model)
shoulder_pan: [1, "xl430-w250"]
shoulder_lift: [2, "xl430-w250"]
elbow_flex: [3, "xl330-m288"]
wrist_flex: [4, "xl330-m288"]
wrist_roll: [5, "xl330-m288"]
gripper: [6, "xl330-m288"]
cameras:
laptop:
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
camera_index: 1
fps: 30
width: 640
height: 480
phone:
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
camera_index: 2
fps: 30
width: 640
height: 480

View File

@ -107,7 +107,6 @@ from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episod
from lerobot.common.datasets.utils import calculate_episode_data_index from lerobot.common.datasets.utils import calculate_episode_data_index
from lerobot.common.datasets.video_utils import encode_video_frames from lerobot.common.datasets.video_utils import encode_video_frames
from lerobot.common.policies.factory import make_policy from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.cameras.utils import convert_torch_image_to_cv2
from lerobot.common.robot_devices.robots.factory import make_robot from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.robot_devices.robots.utils import Robot from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
@ -217,8 +216,10 @@ def teleoperate(robot: Robot, fps: int | None = None, teleop_time_s: float | Non
break break
def record_dataset( def record(
robot: Robot, robot: Robot,
policy: torch.nn.Module | None = None,
hydra_cfg: DictConfig | None = None,
fps: int | None = None, fps: int | None = None,
root="data", root="data",
repo_id="lerobot/debug", repo_id="lerobot/debug",
@ -295,6 +296,21 @@ def record_dataset(
listener = keyboard.Listener(on_press=on_press) listener = keyboard.Listener(on_press=on_press)
listener.start() listener.start()
# Load policy if any
if policy is not None:
# Check device is available
device = get_safe_torch_device(hydra_cfg.device, log=True)
policy.eval()
policy.to(device)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
set_global_seed(hydra_cfg.seed)
# override fps using policy fps
fps = hydra_cfg.env.fps
# Execute a few seconds without recording data, to give times # Execute a few seconds without recording data, to give times
# to the robot devices to connect and start synchronizing. # to the robot devices to connect and start synchronizing.
timestamp = 0 timestamp = 0
@ -308,7 +324,10 @@ def record_dataset(
now = time.perf_counter() now = time.perf_counter()
if policy is None:
observation, action = robot.teleop_step(record_data=True) observation, action = robot.teleop_step(record_data=True)
else:
observation = robot.capture_observation()
if not is_headless(): if not is_headless():
image_keys = [key for key in observation if "image" in key] image_keys = [key for key in observation if "image" in key]
@ -339,7 +358,11 @@ def record_dataset(
start_time = time.perf_counter() start_time = time.perf_counter()
while timestamp < episode_time_s: while timestamp < episode_time_s:
now = time.perf_counter() now = time.perf_counter()
if policy is None:
observation, action = robot.teleop_step(record_data=True) observation, action = robot.teleop_step(record_data=True)
else:
observation = robot.capture_observation()
image_keys = [key for key in observation if "image" in key] image_keys = [key for key in observation if "image" in key]
not_image_keys = [key for key in observation if "image" not in key] not_image_keys = [key for key in observation if "image" not in key]
@ -362,6 +385,33 @@ def record_dataset(
ep_dict[key] = [] ep_dict[key] = []
ep_dict[key].append(observation[key]) ep_dict[key].append(observation[key])
if policy is not None:
with (
torch.inference_mode(),
torch.autocast(device_type=device.type)
if device.type == "cuda" and hydra_cfg.use_amp
else nullcontext(),
):
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
for name in observation:
if "image" in name:
observation[name] = observation[name].type(torch.float32) / 255
observation[name] = observation[name].permute(2, 0, 1).contiguous()
observation[name] = observation[name].unsqueeze(0)
if device.type == "mps":
for name in observation:
observation[name] = observation[name].to(device)
action = policy.select_action(observation)
# remove batch dimension
action = action.squeeze(0)
action = action.to("cpu")
robot.send_action(action)
action = {"action": action}
for key in action: for key in action:
if key not in ep_dict: if key not in ep_dict:
ep_dict[key] = [] ep_dict[key] = []
@ -534,7 +584,7 @@ def record_dataset(
return lerobot_dataset return lerobot_dataset
def replay_episode(robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug"): def replay(robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug"):
# TODO(rcadene): Add option to record logs # TODO(rcadene): Add option to record logs
local_dir = Path(root) / repo_id local_dir = Path(root) / repo_id
if not local_dir.exists(): if not local_dir.exists():
@ -564,151 +614,6 @@ def replay_episode(robot: Robot, episode: int, fps: int | None = None, root="dat
log_control_info(robot, dt_s, fps=fps) log_control_info(robot, dt_s, fps=fps)
def run_policy(
robot: Robot,
policy: torch.nn.Module,
hydra_cfg: DictConfig,
warmup_time_s: float = 4,
run_time_s: float | None = None,
reset_time_s: float = 15,
):
# TODO(rcadene): Add option to record eval dataset and logs
# Check device is available
device = get_safe_torch_device(hydra_cfg.device, log=True)
policy.eval()
policy.to(device)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
set_global_seed(hydra_cfg.seed)
fps = hydra_cfg.env.fps
if not robot.is_connected:
robot.connect()
if is_headless():
logging.info(
"Headless environment detected. Display cameras on screen and keyboard inputs will not be available."
)
# Allow to reset environment or exit early
# by tapping the right arrow key '->'. This might require a sudo permission
# to allow your terminal to monitor keyboard events.
reset_environment = False
exit_reset = False
# Only import pynput if not in a headless environment
if not is_headless():
from pynput import keyboard
def on_press(key):
nonlocal reset_environment, exit_reset
try:
if key == keyboard.Key.right and not reset_environment:
print("Right arrow key pressed. Suspend robot control to reset environment...")
reset_environment = True
elif key == keyboard.Key.right and reset_environment:
print("Right arrow key pressed. Enable robot control and exit reset environment...")
exit_reset = True
except Exception as e:
print(f"Error handling key press: {e}")
listener = keyboard.Listener(on_press=on_press)
listener.start()
# Execute a few seconds without recording data, to give times
# to the robot devices to connect and start synchronizing.
timestamp = 0
start_time = time.perf_counter()
is_warmup_print = False
while timestamp < warmup_time_s:
if not is_warmup_print:
logging.info("Warming up (no data recording)")
os.system('say "Warmup" &')
is_warmup_print = True
now = time.perf_counter()
observation = robot.capture_observation()
if not is_headless():
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(key, convert_torch_image_to_cv2(observation[key]))
cv2.waitKey(1)
dt_s = time.perf_counter() - now
busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - now
log_control_info(robot, dt_s, fps=fps)
timestamp = time.perf_counter() - start_time
start_time = time.perf_counter()
while True:
now = time.perf_counter()
observation = robot.capture_observation()
if not is_headless():
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(key, convert_torch_image_to_cv2(observation[key]))
cv2.waitKey(1)
with (
torch.inference_mode(),
torch.autocast(device_type=device.type)
if device.type == "cuda" and hydra_cfg.use_amp
else nullcontext(),
):
# add batch dimension to 1
for name in observation:
observation[name] = observation[name].unsqueeze(0)
if device.type == "mps":
for name in observation:
observation[name] = observation[name].to(device)
action = policy.select_action(observation)
# remove batch dimension
action = action.squeeze(0)
robot.send_action(action.to("cpu"))
dt_s = time.perf_counter() - now
busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - now
log_control_info(robot, dt_s, fps=fps)
if run_time_s is not None and time.perf_counter() - start_time > run_time_s:
break
if reset_environment:
# Start resetting env while the executor are finishing
logging.info("Reset the environment")
os.system('say "Reset the environment" &')
# Wait if necessary
timestamp = 0
start_time = time.perf_counter()
with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar:
while timestamp < reset_time_s:
time.sleep(1)
timestamp = time.perf_counter() - start_time
pbar.update(1)
if exit_reset:
exit_reset = False
break
reset_environment = False
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(dest="mode", required=True) subparsers = parser.add_subparsers(dest="mode", required=True)
@ -716,10 +621,15 @@ if __name__ == "__main__":
# Set common options for all the subparsers # Set common options for all the subparsers
base_parser = argparse.ArgumentParser(add_help=False) base_parser = argparse.ArgumentParser(add_help=False)
base_parser.add_argument( base_parser.add_argument(
"--robot", "--robot-path",
type=str, type=str,
default="koch", default="lerobot/configs/robot/koch.yaml",
help="Name of the robot provided to the `make_robot(name)` factory function.", help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
)
base_parser.add_argument(
"robot_overrides",
nargs="*",
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
) )
parser_teleop = subparsers.add_parser("teleoperate", parents=[base_parser]) parser_teleop = subparsers.add_parser("teleoperate", parents=[base_parser])
@ -727,7 +637,7 @@ if __name__ == "__main__":
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)" "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
) )
parser_record = subparsers.add_parser("record_dataset", parents=[base_parser]) parser_record = subparsers.add_parser("record", parents=[base_parser])
parser_record.add_argument( parser_record.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)" "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
) )
@ -786,8 +696,22 @@ if __name__ == "__main__":
default=0, default=0,
help="By default, data recording is resumed. When set to 1, delete the local directory and start data recording from scratch.", help="By default, data recording is resumed. When set to 1, delete the local directory and start data recording from scratch.",
) )
parser_record.add_argument(
"-p",
"--pretrained-policy-name-or-path",
type=str,
help=(
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
"saved using `Policy.save_pretrained`."
),
)
parser_record.add_argument(
"overrides",
nargs="*",
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
)
parser_replay = subparsers.add_parser("replay_episode", parents=[base_parser]) parser_replay = subparsers.add_parser("replay", parents=[base_parser])
parser_replay.add_argument( parser_replay.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)" "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
) )
@ -805,41 +729,37 @@ if __name__ == "__main__":
) )
parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episode to replay.") parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episode to replay.")
parser_policy = subparsers.add_parser("run_policy", parents=[base_parser])
parser_policy.add_argument(
"-p",
"--pretrained-policy-name-or-path",
type=str,
help=(
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
"saved using `Policy.save_pretrained`."
),
)
parser_policy.add_argument(
"overrides",
nargs="*",
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
)
args = parser.parse_args() args = parser.parse_args()
init_logging() init_logging()
control_mode = args.mode control_mode = args.mode
robot_name = args.robot robot_path = args.robot_path
robot_overrides = args.robot_overrides
kwargs = vars(args) kwargs = vars(args)
del kwargs["mode"] del kwargs["mode"]
del kwargs["robot"] del kwargs["robot_path"]
del kwargs["robot_overrides"]
robot_cfg = init_hydra_config(robot_path, robot_overrides)
robot = make_robot(robot_cfg)
robot = make_robot(robot_name)
if control_mode == "teleoperate": if control_mode == "teleoperate":
teleoperate(robot, **kwargs) teleoperate(robot, **kwargs)
elif control_mode == "record_dataset":
record_dataset(robot, **kwargs)
elif control_mode == "replay_episode":
replay_episode(robot, **kwargs)
elif control_mode == "run_policy": elif control_mode == "record":
pretrained_policy_path = get_pretrained_policy_path(args.pretrained_policy_name_or_path) pretrained_policy_name_or_path = args.pretrained_policy_name_or_path
hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", args.overrides) overrides = args.overrides
policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path) del kwargs["pretrained_policy_name_or_path"]
run_policy(robot, policy, hydra_cfg) del kwargs["overrides"]
policy_cfg = None
if pretrained_policy_name_or_path is not None:
pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
policy_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", overrides)
policy = make_policy(hydra_cfg=policy_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
record(robot, policy, policy_cfg, **kwargs)
elif control_mode == "replay":
replay(robot, **kwargs)