Change config logic in:

- gym_manipulator
- find_joint_limits
- end_effector_utils
This commit is contained in:
Michel Aractingi 2025-03-25 14:24:46 +01:00 committed by AdilZouitine
parent ee25fd8afe
commit b7b6d8102f
10 changed files with 252 additions and 269 deletions

View File

@ -40,13 +40,18 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
if "images" not in key:
continue
for imgkey, img in imgs.items():
# TODO(aliberts, rcadene): use transforms.ToTensor()?
# TODO(aliberts, rcadene): use transforms.ToTensor()?
if not torch.is_tensor(img):
img = torch.from_numpy(img)
# sanity check that images are channel last
_, h, w, c = img.shape
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
if img.ndim == 3:
img = img.unsqueeze(0)
# sanity check that images are channel last
_, h, w, c = img.shape
assert c < h and c < w, (
f"expect channel last images, but instead got {img.shape=}"
)
# sanity check that images are uint8
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"

View File

@ -87,6 +87,8 @@ class RecordControlConfig(ControlConfig):
play_sounds: bool = True
# Resume recording on an existing dataset.
resume: bool = False
# Reset follower arms to an initial configuration.
reset_follower_arms: bool = True
def __post_init__(self):
# HACK: We parse again the cli args here to get the pretrained path if there was one.

View File

@ -221,7 +221,7 @@ def record_episode(
events=events,
policy=policy,
fps=fps,
record_delta_actions=record_delta_actions,
# record_delta_actions=record_delta_actions,
teleoperate=policy is None,
single_task=single_task,
)
@ -267,8 +267,8 @@ def control_loop(
if teleoperate:
observation, action = robot.teleop_step(record_data=True)
if record_delta_actions:
action["action"] = action["action"] - current_joint_positions
# if record_delta_actions:
# action["action"] = action["action"] - current_joint_positions
else:
observation = robot.capture_observation()

View File

@ -443,7 +443,7 @@ class So100RobotConfig(ManipulatorRobotConfig):
leader_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": FeetechMotorsBusConfig(
port="/dev/tty.usbmodem58760431091",
port="/dev/tty.usbmodem58760433331",
motors={
# name: (index, model)
"shoulder_pan": [1, "sts3215"],
@ -460,7 +460,7 @@ class So100RobotConfig(ManipulatorRobotConfig):
follower_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": FeetechMotorsBusConfig(
port="/dev/tty.usbmodem585A0076891",
port="/dev/tty.usbmodem58760431631",
motors={
# name: (index, model)
"shoulder_pan": [1, "sts3215"],

View File

@ -475,12 +475,12 @@ class ManipulatorRobot:
goal_pos = leader_pos[name]
# If specified, clip the goal positions within predefined bounds specified in the config of the robot
if self.config.joint_position_relative_bounds is not None:
goal_pos = torch.clamp(
goal_pos,
self.config.joint_position_relative_bounds["min"],
self.config.joint_position_relative_bounds["max"],
)
# if self.config.joint_position_relative_bounds is not None:
# goal_pos = torch.clamp(
# goal_pos,
# self.config.joint_position_relative_bounds["min"],
# self.config.joint_position_relative_bounds["max"],
# )
# Cap goal position when too far away from present position.
# Slower fps expected due to reading from the follower.
@ -604,12 +604,12 @@ class ManipulatorRobot:
from_idx = to_idx
# If specified, clip the goal positions within predefined bounds specified in the config of the robot
if self.config.joint_position_relative_bounds is not None:
goal_pos = torch.clamp(
goal_pos,
self.config.joint_position_relative_bounds["min"],
self.config.joint_position_relative_bounds["max"],
)
# if self.config.joint_position_relative_bounds is not None:
# goal_pos = torch.clamp(
# goal_pos,
# self.config.joint_position_relative_bounds["min"],
# self.config.joint_position_relative_bounds["max"],
# )
# Cap goal position when too far away from present position.
# Slower fps expected due to reading from the follower.

View File

@ -366,8 +366,8 @@ def replay(
start_episode_t = time.perf_counter()
action = actions[idx]["action"]
if replay_delta_actions:
action = action + current_joint_positions
# if replay_delta_actions:
# action = action + current_joint_positions
robot.send_action(action)
dt_s = time.perf_counter() - start_episode_t
@ -394,7 +394,6 @@ def control_robot(cfg: ControlPipelineConfig):
replay(robot, cfg.control)
elif isinstance(cfg.control, RemoteRobotConfig):
from lerobot.common.robot_devices.robots.lekiwi_remote import run_lekiwi
run_lekiwi(cfg.robot)
if robot.is_connected:

View File

@ -1,14 +1,13 @@
import argparse
from lerobot.common.robot_devices.utils import busy_wait
from lerobot.scripts.server.kinematics import RobotKinematics
import logging
import time
import numpy as np
import torch
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.robot_devices.utils import busy_wait
from lerobot.common.utils.utils import init_hydra_config
from lerobot.scripts.server.kinematics import RobotKinematics
import numpy as np
import argparse
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
from lerobot.scripts.server.gym_manipulator import make_robot_env, HILSerlRobotEnvConfig
from lerobot.common.robot_devices.robots.configs import RobotConfig
logging.basicConfig(level=logging.INFO)
@ -677,15 +676,6 @@ def teleoperate_gym_env(env, controller, fps: int = 30):
# Close the environment
env.close()
def make_robot_from_config(config_path, overrides=None):
"""Helper function to create a robot from a config file."""
if overrides is None:
overrides = []
robot_cfg = init_hydra_config(config_path, overrides)
return make_robot(robot_cfg)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Test end-effector control")
parser.add_argument(
@ -703,21 +693,22 @@ if __name__ == "__main__":
help="Control mode to use",
)
parser.add_argument(
"--task",
"--robot-type",
type=str,
default="Robot manipulation task",
help="Description of the task being performed",
default="so100",
help="Robot type (so100, koch, aloha, etc.)",
)
parser.add_argument(
"--push-to-hub",
default=True,
type=bool,
help="Push the dataset to Hugging Face Hub",
"--config-path",
type=str,
default=None,
help="Path to the config file in json format",
)
# Add the rest of your existing arguments
args = parser.parse_args()
robot = make_robot_from_config("lerobot/configs/robot/so100.yaml", [])
robot_config = RobotConfig.get_choice_class(args.robot_type)(mock=False)
robot = make_robot_from_config(robot_config)
if not robot.is_connected:
robot.connect()
@ -743,12 +734,12 @@ if __name__ == "__main__":
elif args.mode in ["keyboard_gym", "gamepad_gym"]:
# Gym environment control modes
from lerobot.scripts.server.gym_manipulator import make_robot_env
cfg = HILSerlRobotEnvConfig()
if args.config_path is not None:
cfg = HILSerlRobotEnvConfig.from_json(args.config_path)
cfg = init_hydra_config("lerobot/configs/env/so100_real.yaml", [])
cfg.env.wrapper.ee_action_space_params.use_gamepad = False
env = make_robot_env(robot, None, cfg)
teleoperate_gym_env(env, controller)
env = make_robot_env(cfg, robot)
teleoperate_gym_env(env, controller, fps=args.fps)
elif args.mode == "leader":
# Leader-follower modes don't use controllers

View File

@ -5,10 +5,13 @@ import cv2
import numpy as np
from lerobot.common.robot_devices.control_utils import is_headless
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.utils.utils import init_hydra_config
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
from lerobot.scripts.server.kinematics import RobotKinematics
from lerobot.configs import parser
from lerobot.common.robot_devices.robots.configs import RobotConfig
follower_port = "/dev/tty.usbmodem58760431631"
leader_port = "/dev/tty.usbmodem58760433331"
def find_joint_bounds(
robot,
@ -78,20 +81,28 @@ def find_ee_bounds(
break
def make_robot(robot_type="so100", mock=True):
"""
Create a robot instance using the appropriate robot config class.
Args:
robot_type: Robot type string (e.g., "so100", "koch", "aloha")
mock: Whether to use mock mode for hardware (default: True)
Returns:
Robot instance
"""
# Get the appropriate robot config class based on robot_type
robot_config = RobotConfig.get_choice_class(robot_type)(mock=mock)
robot_config.leader_arms["main"].port = leader_port
robot_config.follower_arms["main"].port = follower_port
return make_robot_from_config(robot_config)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--robot-path",
type=str,
default="lerobot/configs/robot/koch.yaml",
help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
)
parser.add_argument(
"--robot-overrides",
type=str,
nargs="*",
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
)
# Create argparse for script-specific arguments
parser = argparse.ArgumentParser(add_help=False) # Set add_help=False to avoid conflict
parser.add_argument(
"--mode",
type=str,
@ -105,13 +116,29 @@ if __name__ == "__main__":
default=30,
help="Time step to use for control.",
)
args = parser.parse_args()
robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)
robot = make_robot(robot_cfg)
parser.add_argument(
"--robot-type",
type=str,
default="so100",
help="Robot type (so100, koch, aloha, etc.)",
)
parser.add_argument(
"--mock",
type=int,
default=1,
help="Use mock mode for hardware simulation",
)
# Only parse known args, leaving robot config args for Hydra if used
args, _ = parser.parse_known_args()
# Create robot with the appropriate config
robot = make_robot(args.robot_type, args.mock)
if args.mode == "joint":
find_joint_bounds(robot, args.control_time_s)
elif args.mode == "ee":
find_ee_bounds(robot, args.control_time_s)
if robot.is_connected:
robot.disconnect()
robot.disconnect()

View File

@ -1,7 +1,8 @@
import argparse
import logging
import sys
import time
import sys
from threading import Lock
from typing import Annotated, Any, Dict, Tuple
@ -9,6 +10,9 @@ import gymnasium as gym
import numpy as np
import torch
import torchvision.transforms.functional as F # noqa: N812
import json
from dataclasses import dataclass
from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.robot_devices.control_utils import (
@ -16,12 +20,69 @@ from lerobot.common.robot_devices.control_utils import (
is_headless,
reset_follower_position,
)
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.utils.utils import init_hydra_config, log_say
from typing import Optional
from lerobot.common.utils.utils import log_say
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
from lerobot.common.robot_devices.robots.configs import RobotConfig
from lerobot.scripts.server.kinematics import RobotKinematics
from lerobot.configs import parser
logging.basicConfig(level=logging.INFO)
@dataclass
class EEActionSpaceConfig:
"""Configuration parameters for end-effector action space."""
x_step_size: float
y_step_size: float
z_step_size: float
bounds: Dict[str, Any] # Contains 'min' and 'max' keys with position bounds
use_gamepad: bool = False
@dataclass
class EnvWrapperConfig:
"""Configuration for environment wrappers."""
display_cameras: bool = False
delta_action: float = 0.1
use_relative_joint_positions: bool = True
add_joint_velocity_to_observation: bool = False
add_ee_pose_to_observation: bool = False
crop_params_dict: Optional[Dict[str, Tuple[int, int, int, int]]] = None
resize_size: Optional[Tuple[int, int]] = None
control_time_s: float = 20.0
fixed_reset_joint_positions: Optional[Any] = None
reset_time_s: float = 5.0
joint_masking_action_space: Optional[Any] = None
ee_action_space_params: Optional[EEActionSpaceConfig] = None
reward_classifier_pretrained_path: Optional[str] = None
reward_classifier_config_file: Optional[str] = None
@dataclass
class HILSerlRobotEnvConfig:
"""Configuration for the HILSerlRobotEnv environment."""
robot: RobotConfig
wrapper: EnvWrapperConfig
env_name: str = "real_robot"
fps: int = 10
mode: str = None # Either "record", "replay", None
repo_id: Optional[str] = None
dataset_root: Optional[str] = None
task: str = ""
num_episodes: int = 10 # only for record mode
episode: int = 0
device: str = "cuda"
push_to_hub: bool = True
pretrained_policy_name_or_path: Optional[str] = None
@classmethod
def from_json(cls, json_path: str):
with open(json_path, "r") as f:
config = json.load(f)
return cls(**config)
class HILSerlRobotEnv(gym.Env):
"""
@ -49,7 +110,7 @@ class HILSerlRobotEnv(gym.Env):
The environment is set up with a robot interface, which is used to capture observations and send joint commands. The setup
supports both relative (delta) adjustments and absolute joint positions for controlling the robot.
Args:
cfg.
robot: The robot interface object used to connect and interact with the physical robot.
use_delta_action_space (bool): If True, uses a delta (relative) action space for joint control. Otherwise, absolute
joint positions are used.
@ -77,14 +138,17 @@ class HILSerlRobotEnv(gym.Env):
self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")
# Retrieve the size of the joint position interval bound.
self.relative_bounds_size = (
(
self.robot.config.joint_position_relative_bounds["max"]
- self.robot.config.joint_position_relative_bounds["min"]
)
if self.robot.config.joint_position_relative_bounds is not None
else None
)
self.relative_bounds_size = None
# (
# (
# self.robot.config.joint_position_relative_bounds["max"]
# - self.robot.config.joint_position_relative_bounds["min"]
# )
# if self.robot.config.joint_position_relative_bounds is not None
# else None
# )
self.robot.config.joint_position_relative_bounds = None
self.robot.config.max_relative_target = (
self.relative_bounds_size.float() if self.relative_bounds_size is not None else None
@ -168,7 +232,7 @@ class HILSerlRobotEnv(gym.Env):
Reset the environment to its initial state.
This method resets the step counter and clears any episodic data.
Args:
cfg.
seed (Optional[int]): A seed for random number generation to ensure reproducibility.
options (Optional[dict]): Additional options to influence the reset behavior.
@ -203,7 +267,7 @@ class HILSerlRobotEnv(gym.Env):
- When True, a teleoperation step is executed. If using a delta action space, an absolute teleop action is converted
to relative change based on the current joint positions.
Args:
cfg.
action (tuple): A tuple with two elements:
- policy_action (np.ndarray or torch.Tensor): The commanded joint positions.
- intervention_bool (bool): True if the human operator intervenes by providing a teleoperation input.
@ -258,7 +322,8 @@ class HILSerlRobotEnv(gym.Env):
if teleop_action.dim() == 1:
teleop_action = teleop_action.unsqueeze(0)
# self.render()
if self.display_cameras:
self.render()
self.current_step += 1
@ -353,7 +418,7 @@ class RewardWrapper(gym.Wrapper):
"""
Wrapper to add reward prediction to the environment, it use a trained classifer.
Args:
cfg.
env: The environment to wrap
reward_classifier: The reward classifier model
device: The device to run the model on
@ -396,7 +461,7 @@ class JointMaskingActionSpace(gym.Wrapper):
"""
Wrapper to mask out dimensions of the action space.
Args:
cfg.
env: The environment to wrap
mask: Binary mask array where 0 indicates dimensions to remove
"""
@ -428,7 +493,7 @@ class JointMaskingActionSpace(gym.Wrapper):
"""
Convert masked action back to full action space.
Args:
cfg.
action: Action in masked space. For Tuple spaces, the first element is masked.
Returns:
@ -859,7 +924,7 @@ class GamepadControlWrapper(gym.Wrapper):
"""
Initialize the gamepad controller wrapper.
Args:
cfg.
env: The environment to wrap
x_step_size: Base movement step size for X axis in meters
y_step_size: Base movement step size for Y axis in meters
@ -937,7 +1002,7 @@ class GamepadControlWrapper(gym.Wrapper):
"""
Step the environment, using gamepad input to override actions when active.
Args:
cfg.
action: Original action from agent
Returns:
@ -1029,25 +1094,19 @@ class ActionScaleWrapper(gym.ActionWrapper):
return action * self.scale_vector, is_intervention
def make_robot_env(
robot,
reward_classifier,
cfg,
n_envs: int = 1,
) -> gym.vector.VectorEnv:
def make_robot_env(cfg, robot) -> gym.vector.VectorEnv:
"""
Factory function to create a vectorized robot environment.
Args:
cfg.
robot: Robot instance to control
reward_classifier: Classifier model for computing rewards
cfg: Configuration object containing environment parameters
n_envs: Number of environments to create in parallel. Defaults to 1.
Returns:
A vectorized gym environment with all the necessary wrappers applied.
"""
if "maniskill" in cfg.env.name:
if "maniskill" in cfg.env_name:
from lerobot.scripts.server.maniskill_manipulator import make_maniskill
logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN")
@ -1056,69 +1115,69 @@ def make_robot_env(
n_envs=1,
)
return env
# Create base environment
env = HILSerlRobotEnv(
robot=robot,
display_cameras=cfg.env.wrapper.display_cameras,
delta=cfg.env.wrapper.delta_action,
use_delta_action_space=cfg.env.wrapper.use_relative_joint_positions
and cfg.env.wrapper.ee_action_space_params is None,
display_cameras=cfg.wrapper.display_cameras,
delta=cfg.wrapper.delta_action,
use_delta_action_space=cfg.wrapper.use_relative_joint_positions
and cfg.wrapper.ee_action_space_params is None,
)
# Add observation and image processing
if cfg.env.wrapper.add_joint_velocity_to_observation:
if cfg.wrapper.add_joint_velocity_to_observation:
env = AddJointVelocityToObservation(env=env, fps=cfg.fps)
if cfg.env.wrapper.add_ee_pose_to_observation:
env = EEObservationWrapper(env=env, ee_pose_limits=cfg.env.wrapper.ee_action_space_params.bounds)
if cfg.wrapper.add_ee_pose_to_observation:
env = EEObservationWrapper(env=env, ee_pose_limits=cfg.wrapper.ee_action_space_params.bounds)
env = ConvertToLeRobotObservation(env=env, device=cfg.env.device)
env = ConvertToLeRobotObservation(env=env, device=cfg.device)
if cfg.env.wrapper.crop_params_dict is not None:
if cfg.wrapper.crop_params_dict is not None:
env = ImageCropResizeWrapper(
env=env,
crop_params_dict=cfg.env.wrapper.crop_params_dict,
resize_size=cfg.env.wrapper.resize_size,
crop_params_dict=cfg.wrapper.crop_params_dict,
resize_size=cfg.wrapper.resize_size,
)
# Add reward computation and control wrappers
# env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
env = TimeLimitWrapper(env=env, control_time_s=cfg.env.wrapper.control_time_s, fps=cfg.fps)
if cfg.env.wrapper.ee_action_space_params is not None:
env = EEActionWrapper(env=env, ee_action_space_params=cfg.env.wrapper.ee_action_space_params)
env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
if cfg.wrapper.ee_action_space_params is not None:
env = EEActionWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params)
if (
cfg.env.wrapper.ee_action_space_params is not None
and cfg.env.wrapper.ee_action_space_params.use_gamepad
cfg.wrapper.ee_action_space_params is not None
and cfg.wrapper.ee_action_space_params.use_gamepad
):
# env = ActionScaleWrapper(env=env, ee_action_space_params=cfg.env.wrapper.ee_action_space_params)
# env = ActionScaleWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params)
env = GamepadControlWrapper(
env=env,
x_step_size=cfg.env.wrapper.ee_action_space_params.x_step_size,
y_step_size=cfg.env.wrapper.ee_action_space_params.y_step_size,
z_step_size=cfg.env.wrapper.ee_action_space_params.z_step_size,
x_step_size=cfg.wrapper.ee_action_space_params.x_step_size,
y_step_size=cfg.wrapper.ee_action_space_params.y_step_size,
z_step_size=cfg.wrapper.ee_action_space_params.z_step_size,
)
else:
env = KeyboardInterfaceWrapper(env=env)
env = ResetWrapper(
env=env,
reset_pose=cfg.env.wrapper.fixed_reset_joint_positions,
reset_time_s=cfg.env.wrapper.reset_time_s,
reset_pose=cfg.wrapper.fixed_reset_joint_positions,
reset_time_s=cfg.wrapper.reset_time_s,
)
if (
cfg.env.wrapper.ee_action_space_params is None
and cfg.env.wrapper.joint_masking_action_space is not None
cfg.wrapper.ee_action_space_params is None
and cfg.wrapper.joint_masking_action_space is not None
):
env = JointMaskingActionSpace(env=env, mask=cfg.env.wrapper.joint_masking_action_space)
env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)
env = BatchCompitableWrapper(env=env)
return env
def get_classifier(pretrained_path, config_path, device="mps"):
if pretrained_path is None or config_path is None:
def get_classifier(cfg):
if cfg.wrapper.reward_classifier_pretrained_path is None or cfg.wrapper.reward_classifier_config_file is None:
return None
from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
ClassifierConfig,
)
@ -1126,8 +1185,6 @@ def get_classifier(pretrained_path, config_path, device="mps"):
Classifier,
)
cfg = init_hydra_config(config_path)
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
classifier_config.num_cameras = len(cfg.training.image_keys) # TODO automate these paths
model = Classifier(classifier_config)
@ -1136,21 +1193,11 @@ def get_classifier(pretrained_path, config_path, device="mps"):
return model
def record_dataset(
env,
repo_id,
root=None,
num_episodes=1,
control_time_s=20,
fps=30,
push_to_hub=True,
task_description="",
policy=None,
):
def record_dataset(env, policy, cfg: HILSerlRobotEnvConfig):
"""
Record a dataset of robot interactions using either a policy or teleop.
Args:
cfg.
env: The environment to record from
repo_id: Repository ID for dataset storage
root: Local root directory for dataset (optional)
@ -1195,9 +1242,9 @@ def record_dataset(
# Create dataset
dataset = LeRobotDataset.create(
repo_id,
fps,
root=root,
cfg.repo_id,
cfg.fps,
root=cfg.dataset_root,
use_videos=True,
image_writer_threads=4,
image_writer_processes=0,
@ -1206,17 +1253,17 @@ def record_dataset(
# Record episodes
episode_index = 0
while episode_index < num_episodes:
while episode_index < cfg.record_num_episodes:
obs, _ = env.reset()
start_episode_t = time.perf_counter()
log_say(f"Recording episode {episode_index}", play_sounds=True)
# Run episode steps
while time.perf_counter() - start_episode_t < control_time_s:
while time.perf_counter() - start_episode_t < cfg.wrapper.control_time_s:
start_loop_t = time.perf_counter()
# Get action from policy if available
if policy is not None:
if cfg.pretrained_policy_name_or_path is not None:
action = policy.select_action(obs)
# Step environment
@ -1240,9 +1287,9 @@ def record_dataset(
dataset.add_frame(frame)
# Maintain consistent timing
if fps:
if cfg.fps:
dt_s = time.perf_counter() - start_loop_t
busy_wait(1 / fps - dt_s)
busy_wait(1 / cfg.fps - dt_s)
if terminated or truncated:
break
@ -1253,13 +1300,13 @@ def record_dataset(
logging.info(f"Re-recording episode {episode_index}")
continue
dataset.save_episode(task_description)
dataset.save_episode(cfg.task)
episode_index += 1
# Finalize dataset
dataset.consolidate(run_compute_stats=True)
if push_to_hub:
dataset.push_to_hub(repo_id)
if cfg.push_to_hub:
dataset.push_to_hub(cfg.repo_id)
def replay_episode(env, repo_id, root=None, episode=0):
@ -1282,134 +1329,44 @@ def replay_episode(env, repo_id, root=None, episode=0):
busy_wait(1 / 10 - dt_s)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--fps", type=int, default=30, help="control frequency")
parser.add_argument(
"--robot-path",
type=str,
default="lerobot/configs/robot/koch.yaml",
help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
)
parser.add_argument(
"--robot-overrides",
type=str,
nargs="*",
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
)
parser.add_argument(
"-p",
"--pretrained-policy-name-or-path",
help=(
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
"saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch "
),
)
parser.add_argument(
"--display-cameras",
help=("Whether to display the camera feed while the rollout is happening"),
)
parser.add_argument(
"--reward-classifier-pretrained-path",
type=str,
default=None,
help="Path to the pretrained classifier weights.",
)
parser.add_argument(
"--reward-classifier-config-file",
type=str,
default=None,
help="Path to a yaml config file that is necessary to build the reward classifier model.",
)
parser.add_argument("--env-path", type=str, default=None, help="Path to the env yaml file")
parser.add_argument(
"--env-overrides",
type=str,
default=None,
help="Overrides for the env yaml file",
)
parser.add_argument(
"--control-time-s",
type=float,
default=20,
help="Maximum episode length in seconds",
)
parser.add_argument(
"--reset-follower-pos",
type=int,
default=1,
help="Reset follower between episodes",
)
parser.add_argument(
"--replay-repo-id",
type=str,
default=None,
help="Repo ID of the episode to replay",
)
parser.add_argument("--dataset-root", type=str, default=None, help="Root of the dataset to replay")
parser.add_argument("--replay-episode", type=int, default=0, help="Episode to replay")
parser.add_argument(
"--record-repo-id",
type=str,
default=None,
help="Repo ID of the dataset to record",
)
parser.add_argument(
"--record-num-episodes",
type=int,
default=1,
help="Number of episodes to record",
)
parser.add_argument(
"--record-episode-task",
type=str,
default="",
help="Single line description of the task to record",
)
@parser.wrap()
def main(cfg: HILSerlRobotEnvConfig):
args = parser.parse_args()
robot = make_robot_from_config(cfg.robot)
robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)
robot = make_robot(robot_cfg)
reward_classifier = get_classifier(
args.reward_classifier_pretrained_path, args.reward_classifier_config_file
)
reward_classifier = None #get_classifier(
# cfg.wrapper.reward_classifier_pretrained_path, cfg.wrapper.reward_classifier_config_file
# )
user_relative_joint_positions = True
cfg = init_hydra_config(args.env_path, args.env_overrides)
env = make_robot_env(
robot,
reward_classifier,
cfg, # .wrapper,
)
env = make_robot_env(cfg, robot)
if args.record_repo_id is not None:
if cfg.mode == "record":
policy = None
if args.pretrained_policy_name_or_path is not None:
if cfg.pretrained_policy_name_or_path is not None:
from lerobot.common.policies.sac.modeling_sac import SACPolicy
policy = SACPolicy.from_pretrained(args.pretrained_policy_name_or_path)
policy = SACPolicy.from_pretrained(cfg.pretrained_policy_name_or_path)
policy.to(cfg.device)
policy.eval()
record_dataset(
env,
args.record_repo_id,
root=args.dataset_root,
num_episodes=args.record_num_episodes,
fps=args.fps,
task_description=args.record_episode_task,
cfg.repo_id,
root=cfg.dataset_root,
num_episodes=cfg.num_episodes,
fps=cfg.fps,
task_description=cfg.task,
policy=policy,
)
exit()
if args.replay_repo_id is not None:
if cfg.mode == "replay":
replay_episode(
env,
args.replay_repo_id,
root=args.dataset_root,
episode=args.replay_episode,
cfg.replay_repo_id,
root=cfg.dataset_root,
episode=cfg.replay_episode,
)
exit()
@ -1442,7 +1399,10 @@ if __name__ == "__main__":
num_episode += 1
dt_s = time.perf_counter() - start_loop_s
busy_wait(1 / args.fps - dt_s)
busy_wait(1 / cfg.fps - dt_s)
logging.info(f"Success after 20 steps {sucesses}")
logging.info(f"success rate {sum(sucesses) / len(sucesses)}")
if __name__ == "__main__":
main()

View File

@ -59,7 +59,6 @@ dependencies = [
"huggingface-hub[hf-transfer,cli]>=0.27.1 ; python_version < '4.0'",
"imageio[ffmpeg]>=2.34.0",
"jsonlines>=4.0.0",
"mani-skill>=3.0.0b18",
"numba>=0.59.0",
"omegaconf>=2.3.0",
"opencv-python>=4.9.0",