Change config logic in:
- gym_manipulator
- find_joint_limits
- end_effector_utils
parent 6b18e4f3cf
commit 2c39504109
@@ -40,13 +40,18 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
         if "images" not in key:
             continue
 
         for imgkey, img in imgs.items():
             # TODO(aliberts, rcadene): use transforms.ToTensor()?
-            img = torch.from_numpy(img)
+            if not torch.is_tensor(img):
+                img = torch.from_numpy(img)
+
+            if img.ndim == 3:
+                img = img.unsqueeze(0)
 
             # sanity check that images are channel last
             _, h, w, c = img.shape
-            assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
+            assert c < h and c < w, (
+                f"expect channel last images, but instead got {img.shape=}"
+            )
 
             # sanity check that images are uint8
             assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
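The new guards make the image preprocessing tolerant of inputs that are already torch tensors or that lack a batch dimension. A minimal standalone sketch of the same logic (illustrative only, not part of the commit):

import numpy as np
import torch

def to_batched_tensor(img):
    # numpy arrays are converted, tensors pass through unchanged
    if not torch.is_tensor(img):
        img = torch.from_numpy(img)
    # a single HWC image gets a leading batch dimension
    if img.ndim == 3:
        img = img.unsqueeze(0)
    return img

print(to_batched_tensor(np.zeros((128, 128, 3), dtype=np.uint8)).shape)           # torch.Size([1, 128, 128, 3])
print(to_batched_tensor(torch.zeros((2, 128, 128, 3), dtype=torch.uint8)).shape)  # already batched, unchanged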
@@ -87,6 +87,8 @@ class RecordControlConfig(ControlConfig):
     play_sounds: bool = True
     # Resume recording on an existing dataset.
     resume: bool = False
+    # Reset follower arms to an initial configuration.
+    reset_follower_arms: bool = True
 
     def __post_init__(self):
         # HACK: We parse again the cli args here to get the pretrained path if there was one.
@@ -221,7 +221,7 @@ def record_episode(
         events=events,
         policy=policy,
         fps=fps,
-        record_delta_actions=record_delta_actions,
+        # record_delta_actions=record_delta_actions,
         teleoperate=policy is None,
         single_task=single_task,
     )
@@ -267,8 +267,8 @@ def control_loop(
 
         if teleoperate:
             observation, action = robot.teleop_step(record_data=True)
-            if record_delta_actions:
-                action["action"] = action["action"] - current_joint_positions
+            # if record_delta_actions:
+            #     action["action"] = action["action"] - current_joint_positions
         else:
             observation = robot.capture_observation()
 
@@ -443,7 +443,7 @@ class So100RobotConfig(ManipulatorRobotConfig):
     leader_arms: dict[str, MotorsBusConfig] = field(
         default_factory=lambda: {
             "main": FeetechMotorsBusConfig(
-                port="/dev/tty.usbmodem58760431091",
+                port="/dev/tty.usbmodem58760433331",
                 motors={
                     # name: (index, model)
                     "shoulder_pan": [1, "sts3215"],
@@ -460,7 +460,7 @@ class So100RobotConfig(ManipulatorRobotConfig):
     follower_arms: dict[str, MotorsBusConfig] = field(
         default_factory=lambda: {
             "main": FeetechMotorsBusConfig(
-                port="/dev/tty.usbmodem585A0076891",
+                port="/dev/tty.usbmodem58760431631",
                 motors={
                     # name: (index, model)
                     "shoulder_pan": [1, "sts3215"],
@@ -475,12 +475,12 @@ class ManipulatorRobot:
             goal_pos = leader_pos[name]
 
             # If specified, clip the goal positions within predefined bounds specified in the config of the robot
-            if self.config.joint_position_relative_bounds is not None:
-                goal_pos = torch.clamp(
-                    goal_pos,
-                    self.config.joint_position_relative_bounds["min"],
-                    self.config.joint_position_relative_bounds["max"],
-                )
+            # if self.config.joint_position_relative_bounds is not None:
+            #     goal_pos = torch.clamp(
+            #         goal_pos,
+            #         self.config.joint_position_relative_bounds["min"],
+            #         self.config.joint_position_relative_bounds["max"],
+            #     )
 
             # Cap goal position when too far away from present position.
             # Slower fps expected due to reading from the follower.
@@ -604,12 +604,12 @@ class ManipulatorRobot:
             from_idx = to_idx
 
         # If specified, clip the goal positions within predefined bounds specified in the config of the robot
-        if self.config.joint_position_relative_bounds is not None:
-            goal_pos = torch.clamp(
-                goal_pos,
-                self.config.joint_position_relative_bounds["min"],
-                self.config.joint_position_relative_bounds["max"],
-            )
+        # if self.config.joint_position_relative_bounds is not None:
+        #     goal_pos = torch.clamp(
+        #         goal_pos,
+        #         self.config.joint_position_relative_bounds["min"],
+        #         self.config.joint_position_relative_bounds["max"],
+        #     )
 
         # Cap goal position when too far away from present position.
         # Slower fps expected due to reading from the follower.
@@ -366,8 +366,8 @@ def replay(
         start_episode_t = time.perf_counter()
 
         action = actions[idx]["action"]
-        if replay_delta_actions:
-            action = action + current_joint_positions
+        # if replay_delta_actions:
+        #     action = action + current_joint_positions
        robot.send_action(action)
 
        dt_s = time.perf_counter() - start_episode_t
@@ -394,7 +394,6 @@ def control_robot(cfg: ControlPipelineConfig):
         replay(robot, cfg.control)
     elif isinstance(cfg.control, RemoteRobotConfig):
         from lerobot.common.robot_devices.robots.lekiwi_remote import run_lekiwi
-
         run_lekiwi(cfg.robot)
 
     if robot.is_connected:
@@ -1,14 +1,13 @@
-import argparse
-
+from lerobot.common.robot_devices.utils import busy_wait
+from lerobot.scripts.server.kinematics import RobotKinematics
 import logging
 import time
 
-import numpy as np
 import torch
-
-from lerobot.common.robot_devices.robots.factory import make_robot
-from lerobot.common.robot_devices.utils import busy_wait
-from lerobot.common.utils.utils import init_hydra_config
-from lerobot.scripts.server.kinematics import RobotKinematics
+import numpy as np
+import argparse
+from lerobot.common.robot_devices.robots.utils import make_robot_from_config
+from lerobot.scripts.server.gym_manipulator import make_robot_env, HILSerlRobotEnvConfig
+from lerobot.common.robot_devices.robots.configs import RobotConfig
 
 logging.basicConfig(level=logging.INFO)
@@ -677,15 +676,6 @@ def teleoperate_gym_env(env, controller, fps: int = 30):
     # Close the environment
     env.close()
 
 
-def make_robot_from_config(config_path, overrides=None):
-    """Helper function to create a robot from a config file."""
-    if overrides is None:
-        overrides = []
-    robot_cfg = init_hydra_config(config_path, overrides)
-    return make_robot(robot_cfg)
-
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Test end-effector control")
     parser.add_argument(
@@ -703,21 +693,22 @@ if __name__ == "__main__":
         help="Control mode to use",
     )
     parser.add_argument(
-        "--task",
+        "--robot-type",
         type=str,
-        default="Robot manipulation task",
-        help="Description of the task being performed",
+        default="so100",
+        help="Robot type (so100, koch, aloha, etc.)",
     )
     parser.add_argument(
-        "--push-to-hub",
-        default=True,
-        type=bool,
-        help="Push the dataset to Hugging Face Hub",
+        "--config-path",
+        type=str,
+        default=None,
+        help="Path to the config file in json format",
     )
-    # Add the rest of your existing arguments
     args = parser.parse_args()
 
-    robot = make_robot_from_config("lerobot/configs/robot/so100.yaml", [])
+    robot_config = RobotConfig.get_choice_class(args.robot_type)(mock=False)
+    robot = make_robot_from_config(robot_config)
 
     if not robot.is_connected:
         robot.connect()
@@ -743,12 +734,12 @@
 
     elif args.mode in ["keyboard_gym", "gamepad_gym"]:
         # Gym environment control modes
-        from lerobot.scripts.server.gym_manipulator import make_robot_env
-
-        cfg = init_hydra_config("lerobot/configs/env/so100_real.yaml", [])
-        cfg.env.wrapper.ee_action_space_params.use_gamepad = False
-        env = make_robot_env(robot, None, cfg)
-        teleoperate_gym_env(env, controller)
+        cfg = HILSerlRobotEnvConfig()
+        if args.config_path is not None:
+            cfg = HILSerlRobotEnvConfig.from_json(args.config_path)
+
+        env = make_robot_env(cfg, robot)
+        teleoperate_gym_env(env, controller, fps=args.fps)
 
     elif args.mode == "leader":
         # Leader-follower modes don't use controllers
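The end-effector test script now resolves the robot from a registered config class and optionally loads the environment config from JSON instead of Hydra YAML. A sketch of the equivalent Python, with the robot type and JSON path as illustrative values rather than values taken from this commit:

from lerobot.common.robot_devices.robots.configs import RobotConfig
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
from lerobot.scripts.server.gym_manipulator import HILSerlRobotEnvConfig, make_robot_env

# equivalent of `--robot-type so100`
robot_config = RobotConfig.get_choice_class("so100")(mock=False)
robot = make_robot_from_config(robot_config)

# equivalent of `--config-path env_config.json`; the JSON keys must match the dataclass fields
cfg = HILSerlRobotEnvConfig.from_json("env_config.json")
env = make_robot_env(cfg, robot)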
@@ -5,10 +5,13 @@ import cv2
 import numpy as np
 
 from lerobot.common.robot_devices.control_utils import is_headless
-from lerobot.common.robot_devices.robots.factory import make_robot
-from lerobot.common.utils.utils import init_hydra_config
+from lerobot.common.robot_devices.robots.utils import make_robot_from_config
 from lerobot.scripts.server.kinematics import RobotKinematics
+from lerobot.configs import parser
+from lerobot.common.robot_devices.robots.configs import RobotConfig
+
+follower_port = "/dev/tty.usbmodem58760431631"
+leader_port = "/dev/tty.usbmodem58760433331"
 
 def find_joint_bounds(
     robot,
@@ -78,20 +81,28 @@ def find_ee_bounds(
             break
 
 
+def make_robot(robot_type="so100", mock=True):
+    """
+    Create a robot instance using the appropriate robot config class.
+
+    Args:
+        robot_type: Robot type string (e.g., "so100", "koch", "aloha")
+        mock: Whether to use mock mode for hardware (default: True)
+
+    Returns:
+        Robot instance
+    """
+    # Get the appropriate robot config class based on robot_type
+    robot_config = RobotConfig.get_choice_class(robot_type)(mock=mock)
+    robot_config.leader_arms["main"].port = leader_port
+    robot_config.follower_arms["main"].port = follower_port
+
+    return make_robot_from_config(robot_config)
+
+
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--robot-path",
-        type=str,
-        default="lerobot/configs/robot/koch.yaml",
-        help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
-    )
-    parser.add_argument(
-        "--robot-overrides",
-        type=str,
-        nargs="*",
-        help="Any key=value arguments to override config values (use dots for.nested=overrides)",
-    )
+    # Create argparse for script-specific arguments
+    parser = argparse.ArgumentParser(add_help=False)  # Set add_help=False to avoid conflict
     parser.add_argument(
         "--mode",
         type=str,
@@ -105,13 +116,29 @@ if __name__ == "__main__":
         default=30,
         help="Time step to use for control.",
     )
-    args = parser.parse_args()
-    robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)
+    parser.add_argument(
+        "--robot-type",
+        type=str,
+        default="so100",
+        help="Robot type (so100, koch, aloha, etc.)",
+    )
+    parser.add_argument(
+        "--mock",
+        type=int,
+        default=1,
+        help="Use mock mode for hardware simulation",
+    )
+
+    # Only parse known args, leaving robot config args for Hydra if used
+    args, _ = parser.parse_known_args()
+
+    # Create robot with the appropriate config
+    robot = make_robot(args.robot_type, args.mock)
 
-    robot = make_robot(robot_cfg)
     if args.mode == "joint":
         find_joint_bounds(robot, args.control_time_s)
     elif args.mode == "ee":
         find_ee_bounds(robot, args.control_time_s)
 
     if robot.is_connected:
         robot.disconnect()
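find_joint_limits now builds its robot from `--robot-type` and `--mock` through the new module-level `make_robot` helper, which also pins the leader/follower serial ports defined at the top of the file. A short usage sketch mirroring the `__main__` block above; the argument values are illustrative:

# same module as the helper and find_joint_bounds defined above
robot = make_robot(robot_type="so100", mock=True)  # mock=True keeps the config in mock mode

find_joint_bounds(robot, 30)  # second argument is control_time_s, as in the script

if robot.is_connected:
    robot.disconnect()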
@@ -1,7 +1,8 @@
-import argparse
 import logging
 import sys
 import time
+import sys
+
 
 from threading import Lock
 from typing import Annotated, Any, Dict, Tuple
@@ -9,6 +10,9 @@ import gymnasium as gym
 import numpy as np
 import torch
 import torchvision.transforms.functional as F  # noqa: N812
+import json
+
+from dataclasses import dataclass
 
 from lerobot.common.envs.utils import preprocess_observation
 from lerobot.common.robot_devices.control_utils import (
@@ -16,12 +20,69 @@ from lerobot.common.robot_devices.control_utils import (
     is_headless,
     reset_follower_position,
 )
-from lerobot.common.robot_devices.robots.factory import make_robot
-from lerobot.common.utils.utils import init_hydra_config, log_say
+from typing import Optional
+from lerobot.common.utils.utils import log_say
+from lerobot.common.robot_devices.robots.utils import make_robot_from_config
+
+from lerobot.common.robot_devices.robots.configs import RobotConfig
+
 from lerobot.scripts.server.kinematics import RobotKinematics
+from lerobot.configs import parser
+
 logging.basicConfig(level=logging.INFO)
 
+
+@dataclass
+class EEActionSpaceConfig:
+    """Configuration parameters for end-effector action space."""
+    x_step_size: float
+    y_step_size: float
+    z_step_size: float
+    bounds: Dict[str, Any]  # Contains 'min' and 'max' keys with position bounds
+    use_gamepad: bool = False
+
+
+@dataclass
+class EnvWrapperConfig:
+    """Configuration for environment wrappers."""
+    display_cameras: bool = False
+    delta_action: float = 0.1
+    use_relative_joint_positions: bool = True
+    add_joint_velocity_to_observation: bool = False
+    add_ee_pose_to_observation: bool = False
+    crop_params_dict: Optional[Dict[str, Tuple[int, int, int, int]]] = None
+    resize_size: Optional[Tuple[int, int]] = None
+    control_time_s: float = 20.0
+    fixed_reset_joint_positions: Optional[Any] = None
+    reset_time_s: float = 5.0
+    joint_masking_action_space: Optional[Any] = None
+    ee_action_space_params: Optional[EEActionSpaceConfig] = None
+    reward_classifier_pretrained_path: Optional[str] = None
+    reward_classifier_config_file: Optional[str] = None
+
+
+@dataclass
+class HILSerlRobotEnvConfig:
+    """Configuration for the HILSerlRobotEnv environment."""
+    robot: RobotConfig
+    wrapper: EnvWrapperConfig
+    env_name: str = "real_robot"
+    fps: int = 10
+    mode: str = None  # Either "record", "replay", None
+    repo_id: Optional[str] = None
+    dataset_root: Optional[str] = None
+    task: str = ""
+    num_episodes: int = 10  # only for record mode
+    episode: int = 0
+    device: str = "cuda"
+    push_to_hub: bool = True
+    pretrained_policy_name_or_path: Optional[str] = None
+
+    @classmethod
+    def from_json(cls, json_path: str):
+        with open(json_path, "r") as f:
+            config = json.load(f)
+        return cls(**config)
+
 class HILSerlRobotEnv(gym.Env):
     """
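The Hydra env YAML is replaced by these plain dataclasses. A sketch of building the same configuration directly in Python; the concrete robot class and all field values below are illustrative assumptions, not values taken from this commit:

from lerobot.common.robot_devices.robots.configs import So100RobotConfig
from lerobot.scripts.server.gym_manipulator import (
    EEActionSpaceConfig,
    EnvWrapperConfig,
    HILSerlRobotEnvConfig,
)

wrapper = EnvWrapperConfig(
    display_cameras=False,
    control_time_s=20.0,
    ee_action_space_params=EEActionSpaceConfig(
        x_step_size=0.01,
        y_step_size=0.01,
        z_step_size=0.01,
        bounds={"min": [0.20, -0.15, 0.02], "max": [0.40, 0.15, 0.25]},
        use_gamepad=True,
    ),
)

cfg = HILSerlRobotEnvConfig(
    robot=So100RobotConfig(),
    wrapper=wrapper,
    fps=10,
    task="Pick up the cube",
)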
@@ -49,7 +110,7 @@ class HILSerlRobotEnv(gym.Env):
     The environment is set up with a robot interface, which is used to capture observations and send joint commands. The setup
     supports both relative (delta) adjustments and absolute joint positions for controlling the robot.
 
-    Args:
+    cfg.
         robot: The robot interface object used to connect and interact with the physical robot.
         use_delta_action_space (bool): If True, uses a delta (relative) action space for joint control. Otherwise, absolute
             joint positions are used.
@@ -77,14 +138,17 @@ class HILSerlRobotEnv(gym.Env):
         self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")
 
         # Retrieve the size of the joint position interval bound.
-        self.relative_bounds_size = (
-            (
-                self.robot.config.joint_position_relative_bounds["max"]
-                - self.robot.config.joint_position_relative_bounds["min"]
-            )
-            if self.robot.config.joint_position_relative_bounds is not None
-            else None
-        )
+        self.relative_bounds_size = None
+        # (
+        #     (
+        #         self.robot.config.joint_position_relative_bounds["max"]
+        #         - self.robot.config.joint_position_relative_bounds["min"]
+        #     )
+        #     if self.robot.config.joint_position_relative_bounds is not None
+        #     else None
+        # )
+        self.robot.config.joint_position_relative_bounds = None
 
         self.robot.config.max_relative_target = (
             self.relative_bounds_size.float() if self.relative_bounds_size is not None else None
@@ -168,7 +232,7 @@ class HILSerlRobotEnv(gym.Env):
         Reset the environment to its initial state.
         This method resets the step counter and clears any episodic data.
 
-        Args:
+        cfg.
             seed (Optional[int]): A seed for random number generation to ensure reproducibility.
             options (Optional[dict]): Additional options to influence the reset behavior.
 
@@ -203,7 +267,7 @@ class HILSerlRobotEnv(gym.Env):
               - When True, a teleoperation step is executed. If using a delta action space, an absolute teleop action is converted
                 to relative change based on the current joint positions.
 
-        Args:
+        cfg.
             action (tuple): A tuple with two elements:
                 - policy_action (np.ndarray or torch.Tensor): The commanded joint positions.
                 - intervention_bool (bool): True if the human operator intervenes by providing a teleoperation input.
@@ -258,7 +322,8 @@ class HILSerlRobotEnv(gym.Env):
             if teleop_action.dim() == 1:
                 teleop_action = teleop_action.unsqueeze(0)
 
-        # self.render()
+        if self.display_cameras:
+            self.render()
 
         self.current_step += 1
 
@@ -353,7 +418,7 @@ class RewardWrapper(gym.Wrapper):
     """
     Wrapper to add reward prediction to the environment, it use a trained classifer.
 
-    Args:
+    cfg.
         env: The environment to wrap
         reward_classifier: The reward classifier model
        device: The device to run the model on
@@ -396,7 +461,7 @@ class JointMaskingActionSpace(gym.Wrapper):
        """
        Wrapper to mask out dimensions of the action space.
 
-        Args:
+        cfg.
            env: The environment to wrap
            mask: Binary mask array where 0 indicates dimensions to remove
        """
@@ -428,7 +493,7 @@ class JointMaskingActionSpace(gym.Wrapper):
        """
        Convert masked action back to full action space.
 
-        Args:
+        cfg.
            action: Action in masked space. For Tuple spaces, the first element is masked.
 
        Returns:
@@ -859,7 +924,7 @@ class GamepadControlWrapper(gym.Wrapper):
        """
        Initialize the gamepad controller wrapper.
 
-        Args:
+        cfg.
            env: The environment to wrap
            x_step_size: Base movement step size for X axis in meters
            y_step_size: Base movement step size for Y axis in meters
@@ -937,7 +1002,7 @@ class GamepadControlWrapper(gym.Wrapper):
        """
        Step the environment, using gamepad input to override actions when active.
 
-        Args:
+        cfg.
            action: Original action from agent
 
        Returns:
@@ -1029,25 +1094,19 @@ class ActionScaleWrapper(gym.ActionWrapper):
         return action * self.scale_vector, is_intervention
 
 
-def make_robot_env(
-    robot,
-    reward_classifier,
-    cfg,
-    n_envs: int = 1,
-) -> gym.vector.VectorEnv:
+def make_robot_env(cfg, robot) -> gym.vector.VectorEnv:
     """
     Factory function to create a vectorized robot environment.
 
-    Args:
+    cfg.
         robot: Robot instance to control
         reward_classifier: Classifier model for computing rewards
         cfg: Configuration object containing environment parameters
-        n_envs: Number of environments to create in parallel. Defaults to 1.
 
     Returns:
         A vectorized gym environment with all the necessary wrappers applied.
     """
-    if "maniskill" in cfg.env.name:
+    if "maniskill" in cfg.env_name:
         from lerobot.scripts.server.maniskill_manipulator import make_maniskill
 
         logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN")
@@ -1056,69 +1115,69 @@ def make_robot_env(
             n_envs=1,
         )
         return env
 
     # Create base environment
     env = HILSerlRobotEnv(
         robot=robot,
-        display_cameras=cfg.env.wrapper.display_cameras,
-        delta=cfg.env.wrapper.delta_action,
-        use_delta_action_space=cfg.env.wrapper.use_relative_joint_positions
-        and cfg.env.wrapper.ee_action_space_params is None,
+        display_cameras=cfg.wrapper.display_cameras,
+        delta=cfg.wrapper.delta_action,
+        use_delta_action_space=cfg.wrapper.use_relative_joint_positions
+        and cfg.wrapper.ee_action_space_params is None,
     )
 
     # Add observation and image processing
-    if cfg.env.wrapper.add_joint_velocity_to_observation:
+    if cfg.wrapper.add_joint_velocity_to_observation:
         env = AddJointVelocityToObservation(env=env, fps=cfg.fps)
-    if cfg.env.wrapper.add_ee_pose_to_observation:
-        env = EEObservationWrapper(env=env, ee_pose_limits=cfg.env.wrapper.ee_action_space_params.bounds)
+    if cfg.wrapper.add_ee_pose_to_observation:
+        env = EEObservationWrapper(env=env, ee_pose_limits=cfg.wrapper.ee_action_space_params.bounds)
 
-    env = ConvertToLeRobotObservation(env=env, device=cfg.env.device)
+    env = ConvertToLeRobotObservation(env=env, device=cfg.device)
 
-    if cfg.env.wrapper.crop_params_dict is not None:
+    if cfg.wrapper.crop_params_dict is not None:
         env = ImageCropResizeWrapper(
             env=env,
-            crop_params_dict=cfg.env.wrapper.crop_params_dict,
-            resize_size=cfg.env.wrapper.resize_size,
+            crop_params_dict=cfg.wrapper.crop_params_dict,
+            resize_size=cfg.wrapper.resize_size,
         )
 
     # Add reward computation and control wrappers
     # env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
-    env = TimeLimitWrapper(env=env, control_time_s=cfg.env.wrapper.control_time_s, fps=cfg.fps)
-    if cfg.env.wrapper.ee_action_space_params is not None:
-        env = EEActionWrapper(env=env, ee_action_space_params=cfg.env.wrapper.ee_action_space_params)
+    env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
+    if cfg.wrapper.ee_action_space_params is not None:
+        env = EEActionWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params)
     if (
-        cfg.env.wrapper.ee_action_space_params is not None
-        and cfg.env.wrapper.ee_action_space_params.use_gamepad
+        cfg.wrapper.ee_action_space_params is not None
+        and cfg.wrapper.ee_action_space_params.use_gamepad
     ):
-        # env = ActionScaleWrapper(env=env, ee_action_space_params=cfg.env.wrapper.ee_action_space_params)
+        # env = ActionScaleWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params)
         env = GamepadControlWrapper(
             env=env,
-            x_step_size=cfg.env.wrapper.ee_action_space_params.x_step_size,
-            y_step_size=cfg.env.wrapper.ee_action_space_params.y_step_size,
-            z_step_size=cfg.env.wrapper.ee_action_space_params.z_step_size,
+            x_step_size=cfg.wrapper.ee_action_space_params.x_step_size,
+            y_step_size=cfg.wrapper.ee_action_space_params.y_step_size,
+            z_step_size=cfg.wrapper.ee_action_space_params.z_step_size,
         )
     else:
         env = KeyboardInterfaceWrapper(env=env)
 
     env = ResetWrapper(
         env=env,
-        reset_pose=cfg.env.wrapper.fixed_reset_joint_positions,
-        reset_time_s=cfg.env.wrapper.reset_time_s,
+        reset_pose=cfg.wrapper.fixed_reset_joint_positions,
+        reset_time_s=cfg.wrapper.reset_time_s,
     )
     if (
-        cfg.env.wrapper.ee_action_space_params is None
-        and cfg.env.wrapper.joint_masking_action_space is not None
+        cfg.wrapper.ee_action_space_params is None
+        and cfg.wrapper.joint_masking_action_space is not None
     ):
-        env = JointMaskingActionSpace(env=env, mask=cfg.env.wrapper.joint_masking_action_space)
+        env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)
     env = BatchCompitableWrapper(env=env)
 
     return env
 
 
-def get_classifier(pretrained_path, config_path, device="mps"):
-    if pretrained_path is None or config_path is None:
+def get_classifier(cfg):
+    if cfg.wrapper.reward_classifier_pretrained_path is None or cfg.wrapper.reward_classifier_config_file is None:
         return None
 
-    from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
     from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
         ClassifierConfig,
     )
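Call sites of the factory change accordingly: the config comes first, and the reward classifier and n_envs arguments are gone. A sketch of updating an existing call (cfg is a HILSerlRobotEnvConfig, e.g. built as in the dataclass example above):

from lerobot.common.robot_devices.robots.utils import make_robot_from_config
from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env

# old call shape, removed by this commit:
# env = make_robot_env(robot, reward_classifier, cfg, n_envs=1)

# new call shape:
robot = make_robot_from_config(cfg.robot)
env = make_robot_env(cfg, robot)

# the reward classifier, when re-enabled, is resolved from the config itself:
# reward_classifier = get_classifier(cfg)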
@@ -1126,8 +1185,6 @@ def get_classifier(pretrained_path, config_path, device="mps"):
         Classifier,
     )
 
-    cfg = init_hydra_config(config_path)
-
     classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
     classifier_config.num_cameras = len(cfg.training.image_keys)  # TODO automate these paths
     model = Classifier(classifier_config)
@@ -1136,21 +1193,11 @@ def get_classifier(pretrained_path, config_path, device="mps"):
     return model
 
 
-def record_dataset(
-    env,
-    repo_id,
-    root=None,
-    num_episodes=1,
-    control_time_s=20,
-    fps=30,
-    push_to_hub=True,
-    task_description="",
-    policy=None,
-):
+def record_dataset(env, policy, cfg: HILSerlRobotEnvConfig):
     """
     Record a dataset of robot interactions using either a policy or teleop.
 
-    Args:
+    cfg.
         env: The environment to record from
         repo_id: Repository ID for dataset storage
         root: Local root directory for dataset (optional)
@@ -1195,9 +1242,9 @@ def record_dataset(
 
     # Create dataset
     dataset = LeRobotDataset.create(
-        repo_id,
-        fps,
-        root=root,
+        cfg.repo_id,
+        cfg.fps,
+        root=cfg.dataset_root,
         use_videos=True,
         image_writer_threads=4,
         image_writer_processes=0,
@@ -1206,17 +1253,17 @@
 
     # Record episodes
     episode_index = 0
-    while episode_index < num_episodes:
+    while episode_index < cfg.record_num_episodes:
         obs, _ = env.reset()
         start_episode_t = time.perf_counter()
         log_say(f"Recording episode {episode_index}", play_sounds=True)
 
         # Run episode steps
-        while time.perf_counter() - start_episode_t < control_time_s:
+        while time.perf_counter() - start_episode_t < cfg.wrapper.control_time_s:
             start_loop_t = time.perf_counter()
 
             # Get action from policy if available
-            if policy is not None:
+            if cfg.pretrained_policy_name_or_path is not None:
                 action = policy.select_action(obs)
 
             # Step environment
@@ -1240,9 +1287,9 @@ def record_dataset(
             dataset.add_frame(frame)
 
             # Maintain consistent timing
-            if fps:
+            if cfg.fps:
                 dt_s = time.perf_counter() - start_loop_t
-                busy_wait(1 / fps - dt_s)
+                busy_wait(1 / cfg.fps - dt_s)
 
             if terminated or truncated:
                 break
@@ -1253,13 +1300,13 @@
             logging.info(f"Re-recording episode {episode_index}")
             continue
 
-        dataset.save_episode(task_description)
+        dataset.save_episode(cfg.task)
         episode_index += 1
 
     # Finalize dataset
     dataset.consolidate(run_compute_stats=True)
-    if push_to_hub:
-        dataset.push_to_hub(repo_id)
+    if cfg.push_to_hub:
+        dataset.push_to_hub(cfg.repo_id)
 
 
 def replay_episode(env, repo_id, root=None, episode=0):
@@ -1282,134 +1329,44 @@ def replay_episode(env, repo_id, root=None, episode=0):
         busy_wait(1 / 10 - dt_s)
 
 
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--fps", type=int, default=30, help="control frequency")
-    parser.add_argument(
-        "--robot-path",
-        type=str,
-        default="lerobot/configs/robot/koch.yaml",
-        help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
-    )
-    parser.add_argument(
-        "--robot-overrides",
-        type=str,
-        nargs="*",
-        help="Any key=value arguments to override config values (use dots for.nested=overrides)",
-    )
-    parser.add_argument(
-        "-p",
-        "--pretrained-policy-name-or-path",
-        help=(
-            "Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
-            "saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch "
-        ),
-    )
-    parser.add_argument(
-        "--display-cameras",
-        help=("Whether to display the camera feed while the rollout is happening"),
-    )
-    parser.add_argument(
-        "--reward-classifier-pretrained-path",
-        type=str,
-        default=None,
-        help="Path to the pretrained classifier weights.",
-    )
-    parser.add_argument(
-        "--reward-classifier-config-file",
-        type=str,
-        default=None,
-        help="Path to a yaml config file that is necessary to build the reward classifier model.",
-    )
-    parser.add_argument("--env-path", type=str, default=None, help="Path to the env yaml file")
-    parser.add_argument(
-        "--env-overrides",
-        type=str,
-        default=None,
-        help="Overrides for the env yaml file",
-    )
-    parser.add_argument(
-        "--control-time-s",
-        type=float,
-        default=20,
-        help="Maximum episode length in seconds",
-    )
-    parser.add_argument(
-        "--reset-follower-pos",
-        type=int,
-        default=1,
-        help="Reset follower between episodes",
-    )
-    parser.add_argument(
-        "--replay-repo-id",
-        type=str,
-        default=None,
-        help="Repo ID of the episode to replay",
-    )
-    parser.add_argument("--dataset-root", type=str, default=None, help="Root of the dataset to replay")
-    parser.add_argument("--replay-episode", type=int, default=0, help="Episode to replay")
-    parser.add_argument(
-        "--record-repo-id",
-        type=str,
-        default=None,
-        help="Repo ID of the dataset to record",
-    )
-    parser.add_argument(
-        "--record-num-episodes",
-        type=int,
-        default=1,
-        help="Number of episodes to record",
-    )
-    parser.add_argument(
-        "--record-episode-task",
-        type=str,
-        default="",
-        help="Single line description of the task to record",
-    )
-
-    args = parser.parse_args()
-
-    robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)
-    robot = make_robot(robot_cfg)
-
-    reward_classifier = get_classifier(
-        args.reward_classifier_pretrained_path, args.reward_classifier_config_file
-    )
+@parser.wrap()
+def main(cfg: HILSerlRobotEnvConfig):
+    robot = make_robot_from_config(cfg.robot)
+
+    reward_classifier = None #get_classifier(
+    #     cfg.wrapper.reward_classifier_pretrained_path, cfg.wrapper.reward_classifier_config_file
+    # )
     user_relative_joint_positions = True
 
-    cfg = init_hydra_config(args.env_path, args.env_overrides)
-    env = make_robot_env(
-        robot,
-        reward_classifier,
-        cfg,  # .wrapper,
-    )
+    env = make_robot_env(cfg, robot)
 
-    if args.record_repo_id is not None:
+    if cfg.mode == "record":
         policy = None
-        if args.pretrained_policy_name_or_path is not None:
+        if cfg.pretrained_policy_name_or_path is not None:
             from lerobot.common.policies.sac.modeling_sac import SACPolicy
 
-            policy = SACPolicy.from_pretrained(args.pretrained_policy_name_or_path)
+            policy = SACPolicy.from_pretrained(cfg.pretrained_policy_name_or_path)
             policy.to(cfg.device)
             policy.eval()
 
         record_dataset(
             env,
-            args.record_repo_id,
-            root=args.dataset_root,
-            num_episodes=args.record_num_episodes,
-            fps=args.fps,
-            task_description=args.record_episode_task,
+            cfg.repo_id,
+            root=cfg.dataset_root,
+            num_episodes=cfg.num_episodes,
+            fps=cfg.fps,
+            task_description=cfg.task,
             policy=policy,
         )
         exit()
 
-    if args.replay_repo_id is not None:
+    if cfg.mode == "replay":
         replay_episode(
             env,
-            args.replay_repo_id,
-            root=args.dataset_root,
-            episode=args.replay_episode,
+            cfg.replay_repo_id,
+            root=cfg.dataset_root,
+            episode=cfg.replay_episode,
         )
         exit()
 
@@ -1442,7 +1399,10 @@ if __name__ == "__main__":
         num_episode += 1
 
         dt_s = time.perf_counter() - start_loop_s
-        busy_wait(1 / args.fps - dt_s)
+        busy_wait(1 / cfg.fps - dt_s)
 
     logging.info(f"Success after 20 steps {sucesses}")
     logging.info(f"success rate {sum(sucesses) / len(sucesses)}")
+
+
+if __name__ == "__main__":
+    main()
@@ -59,7 +59,6 @@ dependencies = [
     "huggingface-hub[hf-transfer,cli]>=0.27.1 ; python_version < '4.0'",
     "imageio[ffmpeg]>=2.34.0",
     "jsonlines>=4.0.0",
-    "mani-skill>=3.0.0b18",
     "numba>=0.59.0",
     "omegaconf>=2.3.0",
     "opencv-python>=4.9.0",