Add end effector action space to hil-serl (#861)
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
7960f2c3c1
commit
b82faf7d8c
|
@ -135,6 +135,7 @@ class PixelWrapper(gym.Wrapper):
|
||||||
return self._get_obs(obs), reward, terminated, truncated, info
|
return self._get_obs(obs), reward, terminated, truncated, info
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Remove this
|
||||||
class ConvertToLeRobotEnv(gym.Wrapper):
|
class ConvertToLeRobotEnv(gym.Wrapper):
|
||||||
def __init__(self, env, num_envs):
|
def __init__(self, env, num_envs):
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
|
|
|
@ -103,6 +103,6 @@ class SACConfig:
|
||||||
"use_tanh_squash": True,
|
"use_tanh_squash": True,
|
||||||
"log_std_min": -5,
|
"log_std_min": -5,
|
||||||
"log_std_max": 2,
|
"log_std_max": 2,
|
||||||
"init_final": 0.005,
|
"init_final": 0.05,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
|
@ -334,7 +334,7 @@ def reset_environment(robot, events, reset_time_s, fps):
|
||||||
def reset_follower_position(robot: Robot, target_position):
|
def reset_follower_position(robot: Robot, target_position):
|
||||||
current_position = robot.follower_arms["main"].read("Present_Position")
|
current_position = robot.follower_arms["main"].read("Present_Position")
|
||||||
trajectory = torch.from_numpy(
|
trajectory = torch.from_numpy(
|
||||||
np.linspace(current_position, target_position, 30)
|
np.linspace(current_position, target_position, 50)
|
||||||
) # NOTE: 30 is just an aribtrary number
|
) # NOTE: 30 is just an aribtrary number
|
||||||
for pose in trajectory:
|
for pose in trajectory:
|
||||||
robot.send_action(pose)
|
robot.send_action(pose)
|
||||||
|
|
|
@ -5,26 +5,46 @@ fps: 10
|
||||||
env:
|
env:
|
||||||
name: real_world
|
name: real_world
|
||||||
task: null
|
task: null
|
||||||
state_dim: 6
|
state_dim: 15
|
||||||
action_dim: 6
|
action_dim: 3
|
||||||
fps: ${fps}
|
fps: ${fps}
|
||||||
device: mps
|
device: mps
|
||||||
|
|
||||||
wrapper:
|
wrapper:
|
||||||
crop_params_dict:
|
crop_params_dict:
|
||||||
observation.images.front: [102, 43, 358, 523]
|
observation.images.front: [171, 207, 116, 251]
|
||||||
observation.images.side: [92, 123, 379, 349]
|
observation.images.side: [232, 200, 142, 204]
|
||||||
# observation.images.front: [109, 37, 361, 557]
|
|
||||||
# observation.images.side: [94, 161, 372, 315]
|
|
||||||
resize_size: [128, 128]
|
resize_size: [128, 128]
|
||||||
control_time_s: 20
|
control_time_s: 10
|
||||||
reset_follower_pos: true
|
reset_follower_pos: false
|
||||||
use_relative_joint_positions: true
|
use_relative_joint_positions: true
|
||||||
reset_time_s: 5
|
reset_time_s: 5
|
||||||
display_cameras: false
|
display_cameras: false
|
||||||
delta_action: 0.1
|
delta_action: null #0.3
|
||||||
joint_masking_action_space: [1, 1, 1, 1, 0, 0] # disable wrist and gripper
|
joint_masking_action_space: null #[1, 1, 1, 1, 0, 0] # disable wrist and gripper
|
||||||
|
add_joint_velocity_to_observation: true
|
||||||
|
add_ee_pose_to_observation: true
|
||||||
|
|
||||||
|
# If null then the teleoperation will be used to reset the robot
|
||||||
|
# Bounds for pushcube_gamepad_lerobot15 dataset and experiments
|
||||||
|
# fixed_reset_joint_positions: [-19.86, 103.19, 117.33, 42.7, 13.89, 0.297]
|
||||||
|
# ee_action_space_params: # If null then ee_action_space is not used
|
||||||
|
# bounds:
|
||||||
|
# max: [0.291, 0.147, 0.074]
|
||||||
|
# min: [0.139, -0.143, 0.03]
|
||||||
|
|
||||||
|
# Bounds for insertcube_gamepad dataset and experiments
|
||||||
|
fixed_reset_joint_positions: [20.0, 90., 90., 75., -0.7910156, -0.5673759]
|
||||||
|
ee_action_space_params:
|
||||||
|
bounds:
|
||||||
|
max: [0.25295413, 0.07498981, 0.06862044]
|
||||||
|
min: [0.2010096, -0.12, 0.0433196]
|
||||||
|
|
||||||
|
use_gamepad: true
|
||||||
|
x_step_size: 0.03
|
||||||
|
y_step_size: 0.03
|
||||||
|
z_step_size: 0.03
|
||||||
|
|
||||||
reward_classifier:
|
reward_classifier:
|
||||||
pretrained_path: outputs/classifier/13-02-random-sample-resnet10-frozen/checkpoints/best/pretrained_model
|
pretrained_path: null # outputs/classifier/13-02-random-sample-resnet10-frozen/checkpoints/best/pretrained_model
|
||||||
config_path: lerobot/configs/policy/hilserl_classifier.yaml
|
config_path: null # lerobot/configs/policy/hilserl_classifier.yaml
|
||||||
|
|
|
@ -8,8 +8,7 @@
|
||||||
# env.gym.obs_type=environment_state_agent_pos \
|
# env.gym.obs_type=environment_state_agent_pos \
|
||||||
|
|
||||||
seed: 1
|
seed: 1
|
||||||
dataset_repo_id: aractingi/push_cube_overfit_cropped_resized
|
dataset_repo_id: aractingi/insertcube_simple
|
||||||
#aractingi/push_cube_square_offline_demo_cropped_resized
|
|
||||||
|
|
||||||
training:
|
training:
|
||||||
# Offline training dataloader
|
# Offline training dataloader
|
||||||
|
@ -30,7 +29,7 @@ training:
|
||||||
online_steps_between_rollouts: 1000
|
online_steps_between_rollouts: 1000
|
||||||
online_sampling_ratio: 1.0
|
online_sampling_ratio: 1.0
|
||||||
online_env_seed: 10000
|
online_env_seed: 10000
|
||||||
online_buffer_capacity: 1000000
|
online_buffer_capacity: 10000
|
||||||
online_buffer_seed_size: 0
|
online_buffer_seed_size: 0
|
||||||
online_step_before_learning: 100 #5000
|
online_step_before_learning: 100 #5000
|
||||||
do_online_rollout_async: false
|
do_online_rollout_async: false
|
||||||
|
@ -62,7 +61,7 @@ policy:
|
||||||
observation.images.side: [3, 128, 128]
|
observation.images.side: [3, 128, 128]
|
||||||
# observation.image: [3, 128, 128]
|
# observation.image: [3, 128, 128]
|
||||||
output_shapes:
|
output_shapes:
|
||||||
action: [4] # ["${env.action_dim}"]
|
action: ["${env.action_dim}"]
|
||||||
|
|
||||||
# Normalization / Unnormalization
|
# Normalization / Unnormalization
|
||||||
input_normalization_modes:
|
input_normalization_modes:
|
||||||
|
@ -77,23 +76,16 @@ policy:
|
||||||
mean: [0.485, 0.456, 0.406]
|
mean: [0.485, 0.456, 0.406]
|
||||||
std: [0.229, 0.224, 0.225]
|
std: [0.229, 0.224, 0.225]
|
||||||
observation.state:
|
observation.state:
|
||||||
min: [-77.08008, 56.25, 60.55664, 19.511719, 0., -0.63829786]
|
# 6- joint positions, 6- joint velocities, 3- ee position
|
||||||
max: [ 7.215820e+01, 1.5398438e+02, 1.6075195e+02, 9.3251953e+01, 0., -1.4184397e-01]
|
max: [ 52.822266, 136.14258, 142.03125, 72.1582, 22.675781, -0.5673759, 100., 100., 100., 100., 100., 100., 0.25295413, 0.07498981, 0.06862044]
|
||||||
|
min: [-2.6367188, 86.572266, 89.82422, 12.392578, -26.015625, -0.5673759, -100., -100., -100., -100., -100., -100., 0.2010096, -0.12, 0.0433196]
|
||||||
# min: [-87.09961, 62.402344, 67.23633, 36.035156, 77.34375,0.53691274]
|
|
||||||
# max: [58.183594, 131.83594, 145.98633, 82.08984, 78.22266, 0.60402685]
|
|
||||||
# min: [-88.50586, 23.81836, 0.87890625, -32.16797, 78.66211, 0.53691274]
|
|
||||||
# max: [84.55078, 187.11914, 145.98633, 101.60156, 146.60156, 88.18792]
|
|
||||||
|
|
||||||
output_normalization_modes:
|
output_normalization_modes:
|
||||||
action: min_max
|
action: min_max
|
||||||
output_normalization_params:
|
output_normalization_params:
|
||||||
# action:
|
|
||||||
# min: [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0]
|
|
||||||
# max: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
|
|
||||||
action:
|
action:
|
||||||
min: [-149.23828125, -97.734375, -100.1953125, -73.740234375]
|
min: [-0.03, -0.03, -0.01]
|
||||||
max: [149.23828125, 97.734375, 100.1953125, 73.740234375]
|
max: [0.03, 0.03, 0.03]
|
||||||
|
|
||||||
# Architecture / modeling.
|
# Architecture / modeling.
|
||||||
# Neural networks.
|
# Neural networks.
|
||||||
|
|
|
@ -14,9 +14,13 @@ calibration_dir: .cache/calibration/so100
|
||||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||||
# the number of motors in your follower arms.
|
# the number of motors in your follower arms.
|
||||||
max_relative_target: null
|
max_relative_target: null
|
||||||
joint_position_relative_bounds:
|
joint_position_relative_bounds: null
|
||||||
max: [ 7.2158203e+01, 1.5398438e+02, 1.6075195e+02, 9.3251953e+01, 0., -1.4184397e-01]
|
# max: [100, 100, 100, 100, 100, 100]
|
||||||
min: [-77.08008, 56.25, 60.55664, 19.511719, 0., -0.63829786]
|
# min: [-100, -100, -100, -100, -100, -100]
|
||||||
|
# max: [ 7.2158203e+01, 1.5398438e+02, 1.6075195e+02, 9.3251953e+01, 0., -1.4184397e-01]
|
||||||
|
# min: [-77.08008, 56.25, 60.55664, 19.511719, 0., -0.63829786]
|
||||||
|
# max: [ 35.06836 , 103.18359 , 127.61719 , 75.58594 , 0., 0.]
|
||||||
|
# min: [ -8.876953 , 63.808594 , 90.49805 , 49.48242 , 0., 0.]
|
||||||
|
|
||||||
leader_arms:
|
leader_arms:
|
||||||
main:
|
main:
|
||||||
|
@ -47,13 +51,13 @@ follower_arms:
|
||||||
cameras:
|
cameras:
|
||||||
front:
|
front:
|
||||||
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
|
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
|
||||||
camera_index: 0
|
camera_index: 1
|
||||||
fps: 30
|
fps: 30
|
||||||
width: 640
|
width: 640
|
||||||
height: 480
|
height: 480
|
||||||
side:
|
side:
|
||||||
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
|
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
|
||||||
camera_index: 1
|
camera_index: 0
|
||||||
fps: 30
|
fps: 30
|
||||||
width: 640
|
width: 640
|
||||||
height: 480
|
height: 480
|
||||||
|
|
|
@ -54,6 +54,7 @@ from lerobot.scripts.server.network_utils import (
|
||||||
)
|
)
|
||||||
from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env
|
from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env
|
||||||
from lerobot.scripts.server import learner_service
|
from lerobot.scripts.server import learner_service
|
||||||
|
from lerobot.common.robot_devices.utils import busy_wait
|
||||||
|
|
||||||
from torch.multiprocessing import Queue, Event
|
from torch.multiprocessing import Queue, Event
|
||||||
from queue import Empty
|
from queue import Empty
|
||||||
|
@ -312,17 +313,6 @@ def act_with_policy(
|
||||||
|
|
||||||
logging.info("make_policy")
|
logging.info("make_policy")
|
||||||
|
|
||||||
# HACK: This is an ugly hack to pass the normalization parameters to the policy
|
|
||||||
# Because the action space is dynamic so we override the output normalization parameters
|
|
||||||
# it's ugly, we know ... and we will fix it
|
|
||||||
min_action_space: list = online_env.action_space.spaces[0].low.tolist()
|
|
||||||
max_action_space: list = online_env.action_space.spaces[0].high.tolist()
|
|
||||||
output_normalization_params: dict[dict[str, list]] = {
|
|
||||||
"action": {"min": min_action_space, "max": max_action_space}
|
|
||||||
}
|
|
||||||
cfg.policy.output_normalization_params = output_normalization_params
|
|
||||||
cfg.policy.output_shapes["action"] = online_env.action_space.spaces[0].shape
|
|
||||||
|
|
||||||
### Instantiate the policy in both the actor and learner processes
|
### Instantiate the policy in both the actor and learner processes
|
||||||
### To avoid sending a SACPolicy object through the port, we create a policy intance
|
### To avoid sending a SACPolicy object through the port, we create a policy intance
|
||||||
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
|
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
|
||||||
|
@ -347,6 +337,7 @@ def act_with_policy(
|
||||||
episode_intervention = False
|
episode_intervention = False
|
||||||
|
|
||||||
for interaction_step in range(cfg.training.online_steps):
|
for interaction_step in range(cfg.training.online_steps):
|
||||||
|
start_time = time.perf_counter()
|
||||||
if shutdown_event.is_set():
|
if shutdown_event.is_set():
|
||||||
logging.info("[ACTOR] Shutting down act_with_policy")
|
logging.info("[ACTOR] Shutting down act_with_policy")
|
||||||
return
|
return
|
||||||
|
@ -408,7 +399,6 @@ def act_with_policy(
|
||||||
complementary_info=info, # TODO Handle information for the transition, is_demonstraction: bool
|
complementary_info=info, # TODO Handle information for the transition, is_demonstraction: bool
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# assign obs to the next obs and continue the rollout
|
# assign obs to the next obs and continue the rollout
|
||||||
obs = next_obs
|
obs = next_obs
|
||||||
|
|
||||||
|
@ -449,6 +439,10 @@ def act_with_policy(
|
||||||
episode_intervention = False
|
episode_intervention = False
|
||||||
obs, info = online_env.reset()
|
obs, info = online_env.reset()
|
||||||
|
|
||||||
|
if cfg.fps is not None:
|
||||||
|
dt_time = time.perf_counter() - start_time
|
||||||
|
busy_wait(1 / cfg.fps - dt_time)
|
||||||
|
|
||||||
|
|
||||||
def push_transitions_to_transport_queue(transitions: list, transitions_queue):
|
def push_transitions_to_transport_queue(transitions: list, transitions_queue):
|
||||||
"""Send transitions to learner in smaller chunks to avoid network issues.
|
"""Send transitions to learner in smaller chunks to avoid network issues.
|
||||||
|
|
|
@ -263,11 +263,6 @@ if __name__ == "__main__":
|
||||||
with open(args.crop_params_path) as f:
|
with open(args.crop_params_path) as f:
|
||||||
rois = json.load(f)
|
rois = json.load(f)
|
||||||
|
|
||||||
# rois = {
|
|
||||||
# "observation.images.front": [102, 43, 358, 523],
|
|
||||||
# "observation.images.side": [92, 123, 379, 349],
|
|
||||||
# }
|
|
||||||
|
|
||||||
# Print the selected rectangular ROIs
|
# Print the selected rectangular ROIs
|
||||||
print("\nSelected Rectangular Regions of Interest (top, left, height, width):")
|
print("\nSelected Rectangular Regions of Interest (top, left, height, width):")
|
||||||
for key, roi in rois.items():
|
for key, roi in rois.items():
|
||||||
|
|
|
@ -0,0 +1,797 @@
|
||||||
|
from lerobot.common.robot_devices.robots.factory import make_robot
|
||||||
|
from lerobot.common.utils.utils import init_hydra_config
|
||||||
|
from lerobot.common.robot_devices.utils import busy_wait
|
||||||
|
from lerobot.scripts.server.kinematics import RobotKinematics
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
class InputController:
|
||||||
|
"""Base class for input controllers that generate motion deltas."""
|
||||||
|
|
||||||
|
def __init__(self, x_step_size=0.01, y_step_size=0.01, z_step_size=0.01):
|
||||||
|
"""
|
||||||
|
Initialize the controller.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x_step_size: Base movement step size in meters
|
||||||
|
y_step_size: Base movement step size in meters
|
||||||
|
z_step_size: Base movement step size in meters
|
||||||
|
"""
|
||||||
|
self.x_step_size = x_step_size
|
||||||
|
self.y_step_size = y_step_size
|
||||||
|
self.z_step_size = z_step_size
|
||||||
|
self.running = True
|
||||||
|
self.episode_end_status = None # None, "success", or "failure"
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
"""Start the controller and initialize resources."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
"""Stop the controller and release resources."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_deltas(self):
|
||||||
|
"""Get the current movement deltas (dx, dy, dz) in meters."""
|
||||||
|
return 0.0, 0.0, 0.0
|
||||||
|
|
||||||
|
def should_quit(self):
|
||||||
|
"""Return True if the user has requested to quit."""
|
||||||
|
return not self.running
|
||||||
|
|
||||||
|
def update(self):
|
||||||
|
"""Update controller state - call this once per frame."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
"""Support for use in 'with' statements."""
|
||||||
|
self.start()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
"""Ensure resources are released when exiting 'with' block."""
|
||||||
|
self.stop()
|
||||||
|
|
||||||
|
def get_episode_end_status(self):
|
||||||
|
"""
|
||||||
|
Get the current episode end status.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None if episode should continue, "success" or "failure" otherwise
|
||||||
|
"""
|
||||||
|
status = self.episode_end_status
|
||||||
|
self.episode_end_status = None # Reset after reading
|
||||||
|
return status
|
||||||
|
|
||||||
|
|
||||||
|
class KeyboardController(InputController):
|
||||||
|
"""Generate motion deltas from keyboard input."""
|
||||||
|
|
||||||
|
def __init__(self, x_step_size=0.01, y_step_size=0.01, z_step_size=0.01):
|
||||||
|
super().__init__(x_step_size, y_step_size, z_step_size)
|
||||||
|
self.key_states = {
|
||||||
|
"forward_x": False,
|
||||||
|
"backward_x": False,
|
||||||
|
"forward_y": False,
|
||||||
|
"backward_y": False,
|
||||||
|
"forward_z": False,
|
||||||
|
"backward_z": False,
|
||||||
|
"quit": False,
|
||||||
|
"success": False,
|
||||||
|
"failure": False,
|
||||||
|
}
|
||||||
|
self.listener = None
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
"""Start the keyboard listener."""
|
||||||
|
from pynput import keyboard
|
||||||
|
|
||||||
|
def on_press(key):
|
||||||
|
try:
|
||||||
|
if key == keyboard.Key.up:
|
||||||
|
self.key_states["forward_x"] = True
|
||||||
|
elif key == keyboard.Key.down:
|
||||||
|
self.key_states["backward_x"] = True
|
||||||
|
elif key == keyboard.Key.left:
|
||||||
|
self.key_states["forward_y"] = True
|
||||||
|
elif key == keyboard.Key.right:
|
||||||
|
self.key_states["backward_y"] = True
|
||||||
|
elif key == keyboard.Key.shift:
|
||||||
|
self.key_states["backward_z"] = True
|
||||||
|
elif key == keyboard.Key.shift_r:
|
||||||
|
self.key_states["forward_z"] = True
|
||||||
|
elif key == keyboard.Key.esc:
|
||||||
|
self.key_states["quit"] = True
|
||||||
|
self.running = False
|
||||||
|
return False
|
||||||
|
elif key == keyboard.Key.enter:
|
||||||
|
self.key_states["success"] = True
|
||||||
|
self.episode_end_status = "success"
|
||||||
|
elif key == keyboard.Key.backspace:
|
||||||
|
self.key_states["failure"] = True
|
||||||
|
self.episode_end_status = "failure"
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_release(key):
|
||||||
|
try:
|
||||||
|
if key == keyboard.Key.up:
|
||||||
|
self.key_states["forward_x"] = False
|
||||||
|
elif key == keyboard.Key.down:
|
||||||
|
self.key_states["backward_x"] = False
|
||||||
|
elif key == keyboard.Key.left:
|
||||||
|
self.key_states["forward_y"] = False
|
||||||
|
elif key == keyboard.Key.right:
|
||||||
|
self.key_states["backward_y"] = False
|
||||||
|
elif key == keyboard.Key.shift:
|
||||||
|
self.key_states["backward_z"] = False
|
||||||
|
elif key == keyboard.Key.shift_r:
|
||||||
|
self.key_states["forward_z"] = False
|
||||||
|
elif key == keyboard.Key.enter:
|
||||||
|
self.key_states["success"] = False
|
||||||
|
elif key == keyboard.Key.backspace:
|
||||||
|
self.key_states["failure"] = False
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.listener = keyboard.Listener(on_press=on_press, on_release=on_release)
|
||||||
|
self.listener.start()
|
||||||
|
|
||||||
|
print("Keyboard controls:")
|
||||||
|
print(" Arrow keys: Move in X-Y plane")
|
||||||
|
print(" Shift and Shift_R: Move in Z axis")
|
||||||
|
print(" Enter: End episode with SUCCESS")
|
||||||
|
print(" Backspace: End episode with FAILURE")
|
||||||
|
print(" ESC: Exit")
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
"""Stop the keyboard listener."""
|
||||||
|
if self.listener and self.listener.is_alive():
|
||||||
|
self.listener.stop()
|
||||||
|
|
||||||
|
def get_deltas(self):
|
||||||
|
"""Get the current movement deltas from keyboard state."""
|
||||||
|
delta_x = delta_y = delta_z = 0.0
|
||||||
|
|
||||||
|
if self.key_states["forward_x"]:
|
||||||
|
delta_x += self.x_step_size
|
||||||
|
if self.key_states["backward_x"]:
|
||||||
|
delta_x -= self.x_step_size
|
||||||
|
if self.key_states["forward_y"]:
|
||||||
|
delta_y += self.y_step_size
|
||||||
|
if self.key_states["backward_y"]:
|
||||||
|
delta_y -= self.y_step_size
|
||||||
|
if self.key_states["forward_z"]:
|
||||||
|
delta_z += self.z_step_size
|
||||||
|
if self.key_states["backward_z"]:
|
||||||
|
delta_z -= self.z_step_size
|
||||||
|
|
||||||
|
return delta_x, delta_y, delta_z
|
||||||
|
|
||||||
|
def should_quit(self):
|
||||||
|
"""Return True if ESC was pressed."""
|
||||||
|
return self.key_states["quit"]
|
||||||
|
|
||||||
|
def should_save(self):
|
||||||
|
"""Return True if Enter was pressed (save episode)."""
|
||||||
|
return self.key_states["success"] or self.key_states["failure"]
|
||||||
|
|
||||||
|
|
||||||
|
class GamepadController(InputController):
|
||||||
|
"""Generate motion deltas from gamepad input."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, x_step_size=0.01, y_step_size=0.01, z_step_size=0.01, deadzone=0.1
|
||||||
|
):
|
||||||
|
super().__init__(x_step_size, y_step_size, z_step_size)
|
||||||
|
self.deadzone = deadzone
|
||||||
|
self.joystick = None
|
||||||
|
self.intervention_flag = False
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
"""Initialize pygame and the gamepad."""
|
||||||
|
import pygame
|
||||||
|
|
||||||
|
pygame.init()
|
||||||
|
pygame.joystick.init()
|
||||||
|
|
||||||
|
if pygame.joystick.get_count() == 0:
|
||||||
|
logging.error(
|
||||||
|
"No gamepad detected. Please connect a gamepad and try again."
|
||||||
|
)
|
||||||
|
self.running = False
|
||||||
|
return
|
||||||
|
|
||||||
|
self.joystick = pygame.joystick.Joystick(0)
|
||||||
|
self.joystick.init()
|
||||||
|
logging.info(f"Initialized gamepad: {self.joystick.get_name()}")
|
||||||
|
|
||||||
|
print("Gamepad controls:")
|
||||||
|
print(" Left analog stick: Move in X-Y plane")
|
||||||
|
print(" Right analog stick (vertical): Move in Z axis")
|
||||||
|
print(" B/Circle button: Exit")
|
||||||
|
print(" Y/Triangle button: End episode with SUCCESS")
|
||||||
|
print(" A/Cross button: End episode with FAILURE")
|
||||||
|
print(" X/Square button: Rerecord episode")
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
"""Clean up pygame resources."""
|
||||||
|
import pygame
|
||||||
|
|
||||||
|
if pygame.joystick.get_init():
|
||||||
|
if self.joystick:
|
||||||
|
self.joystick.quit()
|
||||||
|
pygame.joystick.quit()
|
||||||
|
pygame.quit()
|
||||||
|
|
||||||
|
def update(self):
|
||||||
|
"""Process pygame events to get fresh gamepad readings."""
|
||||||
|
import pygame
|
||||||
|
|
||||||
|
for event in pygame.event.get():
|
||||||
|
if event.type == pygame.JOYBUTTONDOWN:
|
||||||
|
if event.button == 3:
|
||||||
|
self.episode_end_status = "success"
|
||||||
|
# A button (1) for failure
|
||||||
|
elif event.button == 1:
|
||||||
|
self.episode_end_status = "failure"
|
||||||
|
# X button (0) for rerecord
|
||||||
|
elif event.button == 0:
|
||||||
|
self.episode_end_status = "rerecord_episode"
|
||||||
|
|
||||||
|
# Reset episode status on button release
|
||||||
|
elif event.type == pygame.JOYBUTTONUP:
|
||||||
|
if event.button in [0, 2, 3]:
|
||||||
|
self.episode_end_status = None
|
||||||
|
|
||||||
|
# Check for RB button (typically button 5) for intervention flag
|
||||||
|
if self.joystick.get_button(5):
|
||||||
|
self.intervention_flag = True
|
||||||
|
else:
|
||||||
|
self.intervention_flag = False
|
||||||
|
|
||||||
|
def get_deltas(self):
|
||||||
|
"""Get the current movement deltas from gamepad state."""
|
||||||
|
import pygame
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Read joystick axes
|
||||||
|
# Left stick X and Y (typically axes 0 and 1)
|
||||||
|
x_input = self.joystick.get_axis(0) # Left/Right
|
||||||
|
y_input = self.joystick.get_axis(1) # Up/Down (often inverted)
|
||||||
|
|
||||||
|
# Right stick Y (typically axis 3 or 4)
|
||||||
|
z_input = self.joystick.get_axis(3) # Up/Down for Z
|
||||||
|
|
||||||
|
# Apply deadzone to avoid drift
|
||||||
|
x_input = 0 if abs(x_input) < self.deadzone else x_input
|
||||||
|
y_input = 0 if abs(y_input) < self.deadzone else y_input
|
||||||
|
z_input = 0 if abs(z_input) < self.deadzone else z_input
|
||||||
|
|
||||||
|
# Calculate deltas (note: may need to invert axes depending on controller)
|
||||||
|
delta_x = -y_input * self.y_step_size # Forward/backward
|
||||||
|
delta_y = -x_input * self.x_step_size # Left/right
|
||||||
|
delta_z = -z_input * self.z_step_size # Up/down
|
||||||
|
|
||||||
|
return delta_x, delta_y, delta_z
|
||||||
|
|
||||||
|
except pygame.error:
|
||||||
|
logging.error("Error reading gamepad. Is it still connected?")
|
||||||
|
return 0.0, 0.0, 0.0
|
||||||
|
|
||||||
|
def should_intervene(self):
|
||||||
|
"""Return True if intervention flag was set."""
|
||||||
|
return self.intervention_flag
|
||||||
|
|
||||||
|
|
||||||
|
class GamepadControllerHID(InputController):
|
||||||
|
"""Generate motion deltas from gamepad input using HIDAPI."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
x_step_size=0.01,
|
||||||
|
y_step_size=0.01,
|
||||||
|
z_step_size=0.01,
|
||||||
|
deadzone=0.1,
|
||||||
|
vendor_id=0x046D,
|
||||||
|
product_id=0xC219,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the HID gamepad controller.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
step_size: Base movement step size in meters
|
||||||
|
z_scale: Scaling factor for Z-axis movement
|
||||||
|
deadzone: Joystick deadzone to prevent drift
|
||||||
|
vendor_id: USB vendor ID of the gamepad (default: Logitech)
|
||||||
|
product_id: USB product ID of the gamepad (default: RumblePad 2)
|
||||||
|
"""
|
||||||
|
super().__init__(x_step_size, y_step_size, z_step_size)
|
||||||
|
self.deadzone = deadzone
|
||||||
|
self.vendor_id = vendor_id
|
||||||
|
self.product_id = product_id
|
||||||
|
self.device = None
|
||||||
|
self.device_info = None
|
||||||
|
|
||||||
|
# Movement values (normalized from -1.0 to 1.0)
|
||||||
|
self.left_x = 0.0
|
||||||
|
self.left_y = 0.0
|
||||||
|
self.right_x = 0.0
|
||||||
|
self.right_y = 0.0
|
||||||
|
|
||||||
|
# Button states
|
||||||
|
self.buttons = {}
|
||||||
|
self.quit_requested = False
|
||||||
|
self.save_requested = False
|
||||||
|
self.intervention_flag = False
|
||||||
|
|
||||||
|
def find_device(self):
|
||||||
|
"""Look for the gamepad device by vendor and product ID."""
|
||||||
|
import hid
|
||||||
|
|
||||||
|
devices = hid.enumerate()
|
||||||
|
for device in devices:
|
||||||
|
if (
|
||||||
|
device["vendor_id"] == self.vendor_id
|
||||||
|
and device["product_id"] == self.product_id
|
||||||
|
):
|
||||||
|
logging.info(
|
||||||
|
f"Found gamepad: {device.get('product_string', 'Unknown')}"
|
||||||
|
)
|
||||||
|
return device
|
||||||
|
|
||||||
|
logging.error(
|
||||||
|
f"No gamepad with vendor ID 0x{self.vendor_id:04X} and "
|
||||||
|
f"product ID 0x{self.product_id:04X} found"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
"""Connect to the gamepad using HIDAPI."""
|
||||||
|
import hid
|
||||||
|
|
||||||
|
self.device_info = self.find_device()
|
||||||
|
if not self.device_info:
|
||||||
|
self.running = False
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
logging.info(f"Connecting to gamepad at path: {self.device_info['path']}")
|
||||||
|
self.device = hid.device()
|
||||||
|
self.device.open_path(self.device_info["path"])
|
||||||
|
self.device.set_nonblocking(1)
|
||||||
|
|
||||||
|
manufacturer = self.device.get_manufacturer_string()
|
||||||
|
product = self.device.get_product_string()
|
||||||
|
logging.info(f"Connected to {manufacturer} {product}")
|
||||||
|
|
||||||
|
logging.info("Gamepad controls (HID mode):")
|
||||||
|
logging.info(" Left analog stick: Move in X-Y plane")
|
||||||
|
logging.info(" Right analog stick: Move in Z axis (vertical)")
|
||||||
|
logging.info(" Button 1/B/Circle: Exit")
|
||||||
|
logging.info(" Button 2/A/Cross: End episode with SUCCESS")
|
||||||
|
logging.info(" Button 3/X/Square: End episode with FAILURE")
|
||||||
|
|
||||||
|
except OSError as e:
|
||||||
|
logging.error(f"Error opening gamepad: {e}")
|
||||||
|
logging.error(
|
||||||
|
"You might need to run this with sudo/admin privileges on some systems"
|
||||||
|
)
|
||||||
|
self.running = False
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
"""Close the HID device connection."""
|
||||||
|
if self.device:
|
||||||
|
self.device.close()
|
||||||
|
self.device = None
|
||||||
|
|
||||||
|
def update(self):
|
||||||
|
"""
|
||||||
|
Read and process the latest gamepad data.
|
||||||
|
Due to an issue with the HIDAPI, we need to read the read the device several times in order to get a stable reading
|
||||||
|
"""
|
||||||
|
for _ in range(10):
|
||||||
|
self._update()
|
||||||
|
|
||||||
|
def _update(self):
|
||||||
|
"""Read and process the latest gamepad data."""
|
||||||
|
if not self.device or not self.running:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Read data from the gamepad
|
||||||
|
data = self.device.read(64)
|
||||||
|
if data:
|
||||||
|
# Interpret gamepad data - this will vary by controller model
|
||||||
|
# These offsets are for the Logitech RumblePad 2
|
||||||
|
if len(data) >= 8:
|
||||||
|
# Normalize joystick values from 0-255 to -1.0-1.0
|
||||||
|
self.left_x = (data[1] - 128) / 128.0
|
||||||
|
self.left_y = (data[2] - 128) / 128.0
|
||||||
|
self.right_x = (data[3] - 128) / 128.0
|
||||||
|
self.right_y = (data[4] - 128) / 128.0
|
||||||
|
|
||||||
|
# Apply deadzone
|
||||||
|
self.left_x = 0 if abs(self.left_x) < self.deadzone else self.left_x
|
||||||
|
self.left_y = 0 if abs(self.left_y) < self.deadzone else self.left_y
|
||||||
|
self.right_x = (
|
||||||
|
0 if abs(self.right_x) < self.deadzone else self.right_x
|
||||||
|
)
|
||||||
|
self.right_y = (
|
||||||
|
0 if abs(self.right_y) < self.deadzone else self.right_y
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse button states (byte 5 in the Logitech RumblePad 2)
|
||||||
|
buttons = data[5]
|
||||||
|
|
||||||
|
# Check if RB is pressed then the intervention flag should be set
|
||||||
|
self.intervention_flag = data[6] == 2
|
||||||
|
|
||||||
|
# Check if Y/Triangle button (bit 7) is pressed for saving
|
||||||
|
# Check if X/Square button (bit 5) is pressed for failure
|
||||||
|
# Check if A/Cross button (bit 4) is pressed for rerecording
|
||||||
|
if buttons & 1 << 7:
|
||||||
|
self.episode_end_status = "success"
|
||||||
|
elif buttons & 1 << 5:
|
||||||
|
self.episode_end_status = "failure"
|
||||||
|
elif buttons & 1 << 4:
|
||||||
|
self.episode_end_status = "rerecord_episode"
|
||||||
|
else:
|
||||||
|
self.episode_end_status = None
|
||||||
|
|
||||||
|
except OSError as e:
|
||||||
|
logging.error(f"Error reading from gamepad: {e}")
|
||||||
|
|
||||||
|
def get_deltas(self):
|
||||||
|
"""Get the current movement deltas from gamepad state."""
|
||||||
|
# Calculate deltas - invert as needed based on controller orientation
|
||||||
|
delta_x = -self.left_y * self.x_step_size # Forward/backward
|
||||||
|
delta_y = -self.left_x * self.y_step_size # Left/right
|
||||||
|
delta_z = -self.right_y * self.z_step_size # Up/down
|
||||||
|
|
||||||
|
return delta_x, delta_y, delta_z
|
||||||
|
|
||||||
|
def should_quit(self):
|
||||||
|
"""Return True if quit button was pressed."""
|
||||||
|
return self.quit_requested
|
||||||
|
|
||||||
|
def should_save(self):
|
||||||
|
"""Return True if save button was pressed."""
|
||||||
|
return self.save_requested
|
||||||
|
|
||||||
|
def should_intervene(self):
|
||||||
|
"""Return True if intervention flag was set."""
|
||||||
|
return self.intervention_flag
|
||||||
|
|
||||||
|
|
||||||
|
def test_forward_kinematics(robot, fps=10):
|
||||||
|
logging.info("Testing Forward Kinematics")
|
||||||
|
timestep = time.perf_counter()
|
||||||
|
while time.perf_counter() - timestep < 60.0:
|
||||||
|
loop_start_time = time.perf_counter()
|
||||||
|
robot.teleop_step()
|
||||||
|
obs = robot.capture_observation()
|
||||||
|
joint_positions = obs["observation.state"].cpu().numpy()
|
||||||
|
ee_pos = RobotKinematics.fk_gripper_tip(joint_positions)
|
||||||
|
logging.info(f"EE Position: {ee_pos[:3,3]}")
|
||||||
|
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
|
||||||
|
|
||||||
|
|
||||||
|
def test_inverse_kinematics(robot, fps=10):
|
||||||
|
logging.info("Testing Inverse Kinematics")
|
||||||
|
timestep = time.perf_counter()
|
||||||
|
while time.perf_counter() - timestep < 60.0:
|
||||||
|
loop_start_time = time.perf_counter()
|
||||||
|
obs = robot.capture_observation()
|
||||||
|
joint_positions = obs["observation.state"].cpu().numpy()
|
||||||
|
ee_pos = RobotKinematics.fk_gripper_tip(joint_positions)
|
||||||
|
desired_ee_pos = ee_pos
|
||||||
|
target_joint_state = RobotKinematics.ik(
|
||||||
|
joint_positions, desired_ee_pos, position_only=True
|
||||||
|
)
|
||||||
|
robot.send_action(torch.from_numpy(target_joint_state))
|
||||||
|
logging.info(f"Target Joint State: {target_joint_state}")
|
||||||
|
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
|
||||||
|
|
||||||
|
|
||||||
|
def teleoperate_inverse_kinematics_with_leader(robot, fps=10):
|
||||||
|
logging.info("Testing Inverse Kinematics")
|
||||||
|
fk_func = RobotKinematics.fk_gripper_tip
|
||||||
|
timestep = time.perf_counter()
|
||||||
|
while time.perf_counter() - timestep < 60.0:
|
||||||
|
loop_start_time = time.perf_counter()
|
||||||
|
obs = robot.capture_observation()
|
||||||
|
joint_positions = obs["observation.state"].cpu().numpy()
|
||||||
|
ee_pos = fk_func(joint_positions)
|
||||||
|
|
||||||
|
leader_joint_positions = robot.leader_arms["main"].read("Present_Position")
|
||||||
|
leader_ee = fk_func(leader_joint_positions)
|
||||||
|
|
||||||
|
desired_ee_pos = leader_ee
|
||||||
|
target_joint_state = RobotKinematics.ik(
|
||||||
|
joint_positions, desired_ee_pos, position_only=True, fk_func=fk_func
|
||||||
|
)
|
||||||
|
robot.send_action(torch.from_numpy(target_joint_state))
|
||||||
|
logging.info(f"Leader EE: {leader_ee[:3,3]}, Follower EE: {ee_pos[:3,3]}")
|
||||||
|
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
|
||||||
|
|
||||||
|
|
||||||
|
def teleoperate_delta_inverse_kinematics_with_leader(robot, fps=10):
|
||||||
|
logging.info("Testing Delta End-Effector Control")
|
||||||
|
timestep = time.perf_counter()
|
||||||
|
|
||||||
|
# Initial position capture
|
||||||
|
obs = robot.capture_observation()
|
||||||
|
joint_positions = obs["observation.state"].cpu().numpy()
|
||||||
|
|
||||||
|
fk_func = RobotKinematics.fk_gripper_tip
|
||||||
|
|
||||||
|
leader_joint_positions = robot.leader_arms["main"].read("Present_Position")
|
||||||
|
initial_leader_ee = fk_func(leader_joint_positions)
|
||||||
|
|
||||||
|
desired_ee_pos = np.diag(np.ones(4))
|
||||||
|
|
||||||
|
while time.perf_counter() - timestep < 60.0:
|
||||||
|
loop_start_time = time.perf_counter()
|
||||||
|
|
||||||
|
# Get leader state for teleoperation
|
||||||
|
leader_joint_positions = robot.leader_arms["main"].read("Present_Position")
|
||||||
|
leader_ee = fk_func(leader_joint_positions)
|
||||||
|
|
||||||
|
# Get current state
|
||||||
|
# obs = robot.capture_observation()
|
||||||
|
# joint_positions = obs["observation.state"].cpu().numpy()
|
||||||
|
joint_positions = robot.follower_arms["main"].read("Present_Position")
|
||||||
|
current_ee_pos = fk_func(joint_positions)
|
||||||
|
|
||||||
|
# Calculate delta between leader and follower end-effectors
|
||||||
|
# Scaling factor can be adjusted for sensitivity
|
||||||
|
scaling_factor = 1.0
|
||||||
|
ee_delta = (leader_ee - initial_leader_ee) * scaling_factor
|
||||||
|
|
||||||
|
# Apply delta to current position
|
||||||
|
desired_ee_pos[0, 3] = current_ee_pos[0, 3] + ee_delta[0, 3]
|
||||||
|
desired_ee_pos[1, 3] = current_ee_pos[1, 3] + ee_delta[1, 3]
|
||||||
|
desired_ee_pos[2, 3] = current_ee_pos[2, 3] + ee_delta[2, 3]
|
||||||
|
|
||||||
|
if np.any(np.abs(ee_delta[:3, 3]) > 0.01):
|
||||||
|
# Compute joint targets via inverse kinematics
|
||||||
|
target_joint_state = RobotKinematics.ik(
|
||||||
|
joint_positions, desired_ee_pos, position_only=True, fk_func=fk_func
|
||||||
|
)
|
||||||
|
|
||||||
|
initial_leader_ee = leader_ee.copy()
|
||||||
|
|
||||||
|
# Send command to robot
|
||||||
|
robot.send_action(torch.from_numpy(target_joint_state))
|
||||||
|
|
||||||
|
# Logging
|
||||||
|
logging.info(
|
||||||
|
f"Current EE: {current_ee_pos[:3,3]}, Desired EE: {desired_ee_pos[:3,3]}"
|
||||||
|
)
|
||||||
|
logging.info(f"Delta EE: {ee_delta[:3,3]}")
|
||||||
|
|
||||||
|
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
|
||||||
|
|
||||||
|
|
||||||
|
def teleoperate_delta_inverse_kinematics(
|
||||||
|
robot, controller, fps=10, bounds=None, fk_func=None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Control a robot using delta end-effector movements from any input controller.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
robot: Robot instance to control
|
||||||
|
controller: InputController instance (keyboard, gamepad, etc.)
|
||||||
|
fps: Control frequency in Hz
|
||||||
|
bounds: Optional position limits
|
||||||
|
fk_func: Forward kinematics function to use
|
||||||
|
"""
|
||||||
|
if fk_func is None:
|
||||||
|
fk_func = RobotKinematics.fk_gripper_tip
|
||||||
|
|
||||||
|
logging.info(
|
||||||
|
f"Testing Delta End-Effector Control with {controller.__class__.__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initial position capture
|
||||||
|
obs = robot.capture_observation()
|
||||||
|
joint_positions = obs["observation.state"].cpu().numpy()
|
||||||
|
current_ee_pos = fk_func(joint_positions)
|
||||||
|
|
||||||
|
# Initialize desired position with current position
|
||||||
|
desired_ee_pos = np.eye(4) # Identity matrix
|
||||||
|
|
||||||
|
timestep = time.perf_counter()
|
||||||
|
with controller:
|
||||||
|
while not controller.should_quit() and time.perf_counter() - timestep < 60.0:
|
||||||
|
loop_start_time = time.perf_counter()
|
||||||
|
|
||||||
|
# Process input events
|
||||||
|
controller.update()
|
||||||
|
|
||||||
|
# Get currrent robot state
|
||||||
|
joint_positions = robot.follower_arms["main"].read("Present_Position")
|
||||||
|
current_ee_pos = fk_func(joint_positions)
|
||||||
|
|
||||||
|
# Get movement deltas from the controller
|
||||||
|
delta_x, delta_y, delta_z = controller.get_deltas()
|
||||||
|
|
||||||
|
# Update desired position
|
||||||
|
desired_ee_pos[0, 3] = current_ee_pos[0, 3] + delta_x
|
||||||
|
desired_ee_pos[1, 3] = current_ee_pos[1, 3] + delta_y
|
||||||
|
desired_ee_pos[2, 3] = current_ee_pos[2, 3] + delta_z
|
||||||
|
|
||||||
|
# Apply bounds if provided
|
||||||
|
if bounds is not None:
|
||||||
|
desired_ee_pos[:3, 3] = np.clip(
|
||||||
|
desired_ee_pos[:3, 3], bounds["min"], bounds["max"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Only send commands if there's actual movement
|
||||||
|
if any([abs(v) > 0.001 for v in [delta_x, delta_y, delta_z]]):
|
||||||
|
# Compute joint targets via inverse kinematics
|
||||||
|
target_joint_state = RobotKinematics.ik(
|
||||||
|
joint_positions, desired_ee_pos, position_only=True, fk_func=fk_func
|
||||||
|
)
|
||||||
|
|
||||||
|
# Send command to robot
|
||||||
|
robot.send_action(torch.from_numpy(target_joint_state))
|
||||||
|
|
||||||
|
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
|
||||||
|
|
||||||
|
|
||||||
|
def teleoperate_gym_env(env, controller, fps: int = 30):
|
||||||
|
"""
|
||||||
|
Control a robot through a gym environment using keyboard inputs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env: A gym environment created with make_robot_env
|
||||||
|
fps: Target control frequency
|
||||||
|
"""
|
||||||
|
|
||||||
|
logging.info("Testing Keyboard Control of Gym Environment")
|
||||||
|
print("Keyboard controls:")
|
||||||
|
print(" Arrow keys: Move in X-Y plane")
|
||||||
|
print(" Shift and Shift_R: Move in Z axis")
|
||||||
|
print(" ESC: Exit")
|
||||||
|
|
||||||
|
# Reset the environment to get initial observation
|
||||||
|
obs, info = env.reset()
|
||||||
|
|
||||||
|
try:
|
||||||
|
with controller:
|
||||||
|
while not controller.should_quit():
|
||||||
|
loop_start_time = time.perf_counter()
|
||||||
|
|
||||||
|
# Process input events
|
||||||
|
controller.update()
|
||||||
|
|
||||||
|
# Get movement deltas from the controller
|
||||||
|
delta_x, delta_y, delta_z = controller.get_deltas()
|
||||||
|
|
||||||
|
# Create the action vector
|
||||||
|
action = np.array([delta_x, delta_y, delta_z])
|
||||||
|
|
||||||
|
# Skip if no movement
|
||||||
|
if any([abs(v) > 0.001 for v in [delta_x, delta_y, delta_z]]):
|
||||||
|
# Step the environment - pass action as a tensor with intervention flag
|
||||||
|
action_tensor = torch.from_numpy(action.astype(np.float32))
|
||||||
|
obs, reward, terminated, truncated, info = env.step(
|
||||||
|
(action_tensor, False)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Log information
|
||||||
|
logging.info(
|
||||||
|
f"Action: [{delta_x:.4f}, {delta_y:.4f}, {delta_z:.4f}]"
|
||||||
|
)
|
||||||
|
logging.info(f"Reward: {reward}")
|
||||||
|
|
||||||
|
# Reset if episode ended
|
||||||
|
if terminated or truncated:
|
||||||
|
logging.info("Episode ended, resetting environment")
|
||||||
|
obs, info = env.reset()
|
||||||
|
|
||||||
|
# Maintain target frame rate
|
||||||
|
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Close the environment
|
||||||
|
env.close()
|
||||||
|
|
||||||
|
|
||||||
|
def make_robot_from_config(config_path, overrides=None):
|
||||||
|
"""Helper function to create a robot from a config file."""
|
||||||
|
if overrides is None:
|
||||||
|
overrides = []
|
||||||
|
robot_cfg = init_hydra_config(config_path, overrides)
|
||||||
|
return make_robot(robot_cfg)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Test end-effector control")
|
||||||
|
parser.add_argument(
|
||||||
|
"--mode",
|
||||||
|
type=str,
|
||||||
|
default="keyboard",
|
||||||
|
choices=[
|
||||||
|
"keyboard",
|
||||||
|
"gamepad",
|
||||||
|
"keyboard_gym",
|
||||||
|
"gamepad_gym",
|
||||||
|
"leader",
|
||||||
|
"leader_abs",
|
||||||
|
],
|
||||||
|
help="Control mode to use",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--task",
|
||||||
|
type=str,
|
||||||
|
default="Robot manipulation task",
|
||||||
|
help="Description of the task being performed",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--push-to-hub",
|
||||||
|
default=True,
|
||||||
|
type=bool,
|
||||||
|
help="Push the dataset to Hugging Face Hub",
|
||||||
|
)
|
||||||
|
# Add the rest of your existing arguments
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
robot = make_robot_from_config("lerobot/configs/robot/so100.yaml", [])
|
||||||
|
|
||||||
|
if not robot.is_connected:
|
||||||
|
robot.connect()
|
||||||
|
|
||||||
|
# Example bounds
|
||||||
|
bounds = {
|
||||||
|
"max": np.array([0.32170487, 0.201285, 0.10273342]),
|
||||||
|
"min": np.array([0.16631757, -0.08237468, 0.03364977]),
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Determine controller type based on mode prefix
|
||||||
|
controller = None
|
||||||
|
if args.mode.startswith("keyboard"):
|
||||||
|
controller = KeyboardController(
|
||||||
|
x_step_size=0.01, y_step_size=0.01, z_step_size=0.05
|
||||||
|
)
|
||||||
|
elif args.mode.startswith("gamepad"):
|
||||||
|
controller = GamepadController(
|
||||||
|
x_step_size=0.02, y_step_size=0.02, z_step_size=0.05
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle mode categories
|
||||||
|
if args.mode in ["keyboard", "gamepad"]:
|
||||||
|
# Direct robot control modes
|
||||||
|
teleoperate_delta_inverse_kinematics(
|
||||||
|
robot, controller, bounds=bounds, fps=10
|
||||||
|
)
|
||||||
|
|
||||||
|
elif args.mode in ["keyboard_gym", "gamepad_gym"]:
|
||||||
|
# Gym environment control modes
|
||||||
|
from lerobot.scripts.server.gym_manipulator import make_robot_env
|
||||||
|
|
||||||
|
cfg = init_hydra_config("lerobot/configs/env/so100_real.yaml", [])
|
||||||
|
cfg.env.wrapper.ee_action_space_params.use_gamepad = False
|
||||||
|
env = make_robot_env(robot, None, cfg)
|
||||||
|
teleoperate_gym_env(env, controller)
|
||||||
|
|
||||||
|
elif args.mode == "leader":
|
||||||
|
# Leader-follower modes don't use controllers
|
||||||
|
teleoperate_delta_inverse_kinematics_with_leader(robot)
|
||||||
|
|
||||||
|
elif args.mode == "leader_abs":
|
||||||
|
teleoperate_inverse_kinematics_with_leader(robot)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
if robot.is_connected:
|
||||||
|
robot.disconnect()
|
|
@ -7,25 +7,26 @@ import numpy as np
|
||||||
from lerobot.common.robot_devices.control_utils import is_headless
|
from lerobot.common.robot_devices.control_utils import is_headless
|
||||||
from lerobot.common.robot_devices.robots.factory import make_robot
|
from lerobot.common.robot_devices.robots.factory import make_robot
|
||||||
from lerobot.common.utils.utils import init_hydra_config
|
from lerobot.common.utils.utils import init_hydra_config
|
||||||
|
from lerobot.scripts.server.kinematics import RobotKinematics
|
||||||
|
|
||||||
|
|
||||||
def find_joint_bounds(
|
def find_joint_bounds(
|
||||||
robot,
|
robot,
|
||||||
control_time_s=20,
|
control_time_s=30,
|
||||||
display_cameras=False,
|
display_cameras=False,
|
||||||
):
|
):
|
||||||
# TODO(rcadene): Add option to record logs
|
|
||||||
if not robot.is_connected:
|
if not robot.is_connected:
|
||||||
robot.connect()
|
robot.connect()
|
||||||
|
|
||||||
control_time_s = float("inf")
|
|
||||||
|
|
||||||
timestamp = 0
|
|
||||||
start_episode_t = time.perf_counter()
|
start_episode_t = time.perf_counter()
|
||||||
pos_list = []
|
pos_list = []
|
||||||
while timestamp < control_time_s:
|
while True:
|
||||||
observation, action = robot.teleop_step(record_data=True)
|
observation, action = robot.teleop_step(record_data=True)
|
||||||
|
|
||||||
|
# Wait for 5 seconds to stabilize the robot initial position
|
||||||
|
if time.perf_counter() - start_episode_t < 5:
|
||||||
|
continue
|
||||||
|
|
||||||
pos_list.append(robot.follower_arms["main"].read("Present_Position"))
|
pos_list.append(robot.follower_arms["main"].read("Present_Position"))
|
||||||
|
|
||||||
if display_cameras and not is_headless():
|
if display_cameras and not is_headless():
|
||||||
|
@ -36,8 +37,7 @@ def find_joint_bounds(
|
||||||
)
|
)
|
||||||
cv2.waitKey(1)
|
cv2.waitKey(1)
|
||||||
|
|
||||||
timestamp = time.perf_counter() - start_episode_t
|
if time.perf_counter() - start_episode_t > control_time_s:
|
||||||
if timestamp > 60:
|
|
||||||
max = np.max(np.stack(pos_list), 0)
|
max = np.max(np.stack(pos_list), 0)
|
||||||
min = np.min(np.stack(pos_list), 0)
|
min = np.min(np.stack(pos_list), 0)
|
||||||
print(f"Max angle position per joint {max}")
|
print(f"Max angle position per joint {max}")
|
||||||
|
@ -45,6 +45,43 @@ def find_joint_bounds(
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
|
def find_ee_bounds(
|
||||||
|
robot,
|
||||||
|
control_time_s=30,
|
||||||
|
display_cameras=False,
|
||||||
|
):
|
||||||
|
if not robot.is_connected:
|
||||||
|
robot.connect()
|
||||||
|
|
||||||
|
start_episode_t = time.perf_counter()
|
||||||
|
ee_list = []
|
||||||
|
while True:
|
||||||
|
observation, action = robot.teleop_step(record_data=True)
|
||||||
|
|
||||||
|
# Wait for 5 seconds to stabilize the robot initial position
|
||||||
|
if time.perf_counter() - start_episode_t < 5:
|
||||||
|
continue
|
||||||
|
|
||||||
|
joint_positions = robot.follower_arms["main"].read("Present_Position")
|
||||||
|
print(f"Joint positions: {joint_positions}")
|
||||||
|
ee_list.append(RobotKinematics.fk_gripper_tip(joint_positions)[:3, 3])
|
||||||
|
|
||||||
|
if display_cameras and not is_headless():
|
||||||
|
image_keys = [key for key in observation if "image" in key]
|
||||||
|
for key in image_keys:
|
||||||
|
cv2.imshow(
|
||||||
|
key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
|
||||||
|
)
|
||||||
|
cv2.waitKey(1)
|
||||||
|
|
||||||
|
if time.perf_counter() - start_episode_t > control_time_s:
|
||||||
|
max = np.max(np.stack(ee_list), 0)
|
||||||
|
min = np.min(np.stack(ee_list), 0)
|
||||||
|
print(f"Max ee position {max}")
|
||||||
|
print(f"Min ee position {min}")
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
@ -59,14 +96,26 @@ if __name__ == "__main__":
|
||||||
nargs="*",
|
nargs="*",
|
||||||
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
|
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--mode",
|
||||||
|
type=str,
|
||||||
|
default="joint",
|
||||||
|
choices=["joint", "ee"],
|
||||||
|
help="Mode to run the script in. Can be 'joint' or 'ee'.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--control-time-s",
|
"--control-time-s",
|
||||||
type=float,
|
type=int,
|
||||||
default=20,
|
default=30,
|
||||||
help="Maximum episode length in seconds",
|
help="Time step to use for control.",
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)
|
robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)
|
||||||
|
|
||||||
robot = make_robot(robot_cfg)
|
robot = make_robot(robot_cfg)
|
||||||
find_joint_bounds(robot, control_time_s=args.control_time_s)
|
if args.mode == "joint":
|
||||||
|
find_joint_bounds(robot, args.control_time_s)
|
||||||
|
elif args.mode == "ee":
|
||||||
|
find_ee_bounds(robot, args.control_time_s)
|
||||||
|
if robot.is_connected:
|
||||||
|
robot.disconnect()
|
||||||
|
|
|
@ -1,19 +1,26 @@
|
||||||
import argparse
|
import argparse
|
||||||
|
import sys
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
from typing import Annotated, Any, Callable, Dict, Optional, Tuple
|
from typing import Annotated, Any, Dict, Tuple
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torchvision.transforms.functional as F # noqa: N812
|
import torchvision.transforms.functional as F # noqa: N812
|
||||||
|
|
||||||
from lerobot.common.envs.utils import preprocess_observation
|
from lerobot.common.envs.utils import preprocess_observation
|
||||||
from lerobot.common.robot_devices.control_utils import busy_wait, is_headless
|
from lerobot.common.robot_devices.control_utils import (
|
||||||
|
busy_wait,
|
||||||
|
is_headless,
|
||||||
|
reset_follower_position,
|
||||||
|
)
|
||||||
from lerobot.common.robot_devices.robots.factory import make_robot
|
from lerobot.common.robot_devices.robots.factory import make_robot
|
||||||
from lerobot.common.utils.utils import init_hydra_config, log_say
|
from lerobot.common.utils.utils import init_hydra_config, log_say
|
||||||
|
|
||||||
|
from lerobot.scripts.server.kinematics import RobotKinematics
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
@ -76,13 +83,19 @@ class HILSerlRobotEnv(gym.Env):
|
||||||
|
|
||||||
# Retrieve the size of the joint position interval bound.
|
# Retrieve the size of the joint position interval bound.
|
||||||
self.relative_bounds_size = (
|
self.relative_bounds_size = (
|
||||||
|
(
|
||||||
self.robot.config.joint_position_relative_bounds["max"]
|
self.robot.config.joint_position_relative_bounds["max"]
|
||||||
- self.robot.config.joint_position_relative_bounds["min"]
|
- self.robot.config.joint_position_relative_bounds["min"]
|
||||||
)
|
)
|
||||||
|
if self.robot.config.joint_position_relative_bounds is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
self.delta_relative_bounds_size = self.relative_bounds_size * self.delta
|
self.robot.config.max_relative_target = (
|
||||||
|
self.relative_bounds_size.float()
|
||||||
self.robot.config.max_relative_target = self.delta_relative_bounds_size.float()
|
if self.relative_bounds_size is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
# Dynamically configure the observation and action spaces.
|
# Dynamically configure the observation and action spaces.
|
||||||
self._setup_spaces()
|
self._setup_spaces()
|
||||||
|
@ -99,26 +112,23 @@ class HILSerlRobotEnv(gym.Env):
|
||||||
- The action space is defined as a Tuple where:
|
- The action space is defined as a Tuple where:
|
||||||
• The first element is a Box space representing joint position commands. It is defined as relative (delta)
|
• The first element is a Box space representing joint position commands. It is defined as relative (delta)
|
||||||
or absolute, based on the configuration.
|
or absolute, based on the configuration.
|
||||||
• The second element is a Discrete space (with 2 values) serving as a flag for intervention (teleoperation).
|
• ThE SECONd element is a Discrete space (with 2 values) serving as a flag for intervention (teleoperation).
|
||||||
"""
|
"""
|
||||||
example_obs = self.robot.capture_observation()
|
example_obs = self.robot.capture_observation()
|
||||||
|
|
||||||
# Define observation spaces for images and other states.
|
# Define observation spaces for images and other states.
|
||||||
image_keys = [key for key in example_obs if "image" in key]
|
image_keys = [key for key in example_obs if "image" in key]
|
||||||
state_keys = [key for key in example_obs if "image" not in key]
|
|
||||||
observation_spaces = {
|
observation_spaces = {
|
||||||
key: gym.spaces.Box(
|
key: gym.spaces.Box(
|
||||||
low=0, high=255, shape=example_obs[key].shape, dtype=np.uint8
|
low=0, high=255, shape=example_obs[key].shape, dtype=np.uint8
|
||||||
)
|
)
|
||||||
for key in image_keys
|
for key in image_keys
|
||||||
}
|
}
|
||||||
observation_spaces["observation.state"] = gym.spaces.Dict(
|
observation_spaces["observation.state"] = gym.spaces.Box(
|
||||||
{
|
low=0,
|
||||||
key: gym.spaces.Box(
|
high=10,
|
||||||
low=0, high=10, shape=example_obs[key].shape, dtype=np.float32
|
shape=example_obs["observation.state"].shape,
|
||||||
)
|
dtype=np.float32,
|
||||||
for key in state_keys
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.observation_space = gym.spaces.Dict(observation_spaces)
|
self.observation_space = gym.spaces.Dict(observation_spaces)
|
||||||
|
@ -126,20 +136,31 @@ class HILSerlRobotEnv(gym.Env):
|
||||||
# Define the action space for joint positions along with setting an intervention flag.
|
# Define the action space for joint positions along with setting an intervention flag.
|
||||||
action_dim = len(self.robot.follower_arms["main"].read("Present_Position"))
|
action_dim = len(self.robot.follower_arms["main"].read("Present_Position"))
|
||||||
if self.use_delta_action_space:
|
if self.use_delta_action_space:
|
||||||
|
bounds = (
|
||||||
|
self.relative_bounds_size
|
||||||
|
if self.relative_bounds_size is not None
|
||||||
|
else np.ones(action_dim) * 1000
|
||||||
|
)
|
||||||
action_space_robot = gym.spaces.Box(
|
action_space_robot = gym.spaces.Box(
|
||||||
low=-self.relative_bounds_size.cpu().numpy(),
|
low=-bounds,
|
||||||
high=self.relative_bounds_size.cpu().numpy(),
|
high=bounds,
|
||||||
shape=(action_dim,),
|
shape=(action_dim,),
|
||||||
dtype=np.float32,
|
dtype=np.float32,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
bounds_min = (
|
||||||
|
self.robot.config.joint_position_relative_bounds["min"].cpu().numpy()
|
||||||
|
if self.robot.config.joint_position_relative_bounds is not None
|
||||||
|
else np.ones(action_dim) * -1000
|
||||||
|
)
|
||||||
|
bounds_max = (
|
||||||
|
self.robot.config.joint_position_relative_bounds["max"].cpu().numpy()
|
||||||
|
if self.robot.config.joint_position_relative_bounds is not None
|
||||||
|
else np.ones(action_dim) * 1000
|
||||||
|
)
|
||||||
action_space_robot = gym.spaces.Box(
|
action_space_robot = gym.spaces.Box(
|
||||||
low=self.robot.config.joint_position_relative_bounds["min"]
|
low=bounds_min,
|
||||||
.cpu()
|
high=bounds_max,
|
||||||
.numpy(),
|
|
||||||
high=self.robot.config.joint_position_relative_bounds["max"]
|
|
||||||
.cpu()
|
|
||||||
.numpy(),
|
|
||||||
shape=(action_dim,),
|
shape=(action_dim,),
|
||||||
dtype=np.float32,
|
dtype=np.float32,
|
||||||
)
|
)
|
||||||
|
@ -176,7 +197,7 @@ class HILSerlRobotEnv(gym.Env):
|
||||||
self.current_step = 0
|
self.current_step = 0
|
||||||
self.episode_data = None
|
self.episode_data = None
|
||||||
|
|
||||||
return observation, {"initial_position": self.initial_follower_position}
|
return observation, {}
|
||||||
|
|
||||||
def step(
|
def step(
|
||||||
self, action: Tuple[np.ndarray, bool]
|
self, action: Tuple[np.ndarray, bool]
|
||||||
|
@ -218,6 +239,7 @@ class HILSerlRobotEnv(gym.Env):
|
||||||
policy_action = np.clip(
|
policy_action = np.clip(
|
||||||
policy_action, self.action_space[0].low, self.action_space[0].high
|
policy_action, self.action_space[0].low, self.action_space[0].high
|
||||||
)
|
)
|
||||||
|
|
||||||
if not intervention_bool:
|
if not intervention_bool:
|
||||||
if self.use_delta_action_space:
|
if self.use_delta_action_space:
|
||||||
target_joint_positions = (
|
target_joint_positions = (
|
||||||
|
@ -238,8 +260,9 @@ class HILSerlRobotEnv(gym.Env):
|
||||||
teleop_action = (
|
teleop_action = (
|
||||||
teleop_action - self.current_joint_positions
|
teleop_action - self.current_joint_positions
|
||||||
) / self.delta
|
) / self.delta
|
||||||
if torch.any(teleop_action < -self.relative_bounds_size) and torch.any(
|
if self.relative_bounds_size is not None and (
|
||||||
teleop_action > self.relative_bounds_size
|
torch.any(teleop_action < -self.relative_bounds_size)
|
||||||
|
and torch.any(teleop_action > self.relative_bounds_size)
|
||||||
):
|
):
|
||||||
logging.debug(
|
logging.debug(
|
||||||
f"Relative teleop delta exceeded bounds {self.relative_bounds_size}, teleop_action {teleop_action}\n"
|
f"Relative teleop delta exceeded bounds {self.relative_bounds_size}, teleop_action {teleop_action}\n"
|
||||||
|
@ -299,6 +322,46 @@ class HILSerlRobotEnv(gym.Env):
|
||||||
self.robot.disconnect()
|
self.robot.disconnect()
|
||||||
|
|
||||||
|
|
||||||
|
class AddJointVelocityToObservation(gym.ObservationWrapper):
|
||||||
|
def __init__(self, env, joint_velocity_limits=100.0, fps=30):
|
||||||
|
super().__init__(env)
|
||||||
|
|
||||||
|
# Extend observation space to include joint velocities
|
||||||
|
old_low = self.observation_space["observation.state"].low
|
||||||
|
old_high = self.observation_space["observation.state"].high
|
||||||
|
old_shape = self.observation_space["observation.state"].shape
|
||||||
|
|
||||||
|
self.last_joint_positions = np.zeros(old_shape)
|
||||||
|
|
||||||
|
new_low = np.concatenate(
|
||||||
|
[old_low, np.ones_like(old_low) * -joint_velocity_limits]
|
||||||
|
)
|
||||||
|
new_high = np.concatenate(
|
||||||
|
[old_high, np.ones_like(old_high) * joint_velocity_limits]
|
||||||
|
)
|
||||||
|
|
||||||
|
new_shape = (old_shape[0] * 2,)
|
||||||
|
|
||||||
|
self.observation_space["observation.state"] = gym.spaces.Box(
|
||||||
|
low=new_low,
|
||||||
|
high=new_high,
|
||||||
|
shape=new_shape,
|
||||||
|
dtype=np.float32,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.dt = 1.0 / fps
|
||||||
|
|
||||||
|
def observation(self, observation):
|
||||||
|
joint_velocities = (
|
||||||
|
observation["observation.state"] - self.last_joint_positions
|
||||||
|
) / self.dt
|
||||||
|
self.last_joint_positions = observation["observation.state"].clone()
|
||||||
|
observation["observation.state"] = torch.cat(
|
||||||
|
[observation["observation.state"], joint_velocities], dim=-1
|
||||||
|
)
|
||||||
|
return observation
|
||||||
|
|
||||||
|
|
||||||
class ActionRepeatWrapper(gym.Wrapper):
|
class ActionRepeatWrapper(gym.Wrapper):
|
||||||
def __init__(self, env, nb_repeat: int = 1):
|
def __init__(self, env, nb_repeat: int = 1):
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
|
@ -347,8 +410,6 @@ class RewardWrapper(gym.Wrapper):
|
||||||
)
|
)
|
||||||
info["Reward classifer frequency"] = 1 / (time.perf_counter() - start_time)
|
info["Reward classifer frequency"] = 1 / (time.perf_counter() - start_time)
|
||||||
|
|
||||||
# logging.info(f"Reward: {reward}")
|
|
||||||
|
|
||||||
if reward == 1.0:
|
if reward == 1.0:
|
||||||
terminated = True
|
terminated = True
|
||||||
return observation, reward, terminated, truncated, info
|
return observation, reward, terminated, truncated, info
|
||||||
|
@ -465,9 +526,7 @@ class TimeLimitWrapper(gym.Wrapper):
|
||||||
if 1.0 / time_since_last_step < self.fps:
|
if 1.0 / time_since_last_step < self.fps:
|
||||||
logging.debug(f"Current timestep exceeded expected fps {self.fps}")
|
logging.debug(f"Current timestep exceeded expected fps {self.fps}")
|
||||||
|
|
||||||
if self.episode_time_in_s > self.control_time_s:
|
if self.current_step >= self.max_episode_steps:
|
||||||
# if self.current_step >= self.max_episode_steps:
|
|
||||||
# Terminated = True
|
|
||||||
terminated = True
|
terminated = True
|
||||||
return obs, reward, terminated, truncated, info
|
return obs, reward, terminated, truncated, info
|
||||||
|
|
||||||
|
@ -508,7 +567,20 @@ class ImageCropResizeWrapper(gym.Wrapper):
|
||||||
obs, reward, terminated, truncated, info = self.env.step(action)
|
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||||
for k in self.crop_params_dict:
|
for k in self.crop_params_dict:
|
||||||
device = obs[k].device
|
device = obs[k].device
|
||||||
|
if obs[k].dim() >= 3:
|
||||||
|
# Reshape to combine height and width dimensions for easier calculation
|
||||||
|
batch_size = obs[k].size(0)
|
||||||
|
channels = obs[k].size(1)
|
||||||
|
flattened_spatial_dims = obs[k].view(batch_size, channels, -1)
|
||||||
|
|
||||||
|
# Calculate standard deviation across spatial dimensions (H, W)
|
||||||
|
std_per_channel = torch.std(flattened_spatial_dims, dim=2)
|
||||||
|
|
||||||
|
# If any channel has std=0, all pixels in that channel have the same value
|
||||||
|
if (std_per_channel <= 0.02).any():
|
||||||
|
logging.warning(
|
||||||
|
f"Potential hardware issue detected: All pixels have the same value in observation {k}"
|
||||||
|
)
|
||||||
# Check for NaNs before processing
|
# Check for NaNs before processing
|
||||||
if torch.isnan(obs[k]).any():
|
if torch.isnan(obs[k]).any():
|
||||||
logging.error(
|
logging.error(
|
||||||
|
@ -703,19 +775,21 @@ class ResetWrapper(gym.Wrapper):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
env: HILSerlRobotEnv,
|
env: HILSerlRobotEnv,
|
||||||
reset_fn: Optional[Callable[[], None]] = None,
|
reset_pose: np.ndarray | None = None,
|
||||||
reset_time_s: float = 5,
|
reset_time_s: float = 5,
|
||||||
):
|
):
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
self.reset_fn = reset_fn
|
|
||||||
self.reset_time_s = reset_time_s
|
self.reset_time_s = reset_time_s
|
||||||
|
self.reset_pose = reset_pose
|
||||||
self.robot = self.unwrapped.robot
|
self.robot = self.unwrapped.robot
|
||||||
self.init_pos = self.unwrapped.initial_follower_position
|
|
||||||
|
|
||||||
def reset(self, *, seed=None, options=None):
|
def reset(self, *, seed=None, options=None):
|
||||||
if self.reset_fn is not None:
|
if self.reset_pose is not None:
|
||||||
self.reset_fn(self.env)
|
start_time = time.perf_counter()
|
||||||
|
log_say("Reset the environment.", play_sounds=True)
|
||||||
|
reset_follower_position(self.robot, self.reset_pose)
|
||||||
|
busy_wait(self.reset_time_s - (time.perf_counter() - start_time))
|
||||||
|
log_say("Reset the environment done.", play_sounds=True)
|
||||||
else:
|
else:
|
||||||
log_say(
|
log_say(
|
||||||
f"Manually reset the environment for {self.reset_time_s} seconds.",
|
f"Manually reset the environment for {self.reset_time_s} seconds.",
|
||||||
|
@ -741,10 +815,297 @@ class BatchCompitableWrapper(gym.ObservationWrapper):
|
||||||
observation[key] = observation[key].unsqueeze(0)
|
observation[key] = observation[key].unsqueeze(0)
|
||||||
if "state" in key and observation[key].dim() == 1:
|
if "state" in key and observation[key].dim() == 1:
|
||||||
observation[key] = observation[key].unsqueeze(0)
|
observation[key] = observation[key].unsqueeze(0)
|
||||||
|
if "velocity" in key and observation[key].dim() == 1:
|
||||||
|
observation[key] = observation[key].unsqueeze(0)
|
||||||
return observation
|
return observation
|
||||||
|
|
||||||
|
|
||||||
# TODO: REMOVE TH
|
class EEActionWrapper(gym.ActionWrapper):
|
||||||
|
def __init__(self, env, ee_action_space_params=None):
|
||||||
|
super().__init__(env)
|
||||||
|
self.ee_action_space_params = ee_action_space_params
|
||||||
|
|
||||||
|
# Initialize kinematics instance for the appropriate robot type
|
||||||
|
robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so100")
|
||||||
|
self.kinematics = RobotKinematics(robot_type)
|
||||||
|
self.fk_function = self.kinematics.fk_gripper_tip
|
||||||
|
|
||||||
|
action_space_bounds = np.array(
|
||||||
|
[
|
||||||
|
ee_action_space_params.x_step_size,
|
||||||
|
ee_action_space_params.y_step_size,
|
||||||
|
ee_action_space_params.z_step_size,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
ee_action_space = gym.spaces.Box(
|
||||||
|
low=-action_space_bounds,
|
||||||
|
high=action_space_bounds,
|
||||||
|
shape=(3,),
|
||||||
|
dtype=np.float32,
|
||||||
|
)
|
||||||
|
if isinstance(self.action_space, gym.spaces.Tuple):
|
||||||
|
self.action_space = gym.spaces.Tuple(
|
||||||
|
(ee_action_space, self.action_space[1])
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.action_space = ee_action_space
|
||||||
|
|
||||||
|
self.bounds = ee_action_space_params.bounds
|
||||||
|
|
||||||
|
def action(self, action):
|
||||||
|
is_intervention = False
|
||||||
|
desired_ee_pos = np.eye(4)
|
||||||
|
if isinstance(action, tuple):
|
||||||
|
action, _ = action
|
||||||
|
|
||||||
|
current_joint_pos = self.unwrapped.robot.follower_arms["main"].read(
|
||||||
|
"Present_Position"
|
||||||
|
)
|
||||||
|
current_ee_pos = self.fk_function(current_joint_pos)
|
||||||
|
if isinstance(action, torch.Tensor):
|
||||||
|
action = action.cpu().numpy()
|
||||||
|
desired_ee_pos[:3, 3] = np.clip(
|
||||||
|
current_ee_pos[:3, 3] + action,
|
||||||
|
self.bounds["min"],
|
||||||
|
self.bounds["max"],
|
||||||
|
)
|
||||||
|
target_joint_pos = self.kinematics.ik(
|
||||||
|
current_joint_pos,
|
||||||
|
desired_ee_pos,
|
||||||
|
position_only=True,
|
||||||
|
fk_func=self.fk_function,
|
||||||
|
)
|
||||||
|
return target_joint_pos, is_intervention
|
||||||
|
|
||||||
|
|
||||||
|
class EEObservationWrapper(gym.ObservationWrapper):
|
||||||
|
def __init__(self, env, ee_pose_limits):
|
||||||
|
super().__init__(env)
|
||||||
|
|
||||||
|
# Extend observation space to include end effector pose
|
||||||
|
prev_space = self.observation_space["observation.state"]
|
||||||
|
|
||||||
|
self.observation_space["observation.state"] = gym.spaces.Box(
|
||||||
|
low=np.concatenate([prev_space.low, ee_pose_limits["min"]]),
|
||||||
|
high=np.concatenate([prev_space.high, ee_pose_limits["max"]]),
|
||||||
|
shape=(prev_space.shape[0] + 3,),
|
||||||
|
dtype=np.float32,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize kinematics instance for the appropriate robot type
|
||||||
|
robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so100")
|
||||||
|
self.kinematics = RobotKinematics(robot_type)
|
||||||
|
self.fk_function = self.kinematics.fk_gripper_tip
|
||||||
|
|
||||||
|
def observation(self, observation):
|
||||||
|
current_joint_pos = self.unwrapped.robot.follower_arms["main"].read(
|
||||||
|
"Present_Position"
|
||||||
|
)
|
||||||
|
current_ee_pos = self.fk_function(current_joint_pos)
|
||||||
|
observation["observation.state"] = torch.cat(
|
||||||
|
[
|
||||||
|
observation["observation.state"],
|
||||||
|
torch.from_numpy(current_ee_pos[:3, 3]),
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
return observation
|
||||||
|
|
||||||
|
|
||||||
|
class GamepadControlWrapper(gym.Wrapper):
|
||||||
|
"""
|
||||||
|
Wrapper that allows controlling a gym environment with a gamepad.
|
||||||
|
|
||||||
|
This wrapper intercepts the step method and allows human input via gamepad
|
||||||
|
to override the agent's actions when desired.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
env,
|
||||||
|
x_step_size=1.0,
|
||||||
|
y_step_size=1.0,
|
||||||
|
z_step_size=1.0,
|
||||||
|
auto_reset=False,
|
||||||
|
input_threshold=0.001,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the gamepad controller wrapper.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env: The environment to wrap
|
||||||
|
x_step_size: Base movement step size for X axis in meters
|
||||||
|
y_step_size: Base movement step size for Y axis in meters
|
||||||
|
z_step_size: Base movement step size for Z axis in meters
|
||||||
|
vendor_id: USB vendor ID of the gamepad (default: Logitech)
|
||||||
|
product_id: USB product ID of the gamepad (default: RumblePad 2)
|
||||||
|
auto_reset: Whether to auto reset the environment when episode ends
|
||||||
|
input_threshold: Minimum movement delta to consider as active input
|
||||||
|
"""
|
||||||
|
super().__init__(env)
|
||||||
|
from lerobot.scripts.server.end_effector_control_utils import (
|
||||||
|
GamepadControllerHID,
|
||||||
|
GamepadController,
|
||||||
|
)
|
||||||
|
|
||||||
|
# use HidApi for macos
|
||||||
|
if sys.platform == "darwin":
|
||||||
|
self.controller = GamepadControllerHID(
|
||||||
|
x_step_size=x_step_size,
|
||||||
|
y_step_size=y_step_size,
|
||||||
|
z_step_size=z_step_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.controller = GamepadController(
|
||||||
|
x_step_size=x_step_size,
|
||||||
|
y_step_size=y_step_size,
|
||||||
|
z_step_size=z_step_size,
|
||||||
|
)
|
||||||
|
self.auto_reset = auto_reset
|
||||||
|
self.input_threshold = input_threshold
|
||||||
|
self.controller.start()
|
||||||
|
|
||||||
|
logging.info("Gamepad control wrapper initialized")
|
||||||
|
print("Gamepad controls:")
|
||||||
|
print(" Left analog stick: Move in X-Y plane")
|
||||||
|
print(" Right analog stick: Move in Z axis (up/down)")
|
||||||
|
print(" X/Square button: End episode (FAILURE)")
|
||||||
|
print(" Y/Triangle button: End episode (SUCCESS)")
|
||||||
|
print(" B/Circle button: Exit program")
|
||||||
|
|
||||||
|
def get_gamepad_action(self):
|
||||||
|
"""
|
||||||
|
Get the current action from the gamepad if any input is active.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (is_active, action, terminate_episode, success)
|
||||||
|
"""
|
||||||
|
# Update the controller to get fresh inputs
|
||||||
|
self.controller.update()
|
||||||
|
|
||||||
|
# Get movement deltas from the controller
|
||||||
|
delta_x, delta_y, delta_z = self.controller.get_deltas()
|
||||||
|
|
||||||
|
intervention_is_active = self.controller.should_intervene()
|
||||||
|
|
||||||
|
# Create action from gamepad input
|
||||||
|
gamepad_action = np.array([delta_x, delta_y, delta_z], dtype=np.float32)
|
||||||
|
|
||||||
|
# Check episode ending buttons
|
||||||
|
# We'll rely on controller.get_episode_end_status() which returns "success", "failure", or None
|
||||||
|
episode_end_status = self.controller.get_episode_end_status()
|
||||||
|
terminate_episode = episode_end_status is not None
|
||||||
|
success = episode_end_status == "success"
|
||||||
|
rerecord_episode = episode_end_status == "rerecord_episode"
|
||||||
|
|
||||||
|
return (
|
||||||
|
intervention_is_active,
|
||||||
|
gamepad_action,
|
||||||
|
terminate_episode,
|
||||||
|
success,
|
||||||
|
rerecord_episode,
|
||||||
|
)
|
||||||
|
|
||||||
|
def step(self, action):
|
||||||
|
"""
|
||||||
|
Step the environment, using gamepad input to override actions when active.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action: Original action from agent
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
observation, reward, terminated, truncated, info
|
||||||
|
"""
|
||||||
|
# Get gamepad state and action
|
||||||
|
(
|
||||||
|
is_intervention,
|
||||||
|
gamepad_action,
|
||||||
|
terminate_episode,
|
||||||
|
success,
|
||||||
|
rerecord_episode,
|
||||||
|
) = self.get_gamepad_action()
|
||||||
|
|
||||||
|
# Update episode ending state if requested
|
||||||
|
if terminate_episode:
|
||||||
|
logging.info(
|
||||||
|
f"Episode manually ended: {'SUCCESS' if success else 'FAILURE'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Only override the action if gamepad is active
|
||||||
|
if is_intervention:
|
||||||
|
# Format according to the expected action type
|
||||||
|
if isinstance(self.action_space, gym.spaces.Tuple):
|
||||||
|
# For environments that use (action, is_intervention) tuples
|
||||||
|
final_action = (torch.from_numpy(gamepad_action), False)
|
||||||
|
else:
|
||||||
|
final_action = torch.from_numpy(gamepad_action)
|
||||||
|
else:
|
||||||
|
# Use the original action
|
||||||
|
final_action = action
|
||||||
|
|
||||||
|
# Step the environment
|
||||||
|
obs, reward, terminated, truncated, info = self.env.step(final_action)
|
||||||
|
|
||||||
|
# Add episode ending if requested via gamepad
|
||||||
|
terminated = terminated or truncated or terminate_episode
|
||||||
|
|
||||||
|
if success:
|
||||||
|
reward = 1.0
|
||||||
|
logging.info("Episode ended successfully with reward 1.0")
|
||||||
|
|
||||||
|
info["is_intervention"] = is_intervention
|
||||||
|
action_intervention = (
|
||||||
|
final_action[0] if isinstance(final_action, Tuple) else final_action
|
||||||
|
)
|
||||||
|
if isinstance(action_intervention, np.ndarray):
|
||||||
|
action_intervention = torch.from_numpy(action_intervention)
|
||||||
|
info["action_intervention"] = action_intervention
|
||||||
|
info["rerecord_episode"] = rerecord_episode
|
||||||
|
|
||||||
|
# If episode ended, reset the state
|
||||||
|
if terminated or truncated:
|
||||||
|
# Add success/failure information to info dict
|
||||||
|
info["next.success"] = success
|
||||||
|
|
||||||
|
# Auto reset if configured
|
||||||
|
if self.auto_reset:
|
||||||
|
obs, reset_info = self.reset()
|
||||||
|
info.update(reset_info)
|
||||||
|
|
||||||
|
return obs, reward, terminated, truncated, info
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""Clean up resources when environment closes."""
|
||||||
|
# Stop the controller
|
||||||
|
if hasattr(self, "controller"):
|
||||||
|
self.controller.stop()
|
||||||
|
|
||||||
|
# Call the parent close method
|
||||||
|
return self.env.close()
|
||||||
|
|
||||||
|
|
||||||
|
class ActionScaleWrapper(gym.ActionWrapper):
|
||||||
|
def __init__(self, env, ee_action_space_params=None):
|
||||||
|
super().__init__(env)
|
||||||
|
assert (
|
||||||
|
ee_action_space_params is not None
|
||||||
|
), "TODO: method implemented for ee action space only so far"
|
||||||
|
self.scale_vector = np.array(
|
||||||
|
[
|
||||||
|
[
|
||||||
|
ee_action_space_params.x_step_size,
|
||||||
|
ee_action_space_params.y_step_size,
|
||||||
|
ee_action_space_params.z_step_size,
|
||||||
|
]
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def action(self, action):
|
||||||
|
is_intervention = False
|
||||||
|
if isinstance(action, tuple):
|
||||||
|
action, is_intervention = action
|
||||||
|
|
||||||
|
return action * self.scale_vector, is_intervention
|
||||||
|
|
||||||
|
|
||||||
def make_robot_env(
|
def make_robot_env(
|
||||||
|
@ -779,11 +1140,20 @@ def make_robot_env(
|
||||||
robot=robot,
|
robot=robot,
|
||||||
display_cameras=cfg.env.wrapper.display_cameras,
|
display_cameras=cfg.env.wrapper.display_cameras,
|
||||||
delta=cfg.env.wrapper.delta_action,
|
delta=cfg.env.wrapper.delta_action,
|
||||||
use_delta_action_space=cfg.env.wrapper.use_relative_joint_positions,
|
use_delta_action_space=cfg.env.wrapper.use_relative_joint_positions
|
||||||
|
and cfg.env.wrapper.ee_action_space_params is None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add observation and image processing
|
# Add observation and image processing
|
||||||
env = ConvertToLeRobotObservation(env=env, device=cfg.device)
|
if cfg.env.wrapper.add_joint_velocity_to_observation:
|
||||||
|
env = AddJointVelocityToObservation(env=env, fps=cfg.fps)
|
||||||
|
if cfg.env.wrapper.add_ee_pose_to_observation:
|
||||||
|
env = EEObservationWrapper(
|
||||||
|
env=env, ee_pose_limits=cfg.env.wrapper.ee_action_space_params.bounds
|
||||||
|
)
|
||||||
|
|
||||||
|
env = ConvertToLeRobotObservation(env=env, device=cfg.env.device)
|
||||||
|
|
||||||
if cfg.env.wrapper.crop_params_dict is not None:
|
if cfg.env.wrapper.crop_params_dict is not None:
|
||||||
env = ImageCropResizeWrapper(
|
env = ImageCropResizeWrapper(
|
||||||
env=env,
|
env=env,
|
||||||
|
@ -792,14 +1162,37 @@ def make_robot_env(
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add reward computation and control wrappers
|
# Add reward computation and control wrappers
|
||||||
env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
|
# env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
|
||||||
env = TimeLimitWrapper(
|
env = TimeLimitWrapper(
|
||||||
env=env, control_time_s=cfg.env.wrapper.control_time_s, fps=cfg.fps
|
env=env, control_time_s=cfg.env.wrapper.control_time_s, fps=cfg.fps
|
||||||
)
|
)
|
||||||
env = KeyboardInterfaceWrapper(env=env)
|
if cfg.env.wrapper.ee_action_space_params is not None:
|
||||||
env = ResetWrapper(
|
env = EEActionWrapper(
|
||||||
env=env, reset_fn=None, reset_time_s=cfg.env.wrapper.reset_time_s
|
env=env, ee_action_space_params=cfg.env.wrapper.ee_action_space_params
|
||||||
)
|
)
|
||||||
|
if (
|
||||||
|
cfg.env.wrapper.ee_action_space_params is not None
|
||||||
|
and cfg.env.wrapper.ee_action_space_params.use_gamepad
|
||||||
|
):
|
||||||
|
# env = ActionScaleWrapper(env=env, ee_action_space_params=cfg.env.wrapper.ee_action_space_params)
|
||||||
|
env = GamepadControlWrapper(
|
||||||
|
env=env,
|
||||||
|
x_step_size=cfg.env.wrapper.ee_action_space_params.x_step_size,
|
||||||
|
y_step_size=cfg.env.wrapper.ee_action_space_params.y_step_size,
|
||||||
|
z_step_size=cfg.env.wrapper.ee_action_space_params.z_step_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
env = KeyboardInterfaceWrapper(env=env)
|
||||||
|
|
||||||
|
env = ResetWrapper(
|
||||||
|
env=env,
|
||||||
|
reset_pose=cfg.env.wrapper.fixed_reset_joint_positions,
|
||||||
|
reset_time_s=cfg.env.wrapper.reset_time_s,
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
cfg.env.wrapper.ee_action_space_params is None
|
||||||
|
and cfg.env.wrapper.joint_masking_action_space is not None
|
||||||
|
):
|
||||||
env = JointMaskingActionSpace(
|
env = JointMaskingActionSpace(
|
||||||
env=env, mask=cfg.env.wrapper.joint_masking_action_space
|
env=env, mask=cfg.env.wrapper.joint_masking_action_space
|
||||||
)
|
)
|
||||||
|
@ -807,8 +1200,6 @@ def make_robot_env(
|
||||||
|
|
||||||
return env
|
return env
|
||||||
|
|
||||||
# batched version of the env that returns an observation of shape (b, c)
|
|
||||||
|
|
||||||
|
|
||||||
def get_classifier(pretrained_path, config_path, device="mps"):
|
def get_classifier(pretrained_path, config_path, device="mps"):
|
||||||
if pretrained_path is None or config_path is None:
|
if pretrained_path is None or config_path is None:
|
||||||
|
@ -834,6 +1225,134 @@ def get_classifier(pretrained_path, config_path, device="mps"):
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def record_dataset(
|
||||||
|
env,
|
||||||
|
repo_id,
|
||||||
|
root=None,
|
||||||
|
num_episodes=1,
|
||||||
|
control_time_s=20,
|
||||||
|
fps=30,
|
||||||
|
push_to_hub=True,
|
||||||
|
task_description="",
|
||||||
|
policy=None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Record a dataset of robot interactions using either a policy or teleop.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env: The environment to record from
|
||||||
|
repo_id: Repository ID for dataset storage
|
||||||
|
root: Local root directory for dataset (optional)
|
||||||
|
num_episodes: Number of episodes to record
|
||||||
|
control_time_s: Maximum episode length in seconds
|
||||||
|
fps: Frames per second for recording
|
||||||
|
push_to_hub: Whether to push dataset to Hugging Face Hub
|
||||||
|
task_description: Description of the task being recorded
|
||||||
|
policy: Optional policy to generate actions (if None, uses teleop)
|
||||||
|
"""
|
||||||
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
|
||||||
|
# Setup initial action (zero action if using teleop)
|
||||||
|
dummy_action = env.action_space.sample()
|
||||||
|
dummy_action = (torch.from_numpy(dummy_action[0] * 0.0), False)
|
||||||
|
action = dummy_action
|
||||||
|
|
||||||
|
# Configure dataset features based on environment spaces
|
||||||
|
features = {
|
||||||
|
"observation.state": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": env.observation_space["observation.state"].shape,
|
||||||
|
"names": None,
|
||||||
|
},
|
||||||
|
"action": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": env.action_space[0].shape,
|
||||||
|
"names": None,
|
||||||
|
},
|
||||||
|
"next.reward": {"dtype": "float32", "shape": (1,), "names": None},
|
||||||
|
"next.done": {"dtype": "bool", "shape": (1,), "names": None},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add image features
|
||||||
|
for key in env.observation_space:
|
||||||
|
if "image" in key:
|
||||||
|
features[key] = {
|
||||||
|
"dtype": "video",
|
||||||
|
"shape": env.observation_space[key].shape,
|
||||||
|
"names": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create dataset
|
||||||
|
dataset = LeRobotDataset.create(
|
||||||
|
repo_id,
|
||||||
|
fps,
|
||||||
|
root=root,
|
||||||
|
use_videos=True,
|
||||||
|
image_writer_threads=4,
|
||||||
|
image_writer_processes=0,
|
||||||
|
features=features,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Record episodes
|
||||||
|
episode_index = 0
|
||||||
|
while episode_index < num_episodes:
|
||||||
|
obs, _ = env.reset()
|
||||||
|
start_episode_t = time.perf_counter()
|
||||||
|
log_say(f"Recording episode {episode_index}", play_sounds=True)
|
||||||
|
|
||||||
|
# Run episode steps
|
||||||
|
while time.perf_counter() - start_episode_t < control_time_s:
|
||||||
|
start_loop_t = time.perf_counter()
|
||||||
|
|
||||||
|
# Get action from policy if available
|
||||||
|
if policy is not None:
|
||||||
|
action = policy.select_action(obs)
|
||||||
|
|
||||||
|
# Step environment
|
||||||
|
obs, reward, terminated, truncated, info = env.step(action)
|
||||||
|
|
||||||
|
# Check if episode needs to be rerecorded
|
||||||
|
if info.get("rerecord_episode", False):
|
||||||
|
break
|
||||||
|
|
||||||
|
# For teleop, get action from intervention
|
||||||
|
if policy is None:
|
||||||
|
action = {
|
||||||
|
"action": info["action_intervention"].cpu().squeeze(0).float()
|
||||||
|
}
|
||||||
|
|
||||||
|
# Process observation for dataset
|
||||||
|
obs = {k: v.cpu().squeeze(0).float() for k, v in obs.items()}
|
||||||
|
|
||||||
|
# Add frame to dataset
|
||||||
|
frame = {**obs, **action}
|
||||||
|
frame["next.reward"] = reward
|
||||||
|
frame["next.done"] = terminated or truncated
|
||||||
|
dataset.add_frame(frame)
|
||||||
|
|
||||||
|
# Maintain consistent timing
|
||||||
|
if fps:
|
||||||
|
dt_s = time.perf_counter() - start_loop_t
|
||||||
|
busy_wait(1 / fps - dt_s)
|
||||||
|
|
||||||
|
if terminated or truncated:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Handle episode recording
|
||||||
|
if info.get("rerecord_episode", False):
|
||||||
|
dataset.clear_episode_buffer()
|
||||||
|
logging.info(f"Re-recording episode {episode_index}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
dataset.save_episode(task_description)
|
||||||
|
episode_index += 1
|
||||||
|
|
||||||
|
# Finalize dataset
|
||||||
|
dataset.consolidate(run_compute_stats=True)
|
||||||
|
if push_to_hub:
|
||||||
|
dataset.push_to_hub(repo_id)
|
||||||
|
|
||||||
|
|
||||||
def replay_episode(env, repo_id, root=None, episode=0):
|
def replay_episode(env, repo_id, root=None, episode=0):
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
|
||||||
|
@ -841,14 +1360,16 @@ def replay_episode(env, repo_id, root=None, episode=0):
|
||||||
dataset = LeRobotDataset(
|
dataset = LeRobotDataset(
|
||||||
repo_id, root=root, episodes=[episode], local_files_only=local_files_only
|
repo_id, root=root, episodes=[episode], local_files_only=local_files_only
|
||||||
)
|
)
|
||||||
|
env.reset()
|
||||||
|
|
||||||
actions = dataset.hf_dataset.select_columns("action")
|
actions = dataset.hf_dataset.select_columns("action")
|
||||||
|
|
||||||
for idx in range(dataset.num_frames):
|
for idx in range(dataset.num_frames):
|
||||||
start_episode_t = time.perf_counter()
|
start_episode_t = time.perf_counter()
|
||||||
|
|
||||||
action = actions[idx]["action"][:4]
|
action = actions[idx]["action"][:4]
|
||||||
print(action)
|
env.step((action, False))
|
||||||
env.step((action / env.unwrapped.delta, False))
|
# env.step((action / env.unwrapped.delta, False))
|
||||||
|
|
||||||
dt_s = time.perf_counter() - start_episode_t
|
dt_s = time.perf_counter() - start_episode_t
|
||||||
busy_wait(1 / 10 - dt_s)
|
busy_wait(1 / 10 - dt_s)
|
||||||
|
@ -875,14 +1396,6 @@ if __name__ == "__main__":
|
||||||
help=(
|
help=(
|
||||||
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
|
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
|
||||||
"saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch "
|
"saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch "
|
||||||
"(useful for debugging). This argument is mutually exclusive with `--config`."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--config",
|
|
||||||
help=(
|
|
||||||
"Path to a yaml config you want to use for initializing a policy from scratch (useful for "
|
|
||||||
"debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)."
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
@ -929,11 +1442,30 @@ if __name__ == "__main__":
|
||||||
help="Repo ID of the episode to replay",
|
help="Repo ID of the episode to replay",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--replay-root", type=str, default=None, help="Root of the dataset to replay"
|
"--dataset-root", type=str, default=None, help="Root of the dataset to replay"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--replay-episode", type=int, default=0, help="Episode to replay"
|
"--replay-episode", type=int, default=0, help="Episode to replay"
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--record-repo-id",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Repo ID of the dataset to record",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--record-num-episodes",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Number of episodes to record",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--record-episode-task",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="Single line description of the task to record",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)
|
robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)
|
||||||
|
@ -948,17 +1480,40 @@ if __name__ == "__main__":
|
||||||
env = make_robot_env(
|
env = make_robot_env(
|
||||||
robot,
|
robot,
|
||||||
reward_classifier,
|
reward_classifier,
|
||||||
cfg.env, # .wrapper,
|
cfg, # .wrapper,
|
||||||
)
|
)
|
||||||
|
|
||||||
env.reset()
|
if args.record_repo_id is not None:
|
||||||
|
policy = None
|
||||||
|
if args.pretrained_policy_name_or_path is not None:
|
||||||
|
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||||
|
|
||||||
|
policy = SACPolicy.from_pretrained(args.pretrained_policy_name_or_path)
|
||||||
|
policy.to(cfg.device)
|
||||||
|
policy.eval()
|
||||||
|
|
||||||
|
record_dataset(
|
||||||
|
env,
|
||||||
|
args.record_repo_id,
|
||||||
|
root=args.dataset_root,
|
||||||
|
num_episodes=args.record_num_episodes,
|
||||||
|
fps=args.fps,
|
||||||
|
task_description=args.record_episode_task,
|
||||||
|
policy=policy,
|
||||||
|
)
|
||||||
|
exit()
|
||||||
|
|
||||||
if args.replay_repo_id is not None:
|
if args.replay_repo_id is not None:
|
||||||
replay_episode(
|
replay_episode(
|
||||||
env, args.replay_repo_id, root=args.replay_root, episode=args.replay_episode
|
env,
|
||||||
|
args.replay_repo_id,
|
||||||
|
root=args.dataset_root,
|
||||||
|
episode=args.replay_episode,
|
||||||
)
|
)
|
||||||
exit()
|
exit()
|
||||||
|
|
||||||
|
env.reset()
|
||||||
|
|
||||||
# Retrieve the robot's action space for joint commands.
|
# Retrieve the robot's action space for joint commands.
|
||||||
action_space_robot = env.action_space.spaces[0]
|
action_space_robot = env.action_space.spaces[0]
|
||||||
|
|
||||||
|
@ -967,9 +1522,11 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
# Smoothing coefficient (alpha) defines how much of the new random sample to mix in.
|
# Smoothing coefficient (alpha) defines how much of the new random sample to mix in.
|
||||||
# A value close to 0 makes the trajectory very smooth (slow to change), while a value close to 1 is less smooth.
|
# A value close to 0 makes the trajectory very smooth (slow to change), while a value close to 1 is less smooth.
|
||||||
alpha = 0.4
|
alpha = 1.0
|
||||||
|
|
||||||
while True:
|
num_episode = 0
|
||||||
|
sucesses = []
|
||||||
|
while num_episode < 20:
|
||||||
start_loop_s = time.perf_counter()
|
start_loop_s = time.perf_counter()
|
||||||
# Sample a new random action from the robot's action space.
|
# Sample a new random action from the robot's action space.
|
||||||
new_random_action = action_space_robot.sample()
|
new_random_action = action_space_robot.sample()
|
||||||
|
@ -981,7 +1538,12 @@ if __name__ == "__main__":
|
||||||
(torch.from_numpy(smoothed_action), False)
|
(torch.from_numpy(smoothed_action), False)
|
||||||
)
|
)
|
||||||
if terminated or truncated:
|
if terminated or truncated:
|
||||||
|
sucesses.append(reward)
|
||||||
env.reset()
|
env.reset()
|
||||||
|
num_episode += 1
|
||||||
|
|
||||||
dt_s = time.perf_counter() - start_loop_s
|
dt_s = time.perf_counter() - start_loop_s
|
||||||
busy_wait(1 / args.fps - dt_s)
|
busy_wait(1 / args.fps - dt_s)
|
||||||
|
|
||||||
|
logging.info(f"Success after 20 steps {sucesses}")
|
||||||
|
logging.info(f"success rate {sum(sucesses)/ len(sucesses)}")
|
||||||
|
|
|
@ -0,0 +1,543 @@
|
||||||
|
import numpy as np
|
||||||
|
from scipy.spatial.transform import Rotation
|
||||||
|
|
||||||
|
|
||||||
|
def skew_symmetric(w):
|
||||||
|
"""Creates the skew-symmetric matrix from a 3D vector."""
|
||||||
|
return np.array([[0, -w[2], w[1]], [w[2], 0, -w[0]], [-w[1], w[0], 0]])
|
||||||
|
|
||||||
|
|
||||||
|
def rodrigues_rotation(w, theta):
|
||||||
|
"""Computes the rotation matrix using Rodrigues' formula."""
|
||||||
|
w_hat = skew_symmetric(w)
|
||||||
|
return np.eye(3) + np.sin(theta) * w_hat + (1 - np.cos(theta)) * w_hat @ w_hat
|
||||||
|
|
||||||
|
|
||||||
|
def screw_axis_to_transform(S, theta):
|
||||||
|
"""Converts a screw axis to a 4x4 transformation matrix."""
|
||||||
|
S_w = S[:3]
|
||||||
|
S_v = S[3:]
|
||||||
|
if np.allclose(S_w, 0) and np.linalg.norm(S_v) == 1: # Pure translation
|
||||||
|
T = np.eye(4)
|
||||||
|
T[:3, 3] = S_v * theta
|
||||||
|
elif np.linalg.norm(S_w) == 1: # Rotation and translation
|
||||||
|
w_hat = skew_symmetric(S_w)
|
||||||
|
R = np.eye(3) + np.sin(theta) * w_hat + (1 - np.cos(theta)) * w_hat @ w_hat
|
||||||
|
t = (
|
||||||
|
np.eye(3) * theta
|
||||||
|
+ (1 - np.cos(theta)) * w_hat
|
||||||
|
+ (theta - np.sin(theta)) * w_hat @ w_hat
|
||||||
|
) @ S_v
|
||||||
|
T = np.eye(4)
|
||||||
|
T[:3, :3] = R
|
||||||
|
T[:3, 3] = t
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid screw axis parameters")
|
||||||
|
return T
|
||||||
|
|
||||||
|
|
||||||
|
def pose_difference_se3(pose1, pose2):
|
||||||
|
"""
|
||||||
|
Calculates the SE(3) difference between two 4x4 homogeneous transformation matrices.
|
||||||
|
|
||||||
|
pose1 - pose2
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pose1: A 4x4 numpy array representing the first pose.
|
||||||
|
pose2: A 4x4 numpy array representing the second pose.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple (translation_diff, rotation_diff) where:
|
||||||
|
- translation_diff is a 3x1 numpy array representing the translational difference.
|
||||||
|
- rotation_diff is a 3x1 numpy array representing the rotational difference in axis-angle representation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Extract rotation matrices from poses
|
||||||
|
R1 = pose1[:3, :3]
|
||||||
|
R2 = pose2[:3, :3]
|
||||||
|
|
||||||
|
# Calculate translational difference
|
||||||
|
translation_diff = pose1[:3, 3] - pose2[:3, 3]
|
||||||
|
|
||||||
|
# Calculate rotational difference using scipy's Rotation library
|
||||||
|
R_diff = Rotation.from_matrix(R1 @ R2.T)
|
||||||
|
rotation_diff = R_diff.as_rotvec() # Convert to axis-angle representation
|
||||||
|
|
||||||
|
return np.concatenate([translation_diff, rotation_diff])
|
||||||
|
|
||||||
|
|
||||||
|
def se3_error(target_pose, current_pose):
|
||||||
|
pos_error = target_pose[:3, 3] - current_pose[:3, 3]
|
||||||
|
R_target = target_pose[:3, :3]
|
||||||
|
R_current = current_pose[:3, :3]
|
||||||
|
R_error = R_target @ R_current.T
|
||||||
|
rot_error = Rotation.from_matrix(R_error).as_rotvec()
|
||||||
|
return np.concatenate([pos_error, rot_error])
|
||||||
|
|
||||||
|
|
||||||
|
class RobotKinematics:
|
||||||
|
"""Robot kinematics class supporting multiple robot models."""
|
||||||
|
|
||||||
|
# Robot measurements dictionary
|
||||||
|
ROBOT_MEASUREMENTS = {
|
||||||
|
"koch": {
|
||||||
|
"gripper": [0.239, -0.001, 0.024],
|
||||||
|
"wrist": [0.209, 0, 0.024],
|
||||||
|
"forearm": [0.108, 0, 0.02],
|
||||||
|
"humerus": [0, 0, 0.036],
|
||||||
|
"shoulder": [0, 0, 0],
|
||||||
|
"base": [0, 0, 0.02],
|
||||||
|
},
|
||||||
|
"so100": {
|
||||||
|
"gripper": [0.320, 0, 0.050],
|
||||||
|
"wrist": [0.278, 0, 0.050],
|
||||||
|
"forearm": [0.143, 0, 0.044],
|
||||||
|
"humerus": [0.031, 0, 0.072],
|
||||||
|
"shoulder": [0, 0, 0],
|
||||||
|
"base": [0, 0, 0.02],
|
||||||
|
},
|
||||||
|
"moss": {
|
||||||
|
"gripper": [0.246, 0.013, 0.111],
|
||||||
|
"wrist": [0.245, 0.002, 0.064],
|
||||||
|
"forearm": [0.122, 0, 0.064],
|
||||||
|
"humerus": [0.001, 0.001, 0.063],
|
||||||
|
"shoulder": [0, 0, 0],
|
||||||
|
"base": [0, 0, 0.02],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, robot_type="so100"):
|
||||||
|
"""Initialize kinematics for the specified robot type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
robot_type: String specifying the robot model ("koch", "so100", or "moss")
|
||||||
|
"""
|
||||||
|
if robot_type not in self.ROBOT_MEASUREMENTS:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown robot type: {robot_type}. Available types: {list(self.ROBOT_MEASUREMENTS.keys())}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.robot_type = robot_type
|
||||||
|
self.measurements = self.ROBOT_MEASUREMENTS[robot_type]
|
||||||
|
|
||||||
|
# Initialize all transformation matrices and screw axes
|
||||||
|
self._setup_transforms()
|
||||||
|
|
||||||
|
def _create_translation_matrix(self, x=0, y=0, z=0):
|
||||||
|
"""Create a 4x4 translation matrix."""
|
||||||
|
return np.array([[1, 0, 0, x], [0, 1, 0, y], [0, 0, 1, z], [0, 0, 0, 1]])
|
||||||
|
|
||||||
|
def _setup_transforms(self):
|
||||||
|
"""Setup all transformation matrices and screw axes for the robot."""
|
||||||
|
# Set up rotation matrices (constant across robot types)
|
||||||
|
|
||||||
|
# Gripper orientation
|
||||||
|
self.gripper_X0 = np.array(
|
||||||
|
[
|
||||||
|
[1, 0, 0, 0],
|
||||||
|
[0, 0, 1, 0],
|
||||||
|
[0, -1, 0, 0],
|
||||||
|
[0, 0, 0, 1],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Wrist orientation
|
||||||
|
self.wrist_X0 = np.array(
|
||||||
|
[
|
||||||
|
[0, -1, 0, 0],
|
||||||
|
[1, 0, 0, 0],
|
||||||
|
[0, 0, 1, 0],
|
||||||
|
[0, 0, 0, 1],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Base orientation
|
||||||
|
self.base_X0 = np.array(
|
||||||
|
[
|
||||||
|
[0, 0, 1, 0],
|
||||||
|
[1, 0, 0, 0],
|
||||||
|
[0, 1, 0, 0],
|
||||||
|
[0, 0, 0, 1],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Gripper
|
||||||
|
# Screw axis of gripper frame wrt base frame
|
||||||
|
self.S_BG = np.array(
|
||||||
|
[
|
||||||
|
1,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
self.measurements["gripper"][2],
|
||||||
|
-self.measurements["gripper"][1],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Gripper origin to centroid transform
|
||||||
|
self.X_GoGc = self._create_translation_matrix(x=0.07)
|
||||||
|
|
||||||
|
# Gripper origin to tip transform
|
||||||
|
self.X_GoGt = self._create_translation_matrix(x=0.12)
|
||||||
|
|
||||||
|
# 0-position gripper frame pose wrt base
|
||||||
|
self.X_BoGo = self._create_translation_matrix(
|
||||||
|
x=self.measurements["gripper"][0],
|
||||||
|
y=self.measurements["gripper"][1],
|
||||||
|
z=self.measurements["gripper"][2],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Wrist
|
||||||
|
# Screw axis of wrist frame wrt base frame
|
||||||
|
self.S_BR = np.array(
|
||||||
|
[0, 1, 0, -self.measurements["wrist"][2], 0, self.measurements["wrist"][0]]
|
||||||
|
)
|
||||||
|
|
||||||
|
# 0-position origin to centroid transform
|
||||||
|
self.X_RoRc = self._create_translation_matrix(x=0.0035, y=-0.002)
|
||||||
|
|
||||||
|
# 0-position wrist frame pose wrt base
|
||||||
|
self.X_BR = self._create_translation_matrix(
|
||||||
|
x=self.measurements["wrist"][0],
|
||||||
|
y=self.measurements["wrist"][1],
|
||||||
|
z=self.measurements["wrist"][2],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Forearm
|
||||||
|
# Screw axis of forearm frame wrt base frame
|
||||||
|
self.S_BF = np.array(
|
||||||
|
[
|
||||||
|
0,
|
||||||
|
1,
|
||||||
|
0,
|
||||||
|
-self.measurements["forearm"][2],
|
||||||
|
0,
|
||||||
|
self.measurements["forearm"][0],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Forearm origin + centroid transform
|
||||||
|
self.X_FoFc = self._create_translation_matrix(x=0.036)
|
||||||
|
|
||||||
|
# 0-position forearm frame pose wrt base
|
||||||
|
self.X_BF = self._create_translation_matrix(
|
||||||
|
x=self.measurements["forearm"][0],
|
||||||
|
y=self.measurements["forearm"][1],
|
||||||
|
z=self.measurements["forearm"][2],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Humerus
|
||||||
|
# Screw axis of humerus frame wrt base frame
|
||||||
|
self.S_BH = np.array(
|
||||||
|
[
|
||||||
|
0,
|
||||||
|
-1,
|
||||||
|
0,
|
||||||
|
self.measurements["humerus"][2],
|
||||||
|
0,
|
||||||
|
-self.measurements["humerus"][0],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Humerus origin to centroid transform
|
||||||
|
self.X_HoHc = self._create_translation_matrix(x=0.0475)
|
||||||
|
|
||||||
|
# 0-position humerus frame pose wrt base
|
||||||
|
self.X_BH = self._create_translation_matrix(
|
||||||
|
x=self.measurements["humerus"][0],
|
||||||
|
y=self.measurements["humerus"][1],
|
||||||
|
z=self.measurements["humerus"][2],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Shoulder
|
||||||
|
# Screw axis of shoulder frame wrt Base frame
|
||||||
|
self.S_BS = np.array([0, 0, -1, 0, 0, 0])
|
||||||
|
|
||||||
|
# Shoulder origin to centroid transform
|
||||||
|
self.X_SoSc = self._create_translation_matrix(x=-0.017, z=0.0235)
|
||||||
|
|
||||||
|
# 0-position shoulder frame pose wrt base
|
||||||
|
self.X_BS = self._create_translation_matrix(
|
||||||
|
x=self.measurements["shoulder"][0],
|
||||||
|
y=self.measurements["shoulder"][1],
|
||||||
|
z=self.measurements["shoulder"][2],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Base
|
||||||
|
# Base origin to centroid transform
|
||||||
|
self.X_BoBc = self._create_translation_matrix(y=0.015)
|
||||||
|
|
||||||
|
# World to base transform
|
||||||
|
self.X_WoBo = self._create_translation_matrix(
|
||||||
|
x=self.measurements["base"][0],
|
||||||
|
y=self.measurements["base"][1],
|
||||||
|
z=self.measurements["base"][2],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Pre-compute gripper post-multiplication matrix
|
||||||
|
self._fk_gripper_post = self.X_GoGc @ self.X_BoGo @ self.gripper_X0
|
||||||
|
|
||||||
|
def fk_base(self):
|
||||||
|
"""Forward kinematics for the base frame."""
|
||||||
|
return self.X_WoBo @ self.X_BoBc @ self.base_X0
|
||||||
|
|
||||||
|
def fk_shoulder(self, robot_pos_deg):
|
||||||
|
"""Forward kinematics for the shoulder frame."""
|
||||||
|
robot_pos_rad = robot_pos_deg / 180 * np.pi
|
||||||
|
return (
|
||||||
|
self.X_WoBo
|
||||||
|
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
|
||||||
|
@ self.X_SoSc
|
||||||
|
@ self.X_BS
|
||||||
|
)
|
||||||
|
|
||||||
|
def fk_humerus(self, robot_pos_deg):
|
||||||
|
"""Forward kinematics for the humerus frame."""
|
||||||
|
robot_pos_rad = robot_pos_deg / 180 * np.pi
|
||||||
|
return (
|
||||||
|
self.X_WoBo
|
||||||
|
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
|
||||||
|
@ screw_axis_to_transform(self.S_BH, robot_pos_rad[1])
|
||||||
|
@ self.X_HoHc
|
||||||
|
@ self.X_BH
|
||||||
|
)
|
||||||
|
|
||||||
|
def fk_forearm(self, robot_pos_deg):
|
||||||
|
"""Forward kinematics for the forearm frame."""
|
||||||
|
robot_pos_rad = robot_pos_deg / 180 * np.pi
|
||||||
|
return (
|
||||||
|
self.X_WoBo
|
||||||
|
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
|
||||||
|
@ screw_axis_to_transform(self.S_BH, robot_pos_rad[1])
|
||||||
|
@ screw_axis_to_transform(self.S_BF, robot_pos_rad[2])
|
||||||
|
@ self.X_FoFc
|
||||||
|
@ self.X_BF
|
||||||
|
)
|
||||||
|
|
||||||
|
def fk_wrist(self, robot_pos_deg):
|
||||||
|
"""Forward kinematics for the wrist frame."""
|
||||||
|
robot_pos_rad = robot_pos_deg / 180 * np.pi
|
||||||
|
return (
|
||||||
|
self.X_WoBo
|
||||||
|
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
|
||||||
|
@ screw_axis_to_transform(self.S_BH, robot_pos_rad[1])
|
||||||
|
@ screw_axis_to_transform(self.S_BF, robot_pos_rad[2])
|
||||||
|
@ screw_axis_to_transform(self.S_BR, robot_pos_rad[3])
|
||||||
|
@ self.X_RoRc
|
||||||
|
@ self.X_BR
|
||||||
|
@ self.wrist_X0
|
||||||
|
)
|
||||||
|
|
||||||
|
def fk_gripper(self, robot_pos_deg):
|
||||||
|
"""Forward kinematics for the gripper frame."""
|
||||||
|
robot_pos_rad = robot_pos_deg / 180 * np.pi
|
||||||
|
return (
|
||||||
|
self.X_WoBo
|
||||||
|
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
|
||||||
|
@ screw_axis_to_transform(self.S_BH, robot_pos_rad[1])
|
||||||
|
@ screw_axis_to_transform(self.S_BF, robot_pos_rad[2])
|
||||||
|
@ screw_axis_to_transform(self.S_BR, robot_pos_rad[3])
|
||||||
|
@ screw_axis_to_transform(self.S_BG, robot_pos_rad[4])
|
||||||
|
@ self._fk_gripper_post
|
||||||
|
)
|
||||||
|
|
||||||
|
def fk_gripper_tip(self, robot_pos_deg):
|
||||||
|
"""Forward kinematics for the gripper tip frame."""
|
||||||
|
robot_pos_rad = robot_pos_deg / 180 * np.pi
|
||||||
|
return (
|
||||||
|
self.X_WoBo
|
||||||
|
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
|
||||||
|
@ screw_axis_to_transform(self.S_BH, robot_pos_rad[1])
|
||||||
|
@ screw_axis_to_transform(self.S_BF, robot_pos_rad[2])
|
||||||
|
@ screw_axis_to_transform(self.S_BR, robot_pos_rad[3])
|
||||||
|
@ screw_axis_to_transform(self.S_BG, robot_pos_rad[4])
|
||||||
|
@ self.X_GoGt
|
||||||
|
@ self.X_BoGo
|
||||||
|
@ self.gripper_X0
|
||||||
|
)
|
||||||
|
|
||||||
|
def compute_jacobian(self, robot_pos_deg, fk_func=None):
|
||||||
|
"""Finite differences to compute the Jacobian.
|
||||||
|
J(i, j) represents how the ith component of the end-effector's velocity changes wrt a small change
|
||||||
|
in the jth joint's velocity.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
robot_pos_deg: Current joint positions in degrees
|
||||||
|
fk_func: Forward kinematics function to use (defaults to fk_gripper)
|
||||||
|
"""
|
||||||
|
if fk_func is None:
|
||||||
|
fk_func = self.fk_gripper
|
||||||
|
|
||||||
|
eps = 1e-8
|
||||||
|
jac = np.zeros(shape=(6, 5))
|
||||||
|
delta = np.zeros(len(robot_pos_deg[:-1]), dtype=np.float64)
|
||||||
|
for el_ix in range(len(robot_pos_deg[:-1])):
|
||||||
|
delta *= 0
|
||||||
|
delta[el_ix] = eps / 2
|
||||||
|
Sdot = (
|
||||||
|
pose_difference_se3(
|
||||||
|
fk_func(robot_pos_deg[:-1] + delta),
|
||||||
|
fk_func(robot_pos_deg[:-1] - delta),
|
||||||
|
)
|
||||||
|
/ eps
|
||||||
|
)
|
||||||
|
jac[:, el_ix] = Sdot
|
||||||
|
return jac
|
||||||
|
|
||||||
|
def compute_positional_jacobian(self, robot_pos_deg, fk_func=None):
|
||||||
|
"""Finite differences to compute the positional Jacobian.
|
||||||
|
J(i, j) represents how the ith component of the end-effector's position changes wrt a small change
|
||||||
|
in the jth joint's velocity.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
robot_pos_deg: Current joint positions in degrees
|
||||||
|
fk_func: Forward kinematics function to use (defaults to fk_gripper)
|
||||||
|
"""
|
||||||
|
if fk_func is None:
|
||||||
|
fk_func = self.fk_gripper
|
||||||
|
|
||||||
|
eps = 1e-8
|
||||||
|
jac = np.zeros(shape=(3, 5))
|
||||||
|
delta = np.zeros(len(robot_pos_deg[:-1]), dtype=np.float64)
|
||||||
|
for el_ix in range(len(robot_pos_deg[:-1])):
|
||||||
|
delta *= 0
|
||||||
|
delta[el_ix] = eps / 2
|
||||||
|
Sdot = (
|
||||||
|
fk_func(robot_pos_deg[:-1] + delta)[:3, 3]
|
||||||
|
- fk_func(robot_pos_deg[:-1] - delta)[:3, 3]
|
||||||
|
) / eps
|
||||||
|
jac[:, el_ix] = Sdot
|
||||||
|
return jac
|
||||||
|
|
||||||
|
def ik(
|
||||||
|
self, current_joint_state, desired_ee_pose, position_only=True, fk_func=None
|
||||||
|
):
|
||||||
|
"""Inverse kinematics using gradient descent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
current_joint_state: Initial joint positions in degrees
|
||||||
|
desired_ee_pose: Target end-effector pose as a 4x4 transformation matrix
|
||||||
|
position_only: If True, only match end-effector position, not orientation
|
||||||
|
fk_func: Forward kinematics function to use (defaults to fk_gripper)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Joint positions in degrees that achieve the desired end-effector pose
|
||||||
|
"""
|
||||||
|
if fk_func is None:
|
||||||
|
fk_func = self.fk_gripper
|
||||||
|
|
||||||
|
# Do gradient descent.
|
||||||
|
max_iterations = 5
|
||||||
|
learning_rate = 1
|
||||||
|
for _ in range(max_iterations):
|
||||||
|
current_ee_pose = fk_func(current_joint_state)
|
||||||
|
if not position_only:
|
||||||
|
error = se3_error(desired_ee_pose, current_ee_pose)
|
||||||
|
jac = self.compute_jacobian(current_joint_state, fk_func)
|
||||||
|
else:
|
||||||
|
error = desired_ee_pose[:3, 3] - current_ee_pose[:3, 3]
|
||||||
|
jac = self.compute_positional_jacobian(current_joint_state, fk_func)
|
||||||
|
delta_angles = np.linalg.pinv(jac) @ error
|
||||||
|
current_joint_state[:-1] += learning_rate * delta_angles
|
||||||
|
|
||||||
|
if np.linalg.norm(error) < 5e-3:
|
||||||
|
return current_joint_state
|
||||||
|
return current_joint_state
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import time
|
||||||
|
|
||||||
|
def run_test(robot_type):
|
||||||
|
"""Run test suite for a specific robot type."""
|
||||||
|
print(f"\n--- Testing {robot_type.upper()} Robot ---")
|
||||||
|
|
||||||
|
# Initialize kinematics for this robot
|
||||||
|
robot = RobotKinematics(robot_type)
|
||||||
|
|
||||||
|
# Test 1: Forward kinematics consistency
|
||||||
|
print("Test 1: Forward kinematics consistency")
|
||||||
|
test_angles = np.array(
|
||||||
|
[30, 45, -30, 20, 10, 0]
|
||||||
|
) # Example joint angles in degrees
|
||||||
|
|
||||||
|
# Calculate FK for different joints
|
||||||
|
shoulder_pose = robot.fk_shoulder(test_angles)
|
||||||
|
humerus_pose = robot.fk_humerus(test_angles)
|
||||||
|
forearm_pose = robot.fk_forearm(test_angles)
|
||||||
|
wrist_pose = robot.fk_wrist(test_angles)
|
||||||
|
gripper_pose = robot.fk_gripper(test_angles)
|
||||||
|
gripper_tip_pose = robot.fk_gripper_tip(test_angles)
|
||||||
|
|
||||||
|
# Check that poses form a consistent kinematic chain (positions should be progressively further from origin)
|
||||||
|
distances = [
|
||||||
|
np.linalg.norm(shoulder_pose[:3, 3]),
|
||||||
|
np.linalg.norm(humerus_pose[:3, 3]),
|
||||||
|
np.linalg.norm(forearm_pose[:3, 3]),
|
||||||
|
np.linalg.norm(wrist_pose[:3, 3]),
|
||||||
|
np.linalg.norm(gripper_pose[:3, 3]),
|
||||||
|
np.linalg.norm(gripper_tip_pose[:3, 3]),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Check if distances generally increase along the chain
|
||||||
|
is_consistent = all(
|
||||||
|
distances[i] <= distances[i + 1] for i in range(len(distances) - 1)
|
||||||
|
)
|
||||||
|
print(f" Pose distances from origin: {[round(d, 3) for d in distances]}")
|
||||||
|
print(
|
||||||
|
f" Kinematic chain consistency: {'PASSED' if is_consistent else 'FAILED'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test 2: Jacobian computation
|
||||||
|
print("Test 2: Jacobian computation")
|
||||||
|
jacobian = robot.compute_jacobian(test_angles)
|
||||||
|
positional_jacobian = robot.compute_positional_jacobian(test_angles)
|
||||||
|
|
||||||
|
# Check shapes
|
||||||
|
jacobian_shape_ok = jacobian.shape == (6, 5)
|
||||||
|
pos_jacobian_shape_ok = positional_jacobian.shape == (3, 5)
|
||||||
|
|
||||||
|
print(f" Jacobian shape: {'PASSED' if jacobian_shape_ok else 'FAILED'}")
|
||||||
|
print(
|
||||||
|
f" Positional Jacobian shape: {'PASSED' if pos_jacobian_shape_ok else 'FAILED'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test 3: Inverse kinematics
|
||||||
|
print("Test 3: Inverse kinematics (position only)")
|
||||||
|
|
||||||
|
# Generate target pose from known joint angles
|
||||||
|
original_angles = np.array([10, 20, 30, -10, 5, 0])
|
||||||
|
target_pose = robot.fk_gripper(original_angles)
|
||||||
|
|
||||||
|
# Start IK from a different position
|
||||||
|
initial_guess = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
|
||||||
|
|
||||||
|
# Measure IK performance
|
||||||
|
start_time = time.time()
|
||||||
|
computed_angles = robot.ik(initial_guess.copy(), target_pose)
|
||||||
|
ik_time = time.time() - start_time
|
||||||
|
|
||||||
|
# Compute resulting pose from IK solution
|
||||||
|
result_pose = robot.fk_gripper(computed_angles)
|
||||||
|
|
||||||
|
# Calculate position error
|
||||||
|
pos_error = np.linalg.norm(target_pose[:3, 3] - result_pose[:3, 3])
|
||||||
|
passed = pos_error < 0.01 # Accept errors less than 1cm
|
||||||
|
|
||||||
|
print(f" IK computation time: {ik_time:.4f} seconds")
|
||||||
|
print(f" Position error: {pos_error:.4f}")
|
||||||
|
print(f" IK position accuracy: {'PASSED' if passed else 'FAILED'}")
|
||||||
|
|
||||||
|
return is_consistent and jacobian_shape_ok and pos_jacobian_shape_ok and passed
|
||||||
|
|
||||||
|
# Run tests for all robot types
|
||||||
|
results = {}
|
||||||
|
for robot_type in ["koch", "so100", "moss"]:
|
||||||
|
results[robot_type] = run_test(robot_type)
|
||||||
|
|
||||||
|
# Print overall summary
|
||||||
|
print("\n=== Test Summary ===")
|
||||||
|
all_passed = all(results.values())
|
||||||
|
for robot_type, passed in results.items():
|
||||||
|
print(f"{robot_type.upper()}: {'PASSED' if passed else 'FAILED'}")
|
||||||
|
print(f"\nOverall: {'ALL TESTS PASSED' if all_passed else 'SOME TESTS FAILED'}")
|
|
@ -315,16 +315,49 @@ def start_learner_server(
|
||||||
|
|
||||||
|
|
||||||
def check_nan_in_transition(
|
def check_nan_in_transition(
|
||||||
observations: torch.Tensor, actions: torch.Tensor, next_state: torch.Tensor
|
observations: torch.Tensor,
|
||||||
):
|
actions: torch.Tensor,
|
||||||
for k in observations:
|
next_state: torch.Tensor,
|
||||||
if torch.isnan(observations[k]).any():
|
raise_error: bool = False,
|
||||||
logging.error(f"observations[{k}] contains NaN values")
|
) -> bool:
|
||||||
for k in next_state:
|
"""
|
||||||
if torch.isnan(next_state[k]).any():
|
Check for NaN values in transition data.
|
||||||
logging.error(f"next_state[{k}] contains NaN values")
|
|
||||||
|
Args:
|
||||||
|
observations: Dictionary of observation tensors
|
||||||
|
actions: Action tensor
|
||||||
|
next_state: Dictionary of next state tensors
|
||||||
|
raise_error: If True, raises ValueError when NaN is detected
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if NaN values were detected, False otherwise
|
||||||
|
"""
|
||||||
|
nan_detected = False
|
||||||
|
|
||||||
|
# Check observations
|
||||||
|
for key, tensor in observations.items():
|
||||||
|
if torch.isnan(tensor).any():
|
||||||
|
logging.error(f"observations[{key}] contains NaN values")
|
||||||
|
nan_detected = True
|
||||||
|
if raise_error:
|
||||||
|
raise ValueError(f"NaN detected in observations[{key}]")
|
||||||
|
|
||||||
|
# Check next state
|
||||||
|
for key, tensor in next_state.items():
|
||||||
|
if torch.isnan(tensor).any():
|
||||||
|
logging.error(f"next_state[{key}] contains NaN values")
|
||||||
|
nan_detected = True
|
||||||
|
if raise_error:
|
||||||
|
raise ValueError(f"NaN detected in next_state[{key}]")
|
||||||
|
|
||||||
|
# Check actions
|
||||||
if torch.isnan(actions).any():
|
if torch.isnan(actions).any():
|
||||||
logging.error("actions contains NaN values")
|
logging.error("actions contains NaN values")
|
||||||
|
nan_detected = True
|
||||||
|
if raise_error:
|
||||||
|
raise ValueError("NaN detected in actions")
|
||||||
|
|
||||||
|
return nan_detected
|
||||||
|
|
||||||
|
|
||||||
def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
|
def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
|
||||||
|
@ -460,9 +493,18 @@ def add_actor_information_and_train(
|
||||||
|
|
||||||
for transition in transition_list:
|
for transition in transition_list:
|
||||||
transition = move_transition_to_device(transition, device=device)
|
transition = move_transition_to_device(transition, device=device)
|
||||||
|
if check_nan_in_transition(
|
||||||
|
transition["state"], transition["action"], transition["next_state"]
|
||||||
|
):
|
||||||
|
logging.warning("NaN detected in transition, skipping")
|
||||||
|
continue
|
||||||
replay_buffer.add(**transition)
|
replay_buffer.add(**transition)
|
||||||
if transition.get("complementary_info", {}).get("is_intervention"):
|
|
||||||
|
if cfg.dataset_repo_id is not None and transition.get(
|
||||||
|
"complementary_info", {}
|
||||||
|
).get("is_intervention"):
|
||||||
offline_replay_buffer.add(**transition)
|
offline_replay_buffer.add(**transition)
|
||||||
|
|
||||||
logging.debug("[LEARNER] Received transitions")
|
logging.debug("[LEARNER] Received transitions")
|
||||||
logging.debug("[LEARNER] Waiting for interactions")
|
logging.debug("[LEARNER] Waiting for interactions")
|
||||||
while not interaction_message_queue.empty() and not shutdown_event.is_set():
|
while not interaction_message_queue.empty() and not shutdown_event.is_set():
|
||||||
|
|
Loading…
Reference in New Issue