Change config logic in:

- gym_manipulator
- find_joint_limits
- end_effector_utils
Michel Aractingi 2025-03-25 14:24:46 +01:00
parent 6b18e4f3cf
commit 2c39504109
10 changed files with 252 additions and 269 deletions

View File

@@ -40,13 +40,18 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
         if "images" not in key:
             continue
 
         for imgkey, img in imgs.items():
             # TODO(aliberts, rcadene): use transforms.ToTensor()?
-            img = torch.from_numpy(img)
+            if not torch.is_tensor(img):
+                img = torch.from_numpy(img)
+            if img.ndim == 3:
+                img = img.unsqueeze(0)
 
             # sanity check that images are channel last
             _, h, w, c = img.shape
-            assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
+            assert c < h and c < w, (
+                f"expect channel last images, but instead got {img.shape=}"
+            )
 
             # sanity check that images are uint8
             assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"

View File

@@ -87,6 +87,8 @@ class RecordControlConfig(ControlConfig):
     play_sounds: bool = True
     # Resume recording on an existing dataset.
     resume: bool = False
+    # Reset follower arms to an initial configuration.
+    reset_follower_arms: bool = True
 
     def __post_init__(self):
         # HACK: We parse again the cli args here to get the pretrained path if there was one.

View File

@@ -221,7 +221,7 @@ def record_episode(
         events=events,
         policy=policy,
         fps=fps,
-        record_delta_actions=record_delta_actions,
+        # record_delta_actions=record_delta_actions,
         teleoperate=policy is None,
         single_task=single_task,
     )
@@ -267,8 +267,8 @@ def control_loop(
 
         if teleoperate:
             observation, action = robot.teleop_step(record_data=True)
-            if record_delta_actions:
-                action["action"] = action["action"] - current_joint_positions
+            # if record_delta_actions:
+            #     action["action"] = action["action"] - current_joint_positions
         else:
             observation = robot.capture_observation()

View File

@@ -443,7 +443,7 @@ class So100RobotConfig(ManipulatorRobotConfig):
     leader_arms: dict[str, MotorsBusConfig] = field(
         default_factory=lambda: {
             "main": FeetechMotorsBusConfig(
-                port="/dev/tty.usbmodem58760431091",
+                port="/dev/tty.usbmodem58760433331",
                 motors={
                     # name: (index, model)
                     "shoulder_pan": [1, "sts3215"],
@@ -460,7 +460,7 @@ class So100RobotConfig(ManipulatorRobotConfig):
     follower_arms: dict[str, MotorsBusConfig] = field(
         default_factory=lambda: {
             "main": FeetechMotorsBusConfig(
-                port="/dev/tty.usbmodem585A0076891",
+                port="/dev/tty.usbmodem58760431631",
                 motors={
                     # name: (index, model)
                     "shoulder_pan": [1, "sts3215"],

View File

@@ -475,12 +475,12 @@ class ManipulatorRobot:
             goal_pos = leader_pos[name]
 
             # If specified, clip the goal positions within predefined bounds specified in the config of the robot
-            if self.config.joint_position_relative_bounds is not None:
-                goal_pos = torch.clamp(
-                    goal_pos,
-                    self.config.joint_position_relative_bounds["min"],
-                    self.config.joint_position_relative_bounds["max"],
-                )
+            # if self.config.joint_position_relative_bounds is not None:
+            #     goal_pos = torch.clamp(
+            #         goal_pos,
+            #         self.config.joint_position_relative_bounds["min"],
+            #         self.config.joint_position_relative_bounds["max"],
+            #     )
 
             # Cap goal position when too far away from present position.
             # Slower fps expected due to reading from the follower.
@@ -604,12 +604,12 @@ class ManipulatorRobot:
         from_idx = to_idx
 
         # If specified, clip the goal positions within predefined bounds specified in the config of the robot
-        if self.config.joint_position_relative_bounds is not None:
-            goal_pos = torch.clamp(
-                goal_pos,
-                self.config.joint_position_relative_bounds["min"],
-                self.config.joint_position_relative_bounds["max"],
-            )
+        # if self.config.joint_position_relative_bounds is not None:
+        #     goal_pos = torch.clamp(
+        #         goal_pos,
+        #         self.config.joint_position_relative_bounds["min"],
+        #         self.config.joint_position_relative_bounds["max"],
+        #     )
 
         # Cap goal position when too far away from present position.
         # Slower fps expected due to reading from the follower.

View File

@@ -366,8 +366,8 @@ def replay(
         start_episode_t = time.perf_counter()
         action = actions[idx]["action"]
-        if replay_delta_actions:
-            action = action + current_joint_positions
+        # if replay_delta_actions:
+        #     action = action + current_joint_positions
         robot.send_action(action)
 
         dt_s = time.perf_counter() - start_episode_t
@@ -394,7 +394,6 @@ def control_robot(cfg: ControlPipelineConfig):
         replay(robot, cfg.control)
     elif isinstance(cfg.control, RemoteRobotConfig):
         from lerobot.common.robot_devices.robots.lekiwi_remote import run_lekiwi
-
         run_lekiwi(cfg.robot)
 
     if robot.is_connected:

View File

@@ -1,14 +1,13 @@
-import argparse
+from lerobot.common.robot_devices.utils import busy_wait
+from lerobot.scripts.server.kinematics import RobotKinematics
 import logging
 import time
-
-import numpy as np
 import torch
+import numpy as np
+import argparse
 
-from lerobot.common.robot_devices.robots.factory import make_robot
-from lerobot.common.robot_devices.utils import busy_wait
-from lerobot.common.utils.utils import init_hydra_config
-from lerobot.scripts.server.kinematics import RobotKinematics
+from lerobot.common.robot_devices.robots.utils import make_robot_from_config
+from lerobot.scripts.server.gym_manipulator import make_robot_env, HILSerlRobotEnvConfig
+from lerobot.common.robot_devices.robots.configs import RobotConfig
 
 logging.basicConfig(level=logging.INFO)
@@ -677,15 +676,6 @@ def teleoperate_gym_env(env, controller, fps: int = 30):
     # Close the environment
     env.close()
 
 
-def make_robot_from_config(config_path, overrides=None):
-    """Helper function to create a robot from a config file."""
-    if overrides is None:
-        overrides = []
-    robot_cfg = init_hydra_config(config_path, overrides)
-    return make_robot(robot_cfg)
-
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Test end-effector control")
     parser.add_argument(
@@ -703,21 +693,22 @@ if __name__ == "__main__":
         help="Control mode to use",
     )
     parser.add_argument(
-        "--task",
+        "--robot-type",
         type=str,
-        default="Robot manipulation task",
-        help="Description of the task being performed",
+        default="so100",
+        help="Robot type (so100, koch, aloha, etc.)",
     )
     parser.add_argument(
-        "--push-to-hub",
-        default=True,
-        type=bool,
-        help="Push the dataset to Hugging Face Hub",
+        "--config-path",
+        type=str,
+        default=None,
+        help="Path to the config file in json format",
     )
-    # Add the rest of your existing arguments
 
     args = parser.parse_args()
 
-    robot = make_robot_from_config("lerobot/configs/robot/so100.yaml", [])
+    robot_config = RobotConfig.get_choice_class(args.robot_type)(mock=False)
+    robot = make_robot_from_config(robot_config)
 
     if not robot.is_connected:
         robot.connect()
@@ -743,12 +734,12 @@ if __name__ == "__main__":
     elif args.mode in ["keyboard_gym", "gamepad_gym"]:
         # Gym environment control modes
-        from lerobot.scripts.server.gym_manipulator import make_robot_env
-
-        cfg = init_hydra_config("lerobot/configs/env/so100_real.yaml", [])
-        cfg.env.wrapper.ee_action_space_params.use_gamepad = False
-        env = make_robot_env(robot, None, cfg)
-        teleoperate_gym_env(env, controller)
+        cfg = HILSerlRobotEnvConfig()
+        if args.config_path is not None:
+            cfg = HILSerlRobotEnvConfig.from_json(args.config_path)
+
+        env = make_robot_env(cfg, robot)
+        teleoperate_gym_env(env, controller, fps=args.fps)
 
     elif args.mode == "leader":
         # Leader-follower modes don't use controllers

View File

@@ -5,10 +5,13 @@ import cv2
 import numpy as np
 
 from lerobot.common.robot_devices.control_utils import is_headless
-from lerobot.common.robot_devices.robots.factory import make_robot
-from lerobot.common.utils.utils import init_hydra_config
+from lerobot.common.robot_devices.robots.utils import make_robot_from_config
 from lerobot.scripts.server.kinematics import RobotKinematics
+from lerobot.configs import parser
+from lerobot.common.robot_devices.robots.configs import RobotConfig
+
+follower_port = "/dev/tty.usbmodem58760431631"
+leader_port = "/dev/tty.usbmodem58760433331"
 
 
 def find_joint_bounds(
     robot,
@@ -78,20 +81,28 @@ def find_ee_bounds(
             break
 
 
+def make_robot(robot_type="so100", mock=True):
+    """
+    Create a robot instance using the appropriate robot config class.
+
+    Args:
+        robot_type: Robot type string (e.g., "so100", "koch", "aloha")
+        mock: Whether to use mock mode for hardware (default: True)
+
+    Returns:
+        Robot instance
+    """
+    # Get the appropriate robot config class based on robot_type
+    robot_config = RobotConfig.get_choice_class(robot_type)(mock=mock)
+
+    robot_config.leader_arms["main"].port = leader_port
+    robot_config.follower_arms["main"].port = follower_port
+
+    return make_robot_from_config(robot_config)
+
+
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--robot-path",
-        type=str,
-        default="lerobot/configs/robot/koch.yaml",
-        help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
-    )
-    parser.add_argument(
-        "--robot-overrides",
-        type=str,
-        nargs="*",
-        help="Any key=value arguments to override config values (use dots for.nested=overrides)",
-    )
+    # Create argparse for script-specific arguments
+    parser = argparse.ArgumentParser(add_help=False)  # Set add_help=False to avoid conflict
     parser.add_argument(
         "--mode",
         type=str,
@@ -105,13 +116,29 @@ if __name__ == "__main__":
         default=30,
         help="Time step to use for control.",
     )
-
-    args = parser.parse_args()
-    robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)
-
-    robot = make_robot(robot_cfg)
+    parser.add_argument(
+        "--robot-type",
+        type=str,
+        default="so100",
+        help="Robot type (so100, koch, aloha, etc.)",
+    )
+    parser.add_argument(
+        "--mock",
+        type=int,
+        default=1,
+        help="Use mock mode for hardware simulation",
+    )
+
+    # Only parse known args, leaving robot config args for Hydra if used
+    args, _ = parser.parse_known_args()
+
+    # Create robot with the appropriate config
+    robot = make_robot(args.robot_type, args.mock)
 
     if args.mode == "joint":
         find_joint_bounds(robot, args.control_time_s)
     elif args.mode == "ee":
         find_ee_bounds(robot, args.control_time_s)
 
     if robot.is_connected:
         robot.disconnect()
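
The new `make_robot` helper above hard-codes the leader and follower serial ports at module level. As a minimal sketch, the same robot can be built with different ports using only calls that appear in this diff (the port strings below are placeholders for whatever devices you actually have):

    from lerobot.common.robot_devices.robots.configs import RobotConfig
    from lerobot.common.robot_devices.robots.utils import make_robot_from_config

    # Pick the config class by robot type, then override the serial ports before instantiating the robot.
    robot_config = RobotConfig.get_choice_class("so100")(mock=False)
    robot_config.leader_arms["main"].port = "/dev/ttyACM0"    # placeholder port
    robot_config.follower_arms["main"].port = "/dev/ttyACM1"  # placeholder port
    robot = make_robot_from_config(robot_config)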

View File

@@ -1,7 +1,8 @@
-import argparse
 import logging
 import sys
 import time
+import sys
 
 from threading import Lock
 from typing import Annotated, Any, Dict, Tuple
@@ -9,6 +10,9 @@ import gymnasium as gym
 import numpy as np
 import torch
 import torchvision.transforms.functional as F  # noqa: N812
+import json
+from dataclasses import dataclass
 
 from lerobot.common.envs.utils import preprocess_observation
 from lerobot.common.robot_devices.control_utils import (
@@ -16,12 +20,69 @@ from lerobot.common.robot_devices.control_utils import (
     is_headless,
     reset_follower_position,
 )
-from lerobot.common.robot_devices.robots.factory import make_robot
-from lerobot.common.utils.utils import init_hydra_config, log_say
+from typing import Optional
+from lerobot.common.utils.utils import log_say
+from lerobot.common.robot_devices.robots.utils import make_robot_from_config
+from lerobot.common.robot_devices.robots.configs import RobotConfig
 from lerobot.scripts.server.kinematics import RobotKinematics
+from lerobot.configs import parser
 
 logging.basicConfig(level=logging.INFO)
 
+
+@dataclass
+class EEActionSpaceConfig:
+    """Configuration parameters for end-effector action space."""
+
+    x_step_size: float
+    y_step_size: float
+    z_step_size: float
+    bounds: Dict[str, Any]  # Contains 'min' and 'max' keys with position bounds
+    use_gamepad: bool = False
+
+
+@dataclass
+class EnvWrapperConfig:
+    """Configuration for environment wrappers."""
+
+    display_cameras: bool = False
+    delta_action: float = 0.1
+    use_relative_joint_positions: bool = True
+    add_joint_velocity_to_observation: bool = False
+    add_ee_pose_to_observation: bool = False
+    crop_params_dict: Optional[Dict[str, Tuple[int, int, int, int]]] = None
+    resize_size: Optional[Tuple[int, int]] = None
+    control_time_s: float = 20.0
+    fixed_reset_joint_positions: Optional[Any] = None
+    reset_time_s: float = 5.0
+    joint_masking_action_space: Optional[Any] = None
+    ee_action_space_params: Optional[EEActionSpaceConfig] = None
+    reward_classifier_pretrained_path: Optional[str] = None
+    reward_classifier_config_file: Optional[str] = None
+
+
+@dataclass
+class HILSerlRobotEnvConfig:
+    """Configuration for the HILSerlRobotEnv environment."""
+
+    robot: RobotConfig
+    wrapper: EnvWrapperConfig
+    env_name: str = "real_robot"
+    fps: int = 10
+    mode: str = None  # Either "record", "replay", None
+    repo_id: Optional[str] = None
+    dataset_root: Optional[str] = None
+    task: str = ""
+    num_episodes: int = 10  # only for record mode
+    episode: int = 0
+    device: str = "cuda"
+    push_to_hub: bool = True
+    pretrained_policy_name_or_path: Optional[str] = None
+
+    @classmethod
+    def from_json(cls, json_path: str):
+        with open(json_path, "r") as f:
+            config = json.load(f)
+        return cls(**config)
+
+
 class HILSerlRobotEnv(gym.Env):
     """
@@ -49,7 +110,7 @@ class HILSerlRobotEnv(gym.Env):
     The environment is set up with a robot interface, which is used to capture observations and send joint commands. The setup
     supports both relative (delta) adjustments and absolute joint positions for controlling the robot.
 
-    Args:
+    cfg.
         robot: The robot interface object used to connect and interact with the physical robot.
         use_delta_action_space (bool): If True, uses a delta (relative) action space for joint control. Otherwise, absolute
             joint positions are used.
@@ -77,14 +138,17 @@ class HILSerlRobotEnv(gym.Env):
         self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")
 
         # Retrieve the size of the joint position interval bound.
-        self.relative_bounds_size = (
-            (
-                self.robot.config.joint_position_relative_bounds["max"]
-                - self.robot.config.joint_position_relative_bounds["min"]
-            )
-            if self.robot.config.joint_position_relative_bounds is not None
-            else None
-        )
+        self.relative_bounds_size = None
+        # (
+        #     (
+        #         self.robot.config.joint_position_relative_bounds["max"]
+        #         - self.robot.config.joint_position_relative_bounds["min"]
+        #     )
+        #     if self.robot.config.joint_position_relative_bounds is not None
+        #     else None
+        # )
+        self.robot.config.joint_position_relative_bounds = None
 
         self.robot.config.max_relative_target = (
             self.relative_bounds_size.float() if self.relative_bounds_size is not None else None
@@ -168,7 +232,7 @@
         Reset the environment to its initial state.
         This method resets the step counter and clears any episodic data.
 
-        Args:
+        cfg.
             seed (Optional[int]): A seed for random number generation to ensure reproducibility.
             options (Optional[dict]): Additional options to influence the reset behavior.
@@ -203,7 +267,7 @@
           - When True, a teleoperation step is executed. If using a delta action space, an absolute teleop action is converted
             to relative change based on the current joint positions.
 
-        Args:
+        cfg.
             action (tuple): A tuple with two elements:
                 - policy_action (np.ndarray or torch.Tensor): The commanded joint positions.
                 - intervention_bool (bool): True if the human operator intervenes by providing a teleoperation input.
@@ -258,7 +322,8 @@
             if teleop_action.dim() == 1:
                 teleop_action = teleop_action.unsqueeze(0)
 
-        # self.render()
+        if self.display_cameras:
+            self.render()
 
         self.current_step += 1
@@ -353,7 +418,7 @@ class RewardWrapper(gym.Wrapper):
     """
     Wrapper to add reward prediction to the environment, it use a trained classifer.
 
-    Args:
+    cfg.
         env: The environment to wrap
         reward_classifier: The reward classifier model
         device: The device to run the model on
@@ -396,7 +461,7 @@ class JointMaskingActionSpace(gym.Wrapper):
     """
     Wrapper to mask out dimensions of the action space.
 
-    Args:
+    cfg.
         env: The environment to wrap
         mask: Binary mask array where 0 indicates dimensions to remove
     """
@@ -428,7 +493,7 @@
         """
         Convert masked action back to full action space.
 
-        Args:
+        cfg.
             action: Action in masked space. For Tuple spaces, the first element is masked.
 
         Returns:
@@ -859,7 +924,7 @@ class GamepadControlWrapper(gym.Wrapper):
         """
         Initialize the gamepad controller wrapper.
 
-        Args:
+        cfg.
            env: The environment to wrap
            x_step_size: Base movement step size for X axis in meters
            y_step_size: Base movement step size for Y axis in meters
@@ -937,7 +1002,7 @@
        """
        Step the environment, using gamepad input to override actions when active.
 
-        Args:
+        cfg.
           action: Original action from agent
 
       Returns:
@@ -1029,25 +1094,19 @@ class ActionScaleWrapper(gym.ActionWrapper):
         return action * self.scale_vector, is_intervention
 
 
-def make_robot_env(
-    robot,
-    reward_classifier,
-    cfg,
-    n_envs: int = 1,
-) -> gym.vector.VectorEnv:
+def make_robot_env(cfg, robot) -> gym.vector.VectorEnv:
     """
     Factory function to create a vectorized robot environment.
 
-    Args:
+    cfg.
         robot: Robot instance to control
         reward_classifier: Classifier model for computing rewards
         cfg: Configuration object containing environment parameters
-        n_envs: Number of environments to create in parallel. Defaults to 1.
 
     Returns:
         A vectorized gym environment with all the necessary wrappers applied.
     """
-    if "maniskill" in cfg.env.name:
+    if "maniskill" in cfg.env_name:
         from lerobot.scripts.server.maniskill_manipulator import make_maniskill
 
         logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN")
@@ -1056,69 +1115,69 @@ def make_robot_env(
             n_envs=1,
         )
         return env
 
     # Create base environment
     env = HILSerlRobotEnv(
         robot=robot,
-        display_cameras=cfg.env.wrapper.display_cameras,
-        delta=cfg.env.wrapper.delta_action,
-        use_delta_action_space=cfg.env.wrapper.use_relative_joint_positions
-        and cfg.env.wrapper.ee_action_space_params is None,
+        display_cameras=cfg.wrapper.display_cameras,
+        delta=cfg.wrapper.delta_action,
+        use_delta_action_space=cfg.wrapper.use_relative_joint_positions
+        and cfg.wrapper.ee_action_space_params is None,
     )
 
     # Add observation and image processing
-    if cfg.env.wrapper.add_joint_velocity_to_observation:
+    if cfg.wrapper.add_joint_velocity_to_observation:
         env = AddJointVelocityToObservation(env=env, fps=cfg.fps)
-    if cfg.env.wrapper.add_ee_pose_to_observation:
-        env = EEObservationWrapper(env=env, ee_pose_limits=cfg.env.wrapper.ee_action_space_params.bounds)
+    if cfg.wrapper.add_ee_pose_to_observation:
+        env = EEObservationWrapper(env=env, ee_pose_limits=cfg.wrapper.ee_action_space_params.bounds)
 
-    env = ConvertToLeRobotObservation(env=env, device=cfg.env.device)
+    env = ConvertToLeRobotObservation(env=env, device=cfg.device)
 
-    if cfg.env.wrapper.crop_params_dict is not None:
+    if cfg.wrapper.crop_params_dict is not None:
         env = ImageCropResizeWrapper(
             env=env,
-            crop_params_dict=cfg.env.wrapper.crop_params_dict,
-            resize_size=cfg.env.wrapper.resize_size,
+            crop_params_dict=cfg.wrapper.crop_params_dict,
+            resize_size=cfg.wrapper.resize_size,
         )
 
     # Add reward computation and control wrappers
     # env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
-    env = TimeLimitWrapper(env=env, control_time_s=cfg.env.wrapper.control_time_s, fps=cfg.fps)
-    if cfg.env.wrapper.ee_action_space_params is not None:
-        env = EEActionWrapper(env=env, ee_action_space_params=cfg.env.wrapper.ee_action_space_params)
+    env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
+    if cfg.wrapper.ee_action_space_params is not None:
+        env = EEActionWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params)
     if (
-        cfg.env.wrapper.ee_action_space_params is not None
-        and cfg.env.wrapper.ee_action_space_params.use_gamepad
+        cfg.wrapper.ee_action_space_params is not None
+        and cfg.wrapper.ee_action_space_params.use_gamepad
     ):
-        # env = ActionScaleWrapper(env=env, ee_action_space_params=cfg.env.wrapper.ee_action_space_params)
+        # env = ActionScaleWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params)
         env = GamepadControlWrapper(
             env=env,
-            x_step_size=cfg.env.wrapper.ee_action_space_params.x_step_size,
-            y_step_size=cfg.env.wrapper.ee_action_space_params.y_step_size,
-            z_step_size=cfg.env.wrapper.ee_action_space_params.z_step_size,
+            x_step_size=cfg.wrapper.ee_action_space_params.x_step_size,
+            y_step_size=cfg.wrapper.ee_action_space_params.y_step_size,
+            z_step_size=cfg.wrapper.ee_action_space_params.z_step_size,
        )
     else:
         env = KeyboardInterfaceWrapper(env=env)
 
     env = ResetWrapper(
         env=env,
-        reset_pose=cfg.env.wrapper.fixed_reset_joint_positions,
-        reset_time_s=cfg.env.wrapper.reset_time_s,
+        reset_pose=cfg.wrapper.fixed_reset_joint_positions,
+        reset_time_s=cfg.wrapper.reset_time_s,
     )
     if (
-        cfg.env.wrapper.ee_action_space_params is None
-        and cfg.env.wrapper.joint_masking_action_space is not None
+        cfg.wrapper.ee_action_space_params is None
+        and cfg.wrapper.joint_masking_action_space is not None
     ):
-        env = JointMaskingActionSpace(env=env, mask=cfg.env.wrapper.joint_masking_action_space)
+        env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)
     env = BatchCompitableWrapper(env=env)
 
     return env
 
 
-def get_classifier(pretrained_path, config_path, device="mps"):
-    if pretrained_path is None or config_path is None:
+def get_classifier(cfg):
+    if cfg.wrapper.reward_classifier_pretrained_path is None or cfg.wrapper.reward_classifier_config_file is None:
         return None
 
-    from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
     from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
         ClassifierConfig,
     )
@@ -1126,8 +1185,6 @@ def get_classifier(pretrained_path, config_path, device="mps"):
         Classifier,
     )
 
-    cfg = init_hydra_config(config_path)
-
     classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
     classifier_config.num_cameras = len(cfg.training.image_keys)  # TODO automate these paths
     model = Classifier(classifier_config)
@@ -1136,21 +1193,11 @@ def get_classifier(pretrained_path, config_path, device="mps"):
     return model
 
 
-def record_dataset(
-    env,
-    repo_id,
-    root=None,
-    num_episodes=1,
-    control_time_s=20,
-    fps=30,
-    push_to_hub=True,
-    task_description="",
-    policy=None,
-):
+def record_dataset(env, policy, cfg: HILSerlRobotEnvConfig):
     """
     Record a dataset of robot interactions using either a policy or teleop.
 
-    Args:
+    cfg.
         env: The environment to record from
         repo_id: Repository ID for dataset storage
         root: Local root directory for dataset (optional)
@@ -1195,9 +1242,9 @@ def record_dataset(
     # Create dataset
     dataset = LeRobotDataset.create(
-        repo_id,
-        fps,
-        root=root,
+        cfg.repo_id,
+        cfg.fps,
+        root=cfg.dataset_root,
         use_videos=True,
         image_writer_threads=4,
         image_writer_processes=0,
@@ -1206,17 +1253,17 @@ def record_dataset(
     # Record episodes
     episode_index = 0
-    while episode_index < num_episodes:
+    while episode_index < cfg.record_num_episodes:
         obs, _ = env.reset()
         start_episode_t = time.perf_counter()
         log_say(f"Recording episode {episode_index}", play_sounds=True)
 
         # Run episode steps
-        while time.perf_counter() - start_episode_t < control_time_s:
+        while time.perf_counter() - start_episode_t < cfg.wrapper.control_time_s:
             start_loop_t = time.perf_counter()
 
             # Get action from policy if available
-            if policy is not None:
+            if cfg.pretrained_policy_name_or_path is not None:
                 action = policy.select_action(obs)
 
             # Step environment
@@ -1240,9 +1287,9 @@ def record_dataset(
             dataset.add_frame(frame)
 
             # Maintain consistent timing
-            if fps:
+            if cfg.fps:
                 dt_s = time.perf_counter() - start_loop_t
-                busy_wait(1 / fps - dt_s)
+                busy_wait(1 / cfg.fps - dt_s)
 
             if terminated or truncated:
                 break
@@ -1253,13 +1300,13 @@ def record_dataset(
             logging.info(f"Re-recording episode {episode_index}")
             continue
 
-        dataset.save_episode(task_description)
+        dataset.save_episode(cfg.task)
         episode_index += 1
 
     # Finalize dataset
     dataset.consolidate(run_compute_stats=True)
-    if push_to_hub:
-        dataset.push_to_hub(repo_id)
+    if cfg.push_to_hub:
+        dataset.push_to_hub(cfg.repo_id)
 
 
 def replay_episode(env, repo_id, root=None, episode=0):
@@ -1282,134 +1329,44 @@ def replay_episode(env, repo_id, root=None, episode=0):
         busy_wait(1 / 10 - dt_s)
 
 
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--fps", type=int, default=30, help="control frequency")
-    parser.add_argument(
-        "--robot-path",
-        type=str,
-        default="lerobot/configs/robot/koch.yaml",
-        help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
-    )
-    parser.add_argument(
-        "--robot-overrides",
-        type=str,
-        nargs="*",
-        help="Any key=value arguments to override config values (use dots for.nested=overrides)",
-    )
-    parser.add_argument(
-        "-p",
-        "--pretrained-policy-name-or-path",
-        help=(
-            "Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
-            "saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch "
-        ),
-    )
-    parser.add_argument(
-        "--display-cameras",
-        help=("Whether to display the camera feed while the rollout is happening"),
-    )
-    parser.add_argument(
-        "--reward-classifier-pretrained-path",
-        type=str,
-        default=None,
-        help="Path to the pretrained classifier weights.",
-    )
-    parser.add_argument(
-        "--reward-classifier-config-file",
-        type=str,
-        default=None,
-        help="Path to a yaml config file that is necessary to build the reward classifier model.",
-    )
-    parser.add_argument("--env-path", type=str, default=None, help="Path to the env yaml file")
-    parser.add_argument(
-        "--env-overrides",
-        type=str,
-        default=None,
-        help="Overrides for the env yaml file",
-    )
-    parser.add_argument(
-        "--control-time-s",
-        type=float,
-        default=20,
-        help="Maximum episode length in seconds",
-    )
-    parser.add_argument(
-        "--reset-follower-pos",
-        type=int,
-        default=1,
-        help="Reset follower between episodes",
-    )
-    parser.add_argument(
-        "--replay-repo-id",
-        type=str,
-        default=None,
-        help="Repo ID of the episode to replay",
-    )
-    parser.add_argument("--dataset-root", type=str, default=None, help="Root of the dataset to replay")
-    parser.add_argument("--replay-episode", type=int, default=0, help="Episode to replay")
-    parser.add_argument(
-        "--record-repo-id",
-        type=str,
-        default=None,
-        help="Repo ID of the dataset to record",
-    )
-    parser.add_argument(
-        "--record-num-episodes",
-        type=int,
-        default=1,
-        help="Number of episodes to record",
-    )
-    parser.add_argument(
-        "--record-episode-task",
-        type=str,
-        default="",
-        help="Single line description of the task to record",
-    )
-
-    args = parser.parse_args()
-    robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)
-    robot = make_robot(robot_cfg)
-
-    reward_classifier = get_classifier(
-        args.reward_classifier_pretrained_path, args.reward_classifier_config_file
-    )
+@parser.wrap()
+def main(cfg: HILSerlRobotEnvConfig):
+    robot = make_robot_from_config(cfg.robot)
+
+    reward_classifier = None  # get_classifier(
+    #     cfg.wrapper.reward_classifier_pretrained_path, cfg.wrapper.reward_classifier_config_file
+    # )
     user_relative_joint_positions = True
 
-    cfg = init_hydra_config(args.env_path, args.env_overrides)
-    env = make_robot_env(
-        robot,
-        reward_classifier,
-        cfg,  # .wrapper,
-    )
+    env = make_robot_env(cfg, robot)
 
-    if args.record_repo_id is not None:
+    if cfg.mode == "record":
         policy = None
-        if args.pretrained_policy_name_or_path is not None:
+        if cfg.pretrained_policy_name_or_path is not None:
             from lerobot.common.policies.sac.modeling_sac import SACPolicy
 
-            policy = SACPolicy.from_pretrained(args.pretrained_policy_name_or_path)
+            policy = SACPolicy.from_pretrained(cfg.pretrained_policy_name_or_path)
             policy.to(cfg.device)
             policy.eval()
         record_dataset(
             env,
-            args.record_repo_id,
-            root=args.dataset_root,
-            num_episodes=args.record_num_episodes,
-            fps=args.fps,
-            task_description=args.record_episode_task,
+            cfg.repo_id,
+            root=cfg.dataset_root,
+            num_episodes=cfg.num_episodes,
+            fps=cfg.fps,
+            task_description=cfg.task,
             policy=policy,
         )
         exit()
 
-    if args.replay_repo_id is not None:
+    if cfg.mode == "replay":
         replay_episode(
             env,
-            args.replay_repo_id,
-            root=args.dataset_root,
-            episode=args.replay_episode,
+            cfg.replay_repo_id,
+            root=cfg.dataset_root,
+            episode=cfg.replay_episode,
         )
         exit()
@@ -1442,7 +1399,10 @@ if __name__ == "__main__":
             num_episode += 1
 
         dt_s = time.perf_counter() - start_loop_s
-        busy_wait(1 / args.fps - dt_s)
+        busy_wait(1 / cfg.fps - dt_s)
 
     logging.info(f"Success after 20 steps {sucesses}")
     logging.info(f"success rate {sum(sucesses) / len(sucesses)}")
+
+
+if __name__ == "__main__":
+    main()

View File

@@ -59,7 +59,6 @@ dependencies = [
     "huggingface-hub[hf-transfer,cli]>=0.27.1 ; python_version < '4.0'",
     "imageio[ffmpeg]>=2.34.0",
     "jsonlines>=4.0.0",
-    "mani-skill>=3.0.0b18",
    "numba>=0.59.0",
    "omegaconf>=2.3.0",
    "opencv-python>=4.9.0",