Merge branch 'user/adil-zouitine/2025-1-7-port-hil-serl-new' into hil-serl/optimize_learning_loop

Adil Zouitine 2025-03-19 14:24:04 +01:00 committed by GitHub
commit 99cb691692
17 changed files with 2396 additions and 317 deletions

View File

@ -137,6 +137,7 @@ class PixelWrapper(gym.Wrapper):
return self._get_obs(obs), reward, terminated, truncated, info
# TODO: Remove this
class ConvertToLeRobotEnv(gym.Wrapper):
def __init__(self, env, num_envs):
super().__init__(env)

View File

@ -84,10 +84,12 @@ class SACConfig:
latent_dim: int = 256
target_entropy: float | None = None
use_backup_entropy: bool = True
grad_clip_norm: float = 40.0
critic_network_kwargs: dict[str, Any] = field(
default_factory=lambda: {
"hidden_dims": [256, 256],
"activate_final": True,
"final_activation": None,
}
)
actor_network_kwargs: dict[str, Any] = field(
@ -101,6 +103,6 @@ class SACConfig:
"use_tanh_squash": True,
"log_std_min": -5,
"log_std_max": 2,
"init_final": 0.005,
"init_final": 0.05,
}
)

View File

@ -18,7 +18,7 @@
# TODO: (1) better device management
from copy import deepcopy
from typing import Callable, Optional, Tuple, Union, Dict
from typing import Callable, Optional, Tuple, Union, Dict, List
from pathlib import Path
import einops
@ -88,33 +88,33 @@ class SACPolicy(
encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
encoder_actor = SACObservationEncoder(config, self.normalize_inputs)
# Create a list of critic heads
critic_heads = [
CriticHead(
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
**config.critic_network_kwargs,
)
for _ in range(config.num_critics)
]
self.critic_ensemble = CriticEnsemble(
encoder=encoder_critic,
ensemble=Ensemble(
[
CriticHead(
input_dim=encoder_critic.output_dim
+ config.output_shapes["action"][0],
**config.critic_network_kwargs,
)
for _ in range(config.num_critics)
]
),
ensemble=critic_heads,
output_normalization=self.normalize_targets,
)
# Create target critic heads as deepcopies of the original critic heads
target_critic_heads = [
CriticHead(
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
**config.critic_network_kwargs,
)
for _ in range(config.num_critics)
]
self.critic_target = CriticEnsemble(
encoder=encoder_critic,
ensemble=Ensemble(
[
CriticHead(
input_dim=encoder_critic.output_dim
+ config.output_shapes["action"][0],
**config.critic_network_kwargs,
)
for _ in range(config.num_critics)
]
),
ensemble=target_critic_heads,
output_normalization=self.normalize_targets,
)
@ -149,19 +149,9 @@ class SACPolicy(
import json
from dataclasses import asdict
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME
from safetensors.torch import save_file
from safetensors.torch import save_model
# NOTE: Using tensordict.from_modules in the model to batch the inference using torch.vmap
# implies one side effect: the __batch_size parameters are saved in the state_dict
# __batch_size is torch.Size or safetensor save only torch.Tensor
# so we need to filter them out before saving
simplified_state_dict = {}
for name, param in self.named_parameters():
simplified_state_dict[name] = param
save_file(
simplified_state_dict, os.path.join(save_directory, SAFETENSORS_SINGLE_FILE)
)
save_model(self, os.path.join(save_directory, SAFETENSORS_SINGLE_FILE))
# Save config
config_dict = asdict(self.config)
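For reference, a hedged sketch of the safetensors model helpers the new save/load path relies on, assuming a recent safetensors release in which load_model accepts a device argument; the filename below is illustrative:

# Illustrative sketch only: round-tripping a module with safetensors' model helpers.
import torch.nn as nn
from safetensors.torch import save_model, load_model

net = nn.Linear(4, 2)
save_model(net, "model.safetensors")  # serializes only tensor state, so no manual state_dict filtering is needed
restored = nn.Linear(4, 2)
load_model(restored, "model.safetensors", device="cpu")  # loads the weights in place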
@ -191,7 +181,7 @@ class SACPolicy(
from pathlib import Path
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME
from safetensors.torch import load_file
from safetensors.torch import load_model
from lerobot.common.policies.sac.configuration_sac import SACConfig
# Check if model_id is a local path or a hub model ID
@ -243,28 +233,7 @@ class SACPolicy(
# Load state dict from safetensors file
if os.path.exists(safetensors_file):
# Note: The load_file function returns a dict with the parameters, but __batch_size
# is not loaded so we need to copy it from the model state_dict
# Load the parameters only
loaded_state_dict = load_file(safetensors_file, device=map_location)
# Copy batch size parameters
find_and_copy_params(
original_state_dict=model.state_dict(),
loaded_state_dict=loaded_state_dict,
pattern="__batch_size",
match_type="endswith",
)
# Copy normalization buffer parameters
find_and_copy_params(
original_state_dict=model.state_dict(),
loaded_state_dict=loaded_state_dict,
pattern="_orig_mod.output_normalization.buffer_action",
match_type="contains",
)
model.load_state_dict(loaded_state_dict, strict=False)
load_model(model, filename=safetensors_file, device=map_location)
return model
@ -330,7 +299,7 @@ class SACPolicy(
observation_features: Tensor | None = None,
next_observation_features: Tensor | None = None,
) -> Tensor:
temperature = self.log_alpha.exp().item()
self.temperature = self.log_alpha.exp().item()
with torch.no_grad():
next_action_preds, next_log_probs, _ = self.actor(
next_observations, next_observation_features
@ -358,7 +327,7 @@ class SACPolicy(
# critics subsample size
min_q, _ = q_targets.min(dim=0) # Get values from min operation
if self.config.use_backup_entropy:
min_q = min_q - (temperature * next_log_probs)
min_q = min_q - (self.temperature * next_log_probs)
td_target = rewards + (1 - done) * self.config.discount * min_q
@ -398,7 +367,7 @@ class SACPolicy(
def compute_loss_actor(
self, observations, observation_features: Tensor | None = None
) -> Tensor:
temperature = self.log_alpha.exp().item()
self.temperature = self.log_alpha.exp().item()
actions_pi, log_probs, _ = self.actor(observations, observation_features)
@ -413,7 +382,7 @@ class SACPolicy(
)
min_q_preds = q_preds.min(dim=0)[0]
actor_loss = ((temperature * log_probs) - min_q_preds).mean()
actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
return actor_loss
@ -425,6 +394,7 @@ class MLP(nn.Module):
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
activate_final: bool = False,
dropout_rate: Optional[float] = None,
final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
):
super().__init__()
self.activate_final = activate_final
@ -451,11 +421,24 @@ class MLP(nn.Module):
if dropout_rate is not None and dropout_rate > 0:
layers.append(nn.Dropout(p=dropout_rate))
layers.append(nn.LayerNorm(hidden_dims[i]))
layers.append(
activations
if isinstance(activations, nn.Module)
else getattr(nn, activations)()
)
# If we're at the final layer and a final activation is specified, use it
if (
i + 1 == len(hidden_dims)
and activate_final
and final_activation is not None
):
layers.append(
final_activation
if isinstance(final_activation, nn.Module)
else getattr(nn, final_activation)()
)
else:
layers.append(
activations
if isinstance(activations, nn.Module)
else getattr(nn, activations)()
)
self.net = nn.Sequential(*layers)
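A hedged usage sketch of how the new final_activation option threads from critic_network_kwargs into a CriticHead; the names come from this file, but the dimensions and values below are illustrative assumptions:

# Illustrative only: building a critic head with the new final_activation kwarg.
critic_kwargs = {
    "hidden_dims": [256, 256],
    "activate_final": True,
    "final_activation": None,  # or e.g. "Tanh" to bound the output of the last hidden layer
}
head = CriticHead(input_dim=64 + 7, **critic_kwargs)  # assumed 64-d observation encoding + 7-d action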
@ -516,6 +499,7 @@ class CriticHead(nn.Module):
activate_final: bool = False,
dropout_rate: Optional[float] = None,
init_final: Optional[float] = None,
final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
):
super().__init__()
self.net = MLP(
@ -524,6 +508,7 @@ class CriticHead(nn.Module):
activations=activations,
activate_final=activate_final,
dropout_rate=dropout_rate,
final_activation=final_activation,
)
self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=1)
if init_final is not None:
@ -578,21 +563,21 @@ class CriticEnsemble(nn.Module):
def __init__(
self,
encoder: Optional[nn.Module],
ensemble: "Ensemble[CriticHead]",
ensemble: List[CriticHead],
output_normalization: nn.Module,
init_final: Optional[float] = None,
):
super().__init__()
self.encoder = encoder
self.ensemble = ensemble
self.init_final = init_final
self.output_normalization = output_normalization
self.critics = nn.ModuleList(ensemble)
self.parameters_to_optimize = []
# Handle the case where a part of the encoder is frozen
if self.encoder is not None:
self.parameters_to_optimize += list(self.encoder.parameters_to_optimize)
self.parameters_to_optimize += list(self.ensemble.parameters())
self.parameters_to_optimize += list(self.critics.parameters())
def forward(
self,
@ -616,8 +601,15 @@ class CriticEnsemble(nn.Module):
)
inputs = torch.cat([obs_enc, actions], dim=-1)
q_values = self.ensemble(inputs) # [num_critics, B, 1]
return q_values.squeeze(-1) # [num_critics, B]
# Loop through critics and collect outputs
q_values = []
for critic in self.critics:
q_values.append(critic(inputs))
# Stack outputs to match expected shape [num_critics, batch_size]
q_values = torch.stack([q.squeeze(-1) for q in q_values], dim=0)
return q_values
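For context, a hedged sketch of the vmapped ensemble evaluation that this explicit loop replaces, assuming identical head architectures and torch >= 2.0 with torch.func; the Linear heads stand in for CriticHead:

# Illustrative sketch only: batched ensemble evaluation with torch.vmap.
import copy
import torch
from torch import nn
from torch.func import stack_module_state, functional_call

heads = [nn.Linear(8, 1) for _ in range(4)]  # stand-ins for CriticHead
params, buffers = stack_module_state(heads)  # stack per-head parameters along a new leading dim
base = copy.deepcopy(heads[0]).to("meta")  # stateless template module

def run_one(p, b, x):
    return functional_call(base, (p, b), (x,))

x = torch.randn(32, 8)
q_values = torch.vmap(run_one, in_dims=(0, 0, None))(params, buffers, x)  # [4, 32, 1]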
class Policy(nn.Module):
@ -1009,73 +1001,78 @@ def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
if __name__ == "__main__":
# Test the SACObservationEncoder
# Benchmark the CriticEnsemble performance
import time
config = SACConfig()
config.num_critics = 10
config.vision_encoder_name = None
encoder = SACObservationEncoder(config, nn.Identity())
# actor_encoder = SACObservationEncoder(config)
# encoder = torch.compile(encoder)
# Configuration
num_critics = 10
batch_size = 32
action_dim = 7
obs_dim = 64
hidden_dims = [256, 256]
num_iterations = 100
print("Creating test environment...")
# Create a simple dummy encoder
class DummyEncoder(nn.Module):
def __init__(self):
super().__init__()
self.output_dim = obs_dim
self.parameters_to_optimize = []
def forward(self, obs):
# Just return a random tensor of the right shape
# In practice, this would encode the observations
return torch.randn(batch_size, obs_dim, device=device)
# Create critic heads
print(f"Creating {num_critics} critic heads...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
critic_heads = [
CriticHead(
input_dim=obs_dim + action_dim,
hidden_dims=hidden_dims,
).to(device)
for _ in range(num_critics)
]
# Create the critic ensemble
print("Creating CriticEnsemble...")
critic_ensemble = CriticEnsemble(
encoder=encoder,
ensemble=Ensemble(
[
CriticHead(
input_dim=encoder.output_dim + config.output_shapes["action"][0],
**config.critic_network_kwargs,
)
for _ in range(config.num_critics)
]
),
encoder=DummyEncoder().to(device),
ensemble=critic_heads,
output_normalization=nn.Identity(),
)
# actor = Policy(
# encoder=actor_encoder,
# network=MLP(input_dim=actor_encoder.output_dim, **config.actor_network_kwargs),
# action_dim=config.output_shapes["action"][0],
# encoder_is_shared=config.shared_encoder,
# **config.policy_kwargs,
# )
# encoder = encoder.to("cuda:0")
# critic_ensemble = torch.compile(critic_ensemble)
critic_ensemble = critic_ensemble.to("cuda:0")
# actor = torch.compile(actor)
# actor = actor.to("cuda:0")
).to(device)
# Create random input data
print("Creating input data...")
obs_dict = {
"observation.image": torch.randn(8, 3, 84, 84),
"observation.state": torch.randn(8, 4),
"observation.state": torch.randn(batch_size, obs_dim, device=device),
}
actions = torch.randn(8, 2).to("cuda:0")
# obs_dict = {k: v.to("cuda:0") for k, v in obs_dict.items()}
# print("compiling...")
q_value = critic_ensemble(obs_dict, actions)
print(q_value.size())
# action = actor(obs_dict)
# print("compiled")
# start = time.perf_counter()
# for _ in range(1000):
# # features = encoder(obs_dict)
# action = actor(obs_dict)
# # q_value = critic_ensemble(obs_dict, actions)
# print("Time taken:", time.perf_counter() - start)
# Compare the performance of the ensemble vs a for loop of 16 MLPs
ensemble = Ensemble([CriticHead(256, [256, 256]) for _ in range(2)])
ensemble = ensemble.to("cuda:0")
critic = CriticHead(256, [256, 256])
critic = critic.to("cuda:0")
data_ensemble = torch.randn(8, 256).to("cuda:0")
ensemble = torch.compile(ensemble)
# critic = torch.compile(critic)
print(ensemble(data_ensemble).size())
print(critic(data_ensemble).size())
start = time.perf_counter()
for _ in range(1000):
ensemble(data_ensemble)
print("Time taken:", time.perf_counter() - start)
start = time.perf_counter()
for _ in range(1000):
for i in range(2):
critic(data_ensemble)
print("Time taken:", time.perf_counter() - start)
actions = torch.randn(batch_size, action_dim, device=device)
# Warmup run
print("Warming up...")
_ = critic_ensemble(obs_dict, actions)
# Time the forward pass
print(f"Running benchmark with {num_iterations} iterations...")
start_time = time.perf_counter()
for _ in range(num_iterations):
q_values = critic_ensemble(obs_dict, actions)
end_time = time.perf_counter()
# Print results
elapsed_time = end_time - start_time
print(f"Total time: {elapsed_time:.4f} seconds")
print(f"Average time per iteration: {elapsed_time / num_iterations * 1000:.4f} ms")
print(f"Output shape: {q_values.shape}") # Should be [num_critics, batch_size]
# Verify that all critic heads produce different outputs
# This confirms each critic head is unique
# print("\nVerifying critic outputs are different:")
# for i in range(num_critics):
# for j in range(i + 1, num_critics):
# diff = torch.abs(q_values[i] - q_values[j]).mean().item()
# print(f"Mean difference between critic {i} and {j}: {diff:.6f}")

View File

@ -367,7 +367,7 @@ def reset_environment(robot, events, reset_time_s):
def reset_follower_position(robot: Robot, target_position):
current_position = robot.follower_arms["main"].read("Present_Position")
trajectory = torch.from_numpy(
np.linspace(current_position, target_position, 30)
np.linspace(current_position, target_position, 50)
) # NOTE: 30 is just an arbitrary number
for pose in trajectory:
robot.send_action(pose)

View File

@ -1,6 +1,6 @@
# @package _global_
fps: 20
fps: 400
env:
name: maniskill/pushcube

View File

@ -5,26 +5,46 @@ fps: 10
env:
name: real_world
task: null
state_dim: 6
action_dim: 6
state_dim: 15
action_dim: 3
fps: ${fps}
device: mps
wrapper:
crop_params_dict:
observation.images.front: [102, 43, 358, 523]
observation.images.side: [92, 123, 379, 349]
# observation.images.front: [109, 37, 361, 557]
# observation.images.side: [94, 161, 372, 315]
observation.images.front: [171, 207, 116, 251]
observation.images.side: [232, 200, 142, 204]
resize_size: [128, 128]
control_time_s: 20
reset_follower_pos: true
control_time_s: 10
reset_follower_pos: false
use_relative_joint_positions: true
reset_time_s: 5
display_cameras: false
delta_action: 0.1
joint_masking_action_space: [1, 1, 1, 1, 0, 0] # disable wrist and gripper
delta_action: null #0.3
joint_masking_action_space: null #[1, 1, 1, 1, 0, 0] # disable wrist and gripper
add_joint_velocity_to_observation: true
add_ee_pose_to_observation: true
# If null then the teleoperation will be used to reset the robot
# Bounds for pushcube_gamepad_lerobot15 dataset and experiments
# fixed_reset_joint_positions: [-19.86, 103.19, 117.33, 42.7, 13.89, 0.297]
# ee_action_space_params: # If null then ee_action_space is not used
# bounds:
# max: [0.291, 0.147, 0.074]
# min: [0.139, -0.143, 0.03]
# Bounds for insertcube_gamepad dataset and experiments
fixed_reset_joint_positions: [20.0, 90., 90., 75., -0.7910156, -0.5673759]
ee_action_space_params:
bounds:
max: [0.25295413, 0.07498981, 0.06862044]
min: [0.2010096, -0.12, 0.0433196]
use_gamepad: true
x_step_size: 0.03
y_step_size: 0.03
z_step_size: 0.03
reward_classifier:
pretrained_path: outputs/classifier/13-02-random-sample-resnet10-frozen/checkpoints/best/pretrained_model
config_path: lerobot/configs/policy/hilserl_classifier.yaml
pretrained_path: null # outputs/classifier/13-02-random-sample-resnet10-frozen/checkpoints/best/pretrained_model
config_path: null # lerobot/configs/policy/hilserl_classifier.yaml

View File

@ -8,22 +8,23 @@
# env.gym.obs_type=environment_state_agent_pos \
seed: 1
dataset_repo_id: "AdilZtn/Maniskill-Pushcube-demonstration-medium"
# dataset_repo_id: "AdilZtn/Maniskill-Pushcube-demonstration-medium"
dataset_repo_id: null
training:
# Offline training dataloader
num_workers: 4
batch_size: 512
grad_clip_norm: 10.0
grad_clip_norm: 40.0
lr: 3e-4
storage_device: "cpu"
storage_device: "cuda"
eval_freq: 2500
log_freq: 10
save_freq: 2000000
save_freq: 1000000
online_steps: 1000000
online_rollout_n_episodes: 10
@ -32,17 +33,12 @@ training:
online_sampling_ratio: 1.0
online_env_seed: 10000
online_buffer_capacity: 200000
offline_buffer_capacity: 100000
online_buffer_seed_size: 0
online_step_before_learning: 500
do_online_rollout_async: false
policy_update_freq: 1
# delta_timestamps:
# observation.environment_state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
# observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
# action: "[i / ${fps} for i in range(${policy.horizon})]"
# next.reward: "[i / ${fps} for i in range(${policy.horizon})]"
policy:
name: sac
@ -68,28 +64,33 @@ policy:
camera_number: 1
# Normalization / Unnormalization
input_normalization_modes: null
# input_normalization_modes:
# observation.state: min_max
input_normalization_params: null
# observation.state:
# min: [-1.9361e+00, -7.7640e-01, -7.7094e-01, -2.9709e+00, -8.5656e-01,
# 1.0764e+00, -1.2680e+00, 0.0000e+00, 0.0000e+00, -9.3448e+00,
# -3.3828e+00, -3.8420e+00, -5.2553e+00, -3.4154e+00, -6.5082e+00,
# -6.0500e+00, -8.7193e+00, -8.2337e+00, -3.4650e-01, -4.9441e-01,
# 8.3516e-03, -3.1114e-01, -9.9700e-01, -2.3471e-01, -2.7137e-01]
# input_normalization_modes: null
input_normalization_modes:
observation.state: min_max
observation.image: mean_std
# input_normalization_params: null
input_normalization_params:
observation.state:
min: [-1.9361e+00, -7.7640e-01, -7.7094e-01, -2.9709e+00, -8.5656e-01,
1.0764e+00, -1.2680e+00, 0.0000e+00, 0.0000e+00, -9.3448e+00,
-3.3828e+00, -3.8420e+00, -5.2553e+00, -3.4154e+00, -6.5082e+00,
-6.0500e+00, -8.7193e+00, -8.2337e+00, -3.4650e-01, -4.9441e-01,
8.3516e-03, -3.1114e-01, -9.9700e-01, -2.3471e-01, -2.7137e-01]
max: [ 0.8644, 1.4306, 1.8520, -0.7578, 0.9508, 3.4901, 1.9381, 0.0400,
0.0400, 5.0885, 4.7156, 7.9393, 7.9100, 2.9796, 5.7720, 4.7163,
7.8145, 9.7415, 0.2422, 0.4505, 0.6306, 0.2622, 1.0000, 0.5135,
0.4001]
# max: [ 0.8644, 1.4306, 1.8520, -0.7578, 0.9508, 3.4901, 1.9381, 0.0400,
# 0.0400, 5.0885, 4.7156, 7.9393, 7.9100, 2.9796, 5.7720, 4.7163,
# 7.8145, 9.7415, 0.2422, 0.4505, 0.6306, 0.2622, 1.0000, 0.5135,
# 0.4001]
observation.image:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
output_normalization_modes:
action: min_max
output_normalization_params:
action:
min: [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0]
max: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
min: [-0.03, -0.03, -0.03, -0.03, -0.03, -0.03, -0.03]
max: [0.03, 0.03, 0.03, 0.03, 0.03, 0.03, 0.03]
output_normalization_shapes:
action: [7]
@ -99,8 +100,8 @@ policy:
# discount: 0.99
discount: 0.80
temperature_init: 1.0
num_critics: 10 #10
num_subsample_critics: 2
num_critics: 2 #10
num_subsample_critics: null
critic_lr: 3e-4
actor_lr: 3e-4
temperature_lr: 3e-4
@ -111,7 +112,7 @@ policy:
actor_learner_config:
learner_host: "127.0.0.1"
learner_port: 50051
policy_parameters_push_frequency: 1
policy_parameters_push_frequency: 4
concurrency:
actor: 'processes'
learner: 'processes'
actor: 'threads'
learner: 'threads'

View File

@ -8,8 +8,7 @@
# env.gym.obs_type=environment_state_agent_pos \
seed: 1
dataset_repo_id: aractingi/push_cube_overfit_cropped_resized
#aractingi/push_cube_square_offline_demo_cropped_resized
dataset_repo_id: aractingi/insertcube_simple
training:
# Offline training dataloader
@ -30,7 +29,7 @@ training:
online_steps_between_rollouts: 1000
online_sampling_ratio: 1.0
online_env_seed: 10000
online_buffer_capacity: 1000000
online_buffer_capacity: 10000
online_buffer_seed_size: 0
online_step_before_learning: 100 #5000
do_online_rollout_async: false
@ -62,7 +61,7 @@ policy:
observation.images.side: [3, 128, 128]
# observation.image: [3, 128, 128]
output_shapes:
action: [4] # ["${env.action_dim}"]
action: ["${env.action_dim}"]
# Normalization / Unnormalization
input_normalization_modes:
@ -77,23 +76,16 @@ policy:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
observation.state:
min: [-77.08008, 56.25, 60.55664, 19.511719, 0., -0.63829786]
max: [ 7.215820e+01, 1.5398438e+02, 1.6075195e+02, 9.3251953e+01, 0., -1.4184397e-01]
# min: [-87.09961, 62.402344, 67.23633, 36.035156, 77.34375,0.53691274]
# max: [58.183594, 131.83594, 145.98633, 82.08984, 78.22266, 0.60402685]
# min: [-88.50586, 23.81836, 0.87890625, -32.16797, 78.66211, 0.53691274]
# max: [84.55078, 187.11914, 145.98633, 101.60156, 146.60156, 88.18792]
# 6- joint positions, 6- joint velocities, 3- ee position
max: [ 52.822266, 136.14258, 142.03125, 72.1582, 22.675781, -0.5673759, 100., 100., 100., 100., 100., 100., 0.25295413, 0.07498981, 0.06862044]
min: [-2.6367188, 86.572266, 89.82422, 12.392578, -26.015625, -0.5673759, -100., -100., -100., -100., -100., -100., 0.2010096, -0.12, 0.0433196]
output_normalization_modes:
action: min_max
output_normalization_params:
# action:
# min: [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0]
# max: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
action:
min: [-149.23828125, -97.734375, -100.1953125, -73.740234375]
max: [149.23828125, 97.734375, 100.1953125, 73.740234375]
min: [-0.03, -0.03, -0.01]
max: [0.03, 0.03, 0.03]
# Architecture / modeling.
# Neural networks.

View File

@ -14,9 +14,13 @@ calibration_dir: .cache/calibration/so100
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: null
joint_position_relative_bounds:
max: [ 7.2158203e+01, 1.5398438e+02, 1.6075195e+02, 9.3251953e+01, 0., -1.4184397e-01]
min: [-77.08008, 56.25, 60.55664, 19.511719, 0., -0.63829786]
joint_position_relative_bounds: null
# max: [100, 100, 100, 100, 100, 100]
# min: [-100, -100, -100, -100, -100, -100]
# max: [ 7.2158203e+01, 1.5398438e+02, 1.6075195e+02, 9.3251953e+01, 0., -1.4184397e-01]
# min: [-77.08008, 56.25, 60.55664, 19.511719, 0., -0.63829786]
# max: [ 35.06836 , 103.18359 , 127.61719 , 75.58594 , 0., 0.]
# min: [ -8.876953 , 63.808594 , 90.49805 , 49.48242 , 0., 0.]
leader_arms:
main:
@ -47,13 +51,13 @@ follower_arms:
cameras:
front:
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
camera_index: 0
camera_index: 1
fps: 30
width: 640
height: 480
side:
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
camera_index: 1
camera_index: 0
fps: 30
width: 640
height: 480

View File

@ -54,6 +54,7 @@ from lerobot.scripts.server.network_utils import (
)
from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env
from lerobot.scripts.server import learner_service
from lerobot.common.robot_devices.utils import busy_wait
from torch.multiprocessing import Queue, Event
from queue import Empty
@ -312,17 +313,6 @@ def act_with_policy(
logging.info("make_policy")
# HACK: This is an ugly hack to pass the normalization parameters to the policy
# Because the action space is dynamic so we override the output normalization parameters
# it's ugly, we know ... and we will fix it
min_action_space: list = online_env.action_space.spaces[0].low.tolist()
max_action_space: list = online_env.action_space.spaces[0].high.tolist()
output_normalization_params: dict[dict[str, list]] = {
"action": {"min": min_action_space, "max": max_action_space}
}
cfg.policy.output_normalization_params = output_normalization_params
cfg.policy.output_shapes["action"] = online_env.action_space.spaces[0].shape
### Instantiate the policy in both the actor and learner processes
### To avoid sending a SACPolicy object through the port, we create a policy intance
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
@ -347,6 +337,7 @@ def act_with_policy(
episode_intervention = False
for interaction_step in range(cfg.training.online_steps):
start_time = time.perf_counter()
if shutdown_event.is_set():
logging.info("[ACTOR] Shutting down act_with_policy")
return
@ -408,7 +399,6 @@ def act_with_policy(
complementary_info=info, # TODO Handle information for the transition, is_demonstration: bool
)
)
# assign obs to the next obs and continue the rollout
obs = next_obs
@ -449,6 +439,10 @@ def act_with_policy(
episode_intervention = False
obs, info = online_env.reset()
if cfg.fps is not None:
dt_time = time.perf_counter() - start_time
busy_wait(1 / cfg.fps - dt_time)
def push_transitions_to_transport_queue(transitions: list, transitions_queue):
"""Send transitions to learner in smaller chunks to avoid network issues.

View File

@ -263,11 +263,6 @@ if __name__ == "__main__":
with open(args.crop_params_path) as f:
rois = json.load(f)
# rois = {
# "observation.images.front": [102, 43, 358, 523],
# "observation.images.side": [92, 123, 379, 349],
# }
# Print the selected rectangular ROIs
print("\nSelected Rectangular Regions of Interest (top, left, height, width):")
for key, roi in rois.items():

View File

@ -0,0 +1,797 @@
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.utils.utils import init_hydra_config
from lerobot.common.robot_devices.utils import busy_wait
from lerobot.scripts.server.kinematics import RobotKinematics
import logging
import time
import torch
import numpy as np
import argparse
logging.basicConfig(level=logging.INFO)
class InputController:
"""Base class for input controllers that generate motion deltas."""
def __init__(self, x_step_size=0.01, y_step_size=0.01, z_step_size=0.01):
"""
Initialize the controller.
Args:
x_step_size: Base movement step size along the X axis, in meters
y_step_size: Base movement step size along the Y axis, in meters
z_step_size: Base movement step size along the Z axis, in meters
"""
self.x_step_size = x_step_size
self.y_step_size = y_step_size
self.z_step_size = z_step_size
self.running = True
self.episode_end_status = None # None, "success", or "failure"
def start(self):
"""Start the controller and initialize resources."""
pass
def stop(self):
"""Stop the controller and release resources."""
pass
def get_deltas(self):
"""Get the current movement deltas (dx, dy, dz) in meters."""
return 0.0, 0.0, 0.0
def should_quit(self):
"""Return True if the user has requested to quit."""
return not self.running
def update(self):
"""Update controller state - call this once per frame."""
pass
def __enter__(self):
"""Support for use in 'with' statements."""
self.start()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Ensure resources are released when exiting 'with' block."""
self.stop()
def get_episode_end_status(self):
"""
Get the current episode end status.
Returns:
None if episode should continue, "success" or "failure" otherwise
"""
status = self.episode_end_status
self.episode_end_status = None # Reset after reading
return status
class KeyboardController(InputController):
"""Generate motion deltas from keyboard input."""
def __init__(self, x_step_size=0.01, y_step_size=0.01, z_step_size=0.01):
super().__init__(x_step_size, y_step_size, z_step_size)
self.key_states = {
"forward_x": False,
"backward_x": False,
"forward_y": False,
"backward_y": False,
"forward_z": False,
"backward_z": False,
"quit": False,
"success": False,
"failure": False,
}
self.listener = None
def start(self):
"""Start the keyboard listener."""
from pynput import keyboard
def on_press(key):
try:
if key == keyboard.Key.up:
self.key_states["forward_x"] = True
elif key == keyboard.Key.down:
self.key_states["backward_x"] = True
elif key == keyboard.Key.left:
self.key_states["forward_y"] = True
elif key == keyboard.Key.right:
self.key_states["backward_y"] = True
elif key == keyboard.Key.shift:
self.key_states["backward_z"] = True
elif key == keyboard.Key.shift_r:
self.key_states["forward_z"] = True
elif key == keyboard.Key.esc:
self.key_states["quit"] = True
self.running = False
return False
elif key == keyboard.Key.enter:
self.key_states["success"] = True
self.episode_end_status = "success"
elif key == keyboard.Key.backspace:
self.key_states["failure"] = True
self.episode_end_status = "failure"
except AttributeError:
pass
def on_release(key):
try:
if key == keyboard.Key.up:
self.key_states["forward_x"] = False
elif key == keyboard.Key.down:
self.key_states["backward_x"] = False
elif key == keyboard.Key.left:
self.key_states["forward_y"] = False
elif key == keyboard.Key.right:
self.key_states["backward_y"] = False
elif key == keyboard.Key.shift:
self.key_states["backward_z"] = False
elif key == keyboard.Key.shift_r:
self.key_states["forward_z"] = False
elif key == keyboard.Key.enter:
self.key_states["success"] = False
elif key == keyboard.Key.backspace:
self.key_states["failure"] = False
except AttributeError:
pass
self.listener = keyboard.Listener(on_press=on_press, on_release=on_release)
self.listener.start()
print("Keyboard controls:")
print(" Arrow keys: Move in X-Y plane")
print(" Shift and Shift_R: Move in Z axis")
print(" Enter: End episode with SUCCESS")
print(" Backspace: End episode with FAILURE")
print(" ESC: Exit")
def stop(self):
"""Stop the keyboard listener."""
if self.listener and self.listener.is_alive():
self.listener.stop()
def get_deltas(self):
"""Get the current movement deltas from keyboard state."""
delta_x = delta_y = delta_z = 0.0
if self.key_states["forward_x"]:
delta_x += self.x_step_size
if self.key_states["backward_x"]:
delta_x -= self.x_step_size
if self.key_states["forward_y"]:
delta_y += self.y_step_size
if self.key_states["backward_y"]:
delta_y -= self.y_step_size
if self.key_states["forward_z"]:
delta_z += self.z_step_size
if self.key_states["backward_z"]:
delta_z -= self.z_step_size
return delta_x, delta_y, delta_z
def should_quit(self):
"""Return True if ESC was pressed."""
return self.key_states["quit"]
def should_save(self):
"""Return True if Enter was pressed (save episode)."""
return self.key_states["success"] or self.key_states["failure"]
class GamepadController(InputController):
"""Generate motion deltas from gamepad input."""
def __init__(
self, x_step_size=0.01, y_step_size=0.01, z_step_size=0.01, deadzone=0.1
):
super().__init__(x_step_size, y_step_size, z_step_size)
self.deadzone = deadzone
self.joystick = None
self.intervention_flag = False
def start(self):
"""Initialize pygame and the gamepad."""
import pygame
pygame.init()
pygame.joystick.init()
if pygame.joystick.get_count() == 0:
logging.error(
"No gamepad detected. Please connect a gamepad and try again."
)
self.running = False
return
self.joystick = pygame.joystick.Joystick(0)
self.joystick.init()
logging.info(f"Initialized gamepad: {self.joystick.get_name()}")
print("Gamepad controls:")
print(" Left analog stick: Move in X-Y plane")
print(" Right analog stick (vertical): Move in Z axis")
print(" B/Circle button: Exit")
print(" Y/Triangle button: End episode with SUCCESS")
print(" A/Cross button: End episode with FAILURE")
print(" X/Square button: Rerecord episode")
def stop(self):
"""Clean up pygame resources."""
import pygame
if pygame.joystick.get_init():
if self.joystick:
self.joystick.quit()
pygame.joystick.quit()
pygame.quit()
def update(self):
"""Process pygame events to get fresh gamepad readings."""
import pygame
for event in pygame.event.get():
if event.type == pygame.JOYBUTTONDOWN:
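# Y button (3) for success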
if event.button == 3:
self.episode_end_status = "success"
# A button (1) for failure
elif event.button == 1:
self.episode_end_status = "failure"
# X button (0) for rerecord
elif event.button == 0:
self.episode_end_status = "rerecord_episode"
# Reset episode status on button release
elif event.type == pygame.JOYBUTTONUP:
if event.button in [0, 2, 3]:
self.episode_end_status = None
# Check for RB button (typically button 5) for intervention flag
if self.joystick.get_button(5):
self.intervention_flag = True
else:
self.intervention_flag = False
def get_deltas(self):
"""Get the current movement deltas from gamepad state."""
import pygame
try:
# Read joystick axes
# Left stick X and Y (typically axes 0 and 1)
x_input = self.joystick.get_axis(0) # Left/Right
y_input = self.joystick.get_axis(1) # Up/Down (often inverted)
# Right stick Y (typically axis 3 or 4)
z_input = self.joystick.get_axis(3) # Up/Down for Z
# Apply deadzone to avoid drift
x_input = 0 if abs(x_input) < self.deadzone else x_input
y_input = 0 if abs(y_input) < self.deadzone else y_input
z_input = 0 if abs(z_input) < self.deadzone else z_input
# Calculate deltas (note: may need to invert axes depending on controller)
delta_x = -y_input * self.y_step_size # Forward/backward
delta_y = -x_input * self.x_step_size # Left/right
delta_z = -z_input * self.z_step_size # Up/down
return delta_x, delta_y, delta_z
except pygame.error:
logging.error("Error reading gamepad. Is it still connected?")
return 0.0, 0.0, 0.0
def should_intervene(self):
"""Return True if intervention flag was set."""
return self.intervention_flag
class GamepadControllerHID(InputController):
"""Generate motion deltas from gamepad input using HIDAPI."""
def __init__(
self,
x_step_size=0.01,
y_step_size=0.01,
z_step_size=0.01,
deadzone=0.1,
vendor_id=0x046D,
product_id=0xC219,
):
"""
Initialize the HID gamepad controller.
Args:
x_step_size: Base movement step size along the X axis, in meters
y_step_size: Base movement step size along the Y axis, in meters
z_step_size: Base movement step size along the Z axis, in meters
deadzone: Joystick deadzone to prevent drift
vendor_id: USB vendor ID of the gamepad (default: Logitech)
product_id: USB product ID of the gamepad (default: RumblePad 2)
"""
super().__init__(x_step_size, y_step_size, z_step_size)
self.deadzone = deadzone
self.vendor_id = vendor_id
self.product_id = product_id
self.device = None
self.device_info = None
# Movement values (normalized from -1.0 to 1.0)
self.left_x = 0.0
self.left_y = 0.0
self.right_x = 0.0
self.right_y = 0.0
# Button states
self.buttons = {}
self.quit_requested = False
self.save_requested = False
self.intervention_flag = False
def find_device(self):
"""Look for the gamepad device by vendor and product ID."""
import hid
devices = hid.enumerate()
for device in devices:
if (
device["vendor_id"] == self.vendor_id
and device["product_id"] == self.product_id
):
logging.info(
f"Found gamepad: {device.get('product_string', 'Unknown')}"
)
return device
logging.error(
f"No gamepad with vendor ID 0x{self.vendor_id:04X} and "
f"product ID 0x{self.product_id:04X} found"
)
return None
def start(self):
"""Connect to the gamepad using HIDAPI."""
import hid
self.device_info = self.find_device()
if not self.device_info:
self.running = False
return
try:
logging.info(f"Connecting to gamepad at path: {self.device_info['path']}")
self.device = hid.device()
self.device.open_path(self.device_info["path"])
self.device.set_nonblocking(1)
manufacturer = self.device.get_manufacturer_string()
product = self.device.get_product_string()
logging.info(f"Connected to {manufacturer} {product}")
logging.info("Gamepad controls (HID mode):")
logging.info(" Left analog stick: Move in X-Y plane")
logging.info(" Right analog stick: Move in Z axis (vertical)")
logging.info(" Button 1/B/Circle: Exit")
logging.info(" Button 2/A/Cross: End episode with SUCCESS")
logging.info(" Button 3/X/Square: End episode with FAILURE")
except OSError as e:
logging.error(f"Error opening gamepad: {e}")
logging.error(
"You might need to run this with sudo/admin privileges on some systems"
)
self.running = False
def stop(self):
"""Close the HID device connection."""
if self.device:
self.device.close()
self.device = None
def update(self):
"""
Read and process the latest gamepad data.
Due to an issue with HIDAPI, we need to read the device several times in order to get a stable reading
"""
for _ in range(10):
self._update()
def _update(self):
"""Read and process the latest gamepad data."""
if not self.device or not self.running:
return
try:
# Read data from the gamepad
data = self.device.read(64)
if data:
# Interpret gamepad data - this will vary by controller model
# These offsets are for the Logitech RumblePad 2
if len(data) >= 8:
# Normalize joystick values from 0-255 to -1.0-1.0
self.left_x = (data[1] - 128) / 128.0
self.left_y = (data[2] - 128) / 128.0
self.right_x = (data[3] - 128) / 128.0
self.right_y = (data[4] - 128) / 128.0
# Apply deadzone
self.left_x = 0 if abs(self.left_x) < self.deadzone else self.left_x
self.left_y = 0 if abs(self.left_y) < self.deadzone else self.left_y
self.right_x = (
0 if abs(self.right_x) < self.deadzone else self.right_x
)
self.right_y = (
0 if abs(self.right_y) < self.deadzone else self.right_y
)
# Parse button states (byte 5 in the Logitech RumblePad 2)
buttons = data[5]
# Set the intervention flag while RB is held
self.intervention_flag = data[6] == 2
# Check if Y/Triangle button (bit 7) is pressed for saving
# Check if X/Square button (bit 5) is pressed for failure
# Check if A/Cross button (bit 4) is pressed for rerecording
if buttons & 1 << 7:
self.episode_end_status = "success"
elif buttons & 1 << 5:
self.episode_end_status = "failure"
elif buttons & 1 << 4:
self.episode_end_status = "rerecord_episode"
else:
self.episode_end_status = None
except OSError as e:
logging.error(f"Error reading from gamepad: {e}")
def get_deltas(self):
"""Get the current movement deltas from gamepad state."""
# Calculate deltas - invert as needed based on controller orientation
delta_x = -self.left_y * self.x_step_size # Forward/backward
delta_y = -self.left_x * self.y_step_size # Left/right
delta_z = -self.right_y * self.z_step_size # Up/down
return delta_x, delta_y, delta_z
def should_quit(self):
"""Return True if quit button was pressed."""
return self.quit_requested
def should_save(self):
"""Return True if save button was pressed."""
return self.save_requested
def should_intervene(self):
"""Return True if intervention flag was set."""
return self.intervention_flag
def test_forward_kinematics(robot, fps=10):
logging.info("Testing Forward Kinematics")
timestep = time.perf_counter()
while time.perf_counter() - timestep < 60.0:
loop_start_time = time.perf_counter()
robot.teleop_step()
obs = robot.capture_observation()
joint_positions = obs["observation.state"].cpu().numpy()
ee_pos = RobotKinematics.fk_gripper_tip(joint_positions)
logging.info(f"EE Position: {ee_pos[:3,3]}")
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
def test_inverse_kinematics(robot, fps=10):
logging.info("Testing Inverse Kinematics")
timestep = time.perf_counter()
while time.perf_counter() - timestep < 60.0:
loop_start_time = time.perf_counter()
obs = robot.capture_observation()
joint_positions = obs["observation.state"].cpu().numpy()
ee_pos = RobotKinematics.fk_gripper_tip(joint_positions)
desired_ee_pos = ee_pos
target_joint_state = RobotKinematics.ik(
joint_positions, desired_ee_pos, position_only=True
)
robot.send_action(torch.from_numpy(target_joint_state))
logging.info(f"Target Joint State: {target_joint_state}")
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
def teleoperate_inverse_kinematics_with_leader(robot, fps=10):
logging.info("Testing Inverse Kinematics")
fk_func = RobotKinematics.fk_gripper_tip
timestep = time.perf_counter()
while time.perf_counter() - timestep < 60.0:
loop_start_time = time.perf_counter()
obs = robot.capture_observation()
joint_positions = obs["observation.state"].cpu().numpy()
ee_pos = fk_func(joint_positions)
leader_joint_positions = robot.leader_arms["main"].read("Present_Position")
leader_ee = fk_func(leader_joint_positions)
desired_ee_pos = leader_ee
target_joint_state = RobotKinematics.ik(
joint_positions, desired_ee_pos, position_only=True, fk_func=fk_func
)
robot.send_action(torch.from_numpy(target_joint_state))
logging.info(f"Leader EE: {leader_ee[:3,3]}, Follower EE: {ee_pos[:3,3]}")
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
def teleoperate_delta_inverse_kinematics_with_leader(robot, fps=10):
logging.info("Testing Delta End-Effector Control")
timestep = time.perf_counter()
# Initial position capture
obs = robot.capture_observation()
joint_positions = obs["observation.state"].cpu().numpy()
fk_func = RobotKinematics.fk_gripper_tip
leader_joint_positions = robot.leader_arms["main"].read("Present_Position")
initial_leader_ee = fk_func(leader_joint_positions)
desired_ee_pos = np.diag(np.ones(4))
while time.perf_counter() - timestep < 60.0:
loop_start_time = time.perf_counter()
# Get leader state for teleoperation
leader_joint_positions = robot.leader_arms["main"].read("Present_Position")
leader_ee = fk_func(leader_joint_positions)
# Get current state
# obs = robot.capture_observation()
# joint_positions = obs["observation.state"].cpu().numpy()
joint_positions = robot.follower_arms["main"].read("Present_Position")
current_ee_pos = fk_func(joint_positions)
# Calculate delta between leader and follower end-effectors
# Scaling factor can be adjusted for sensitivity
scaling_factor = 1.0
ee_delta = (leader_ee - initial_leader_ee) * scaling_factor
# Apply delta to current position
desired_ee_pos[0, 3] = current_ee_pos[0, 3] + ee_delta[0, 3]
desired_ee_pos[1, 3] = current_ee_pos[1, 3] + ee_delta[1, 3]
desired_ee_pos[2, 3] = current_ee_pos[2, 3] + ee_delta[2, 3]
if np.any(np.abs(ee_delta[:3, 3]) > 0.01):
# Compute joint targets via inverse kinematics
target_joint_state = RobotKinematics.ik(
joint_positions, desired_ee_pos, position_only=True, fk_func=fk_func
)
initial_leader_ee = leader_ee.copy()
# Send command to robot
robot.send_action(torch.from_numpy(target_joint_state))
# Logging
logging.info(
f"Current EE: {current_ee_pos[:3,3]}, Desired EE: {desired_ee_pos[:3,3]}"
)
logging.info(f"Delta EE: {ee_delta[:3,3]}")
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
def teleoperate_delta_inverse_kinematics(
robot, controller, fps=10, bounds=None, fk_func=None
):
"""
Control a robot using delta end-effector movements from any input controller.
Args:
robot: Robot instance to control
controller: InputController instance (keyboard, gamepad, etc.)
fps: Control frequency in Hz
bounds: Optional position limits
fk_func: Forward kinematics function to use
"""
if fk_func is None:
fk_func = RobotKinematics.fk_gripper_tip
logging.info(
f"Testing Delta End-Effector Control with {controller.__class__.__name__}"
)
# Initial position capture
obs = robot.capture_observation()
joint_positions = obs["observation.state"].cpu().numpy()
current_ee_pos = fk_func(joint_positions)
# Initialize desired position with current position
desired_ee_pos = np.eye(4) # Identity matrix
timestep = time.perf_counter()
with controller:
while not controller.should_quit() and time.perf_counter() - timestep < 60.0:
loop_start_time = time.perf_counter()
# Process input events
controller.update()
# Get current robot state
joint_positions = robot.follower_arms["main"].read("Present_Position")
current_ee_pos = fk_func(joint_positions)
# Get movement deltas from the controller
delta_x, delta_y, delta_z = controller.get_deltas()
# Update desired position
desired_ee_pos[0, 3] = current_ee_pos[0, 3] + delta_x
desired_ee_pos[1, 3] = current_ee_pos[1, 3] + delta_y
desired_ee_pos[2, 3] = current_ee_pos[2, 3] + delta_z
# Apply bounds if provided
if bounds is not None:
desired_ee_pos[:3, 3] = np.clip(
desired_ee_pos[:3, 3], bounds["min"], bounds["max"]
)
# Only send commands if there's actual movement
if any([abs(v) > 0.001 for v in [delta_x, delta_y, delta_z]]):
# Compute joint targets via inverse kinematics
target_joint_state = RobotKinematics.ik(
joint_positions, desired_ee_pos, position_only=True, fk_func=fk_func
)
# Send command to robot
robot.send_action(torch.from_numpy(target_joint_state))
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
def teleoperate_gym_env(env, controller, fps: int = 30):
"""
Control a robot through a gym environment using an input controller.
Args:
env: A gym environment created with make_robot_env
controller: InputController instance (keyboard, gamepad, etc.)
fps: Target control frequency
"""
logging.info("Testing Keyboard Control of Gym Environment")
print("Keyboard controls:")
print(" Arrow keys: Move in X-Y plane")
print(" Shift and Shift_R: Move in Z axis")
print(" ESC: Exit")
# Reset the environment to get initial observation
obs, info = env.reset()
try:
with controller:
while not controller.should_quit():
loop_start_time = time.perf_counter()
# Process input events
controller.update()
# Get movement deltas from the controller
delta_x, delta_y, delta_z = controller.get_deltas()
# Create the action vector
action = np.array([delta_x, delta_y, delta_z])
# Skip if no movement
if any([abs(v) > 0.001 for v in [delta_x, delta_y, delta_z]]):
# Step the environment - pass action as a tensor with intervention flag
action_tensor = torch.from_numpy(action.astype(np.float32))
obs, reward, terminated, truncated, info = env.step(
(action_tensor, False)
)
# Log information
logging.info(
f"Action: [{delta_x:.4f}, {delta_y:.4f}, {delta_z:.4f}]"
)
logging.info(f"Reward: {reward}")
# Reset if episode ended
if terminated or truncated:
logging.info("Episode ended, resetting environment")
obs, info = env.reset()
# Maintain target frame rate
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
finally:
# Close the environment
env.close()
def make_robot_from_config(config_path, overrides=None):
"""Helper function to create a robot from a config file."""
if overrides is None:
overrides = []
robot_cfg = init_hydra_config(config_path, overrides)
return make_robot(robot_cfg)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Test end-effector control")
parser.add_argument(
"--mode",
type=str,
default="keyboard",
choices=[
"keyboard",
"gamepad",
"keyboard_gym",
"gamepad_gym",
"leader",
"leader_abs",
],
help="Control mode to use",
)
parser.add_argument(
"--task",
type=str,
default="Robot manipulation task",
help="Description of the task being performed",
)
parser.add_argument(
"--push-to-hub",
default=True,
type=bool,
help="Push the dataset to Hugging Face Hub",
)
# Add the rest of your existing arguments
args = parser.parse_args()
robot = make_robot_from_config("lerobot/configs/robot/so100.yaml", [])
if not robot.is_connected:
robot.connect()
# Example bounds
bounds = {
"max": np.array([0.32170487, 0.201285, 0.10273342]),
"min": np.array([0.16631757, -0.08237468, 0.03364977]),
}
try:
# Determine controller type based on mode prefix
controller = None
if args.mode.startswith("keyboard"):
controller = KeyboardController(
x_step_size=0.01, y_step_size=0.01, z_step_size=0.05
)
elif args.mode.startswith("gamepad"):
controller = GamepadController(
x_step_size=0.02, y_step_size=0.02, z_step_size=0.05
)
# Handle mode categories
if args.mode in ["keyboard", "gamepad"]:
# Direct robot control modes
teleoperate_delta_inverse_kinematics(
robot, controller, bounds=bounds, fps=10
)
elif args.mode in ["keyboard_gym", "gamepad_gym"]:
# Gym environment control modes
from lerobot.scripts.server.gym_manipulator import make_robot_env
cfg = init_hydra_config("lerobot/configs/env/so100_real.yaml", [])
cfg.env.wrapper.ee_action_space_params.use_gamepad = False
env = make_robot_env(robot, None, cfg)
teleoperate_gym_env(env, controller)
elif args.mode == "leader":
# Leader-follower modes don't use controllers
teleoperate_delta_inverse_kinematics_with_leader(robot)
elif args.mode == "leader_abs":
teleoperate_inverse_kinematics_with_leader(robot)
finally:
if robot.is_connected:
robot.disconnect()

View File

@ -7,25 +7,26 @@ import numpy as np
from lerobot.common.robot_devices.control_utils import is_headless
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.utils.utils import init_hydra_config
from lerobot.scripts.server.kinematics import RobotKinematics
def find_joint_bounds(
robot,
control_time_s=20,
control_time_s=30,
display_cameras=False,
):
# TODO(rcadene): Add option to record logs
if not robot.is_connected:
robot.connect()
control_time_s = float("inf")
timestamp = 0
start_episode_t = time.perf_counter()
pos_list = []
while timestamp < control_time_s:
while True:
observation, action = robot.teleop_step(record_data=True)
# Wait for 5 seconds to stabilize the robot initial position
if time.perf_counter() - start_episode_t < 5:
continue
pos_list.append(robot.follower_arms["main"].read("Present_Position"))
if display_cameras and not is_headless():
@ -36,8 +37,7 @@ def find_joint_bounds(
)
cv2.waitKey(1)
timestamp = time.perf_counter() - start_episode_t
if timestamp > 60:
if time.perf_counter() - start_episode_t > control_time_s:
max = np.max(np.stack(pos_list), 0)
min = np.min(np.stack(pos_list), 0)
print(f"Max angle position per joint {max}")
@ -45,6 +45,43 @@ def find_joint_bounds(
break
def find_ee_bounds(
robot,
control_time_s=30,
display_cameras=False,
):
if not robot.is_connected:
robot.connect()
start_episode_t = time.perf_counter()
ee_list = []
while True:
observation, action = robot.teleop_step(record_data=True)
# Wait for 5 seconds to stabilize the robot initial position
if time.perf_counter() - start_episode_t < 5:
continue
joint_positions = robot.follower_arms["main"].read("Present_Position")
print(f"Joint positions: {joint_positions}")
ee_list.append(RobotKinematics.fk_gripper_tip(joint_positions)[:3, 3])
if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(
key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
)
cv2.waitKey(1)
if time.perf_counter() - start_episode_t > control_time_s:
max = np.max(np.stack(ee_list), 0)
min = np.min(np.stack(ee_list), 0)
print(f"Max ee position {max}")
print(f"Min ee position {min}")
break
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
@ -59,14 +96,26 @@ if __name__ == "__main__":
nargs="*",
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
)
parser.add_argument(
"--mode",
type=str,
default="joint",
choices=["joint", "ee"],
help="Mode to run the script in. Can be 'joint' or 'ee'.",
)
parser.add_argument(
"--control-time-s",
type=float,
default=20,
help="Maximum episode length in seconds",
type=int,
default=30,
help="Time step to use for control.",
)
args = parser.parse_args()
robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)
robot = make_robot(robot_cfg)
find_joint_bounds(robot, control_time_s=args.control_time_s)
if args.mode == "joint":
find_joint_bounds(robot, args.control_time_s)
elif args.mode == "ee":
find_ee_bounds(robot, args.control_time_s)
if robot.is_connected:
robot.disconnect()

View File

@ -1,19 +1,26 @@
import argparse
import sys
import logging
import time
from threading import Lock
from typing import Annotated, Any, Callable, Dict, Optional, Tuple
from typing import Annotated, Any, Dict, Tuple
import gymnasium as gym
import numpy as np
import torch
import torchvision.transforms.functional as F # noqa: N812
from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.robot_devices.control_utils import busy_wait, is_headless
from lerobot.common.robot_devices.control_utils import (
busy_wait,
is_headless,
reset_follower_position,
)
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.utils.utils import init_hydra_config, log_say
from lerobot.scripts.server.kinematics import RobotKinematics
logging.basicConfig(level=logging.INFO)
@ -76,13 +83,19 @@ class HILSerlRobotEnv(gym.Env):
# Retrieve the size of the joint position interval bound.
self.relative_bounds_size = (
self.robot.config.joint_position_relative_bounds["max"]
- self.robot.config.joint_position_relative_bounds["min"]
(
self.robot.config.joint_position_relative_bounds["max"]
- self.robot.config.joint_position_relative_bounds["min"]
)
if self.robot.config.joint_position_relative_bounds is not None
else None
)
self.delta_relative_bounds_size = self.relative_bounds_size * self.delta
self.robot.config.max_relative_target = self.delta_relative_bounds_size.float()
self.robot.config.max_relative_target = (
self.relative_bounds_size.float()
if self.relative_bounds_size is not None
else None
)
# Dynamically configure the observation and action spaces.
self._setup_spaces()
@ -99,26 +112,23 @@ class HILSerlRobotEnv(gym.Env):
- The action space is defined as a Tuple where:
The first element is a Box space representing joint position commands. It is defined as relative (delta)
or absolute, based on the configuration.
The second element is a Discrete space (with 2 values) serving as a flag for intervention (teleoperation).
The second element is a Discrete space (with 2 values) serving as a flag for intervention (teleoperation).
"""
example_obs = self.robot.capture_observation()
# Define observation spaces for images and other states.
image_keys = [key for key in example_obs if "image" in key]
state_keys = [key for key in example_obs if "image" not in key]
observation_spaces = {
key: gym.spaces.Box(
low=0, high=255, shape=example_obs[key].shape, dtype=np.uint8
)
for key in image_keys
}
observation_spaces["observation.state"] = gym.spaces.Dict(
{
key: gym.spaces.Box(
low=0, high=10, shape=example_obs[key].shape, dtype=np.float32
)
for key in state_keys
}
observation_spaces["observation.state"] = gym.spaces.Box(
low=0,
high=10,
shape=example_obs["observation.state"].shape,
dtype=np.float32,
)
self.observation_space = gym.spaces.Dict(observation_spaces)
@ -126,20 +136,31 @@ class HILSerlRobotEnv(gym.Env):
# Define the action space for joint positions along with setting an intervention flag.
action_dim = len(self.robot.follower_arms["main"].read("Present_Position"))
if self.use_delta_action_space:
bounds = (
self.relative_bounds_size
if self.relative_bounds_size is not None
else np.ones(action_dim) * 1000
)
action_space_robot = gym.spaces.Box(
low=-self.relative_bounds_size.cpu().numpy(),
high=self.relative_bounds_size.cpu().numpy(),
low=-bounds,
high=bounds,
shape=(action_dim,),
dtype=np.float32,
)
else:
bounds_min = (
self.robot.config.joint_position_relative_bounds["min"].cpu().numpy()
if self.robot.config.joint_position_relative_bounds is not None
else np.ones(action_dim) * -1000
)
bounds_max = (
self.robot.config.joint_position_relative_bounds["max"].cpu().numpy()
if self.robot.config.joint_position_relative_bounds is not None
else np.ones(action_dim) * 1000
)
action_space_robot = gym.spaces.Box(
low=self.robot.config.joint_position_relative_bounds["min"]
.cpu()
.numpy(),
high=self.robot.config.joint_position_relative_bounds["max"]
.cpu()
.numpy(),
low=bounds_min,
high=bounds_max,
shape=(action_dim,),
dtype=np.float32,
)
@ -176,7 +197,7 @@ class HILSerlRobotEnv(gym.Env):
self.current_step = 0
self.episode_data = None
return observation, {"initial_position": self.initial_follower_position}
return observation, {}
def step(
self, action: Tuple[np.ndarray, bool]
@ -218,6 +239,7 @@ class HILSerlRobotEnv(gym.Env):
policy_action = np.clip(
policy_action, self.action_space[0].low, self.action_space[0].high
)
if not intervention_bool:
if self.use_delta_action_space:
target_joint_positions = (
@ -238,8 +260,9 @@ class HILSerlRobotEnv(gym.Env):
teleop_action = (
teleop_action - self.current_joint_positions
) / self.delta
if torch.any(teleop_action < -self.relative_bounds_size) and torch.any(
teleop_action > self.relative_bounds_size
if self.relative_bounds_size is not None and (
torch.any(teleop_action < -self.relative_bounds_size)
and torch.any(teleop_action > self.relative_bounds_size)
):
logging.debug(
f"Relative teleop delta exceeded bounds {self.relative_bounds_size}, teleop_action {teleop_action}\n"
@ -299,6 +322,46 @@ class HILSerlRobotEnv(gym.Env):
self.robot.disconnect()
class AddJointVelocityToObservation(gym.ObservationWrapper):
def __init__(self, env, joint_velocity_limits=100.0, fps=30):
super().__init__(env)
# Extend observation space to include joint velocities
old_low = self.observation_space["observation.state"].low
old_high = self.observation_space["observation.state"].high
old_shape = self.observation_space["observation.state"].shape
self.last_joint_positions = np.zeros(old_shape)
new_low = np.concatenate(
[old_low, np.ones_like(old_low) * -joint_velocity_limits]
)
new_high = np.concatenate(
[old_high, np.ones_like(old_high) * joint_velocity_limits]
)
new_shape = (old_shape[0] * 2,)
self.observation_space["observation.state"] = gym.spaces.Box(
low=new_low,
high=new_high,
shape=new_shape,
dtype=np.float32,
)
self.dt = 1.0 / fps
def observation(self, observation):
joint_velocities = (
observation["observation.state"] - self.last_joint_positions
) / self.dt
self.last_joint_positions = observation["observation.state"].clone()
observation["observation.state"] = torch.cat(
[observation["observation.state"], joint_velocities], dim=-1
)
return observation
class ActionRepeatWrapper(gym.Wrapper):
def __init__(self, env, nb_repeat: int = 1):
super().__init__(env)
@ -347,8 +410,6 @@ class RewardWrapper(gym.Wrapper):
)
info["Reward classifer frequency"] = 1 / (time.perf_counter() - start_time)
# logging.info(f"Reward: {reward}")
if reward == 1.0:
terminated = True
return observation, reward, terminated, truncated, info
@ -465,9 +526,7 @@ class TimeLimitWrapper(gym.Wrapper):
if 1.0 / time_since_last_step < self.fps:
logging.debug(f"Current timestep exceeded expected fps {self.fps}")
if self.episode_time_in_s > self.control_time_s:
# if self.current_step >= self.max_episode_steps:
# Terminated = True
if self.current_step >= self.max_episode_steps:
terminated = True
return obs, reward, terminated, truncated, info
@ -508,7 +567,20 @@ class ImageCropResizeWrapper(gym.Wrapper):
obs, reward, terminated, truncated, info = self.env.step(action)
for k in self.crop_params_dict:
device = obs[k].device
if obs[k].dim() >= 3:
# Reshape to combine height and width dimensions for easier calculation
batch_size = obs[k].size(0)
channels = obs[k].size(1)
flattened_spatial_dims = obs[k].view(batch_size, channels, -1)
# Calculate standard deviation across spatial dimensions (H, W)
std_per_channel = torch.std(flattened_spatial_dims, dim=2)
# If any channel has near-zero std, all pixels in that channel have (almost) the same value
if (std_per_channel <= 0.02).any():
logging.warning(
f"Potential hardware issue detected: All pixels have the same value in observation {k}"
)
# Check for NaNs before processing
if torch.isnan(obs[k]).any():
logging.error(
@ -703,19 +775,21 @@ class ResetWrapper(gym.Wrapper):
def __init__(
self,
env: HILSerlRobotEnv,
reset_fn: Optional[Callable[[], None]] = None,
reset_pose: np.ndarray | None = None,
reset_time_s: float = 5,
):
super().__init__(env)
self.reset_fn = reset_fn
self.reset_time_s = reset_time_s
self.reset_pose = reset_pose
self.robot = self.unwrapped.robot
self.init_pos = self.unwrapped.initial_follower_position
def reset(self, *, seed=None, options=None):
if self.reset_fn is not None:
self.reset_fn(self.env)
if self.reset_pose is not None:
start_time = time.perf_counter()
log_say("Reset the environment.", play_sounds=True)
reset_follower_position(self.robot, self.reset_pose)
busy_wait(self.reset_time_s - (time.perf_counter() - start_time))
log_say("Reset the environment done.", play_sounds=True)
else:
log_say(
f"Manually reset the environment for {self.reset_time_s} seconds.",
@ -741,10 +815,297 @@ class BatchCompitableWrapper(gym.ObservationWrapper):
observation[key] = observation[key].unsqueeze(0)
if "state" in key and observation[key].dim() == 1:
observation[key] = observation[key].unsqueeze(0)
if "velocity" in key and observation[key].dim() == 1:
observation[key] = observation[key].unsqueeze(0)
return observation
# TODO: REMOVE THIS
class EEActionWrapper(gym.ActionWrapper):
def __init__(self, env, ee_action_space_params=None):
super().__init__(env)
self.ee_action_space_params = ee_action_space_params
# Initialize kinematics instance for the appropriate robot type
robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so100")
self.kinematics = RobotKinematics(robot_type)
self.fk_function = self.kinematics.fk_gripper_tip
action_space_bounds = np.array(
[
ee_action_space_params.x_step_size,
ee_action_space_params.y_step_size,
ee_action_space_params.z_step_size,
]
)
ee_action_space = gym.spaces.Box(
low=-action_space_bounds,
high=action_space_bounds,
shape=(3,),
dtype=np.float32,
)
if isinstance(self.action_space, gym.spaces.Tuple):
self.action_space = gym.spaces.Tuple(
(ee_action_space, self.action_space[1])
)
else:
self.action_space = ee_action_space
self.bounds = ee_action_space_params.bounds
def action(self, action):
is_intervention = False
desired_ee_pos = np.eye(4)
if isinstance(action, tuple):
action, _ = action
current_joint_pos = self.unwrapped.robot.follower_arms["main"].read(
"Present_Position"
)
current_ee_pos = self.fk_function(current_joint_pos)
if isinstance(action, torch.Tensor):
action = action.cpu().numpy()
desired_ee_pos[:3, 3] = np.clip(
current_ee_pos[:3, 3] + action,
self.bounds["min"],
self.bounds["max"],
)
target_joint_pos = self.kinematics.ik(
current_joint_pos,
desired_ee_pos,
position_only=True,
fk_func=self.fk_function,
)
return target_joint_pos, is_intervention
class EEObservationWrapper(gym.ObservationWrapper):
def __init__(self, env, ee_pose_limits):
super().__init__(env)
# Extend observation space to include end effector pose
prev_space = self.observation_space["observation.state"]
self.observation_space["observation.state"] = gym.spaces.Box(
low=np.concatenate([prev_space.low, ee_pose_limits["min"]]),
high=np.concatenate([prev_space.high, ee_pose_limits["max"]]),
shape=(prev_space.shape[0] + 3,),
dtype=np.float32,
)
# Initialize kinematics instance for the appropriate robot type
robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so100")
self.kinematics = RobotKinematics(robot_type)
self.fk_function = self.kinematics.fk_gripper_tip
def observation(self, observation):
current_joint_pos = self.unwrapped.robot.follower_arms["main"].read(
"Present_Position"
)
current_ee_pos = self.fk_function(current_joint_pos)
observation["observation.state"] = torch.cat(
[
observation["observation.state"],
torch.from_numpy(current_ee_pos[:3, 3]),
],
dim=-1,
)
return observation
class GamepadControlWrapper(gym.Wrapper):
"""
Wrapper that allows controlling a gym environment with a gamepad.
This wrapper intercepts the step method and allows human input via gamepad
to override the agent's actions when desired.
"""
def __init__(
self,
env,
x_step_size=1.0,
y_step_size=1.0,
z_step_size=1.0,
auto_reset=False,
input_threshold=0.001,
):
"""
Initialize the gamepad controller wrapper.
Args:
env: The environment to wrap
x_step_size: Base movement step size for X axis in meters
y_step_size: Base movement step size for Y axis in meters
z_step_size: Base movement step size for Z axis in meters
auto_reset: Whether to auto reset the environment when episode ends
input_threshold: Minimum movement delta to consider as active input
"""
super().__init__(env)
from lerobot.scripts.server.end_effector_control_utils import (
GamepadControllerHID,
GamepadController,
)
# use HidApi for macos
if sys.platform == "darwin":
self.controller = GamepadControllerHID(
x_step_size=x_step_size,
y_step_size=y_step_size,
z_step_size=z_step_size,
)
else:
self.controller = GamepadController(
x_step_size=x_step_size,
y_step_size=y_step_size,
z_step_size=z_step_size,
)
self.auto_reset = auto_reset
self.input_threshold = input_threshold
self.controller.start()
logging.info("Gamepad control wrapper initialized")
print("Gamepad controls:")
print(" Left analog stick: Move in X-Y plane")
print(" Right analog stick: Move in Z axis (up/down)")
print(" X/Square button: End episode (FAILURE)")
print(" Y/Triangle button: End episode (SUCCESS)")
print(" B/Circle button: Exit program")
def get_gamepad_action(self):
"""
Get the current action from the gamepad if any input is active.
Returns:
Tuple of (is_active, action, terminate_episode, success, rerecord_episode)
"""
# Update the controller to get fresh inputs
self.controller.update()
# Get movement deltas from the controller
delta_x, delta_y, delta_z = self.controller.get_deltas()
intervention_is_active = self.controller.should_intervene()
# Create action from gamepad input
gamepad_action = np.array([delta_x, delta_y, delta_z], dtype=np.float32)
# Check episode ending buttons
# We'll rely on controller.get_episode_end_status() which returns "success", "failure", "rerecord_episode", or None
episode_end_status = self.controller.get_episode_end_status()
terminate_episode = episode_end_status is not None
success = episode_end_status == "success"
rerecord_episode = episode_end_status == "rerecord_episode"
return (
intervention_is_active,
gamepad_action,
terminate_episode,
success,
rerecord_episode,
)
def step(self, action):
"""
Step the environment, using gamepad input to override actions when active.
Args:
action: Original action from agent
Returns:
observation, reward, terminated, truncated, info
"""
# Get gamepad state and action
(
is_intervention,
gamepad_action,
terminate_episode,
success,
rerecord_episode,
) = self.get_gamepad_action()
# Update episode ending state if requested
if terminate_episode:
logging.info(
f"Episode manually ended: {'SUCCESS' if success else 'FAILURE'}"
)
# Only override the action if gamepad is active
if is_intervention:
# Format according to the expected action type
if isinstance(self.action_space, gym.spaces.Tuple):
# For environments that use (action, is_intervention) tuples
final_action = (torch.from_numpy(gamepad_action), False)
else:
final_action = torch.from_numpy(gamepad_action)
else:
# Use the original action
final_action = action
# Step the environment
obs, reward, terminated, truncated, info = self.env.step(final_action)
# Add episode ending if requested via gamepad
terminated = terminated or truncated or terminate_episode
if success:
reward = 1.0
logging.info("Episode ended successfully with reward 1.0")
info["is_intervention"] = is_intervention
action_intervention = (
final_action[0] if isinstance(final_action, Tuple) else final_action
)
if isinstance(action_intervention, np.ndarray):
action_intervention = torch.from_numpy(action_intervention)
info["action_intervention"] = action_intervention
info["rerecord_episode"] = rerecord_episode
# If episode ended, reset the state
if terminated or truncated:
# Add success/failure information to info dict
info["next.success"] = success
# Auto reset if configured
if self.auto_reset:
obs, reset_info = self.reset()
info.update(reset_info)
return obs, reward, terminated, truncated, info
def close(self):
"""Clean up resources when environment closes."""
# Stop the controller
if hasattr(self, "controller"):
self.controller.stop()
# Call the parent close method
return self.env.close()
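A minimal usage sketch for this wrapper (a sketch only: the base env is assumed to be the end-effector robot env built further down in make_robot_env, and a gamepad must be connected):
import torch
# env is assumed to already exist (e.g. the EE-action robot env assembled in make_robot_env below)
env = GamepadControlWrapper(env, x_step_size=0.02, y_step_size=0.02, z_step_size=0.02)
obs, info = env.reset()
# Send a zero policy action; the human can take over at any time by moving the sticks.
obs, reward, terminated, truncated, info = env.step(torch.zeros(3))
if info["is_intervention"]:
    # The action that was actually executed, as reported by the wrapper.
    print("human override:", info["action_intervention"])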
class ActionScaleWrapper(gym.ActionWrapper):
def __init__(self, env, ee_action_space_params=None):
super().__init__(env)
assert (
ee_action_space_params is not None
), "TODO: method implemented for ee action space only so far"
self.scale_vector = np.array(
[
[
ee_action_space_params.x_step_size,
ee_action_space_params.y_step_size,
ee_action_space_params.z_step_size,
]
]
)
def action(self, action):
is_intervention = False
if isinstance(action, tuple):
action, is_intervention = action
return action * self.scale_vector, is_intervention
def make_robot_env(
@ -779,11 +1140,20 @@ def make_robot_env(
robot=robot,
display_cameras=cfg.env.wrapper.display_cameras,
delta=cfg.env.wrapper.delta_action,
use_delta_action_space=cfg.env.wrapper.use_relative_joint_positions,
use_delta_action_space=cfg.env.wrapper.use_relative_joint_positions
and cfg.env.wrapper.ee_action_space_params is None,
)
# Add observation and image processing
env = ConvertToLeRobotObservation(env=env, device=cfg.device)
if cfg.env.wrapper.add_joint_velocity_to_observation:
env = AddJointVelocityToObservation(env=env, fps=cfg.fps)
if cfg.env.wrapper.add_ee_pose_to_observation:
env = EEObservationWrapper(
env=env, ee_pose_limits=cfg.env.wrapper.ee_action_space_params.bounds
)
env = ConvertToLeRobotObservation(env=env, device=cfg.env.device)
if cfg.env.wrapper.crop_params_dict is not None:
env = ImageCropResizeWrapper(
env=env,
@ -792,23 +1162,44 @@ def make_robot_env(
)
# Add reward computation and control wrappers
env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
# env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
env = TimeLimitWrapper(
env=env, control_time_s=cfg.env.wrapper.control_time_s, fps=cfg.fps
)
env = KeyboardInterfaceWrapper(env=env)
if cfg.env.wrapper.ee_action_space_params is not None:
env = EEActionWrapper(
env=env, ee_action_space_params=cfg.env.wrapper.ee_action_space_params
)
if (
cfg.env.wrapper.ee_action_space_params is not None
and cfg.env.wrapper.ee_action_space_params.use_gamepad
):
# env = ActionScaleWrapper(env=env, ee_action_space_params=cfg.env.wrapper.ee_action_space_params)
env = GamepadControlWrapper(
env=env,
x_step_size=cfg.env.wrapper.ee_action_space_params.x_step_size,
y_step_size=cfg.env.wrapper.ee_action_space_params.y_step_size,
z_step_size=cfg.env.wrapper.ee_action_space_params.z_step_size,
)
else:
env = KeyboardInterfaceWrapper(env=env)
env = ResetWrapper(
env=env, reset_fn=None, reset_time_s=cfg.env.wrapper.reset_time_s
)
env = JointMaskingActionSpace(
env=env, mask=cfg.env.wrapper.joint_masking_action_space
env=env,
reset_pose=cfg.env.wrapper.fixed_reset_joint_positions,
reset_time_s=cfg.env.wrapper.reset_time_s,
)
if (
cfg.env.wrapper.ee_action_space_params is None
and cfg.env.wrapper.joint_masking_action_space is not None
):
env = JointMaskingActionSpace(
env=env, mask=cfg.env.wrapper.joint_masking_action_space
)
env = BatchCompitableWrapper(env=env)
return env
# batched version of the env that returns an observation of shape (b, c)
def get_classifier(pretrained_path, config_path, device="mps"):
if pretrained_path is None or config_path is None:
@ -834,6 +1225,134 @@ def get_classifier(pretrained_path, config_path, device="mps"):
return model
def record_dataset(
env,
repo_id,
root=None,
num_episodes=1,
control_time_s=20,
fps=30,
push_to_hub=True,
task_description="",
policy=None,
):
"""
Record a dataset of robot interactions using either a policy or teleop.
Args:
env: The environment to record from
repo_id: Repository ID for dataset storage
root: Local root directory for dataset (optional)
num_episodes: Number of episodes to record
control_time_s: Maximum episode length in seconds
fps: Frames per second for recording
push_to_hub: Whether to push dataset to Hugging Face Hub
task_description: Description of the task being recorded
policy: Optional policy to generate actions (if None, uses teleop)
"""
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
# Setup initial action (zero action if using teleop)
dummy_action = env.action_space.sample()
dummy_action = (torch.from_numpy(dummy_action[0] * 0.0), False)
action = dummy_action
# Configure dataset features based on environment spaces
features = {
"observation.state": {
"dtype": "float32",
"shape": env.observation_space["observation.state"].shape,
"names": None,
},
"action": {
"dtype": "float32",
"shape": env.action_space[0].shape,
"names": None,
},
"next.reward": {"dtype": "float32", "shape": (1,), "names": None},
"next.done": {"dtype": "bool", "shape": (1,), "names": None},
}
# Add image features
for key in env.observation_space:
if "image" in key:
features[key] = {
"dtype": "video",
"shape": env.observation_space[key].shape,
"names": None,
}
# Create dataset
dataset = LeRobotDataset.create(
repo_id,
fps,
root=root,
use_videos=True,
image_writer_threads=4,
image_writer_processes=0,
features=features,
)
# Record episodes
episode_index = 0
while episode_index < num_episodes:
obs, _ = env.reset()
start_episode_t = time.perf_counter()
log_say(f"Recording episode {episode_index}", play_sounds=True)
# Run episode steps
while time.perf_counter() - start_episode_t < control_time_s:
start_loop_t = time.perf_counter()
# Get action from policy if available
if policy is not None:
action = policy.select_action(obs)
# Step environment
obs, reward, terminated, truncated, info = env.step(action)
# Check if episode needs to be rerecorded
if info.get("rerecord_episode", False):
break
# For teleop, get action from intervention
if policy is None:
action = {
"action": info["action_intervention"].cpu().squeeze(0).float()
}
# Process observation for dataset
obs = {k: v.cpu().squeeze(0).float() for k, v in obs.items()}
# Add frame to dataset
frame = {**obs, **action}
frame["next.reward"] = reward
frame["next.done"] = terminated or truncated
dataset.add_frame(frame)
# Maintain consistent timing
if fps:
dt_s = time.perf_counter() - start_loop_t
busy_wait(1 / fps - dt_s)
if terminated or truncated:
break
# Handle episode recording
if info.get("rerecord_episode", False):
dataset.clear_episode_buffer()
logging.info(f"Re-recording episode {episode_index}")
continue
dataset.save_episode(task_description)
episode_index += 1
# Finalize dataset
dataset.consolidate(run_compute_stats=True)
if push_to_hub:
dataset.push_to_hub(repo_id)
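A minimal invocation sketch (the repo id and task string are placeholders, and env is assumed to be the gamepad-wrapped robot env): record five teleoperated episodes and push them to the Hub.
record_dataset(
    env,
    repo_id="username/so100_pick_cube",  # placeholder repo id
    num_episodes=5,
    control_time_s=20,
    fps=30,
    push_to_hub=True,
    task_description="Pick up the cube",  # placeholder task string
    policy=None,  # None: frames are recorded from human interventions
)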
def replay_episode(env, repo_id, root=None, episode=0):
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
@ -841,14 +1360,16 @@ def replay_episode(env, repo_id, root=None, episode=0):
dataset = LeRobotDataset(
repo_id, root=root, episodes=[episode], local_files_only=local_files_only
)
env.reset()
actions = dataset.hf_dataset.select_columns("action")
for idx in range(dataset.num_frames):
start_episode_t = time.perf_counter()
action = actions[idx]["action"][:4]
print(action)
env.step((action / env.unwrapped.delta, False))
env.step((action, False))
# env.step((action / env.unwrapped.delta, False))
dt_s = time.perf_counter() - start_episode_t
busy_wait(1 / 10 - dt_s)
@ -875,14 +1396,6 @@ if __name__ == "__main__":
help=(
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
"saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch "
"(useful for debugging). This argument is mutually exclusive with `--config`."
),
)
parser.add_argument(
"--config",
help=(
"Path to a yaml config you want to use for initializing a policy from scratch (useful for "
"debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)."
),
)
parser.add_argument(
@ -929,11 +1442,30 @@ if __name__ == "__main__":
help="Repo ID of the episode to replay",
)
parser.add_argument(
"--replay-root", type=str, default=None, help="Root of the dataset to replay"
"--dataset-root", type=str, default=None, help="Root of the dataset to replay"
)
parser.add_argument(
"--replay-episode", type=int, default=0, help="Episode to replay"
)
parser.add_argument(
"--record-repo-id",
type=str,
default=None,
help="Repo ID of the dataset to record",
)
parser.add_argument(
"--record-num-episodes",
type=int,
default=1,
help="Number of episodes to record",
)
parser.add_argument(
"--record-episode-task",
type=str,
default="",
help="Single line description of the task to record",
)
args = parser.parse_args()
robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)
@ -948,17 +1480,40 @@ if __name__ == "__main__":
env = make_robot_env(
robot,
reward_classifier,
cfg.env, # .wrapper,
cfg, # .wrapper,
)
env.reset()
if args.record_repo_id is not None:
policy = None
if args.pretrained_policy_name_or_path is not None:
from lerobot.common.policies.sac.modeling_sac import SACPolicy
policy = SACPolicy.from_pretrained(args.pretrained_policy_name_or_path)
policy.to(cfg.device)
policy.eval()
record_dataset(
env,
args.record_repo_id,
root=args.dataset_root,
num_episodes=args.record_num_episodes,
fps=args.fps,
task_description=args.record_episode_task,
policy=policy,
)
exit()
if args.replay_repo_id is not None:
replay_episode(
env, args.replay_repo_id, root=args.replay_root, episode=args.replay_episode
env,
args.replay_repo_id,
root=args.dataset_root,
episode=args.replay_episode,
)
exit()
env.reset()
# Retrieve the robot's action space for joint commands.
action_space_robot = env.action_space.spaces[0]
@ -967,9 +1522,11 @@ if __name__ == "__main__":
# Smoothing coefficient (alpha) defines how much of the new random sample to mix in.
# A value close to 0 makes the trajectory very smooth (slow to change), while a value close to 1 is less smooth.
alpha = 0.4
alpha = 1.0
while True:
num_episode = 0
successes = []
while num_episode < 20:
start_loop_s = time.perf_counter()
# Sample a new random action from the robot's action space.
new_random_action = action_space_robot.sample()
@ -981,7 +1538,12 @@ if __name__ == "__main__":
(torch.from_numpy(smoothed_action), False)
)
if terminated or truncated:
successes.append(reward)
env.reset()
num_episode += 1
dt_s = time.perf_counter() - start_loop_s
busy_wait(1 / args.fps - dt_s)
logging.info(f"Success after 20 steps {sucesses}")
logging.info(f"success rate {sum(sucesses)/ len(sucesses)}")

View File

@ -0,0 +1,543 @@
import numpy as np
from scipy.spatial.transform import Rotation
def skew_symmetric(w):
"""Creates the skew-symmetric matrix from a 3D vector."""
return np.array([[0, -w[2], w[1]], [w[2], 0, -w[0]], [-w[1], w[0], 0]])
def rodrigues_rotation(w, theta):
"""Computes the rotation matrix using Rodrigues' formula."""
w_hat = skew_symmetric(w)
return np.eye(3) + np.sin(theta) * w_hat + (1 - np.cos(theta)) * w_hat @ w_hat
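A quick numerical check of the formula above (a standalone sketch, not part of the module): rotating the x-axis by 90 degrees about the z-axis should give the y-axis.
import numpy as np
R = rodrigues_rotation(np.array([0.0, 0.0, 1.0]), np.pi / 2)
assert np.allclose(R @ np.array([1.0, 0.0, 0.0]), np.array([0.0, 1.0, 0.0]))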
def screw_axis_to_transform(S, theta):
"""Converts a screw axis to a 4x4 transformation matrix."""
S_w = S[:3]
S_v = S[3:]
if np.allclose(S_w, 0) and np.linalg.norm(S_v) == 1: # Pure translation
T = np.eye(4)
T[:3, 3] = S_v * theta
elif np.linalg.norm(S_w) == 1: # Rotation and translation
w_hat = skew_symmetric(S_w)
R = np.eye(3) + np.sin(theta) * w_hat + (1 - np.cos(theta)) * w_hat @ w_hat
t = (
np.eye(3) * theta
+ (1 - np.cos(theta)) * w_hat
+ (theta - np.sin(theta)) * w_hat @ w_hat
) @ S_v
T = np.eye(4)
T[:3, :3] = R
T[:3, 3] = t
else:
raise ValueError("Invalid screw axis parameters")
return T
def pose_difference_se3(pose1, pose2):
"""
Calculates the SE(3) difference between two 4x4 homogeneous transformation matrices.
pose1 - pose2
Args:
pose1: A 4x4 numpy array representing the first pose.
pose2: A 4x4 numpy array representing the second pose.
Returns:
A tuple (translation_diff, rotation_diff) where:
- translation_diff is a 3x1 numpy array representing the translational difference.
- rotation_diff is a 3x1 numpy array representing the rotational difference in axis-angle representation.
"""
# Extract rotation matrices from poses
R1 = pose1[:3, :3]
R2 = pose2[:3, :3]
# Calculate translational difference
translation_diff = pose1[:3, 3] - pose2[:3, 3]
# Calculate rotational difference using scipy's Rotation library
R_diff = Rotation.from_matrix(R1 @ R2.T)
rotation_diff = R_diff.as_rotvec() # Convert to axis-angle representation
return np.concatenate([translation_diff, rotation_diff])
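A small sanity sketch for the function above: identical poses give the zero 6-vector, and a pure translation shows up only in the first three components.
import numpy as np
T_a = np.eye(4)
T_b = T_a.copy()
T_b[:3, 3] = [0.1, 0.0, 0.0]
assert np.allclose(pose_difference_se3(T_a, T_a), np.zeros(6))
assert np.allclose(pose_difference_se3(T_b, T_a), [0.1, 0.0, 0.0, 0.0, 0.0, 0.0])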
def se3_error(target_pose, current_pose):
pos_error = target_pose[:3, 3] - current_pose[:3, 3]
R_target = target_pose[:3, :3]
R_current = current_pose[:3, :3]
R_error = R_target @ R_current.T
rot_error = Rotation.from_matrix(R_error).as_rotvec()
return np.concatenate([pos_error, rot_error])
class RobotKinematics:
"""Robot kinematics class supporting multiple robot models."""
# Robot measurements dictionary
ROBOT_MEASUREMENTS = {
"koch": {
"gripper": [0.239, -0.001, 0.024],
"wrist": [0.209, 0, 0.024],
"forearm": [0.108, 0, 0.02],
"humerus": [0, 0, 0.036],
"shoulder": [0, 0, 0],
"base": [0, 0, 0.02],
},
"so100": {
"gripper": [0.320, 0, 0.050],
"wrist": [0.278, 0, 0.050],
"forearm": [0.143, 0, 0.044],
"humerus": [0.031, 0, 0.072],
"shoulder": [0, 0, 0],
"base": [0, 0, 0.02],
},
"moss": {
"gripper": [0.246, 0.013, 0.111],
"wrist": [0.245, 0.002, 0.064],
"forearm": [0.122, 0, 0.064],
"humerus": [0.001, 0.001, 0.063],
"shoulder": [0, 0, 0],
"base": [0, 0, 0.02],
},
}
def __init__(self, robot_type="so100"):
"""Initialize kinematics for the specified robot type.
Args:
robot_type: String specifying the robot model ("koch", "so100", or "moss")
"""
if robot_type not in self.ROBOT_MEASUREMENTS:
raise ValueError(
f"Unknown robot type: {robot_type}. Available types: {list(self.ROBOT_MEASUREMENTS.keys())}"
)
self.robot_type = robot_type
self.measurements = self.ROBOT_MEASUREMENTS[robot_type]
# Initialize all transformation matrices and screw axes
self._setup_transforms()
def _create_translation_matrix(self, x=0, y=0, z=0):
"""Create a 4x4 translation matrix."""
return np.array([[1, 0, 0, x], [0, 1, 0, y], [0, 0, 1, z], [0, 0, 0, 1]])
def _setup_transforms(self):
"""Setup all transformation matrices and screw axes for the robot."""
# Set up rotation matrices (constant across robot types)
# Gripper orientation
self.gripper_X0 = np.array(
[
[1, 0, 0, 0],
[0, 0, 1, 0],
[0, -1, 0, 0],
[0, 0, 0, 1],
]
)
# Wrist orientation
self.wrist_X0 = np.array(
[
[0, -1, 0, 0],
[1, 0, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1],
]
)
# Base orientation
self.base_X0 = np.array(
[
[0, 0, 1, 0],
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 0, 1],
]
)
# Gripper
# Screw axis of gripper frame wrt base frame
self.S_BG = np.array(
[
1,
0,
0,
0,
self.measurements["gripper"][2],
-self.measurements["gripper"][1],
]
)
# Gripper origin to centroid transform
self.X_GoGc = self._create_translation_matrix(x=0.07)
# Gripper origin to tip transform
self.X_GoGt = self._create_translation_matrix(x=0.12)
# 0-position gripper frame pose wrt base
self.X_BoGo = self._create_translation_matrix(
x=self.measurements["gripper"][0],
y=self.measurements["gripper"][1],
z=self.measurements["gripper"][2],
)
# Wrist
# Screw axis of wrist frame wrt base frame
self.S_BR = np.array(
[0, 1, 0, -self.measurements["wrist"][2], 0, self.measurements["wrist"][0]]
)
# 0-position origin to centroid transform
self.X_RoRc = self._create_translation_matrix(x=0.0035, y=-0.002)
# 0-position wrist frame pose wrt base
self.X_BR = self._create_translation_matrix(
x=self.measurements["wrist"][0],
y=self.measurements["wrist"][1],
z=self.measurements["wrist"][2],
)
# Forearm
# Screw axis of forearm frame wrt base frame
self.S_BF = np.array(
[
0,
1,
0,
-self.measurements["forearm"][2],
0,
self.measurements["forearm"][0],
]
)
# Forearm origin + centroid transform
self.X_FoFc = self._create_translation_matrix(x=0.036)
# 0-position forearm frame pose wrt base
self.X_BF = self._create_translation_matrix(
x=self.measurements["forearm"][0],
y=self.measurements["forearm"][1],
z=self.measurements["forearm"][2],
)
# Humerus
# Screw axis of humerus frame wrt base frame
self.S_BH = np.array(
[
0,
-1,
0,
self.measurements["humerus"][2],
0,
-self.measurements["humerus"][0],
]
)
# Humerus origin to centroid transform
self.X_HoHc = self._create_translation_matrix(x=0.0475)
# 0-position humerus frame pose wrt base
self.X_BH = self._create_translation_matrix(
x=self.measurements["humerus"][0],
y=self.measurements["humerus"][1],
z=self.measurements["humerus"][2],
)
# Shoulder
# Screw axis of shoulder frame wrt Base frame
self.S_BS = np.array([0, 0, -1, 0, 0, 0])
# Shoulder origin to centroid transform
self.X_SoSc = self._create_translation_matrix(x=-0.017, z=0.0235)
# 0-position shoulder frame pose wrt base
self.X_BS = self._create_translation_matrix(
x=self.measurements["shoulder"][0],
y=self.measurements["shoulder"][1],
z=self.measurements["shoulder"][2],
)
# Base
# Base origin to centroid transform
self.X_BoBc = self._create_translation_matrix(y=0.015)
# World to base transform
self.X_WoBo = self._create_translation_matrix(
x=self.measurements["base"][0],
y=self.measurements["base"][1],
z=self.measurements["base"][2],
)
# Pre-compute gripper post-multiplication matrix
self._fk_gripper_post = self.X_GoGc @ self.X_BoGo @ self.gripper_X0
def fk_base(self):
"""Forward kinematics for the base frame."""
return self.X_WoBo @ self.X_BoBc @ self.base_X0
def fk_shoulder(self, robot_pos_deg):
"""Forward kinematics for the shoulder frame."""
robot_pos_rad = robot_pos_deg / 180 * np.pi
return (
self.X_WoBo
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
@ self.X_SoSc
@ self.X_BS
)
def fk_humerus(self, robot_pos_deg):
"""Forward kinematics for the humerus frame."""
robot_pos_rad = robot_pos_deg / 180 * np.pi
return (
self.X_WoBo
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
@ screw_axis_to_transform(self.S_BH, robot_pos_rad[1])
@ self.X_HoHc
@ self.X_BH
)
def fk_forearm(self, robot_pos_deg):
"""Forward kinematics for the forearm frame."""
robot_pos_rad = robot_pos_deg / 180 * np.pi
return (
self.X_WoBo
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
@ screw_axis_to_transform(self.S_BH, robot_pos_rad[1])
@ screw_axis_to_transform(self.S_BF, robot_pos_rad[2])
@ self.X_FoFc
@ self.X_BF
)
def fk_wrist(self, robot_pos_deg):
"""Forward kinematics for the wrist frame."""
robot_pos_rad = robot_pos_deg / 180 * np.pi
return (
self.X_WoBo
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
@ screw_axis_to_transform(self.S_BH, robot_pos_rad[1])
@ screw_axis_to_transform(self.S_BF, robot_pos_rad[2])
@ screw_axis_to_transform(self.S_BR, robot_pos_rad[3])
@ self.X_RoRc
@ self.X_BR
@ self.wrist_X0
)
def fk_gripper(self, robot_pos_deg):
"""Forward kinematics for the gripper frame."""
robot_pos_rad = robot_pos_deg / 180 * np.pi
return (
self.X_WoBo
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
@ screw_axis_to_transform(self.S_BH, robot_pos_rad[1])
@ screw_axis_to_transform(self.S_BF, robot_pos_rad[2])
@ screw_axis_to_transform(self.S_BR, robot_pos_rad[3])
@ screw_axis_to_transform(self.S_BG, robot_pos_rad[4])
@ self._fk_gripper_post
)
def fk_gripper_tip(self, robot_pos_deg):
"""Forward kinematics for the gripper tip frame."""
robot_pos_rad = robot_pos_deg / 180 * np.pi
return (
self.X_WoBo
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
@ screw_axis_to_transform(self.S_BH, robot_pos_rad[1])
@ screw_axis_to_transform(self.S_BF, robot_pos_rad[2])
@ screw_axis_to_transform(self.S_BR, robot_pos_rad[3])
@ screw_axis_to_transform(self.S_BG, robot_pos_rad[4])
@ self.X_GoGt
@ self.X_BoGo
@ self.gripper_X0
)
def compute_jacobian(self, robot_pos_deg, fk_func=None):
"""Finite differences to compute the Jacobian.
J(i, j) represents how the ith component of the end-effector's velocity changes wrt a small change
in the jth joint's velocity.
Args:
robot_pos_deg: Current joint positions in degrees
fk_func: Forward kinematics function to use (defaults to fk_gripper)
"""
if fk_func is None:
fk_func = self.fk_gripper
eps = 1e-8
jac = np.zeros(shape=(6, 5))
delta = np.zeros(len(robot_pos_deg[:-1]), dtype=np.float64)
for el_ix in range(len(robot_pos_deg[:-1])):
delta *= 0
delta[el_ix] = eps / 2
Sdot = (
pose_difference_se3(
fk_func(robot_pos_deg[:-1] + delta),
fk_func(robot_pos_deg[:-1] - delta),
)
/ eps
)
jac[:, el_ix] = Sdot
return jac
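# A small consistency check, kept as a comment because it sits inside the class body:
# with kin = RobotKinematics("so100"), q a 6-vector of joint angles in degrees and
# dq = np.array([0.1, 0.0, 0.0, 0.0, 0.0]) a small step on the first five joints,
# pose_difference_se3(kin.fk_gripper(q + np.append(dq, 0.0)), kin.fk_gripper(q))
# should agree with kin.compute_jacobian(q) @ dq to well under a millimetre in the
# translational part (the Jacobian is expressed per degree, matching the inputs).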
def compute_positional_jacobian(self, robot_pos_deg, fk_func=None):
"""Finite differences to compute the positional Jacobian.
J(i, j) represents how the ith component of the end-effector's position changes wrt a small change
in the jth joint's position.
Args:
robot_pos_deg: Current joint positions in degrees
fk_func: Forward kinematics function to use (defaults to fk_gripper)
"""
if fk_func is None:
fk_func = self.fk_gripper
eps = 1e-8
jac = np.zeros(shape=(3, 5))
delta = np.zeros(len(robot_pos_deg[:-1]), dtype=np.float64)
for el_ix in range(len(robot_pos_deg[:-1])):
delta *= 0
delta[el_ix] = eps / 2
Sdot = (
fk_func(robot_pos_deg[:-1] + delta)[:3, 3]
- fk_func(robot_pos_deg[:-1] - delta)[:3, 3]
) / eps
jac[:, el_ix] = Sdot
return jac
def ik(
self, current_joint_state, desired_ee_pose, position_only=True, fk_func=None
):
"""Inverse kinematics using gradient descent.
Args:
current_joint_state: Initial joint positions in degrees
desired_ee_pose: Target end-effector pose as a 4x4 transformation matrix
position_only: If True, only match end-effector position, not orientation
fk_func: Forward kinematics function to use (defaults to fk_gripper)
Returns:
Joint positions in degrees that achieve the desired end-effector pose
"""
if fk_func is None:
fk_func = self.fk_gripper
# Do gradient descent.
max_iterations = 5
learning_rate = 1
for _ in range(max_iterations):
current_ee_pose = fk_func(current_joint_state)
if not position_only:
error = se3_error(desired_ee_pose, current_ee_pose)
jac = self.compute_jacobian(current_joint_state, fk_func)
else:
error = desired_ee_pose[:3, 3] - current_ee_pose[:3, 3]
jac = self.compute_positional_jacobian(current_joint_state, fk_func)
delta_angles = np.linalg.pinv(jac) @ error
current_joint_state[:-1] += learning_rate * delta_angles
if np.linalg.norm(error) < 5e-3:
return current_joint_state
return current_joint_state
if __name__ == "__main__":
import time
def run_test(robot_type):
"""Run test suite for a specific robot type."""
print(f"\n--- Testing {robot_type.upper()} Robot ---")
# Initialize kinematics for this robot
robot = RobotKinematics(robot_type)
# Test 1: Forward kinematics consistency
print("Test 1: Forward kinematics consistency")
test_angles = np.array(
[30, 45, -30, 20, 10, 0]
) # Example joint angles in degrees
# Calculate FK for different joints
shoulder_pose = robot.fk_shoulder(test_angles)
humerus_pose = robot.fk_humerus(test_angles)
forearm_pose = robot.fk_forearm(test_angles)
wrist_pose = robot.fk_wrist(test_angles)
gripper_pose = robot.fk_gripper(test_angles)
gripper_tip_pose = robot.fk_gripper_tip(test_angles)
# Check that poses form a consistent kinematic chain (positions should be progressively further from origin)
distances = [
np.linalg.norm(shoulder_pose[:3, 3]),
np.linalg.norm(humerus_pose[:3, 3]),
np.linalg.norm(forearm_pose[:3, 3]),
np.linalg.norm(wrist_pose[:3, 3]),
np.linalg.norm(gripper_pose[:3, 3]),
np.linalg.norm(gripper_tip_pose[:3, 3]),
]
# Check if distances generally increase along the chain
is_consistent = all(
distances[i] <= distances[i + 1] for i in range(len(distances) - 1)
)
print(f" Pose distances from origin: {[round(d, 3) for d in distances]}")
print(
f" Kinematic chain consistency: {'PASSED' if is_consistent else 'FAILED'}"
)
# Test 2: Jacobian computation
print("Test 2: Jacobian computation")
jacobian = robot.compute_jacobian(test_angles)
positional_jacobian = robot.compute_positional_jacobian(test_angles)
# Check shapes
jacobian_shape_ok = jacobian.shape == (6, 5)
pos_jacobian_shape_ok = positional_jacobian.shape == (3, 5)
print(f" Jacobian shape: {'PASSED' if jacobian_shape_ok else 'FAILED'}")
print(
f" Positional Jacobian shape: {'PASSED' if pos_jacobian_shape_ok else 'FAILED'}"
)
# Test 3: Inverse kinematics
print("Test 3: Inverse kinematics (position only)")
# Generate target pose from known joint angles
original_angles = np.array([10, 20, 30, -10, 5, 0])
target_pose = robot.fk_gripper(original_angles)
# Start IK from a different position
initial_guess = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
# Measure IK performance
start_time = time.time()
computed_angles = robot.ik(initial_guess.copy(), target_pose)
ik_time = time.time() - start_time
# Compute resulting pose from IK solution
result_pose = robot.fk_gripper(computed_angles)
# Calculate position error
pos_error = np.linalg.norm(target_pose[:3, 3] - result_pose[:3, 3])
passed = pos_error < 0.01 # Accept errors less than 1cm
print(f" IK computation time: {ik_time:.4f} seconds")
print(f" Position error: {pos_error:.4f}")
print(f" IK position accuracy: {'PASSED' if passed else 'FAILED'}")
return is_consistent and jacobian_shape_ok and pos_jacobian_shape_ok and passed
# Run tests for all robot types
results = {}
for robot_type in ["koch", "so100", "moss"]:
results[robot_type] = run_test(robot_type)
# Print overall summary
print("\n=== Test Summary ===")
all_passed = all(results.values())
for robot_type, passed in results.items():
print(f"{robot_type.upper()}: {'PASSED' if passed else 'FAILED'}")
print(f"\nOverall: {'ALL TESTS PASSED' if all_passed else 'SOME TESTS FAILED'}")

View File

@ -121,7 +121,7 @@ def load_training_state(
return None, None
training_state = torch.load(
logger.last_checkpoint_dir / logger.training_state_file_name
logger.last_checkpoint_dir / logger.training_state_file_name, weights_only=False
)
if isinstance(training_state["optimizer"], dict):
@ -160,6 +160,7 @@ def initialize_replay_buffer(
optimize_memory=True,
)
logging.info("Resume training load the online dataset")
dataset = LeRobotDataset(
repo_id=cfg.dataset_repo_id,
local_files_only=True,
@ -174,6 +175,38 @@ def initialize_replay_buffer(
)
def initialize_offline_replay_buffer(
cfg: DictConfig,
logger: Logger,
device: str,
storage_device: str,
active_action_dims: list[int] | None = None,
) -> ReplayBuffer:
if not cfg.resume:
logging.info("make_dataset offline buffer")
offline_dataset = make_dataset(cfg)
if cfg.resume:
logging.info("load offline dataset")
offline_dataset = LeRobotDataset(
repo_id=cfg.dataset_repo_id,
local_files_only=True,
root=logger.log_dir / "dataset_offline",
)
logging.info("Convert to a offline replay buffer")
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
offline_dataset,
device=device,
state_keys=cfg.policy.input_shapes.keys(),
action_mask=active_action_dims,
action_delta=cfg.env.wrapper.delta_action,
storage_device=storage_device,
optimize_memory=True,
capacity=cfg.training.offline_buffer_capacity,
)
return offline_replay_buffer
def get_observation_features(
policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
@ -315,16 +348,49 @@ def start_learner_server(
def check_nan_in_transition(
observations: torch.Tensor, actions: torch.Tensor, next_state: torch.Tensor
):
for k in observations:
if torch.isnan(observations[k]).any():
logging.error(f"observations[{k}] contains NaN values")
for k in next_state:
if torch.isnan(next_state[k]).any():
logging.error(f"next_state[{k}] contains NaN values")
observations: torch.Tensor,
actions: torch.Tensor,
next_state: torch.Tensor,
raise_error: bool = False,
) -> bool:
"""
Check for NaN values in transition data.
Args:
observations: Dictionary of observation tensors
actions: Action tensor
next_state: Dictionary of next state tensors
raise_error: If True, raises ValueError when NaN is detected
Returns:
bool: True if NaN values were detected, False otherwise
"""
nan_detected = False
# Check observations
for key, tensor in observations.items():
if torch.isnan(tensor).any():
logging.error(f"observations[{key}] contains NaN values")
nan_detected = True
if raise_error:
raise ValueError(f"NaN detected in observations[{key}]")
# Check next state
for key, tensor in next_state.items():
if torch.isnan(tensor).any():
logging.error(f"next_state[{key}] contains NaN values")
nan_detected = True
if raise_error:
raise ValueError(f"NaN detected in next_state[{key}]")
# Check actions
if torch.isnan(actions).any():
logging.error("actions contains NaN values")
nan_detected = True
if raise_error:
raise ValueError("NaN detected in actions")
return nan_detected
def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
@ -390,6 +456,10 @@ def add_actor_information_and_train(
if cfg.resume
else None,
)
# Read the gradient clipping norm value from the training config
clip_grad_norm_value = cfg.training.grad_clip_norm
# compile policy
policy = torch.compile(policy)
assert isinstance(policy, nn.Module)
@ -410,9 +480,6 @@ def add_actor_information_and_train(
offline_replay_buffer = None
if cfg.dataset_repo_id is not None:
logging.info("make_dataset offline buffer")
offline_dataset = make_dataset(cfg)
logging.info("Convertion to a offline replay buffer")
active_action_dims = None
if cfg.env.wrapper.joint_masking_action_space is not None:
active_action_dims = [
@ -420,14 +487,12 @@ def add_actor_information_and_train(
for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
if mask
]
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
offline_dataset,
offline_replay_buffer = initialize_offline_replay_buffer(
cfg=cfg,
logger=logger,
device=device,
state_keys=cfg.policy.input_shapes.keys(),
action_mask=active_action_dims,
action_delta=cfg.env.wrapper.delta_action,
storage_device=storage_device,
optimize_memory=True,
active_action_dims=active_action_dims,
)
batch_size: int = batch_size // 2 # We will sample from both replay buffer
@ -472,9 +537,18 @@ def add_actor_information_and_train(
for transition in transition_list:
transition = move_transition_to_device(transition, device=device)
if check_nan_in_transition(
transition["state"], transition["action"], transition["next_state"]
):
logging.warning("NaN detected in transition, skipping")
continue
replay_buffer.add(**transition)
if transition.get("complementary_info", {}).get("is_intervention"):
if cfg.dataset_repo_id is not None and transition.get(
"complementary_info", {}
).get("is_intervention"):
offline_replay_buffer.add(**transition)
logging.debug("[LEARNER] Received transitions")
logging.debug("[LEARNER] Waiting for interactions")
while not interaction_message_queue.empty() and not shutdown_event.is_set():
@ -523,6 +597,12 @@ def add_actor_information_and_train(
)
optimizers["critic"].zero_grad()
loss_critic.backward()
# clip gradients
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
policy.critic_ensemble.parameters(), clip_grad_norm_value
)
optimizers["critic"].step()
batch = replay_buffer.sample(batch_size)
@ -557,10 +637,17 @@ def add_actor_information_and_train(
)
optimizers["critic"].zero_grad()
loss_critic.backward()
# clip gradients
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
policy.critic_ensemble.parameters(), clip_grad_norm_value
).item()
optimizers["critic"].step()
training_infos = {}
training_infos["loss_critic"] = loss_critic.item()
training_infos["critic_grad_norm"] = critic_grad_norm
if optimization_step % policy_update_freq == 0:
for _ in range(policy_update_freq):
@ -571,26 +658,48 @@ def add_actor_information_and_train(
optimizers["actor"].zero_grad()
loss_actor.backward()
# clip gradients
actor_grad_norm = torch.nn.utils.clip_grad_norm_(
policy.actor.parameters_to_optimize, clip_grad_norm_value
).item()
optimizers["actor"].step()
training_infos["loss_actor"] = loss_actor.item()
training_infos["actor_grad_norm"] = actor_grad_norm
# Temperature optimization
loss_temperature = policy.compute_loss_temperature(
observations=observations,
observation_features=observation_features,
)
optimizers["temperature"].zero_grad()
loss_temperature.backward()
# clip gradients
temp_grad_norm = torch.nn.utils.clip_grad_norm_(
[policy.log_alpha], clip_grad_norm_value
).item()
optimizers["temperature"].step()
training_infos["loss_temperature"] = loss_temperature.item()
training_infos["temperature_grad_norm"] = temp_grad_norm
training_infos["temperature"] = policy.temperature
if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
push_actor_policy_to_queue(parameters_queue, policy)
last_time_policy_pushed = time.time()
policy.update_target_networks()
if optimization_step % log_freq == 0:
if optimization_step % cfg.training.log_freq == 0:
training_infos["replay_buffer_size"] = len(replay_buffer)
if offline_replay_buffer is not None:
training_infos["offline_replay_buffer_size"] = len(
offline_replay_buffer
)
training_infos["Optimization step"] = optimization_step
logger.log_dict(
d=training_infos, mode="train", custom_step_key="Optimization step"
@ -649,6 +758,19 @@ def add_actor_information_and_train(
replay_buffer.to_lerobot_dataset(
dataset_repo_id, fps=fps, root=logger.log_dir / "dataset"
)
if offline_replay_buffer is not None:
dataset_dir = logger.log_dir / "dataset_offline"
if dataset_dir.exists() and dataset_dir.is_dir():
shutil.rmtree(
dataset_dir,
)
offline_replay_buffer.to_lerobot_dataset(
cfg.dataset_repo_id,
fps=cfg.fps,
root=logger.log_dir / "dataset_offline",
)
logging.info("Resume training")

View File

@ -159,7 +159,7 @@ def make_maniskill(
env.unwrapped.metadata["render_fps"] = 20
env = ManiSkillCompat(env)
env = ManiSkillActionWrapper(env)
env = ManiSkillMultiplyActionWrapper(env, multiply_factor=1)
env = ManiSkillMultiplyActionWrapper(env, multiply_factor=0.03)
return env