Change HILSerlRobotEnvConfig to inherit from EnvConfig

Added support for training the hil_serl classifier with train.py.
Run classifier training with: python lerobot/scripts/train.py --policy.type=hilserl_classifier
Fixes in find_joint_limits, control_robot, and end_effector_control_utils.
Michel Aractingi 2025-03-27 10:23:14 +01:00 committed by AdilZouitine
parent 052a4acfc2
commit d0b7690bc0
13 changed files with 389 additions and 340 deletions
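For reference, a minimal usage sketch of what the factory changes below enable (not part of the diff; keyword values are illustrative): once ClassifierConfig and Classifier are registered, the classifier can be resolved by name, which is what --policy.type=hilserl_classifier relies on.

    # Sketch only: resolve the newly registered classifier through the policy factory.
    from lerobot.common.policies.factory import get_policy_class, make_policy_config

    config = make_policy_config("hilserl_classifier", num_classes=2, model_type="cnn")
    policy_cls = get_policy_class("hilserl_classifier")  # resolves to the Classifier class
    print(policy_cls.__name__, config.num_classes)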

View File

@@ -412,7 +412,7 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
            names = ft["names"]
            # Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
-            if names[2] in ["channel", "channels"]:  # (h, w, c) -> (c, h, w)
+            if names is not None and names[2] in ["channel", "channels"]:  # (h, w, c) -> (c, h, w)
                shape = (shape[2], shape[0], shape[1])
        elif key == "observation.environment_state":
            type = FeatureType.ENV

View File

@@ -68,4 +68,3 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
    )
    return env
-

View File

@@ -28,6 +28,7 @@ from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
 from lerobot.common.policies.pretrained import PreTrainedPolicy
 from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
 from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
+from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
 from lerobot.configs.policies import PreTrainedConfig
 from lerobot.configs.types import FeatureType
@@ -58,6 +59,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
        from lerobot.common.policies.sac.modeling_sac import SACPolicy
        return SACPolicy
+    elif name == "hilserl_classifier":
+        from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
+
+        return Classifier
    else:
        raise NotImplementedError(f"Policy with name {name} is not implemented.")
@@ -73,6 +78,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
        return VQBeTConfig(**kwargs)
    elif policy_type == "pi0":
        return PI0Config(**kwargs)
+    elif policy_type == "hilserl_classifier":
+        return ClassifierConfig(**kwargs)
    else:
        raise ValueError(f"Policy type '{policy_type}' is not available.")

View File

@@ -1,12 +1,18 @@
-import json
-import os
-from dataclasses import asdict, dataclass
+from dataclasses import dataclass, field
+from typing import Dict, List
+
+from lerobot.common.optim.optimizers import AdamWConfig, OptimizerConfig
+from lerobot.common.optim.schedulers import LRSchedulerConfig
+from lerobot.configs.policies import PreTrainedConfig
+from lerobot.configs.types import FeatureType, PolicyFeature


+@PreTrainedConfig.register_subclass(name="hilserl_classifier")
 @dataclass
-class ClassifierConfig:
+class ClassifierConfig(PreTrainedConfig):
    """Configuration for the Classifier model."""

+    name: str = "hilserl_classifier"
    num_classes: int = 2
    hidden_dim: int = 256
    dropout_rate: float = 0.1
@@ -14,22 +20,35 @@ class ClassifierConfig:
    device: str = "cpu"
    model_type: str = "cnn"  # "transformer" or "cnn"
    num_cameras: int = 2
+    learning_rate: float = 1e-4
+    normalization_mode = None
+    # output_features: Dict[str, PolicyFeature] = field(
+    #     default_factory=lambda: {"next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(1,))}
+    # )

-    def save_pretrained(self, save_dir):
-        """Save config to json file."""
-        os.makedirs(save_dir, exist_ok=True)
-
-        # Convert to dict and save as JSON
-        config_dict = asdict(self)
-        with open(os.path.join(save_dir, "config.json"), "w") as f:
-            json.dump(config_dict, f, indent=2)
-
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path):
-        """Load config from json file."""
-        config_file = os.path.join(pretrained_model_name_or_path, "config.json")
-        with open(config_file) as f:
-            config_dict = json.load(f)
-        return cls(**config_dict)
+    @property
+    def observation_delta_indices(self) -> List | None:
+        return None
+
+    @property
+    def action_delta_indices(self) -> List | None:
+        return None
+
+    @property
+    def reward_delta_indices(self) -> List | None:
+        return None
+
+    def get_optimizer_preset(self) -> OptimizerConfig:
+        return AdamWConfig(
+            lr=self.learning_rate,
+            weight_decay=0.01,
+            grad_clip_norm=1.0,
+        )
+
+    def get_scheduler_preset(self) -> LRSchedulerConfig | None:
+        return None
+
+    def validate_features(self) -> None:
+        """Validate feature configurations."""
+        # Classifier doesn't need specific feature validation
+        pass
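A quick sketch of the new optimizer preset added above (illustrative value; assumes the config instantiates with its defaults):

    from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig

    cfg = ClassifierConfig(learning_rate=3e-4)
    optimizer_cfg = cfg.get_optimizer_preset()  # AdamWConfig(lr=3e-4, weight_decay=0.01, grad_clip_norm=1.0)
    print(type(optimizer_cfg).__name__, optimizer_cfg.lr)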

View File

@@ -1,11 +1,15 @@
 import logging
-from typing import Optional
+from typing import Dict, Optional, Tuple

 import torch
-from huggingface_hub import PyTorchModelHubMixin
 from torch import Tensor, nn

-from .configuration_classifier import ClassifierConfig
+from lerobot.common.constants import OBS_IMAGE
+from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
+    ClassifierConfig,
+)
+from lerobot.common.policies.normalize import Normalize, Unnormalize
+from lerobot.common.policies.pretrained import PreTrainedPolicy

 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
 logger = logging.getLogger(__name__)
@@ -32,25 +36,32 @@ class ClassifierOutput:
        )


-class Classifier(
-    nn.Module,
-    PyTorchModelHubMixin,
-    # Add Hub metadata
-    library_name="lerobot",
-    repo_url="https://github.com/huggingface/lerobot",
-    tags=["robotics", "vision-classifier"],
-):
+class Classifier(PreTrainedPolicy):
    """Image classifier built on top of a pre-trained encoder."""

-    # Add name attribute for factory
-    name = "classifier"
+    name = "hilserl_classifier"
+    config_class = ClassifierConfig

-    def __init__(self, config: ClassifierConfig):
+    def __init__(
+        self,
+        config: ClassifierConfig,
+        dataset_stats: Dict[str, Dict[str, Tensor]] | None = None,
+    ):
        from transformers import AutoModel

-        super().__init__()
+        super().__init__(config)
        self.config = config
-        # self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True)
+
+        # Initialize normalization (standardized with the policy framework)
+        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
+        self.normalize_targets = Normalize(
+            config.output_features, config.normalization_mapping, dataset_stats
+        )
+        self.unnormalize_outputs = Unnormalize(
+            config.output_features, config.normalization_mapping, dataset_stats
+        )
+
+        # Set up encoder
        encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)
        # Extract vision model if we're given a multimodal model
        if hasattr(encoder, "vision_model"):
@@ -81,8 +92,6 @@ class Classifier(
        else:
            raise ValueError("Unsupported CNN architecture")

-        self.encoder = self.encoder.to(self.config.device)
-
    def _freeze_encoder(self) -> None:
        """Freeze the encoder parameters."""
        for param in self.encoder.parameters():
@@ -109,22 +118,13 @@ class Classifier(
                1 if self.config.num_classes == 2 else self.config.num_classes,
            ),
        )
-        self.classifier_head = self.classifier_head.to(self.config.device)

    def _get_encoder_output(self, x: torch.Tensor) -> torch.Tensor:
        """Extract the appropriate output from the encoder."""
-        # Process images with the processor (handles resizing and normalization)
-        # processed = self.processor(
-        #     images=x,  # LeRobotDataset already provides proper tensor format
-        #     return_tensors="pt",
-        # )
-        # processed = processed["pixel_values"].to(x.device)
-        processed = x
-
        with torch.no_grad():
            if self.is_cnn:
                # The HF ResNet applies pooling internally
-                outputs = self.encoder(processed)
+                outputs = self.encoder(x)

                # Get pooled output directly
                features = outputs.pooler_output
@@ -132,14 +132,24 @@ class Classifier(
                features = features.squeeze(-1).squeeze(-1)
                return features
            else:  # Transformer models
-                outputs = self.encoder(processed)
+                outputs = self.encoder(x)
                if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
                    return outputs.pooler_output
                return outputs.last_hidden_state[:, 0, :]

-    def forward(self, xs: torch.Tensor) -> ClassifierOutput:
-        """Forward pass of the classifier."""
-        # For training, we expect input to be a tensor directly from LeRobotDataset
+    def extract_images_and_labels(self, batch: Dict[str, Tensor]) -> Tuple[list, Tensor]:
+        """Extract image tensors and label tensors from batch."""
+        # Find image keys in input features
+        image_keys = [key for key in self.config.input_features if key.startswith(OBS_IMAGE)]
+
+        # Extract the images and labels
+        images = [batch[key] for key in image_keys]
+        labels = batch["next.reward"]
+
+        return images, labels
+
+    def predict(self, xs: list) -> ClassifierOutput:
+        """Forward pass of the classifier for inference."""
        encoder_outputs = torch.hstack([self._get_encoder_output(x) for x in xs])
        logits = self.classifier_head(encoder_outputs)
@@ -151,10 +161,77 @@ class Classifier(

        return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs)

-    def predict_reward(self, x, threshold=0.6):
+    def forward(self, batch: Dict[str, Tensor]) -> Tuple[Tensor, Dict[str, Tensor]]:
+        """Standard forward pass for training compatible with train.py."""
+        # Normalize inputs if needed
+        batch = self.normalize_inputs(batch)
+        batch = self.normalize_targets(batch)
+
+        # Extract images and labels
+        images, labels = self.extract_images_and_labels(batch)
+
+        # Get predictions
+        outputs = self.predict(images)
+
+        # Calculate loss
        if self.config.num_classes == 2:
-            probs = self.forward(x).probabilities
+            # Binary classification
+            loss = nn.functional.binary_cross_entropy_with_logits(outputs.logits, labels)
+            predictions = (torch.sigmoid(outputs.logits) > 0.5).float()
+        else:
+            # Multi-class classification
+            loss = nn.functional.cross_entropy(outputs.logits, labels.long())
+            predictions = torch.argmax(outputs.logits, dim=1)
+
+        # Calculate accuracy for logging
+        correct = (predictions == labels).sum().item()
+        total = labels.size(0)
+        accuracy = 100 * correct / total
+
+        # Return loss and metrics for logging
+        output_dict = {
+            "accuracy": accuracy,
+            "correct": correct,
+            "total": total,
+        }
+
+        return loss, output_dict
+
+    def predict_reward(self, batch, threshold=0.6):
+        """Legacy method for compatibility."""
+        images, _ = self.extract_images_and_labels(batch)
+        if self.config.num_classes == 2:
+            probs = self.predict(images).probabilities
            logging.debug(f"Predicted reward images: {probs}")
            return (probs > threshold).float()
        else:
-            return torch.argmax(self.forward(x).probabilities, dim=1)
+            return torch.argmax(self.predict(images).probabilities, dim=1)
+
+    # Methods required by PreTrainedPolicy abstract class
+    def get_optim_params(self) -> dict:
+        """Return optimizer parameters for the policy."""
+        return {
+            "params": self.parameters(),
+            "lr": getattr(self.config, "learning_rate", 1e-4),
+            "weight_decay": getattr(self.config, "weight_decay", 0.01),
+        }
+
+    def reset(self):
+        """Reset any stateful components (required by PreTrainedPolicy)."""
+        # Classifier doesn't have stateful components that need resetting
+        pass
+
+    def select_action(self, batch: Dict[str, Tensor]) -> Tensor:
+        """Return action (class prediction) based on input observation."""
+        images, _ = self.extract_images_and_labels(batch)
+        with torch.no_grad():
+            outputs = self.predict(images)
+            if self.config.num_classes == 2:
+                # For binary classification return 0 or 1
+                return (outputs.probabilities > 0.5).float()
+            else:
+                # For multi-class return the predicted class
+                return torch.argmax(outputs.probabilities, dim=1)
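The new forward returns a (loss, metrics) pair so train.py can consume the classifier like any other policy. A self-contained sketch of the loss/accuracy selection it performs (plain PyTorch with illustrative shapes, not the classifier itself):

    import torch
    from torch import nn

    num_classes = 2
    logits = torch.randn(8, 1)  # the binary head emits a single logit per sample
    labels = torch.randint(0, 2, (8, 1)).float()  # stand-in for batch["next.reward"]

    if num_classes == 2:
        loss = nn.functional.binary_cross_entropy_with_logits(logits, labels)
        predictions = (torch.sigmoid(logits) > 0.5).float()
    else:
        loss = nn.functional.cross_entropy(logits, labels.long())
        predictions = torch.argmax(logits, dim=1)

    accuracy = 100 * (predictions == labels).float().mean().item()
    print(loss.item(), accuracy)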

View File

@@ -23,6 +23,7 @@ class FeatureType(str, Enum):
    VISUAL = "VISUAL"
    ENV = "ENV"
    ACTION = "ACTION"
+    REWARD = "REWARD"


 class NormalizationMode(str, Enum):

View File

@@ -274,10 +274,7 @@ def record(
    if not robot.is_connected:
        robot.connect()

-    listener, events = init_keyboard_listener(assign_rewards=assign_rewards)
-
-    if reset_follower:
-        initial_position = robot.follower_arms["main"].read("Present_Position")
+    listener, events = init_keyboard_listener(assign_rewards=cfg.assign_rewards)

    # Execute a few seconds without recording to:
    # 1. teleoperate the robot to move it in starting position if no policy provided,
@@ -394,6 +391,7 @@ def control_robot(cfg: ControlPipelineConfig):
        replay(robot, cfg.control)
    elif isinstance(cfg.control, RemoteRobotConfig):
        from lerobot.common.robot_devices.robots.lekiwi_remote import run_lekiwi
+
        run_lekiwi(cfg.robot)

    if robot.is_connected:

View File

@@ -1,13 +1,13 @@
+import argparse
+import logging
+import sys
+import time
+
+import numpy as np
+import torch
+
 from lerobot.common.robot_devices.utils import busy_wait
 from lerobot.scripts.server.kinematics import RobotKinematics
-import logging
-import time
-import torch
-import numpy as np
-import argparse
-from lerobot.common.robot_devices.robots.utils import make_robot_from_config
-from lerobot.scripts.server.gym_manipulator import make_robot_env, HILSerlRobotEnvConfig
-from lerobot.common.robot_devices.robots.configs import RobotConfig

 logging.basicConfig(level=logging.INFO)
@@ -458,12 +458,13 @@ class GamepadControllerHID(InputController):
 def test_forward_kinematics(robot, fps=10):
    logging.info("Testing Forward Kinematics")
    timestep = time.perf_counter()
+    kinematics = RobotKinematics(robot.robot_type)
    while time.perf_counter() - timestep < 60.0:
        loop_start_time = time.perf_counter()
        robot.teleop_step()
        obs = robot.capture_observation()
        joint_positions = obs["observation.state"].cpu().numpy()
-        ee_pos = RobotKinematics.fk_gripper_tip(joint_positions)
+        ee_pos = kinematics.fk_gripper_tip(joint_positions)
        logging.info(f"EE Position: {ee_pos[:3, 3]}")
        busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
@@ -485,21 +486,19 @@ def test_inverse_kinematics(robot, fps=10):
 def teleoperate_inverse_kinematics_with_leader(robot, fps=10):
    logging.info("Testing Inverse Kinematics")
-    fk_func = RobotKinematics.fk_gripper_tip
+    kinematics = RobotKinematics(robot.robot_type)
    timestep = time.perf_counter()
    while time.perf_counter() - timestep < 60.0:
        loop_start_time = time.perf_counter()
        obs = robot.capture_observation()
        joint_positions = obs["observation.state"].cpu().numpy()
-        ee_pos = fk_func(joint_positions)
+        ee_pos = kinematics.fk_gripper_tip(joint_positions)

        leader_joint_positions = robot.leader_arms["main"].read("Present_Position")
-        leader_ee = fk_func(leader_joint_positions)
+        leader_ee = kinematics.fk_gripper_tip(leader_joint_positions)
        desired_ee_pos = leader_ee
-        target_joint_state = RobotKinematics.ik(
-            joint_positions, desired_ee_pos, position_only=True, fk_func=fk_func
-        )
+        target_joint_state = kinematics.ik(joint_positions, desired_ee_pos, position_only=True)

        robot.send_action(torch.from_numpy(target_joint_state))
        logging.info(f"Leader EE: {leader_ee[:3, 3]}, Follower EE: {ee_pos[:3, 3]}")
        busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
@@ -513,10 +512,10 @@ def teleoperate_delta_inverse_kinematics_with_leader(robot, fps=10):
    obs = robot.capture_observation()
    joint_positions = obs["observation.state"].cpu().numpy()

-    fk_func = RobotKinematics.fk_gripper_tip
+    kinematics = RobotKinematics(robot.robot_type)

    leader_joint_positions = robot.leader_arms["main"].read("Present_Position")
-    initial_leader_ee = fk_func(leader_joint_positions)
+    initial_leader_ee = kinematics.fk_gripper_tip(leader_joint_positions)

    desired_ee_pos = np.diag(np.ones(4))
@@ -525,13 +524,13 @@ def teleoperate_delta_inverse_kinematics_with_leader(robot, fps=10):
        # Get leader state for teleoperation
        leader_joint_positions = robot.leader_arms["main"].read("Present_Position")
-        leader_ee = fk_func(leader_joint_positions)
+        leader_ee = kinematics.fk_gripper_tip(leader_joint_positions)

        # Get current state
        # obs = robot.capture_observation()
        # joint_positions = obs["observation.state"].cpu().numpy()
        joint_positions = robot.follower_arms["main"].read("Present_Position")
-        current_ee_pos = fk_func(joint_positions)
+        current_ee_pos = kinematics.fk_gripper_tip(joint_positions)

        # Calculate delta between leader and follower end-effectors
        # Scaling factor can be adjusted for sensitivity
@@ -545,9 +544,7 @@ def teleoperate_delta_inverse_kinematics_with_leader(robot, fps=10):
        if np.any(np.abs(ee_delta[:3, 3]) > 0.01):
            # Compute joint targets via inverse kinematics
-            target_joint_state = RobotKinematics.ik(
-                joint_positions, desired_ee_pos, position_only=True, fk_func=fk_func
-            )
+            target_joint_state = kinematics.ik(joint_positions, desired_ee_pos, position_only=True)

            initial_leader_ee = leader_ee.copy()
@@ -580,7 +577,8 @@ def teleoperate_delta_inverse_kinematics(robot, controller, fps=10, bounds=None,
    # Initial position capture
    obs = robot.capture_observation()
    joint_positions = obs["observation.state"].cpu().numpy()
-    current_ee_pos = fk_func(joint_positions)
+    kinematics = RobotKinematics(robot.robot_type)
+    current_ee_pos = kinematics.fk_gripper_tip(joint_positions)

    # Initialize desired position with current position
    desired_ee_pos = np.eye(4)  # Identity matrix
@@ -595,7 +593,7 @@ def teleoperate_delta_inverse_kinematics(robot, controller, fps=10, bounds=None,
        # Get currrent robot state
        joint_positions = robot.follower_arms["main"].read("Present_Position")
-        current_ee_pos = fk_func(joint_positions)
+        current_ee_pos = kinematics.fk_gripper_tip(joint_positions)

        # Get movement deltas from the controller
        delta_x, delta_y, delta_z = controller.get_deltas()
@@ -612,9 +610,7 @@ def teleoperate_delta_inverse_kinematics(robot, controller, fps=10, bounds=None,
        # Only send commands if there's actual movement
        if any([abs(v) > 0.001 for v in [delta_x, delta_y, delta_z]]):
            # Compute joint targets via inverse kinematics
-            target_joint_state = RobotKinematics.ik(
-                joint_positions, desired_ee_pos, position_only=True, fk_func=fk_func
-            )
+            target_joint_state = kinematics.ik(joint_positions, desired_ee_pos, position_only=True)

            # Send command to robot
            robot.send_action(torch.from_numpy(target_joint_state))
@@ -676,7 +672,17 @@ def teleoperate_gym_env(env, controller, fps: int = 30):
        # Close the environment
        env.close()

 if __name__ == "__main__":
+    from lerobot.common.robot_devices.robots.configs import RobotConfig
+    from lerobot.common.robot_devices.robots.utils import make_robot_from_config
+    from lerobot.scripts.server.gym_manipulator import (
+        EEActionSpaceConfig,
+        EnvWrapperConfig,
+        HILSerlRobotEnvConfig,
+        make_robot_env,
+    )
+
    parser = argparse.ArgumentParser(description="Test end-effector control")
    parser.add_argument(
        "--mode",
@@ -698,12 +704,6 @@ if __name__ == "__main__":
        default="so100",
        help="Robot type (so100, koch, aloha, etc.)",
    )
-    parser.add_argument(
-        "--config-path",
-        type=str,
-        default=None,
-        help="Path to the config file in json format",
-    )

    args = parser.parse_args()
@@ -725,7 +725,10 @@ if __name__ == "__main__":
    if args.mode.startswith("keyboard"):
        controller = KeyboardController(x_step_size=0.01, y_step_size=0.01, z_step_size=0.05)
    elif args.mode.startswith("gamepad"):
-        controller = GamepadController(x_step_size=0.02, y_step_size=0.02, z_step_size=0.05)
+        if sys.platform == "darwin":
+            controller = GamepadControllerHID(x_step_size=0.01, y_step_size=0.01, z_step_size=0.05)
+        else:
+            controller = GamepadController(x_step_size=0.01, y_step_size=0.01, z_step_size=0.05)

    # Handle mode categories
    if args.mode in ["keyboard", "gamepad"]:
@@ -734,12 +737,14 @@ if __name__ == "__main__":
    elif args.mode in ["keyboard_gym", "gamepad_gym"]:
        # Gym environment control modes
-        cfg = HILSerlRobotEnvConfig()
-        if args.config_path is not None:
-            cfg = HILSerlRobotEnvConfig.from_json(args.config_path)
+        cfg = HILSerlRobotEnvConfig(robot=robot_config, wrapper=EnvWrapperConfig())
+        cfg.wrapper.ee_action_space_params = EEActionSpaceConfig(
+            x_step_size=0.03, y_step_size=0.03, z_step_size=0.03, bounds=bounds
+        )
+        cfg.wrapper.ee_action_space_params.use_gamepad = False
+        cfg.device = "cpu"
        env = make_robot_env(cfg, robot)
-        teleoperate_gym_env(env, controller, fps=args.fps)
+        teleoperate_gym_env(env, controller, fps=cfg.fps)

    elif args.mode == "leader":
        # Leader-follower modes don't use controllers
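The kinematics calls above switch from the static RobotKinematics.fk_gripper_tip(...) pattern to an instance constructed per robot type. A small sketch of the new usage (the robot type string and joint vector are placeholders):

    import numpy as np

    from lerobot.scripts.server.kinematics import RobotKinematics

    kinematics = RobotKinematics("so100")  # placeholder robot type
    joint_positions = np.zeros(6)  # placeholder joint vector
    ee_pose = kinematics.fk_gripper_tip(joint_positions)  # 4x4 end-effector pose
    print(ee_pose[:3, 3])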

View File

@@ -63,9 +63,10 @@ def find_ee_bounds(
        if time.perf_counter() - start_episode_t < 5:
            continue

+        kinematics = RobotKinematics(robot.robot_type)
        joint_positions = robot.follower_arms["main"].read("Present_Position")
        print(f"Joint positions: {joint_positions}")
-        ee_list.append(RobotKinematics.fk_gripper_tip(joint_positions)[:3, 3])
+        ee_list.append(kinematics.fk_gripper_tip(joint_positions)[:3, 3])

        if display_cameras and not is_headless():
            image_keys = [key for key in observation if "image" in key]
@@ -81,20 +82,19 @@ def find_ee_bounds(
            break


-def make_robot(robot_type="so100", mock=True):
+def make_robot(robot_type="so100"):
    """
    Create a robot instance using the appropriate robot config class.

    Args:
        robot_type: Robot type string (e.g., "so100", "koch", "aloha")
-        mock: Whether to use mock mode for hardware (default: True)

    Returns:
        Robot instance
    """
    # Get the appropriate robot config class based on robot_type
-    robot_config = RobotConfig.get_choice_class(robot_type)(mock=mock)
+    robot_config = RobotConfig.get_choice_class(robot_type)(mock=False)

    robot_config.leader_arms["main"].port = leader_port
    robot_config.follower_arms["main"].port = follower_port
@@ -122,18 +122,12 @@ if __name__ == "__main__":
        default="so100",
        help="Robot type (so100, koch, aloha, etc.)",
    )
-    parser.add_argument(
-        "--mock",
-        type=int,
-        default=1,
-        help="Use mock mode for hardware simulation",
-    )

    # Only parse known args, leaving robot config args for Hydra if used
-    args, _ = parser.parse_known_args()
+    args = parser.parse_args()

    # Create robot with the appropriate config
-    robot = make_robot(args.robot_type, args.mock)
+    robot = make_robot(args.robot_type)

    if args.mode == "joint":
        find_joint_bounds(robot, args.control_time_s)

View File

@@ -1,45 +1,37 @@
 import logging
 import sys
 import time
-import sys
+from dataclasses import dataclass
 from threading import Lock
-from typing import Annotated, Any, Dict, Tuple
+from typing import Annotated, Any, Dict, Optional, Tuple

 import gymnasium as gym
 import numpy as np
 import torch
 import torchvision.transforms.functional as F  # noqa: N812
-import json
-from dataclasses import dataclass
-from lerobot.common.envs.utils import preprocess_observation
-from lerobot.configs.train import TrainPipelineConfig
+
 from lerobot.common.envs.configs import EnvConfig
+from lerobot.common.envs.utils import preprocess_observation
 from lerobot.common.robot_devices.control_utils import (
    busy_wait,
    is_headless,
-    # reset_follower_position,
+    reset_follower_position,
 )
-from typing import Optional
-from lerobot.common.utils.utils import log_say
-from lerobot.common.robot_devices.robots.utils import make_robot_from_config
 from lerobot.common.robot_devices.robots.configs import RobotConfig
-from lerobot.scripts.server.kinematics import RobotKinematics
-from lerobot.scripts.server.maniskill_manipulator import ManiskillEnvConfig, make_maniskill
+from lerobot.common.robot_devices.robots.utils import make_robot_from_config
+from lerobot.common.utils.utils import log_say
 from lerobot.configs import parser
+from lerobot.scripts.server.kinematics import RobotKinematics

 logging.basicConfig(level=logging.INFO)


 @dataclass
 class EEActionSpaceConfig:
    """Configuration parameters for end-effector action space."""

    x_step_size: float
    y_step_size: float
    z_step_size: float
    bounds: Dict[str, Any]  # Contains 'min' and 'max' keys with position bounds
    use_gamepad: bool = False
@@ -48,6 +40,7 @@ class EEActionSpaceConfig:
 @dataclass
 class EnvWrapperConfig:
    """Configuration for environment wrappers."""
+
    display_cameras: bool = False
    delta_action: float = 0.1
    use_relative_joint_positions: bool = True
@@ -64,28 +57,27 @@ class EnvWrapperConfig:
    reward_classifier_config_file: Optional[str] = None


+@EnvConfig.register_subclass(name="gym_manipulator")
 @dataclass
-class HILSerlRobotEnvConfig:
+class HILSerlRobotEnvConfig(EnvConfig):
    """Configuration for the HILSerlRobotEnv environment."""

-    robot: RobotConfig
-    wrapper: EnvWrapperConfig
-    env_name: str = "real_robot"
+    robot: Optional[RobotConfig] = None
+    wrapper: Optional[EnvWrapperConfig] = None
    fps: int = 10
    mode: str = None  # Either "record", "replay", None
    repo_id: Optional[str] = None
    dataset_root: Optional[str] = None
    task: str = ""
    num_episodes: int = 10  # only for record mode
    episode: int = 0
    device: str = "cuda"
    push_to_hub: bool = True
    pretrained_policy_name_or_path: Optional[str] = None

-    @classmethod
-    def from_json(cls, json_path: str):
-        with open(json_path, "r") as f:
-            config = json.load(f)
-        return cls(**config)
+    def gym_kwargs(self) -> dict:
+        return {}


 class HILSerlRobotEnv(gym.Env):
    """
@@ -580,8 +572,7 @@ class ImageCropResizeWrapper(gym.Wrapper):
            if key_crop not in self.env.observation_space.keys():  # noqa: SIM118
                raise ValueError(f"Key {key_crop} not in observation space")
        for key in crop_params_dict:
-            top, left, height, width = crop_params_dict[key]
-            new_shape = (top + height, left + width)
+            new_shape = (3, resize_size[0], resize_size[1])
            self.observation_space[key] = gym.spaces.Box(low=0, high=255, shape=new_shape)

        self.resize_size = resize_size
@@ -1097,9 +1088,7 @@ class ActionScaleWrapper(gym.ActionWrapper):
        return action * self.scale_vector, is_intervention


-def make_robot_env(cfg: EnvConfig) -> gym.vector.VectorEnv:
-    # def make_robot_env(cfg: TrainPipelineConfig) -> gym.vector.VectorEnv:
-    # def make_robot_env(cfg: ManiskillEnvConfig) -> gym.vector.VectorEnv:
+def make_robot_env(cfg) -> gym.vector.VectorEnv:
    """
    Factory function to create a vectorized robot environment.
@@ -1111,16 +1100,16 @@ def make_robot_env(cfg: EnvConfig) -> gym.vector.VectorEnv:
    Returns:
        A vectorized gym environment with all the necessary wrappers applied.
    """
-    if "maniskill" in cfg.name:
-        from lerobot.scripts.server.maniskill_manipulator import make_maniskill
-        logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN")
-        env = make_maniskill(
-            cfg=cfg,
-            n_envs=1,
-        )
-        return env
-    robot = cfg.robot
+    # if "maniskill" in cfg.name:
+    #     from lerobot.scripts.server.maniskill_manipulator import make_maniskill
+    #     logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN")
+    #     env = make_maniskill(
+    #         cfg=cfg,
+    #         n_envs=1,
+    #     )
+    #     return env
+    robot = make_robot_from_config(cfg.robot)

    # Create base environment
    env = HILSerlRobotEnv(
        robot=robot,
@@ -1150,10 +1139,7 @@ def make_robot_env(cfg: EnvConfig) -> gym.vector.VectorEnv:
    env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
    if cfg.wrapper.ee_action_space_params is not None:
        env = EEActionWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params)
-    if (
-        cfg.wrapper.ee_action_space_params is not None
-        and cfg.wrapper.ee_action_space_params.use_gamepad
-    ):
+    if cfg.wrapper.ee_action_space_params is not None and cfg.wrapper.ee_action_space_params.use_gamepad:
        # env = ActionScaleWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params)
        env = GamepadControlWrapper(
            env=env,
@@ -1169,10 +1155,7 @@ def make_robot_env(cfg: EnvConfig) -> gym.vector.VectorEnv:
        reset_pose=cfg.wrapper.fixed_reset_joint_positions,
        reset_time_s=cfg.wrapper.reset_time_s,
    )
-    if (
-        cfg.wrapper.ee_action_space_params is None
-        and cfg.wrapper.joint_masking_action_space is not None
-    ):
+    if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None:
        env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)
    env = BatchCompitableWrapper(env=env)
@@ -1180,7 +1163,10 @@ def make_robot_env(cfg: EnvConfig) -> gym.vector.VectorEnv:

 def get_classifier(cfg):
-    if cfg.wrapper.reward_classifier_pretrained_path is None or cfg.wrapper.reward_classifier_config_file is None:
+    if (
+        cfg.wrapper.reward_classifier_pretrained_path is None
+        or cfg.wrapper.reward_classifier_config_file is None
+    ):
        return None

    from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
@@ -1258,7 +1244,8 @@ def record_dataset(env, policy, cfg: HILSerlRobotEnvConfig):
    # Record episodes
    episode_index = 0
-    while episode_index < cfg.record_num_episodes:
+    recorded_action = None
+    while episode_index < cfg.num_episodes:
        obs, _ = env.reset()
        start_episode_t = time.perf_counter()
        log_say(f"Recording episode {episode_index}", play_sounds=True)
@@ -1279,16 +1266,19 @@ def record_dataset(env, policy, cfg: HILSerlRobotEnvConfig):
                break

            # For teleop, get action from intervention
-            if policy is None:
-                action = {"action": info["action_intervention"].cpu().squeeze(0).float()}
+            recorded_action = {
+                "action": info["action_intervention"].cpu().squeeze(0).float() if policy is None else action
+            }

            # Process observation for dataset
            obs = {k: v.cpu().squeeze(0).float() for k, v in obs.items()}
+            obs["observation.images.side"] = torch.clamp(obs["observation.images.side"], 0, 1)

            # Add frame to dataset
-            frame = {**obs, **action}
-            frame["next.reward"] = reward
-            frame["next.done"] = terminated or truncated
+            frame = {**obs, **recorded_action}
+            frame["next.reward"] = np.array([reward], dtype=np.float32)
+            frame["next.done"] = np.array([terminated or truncated], dtype=bool)
+            frame["task"] = cfg.task
            dataset.add_frame(frame)

            # Maintain consistent timing
@@ -1309,9 +1299,9 @@ def record_dataset(env, policy, cfg: HILSerlRobotEnvConfig):
        episode_index += 1

    # Finalize dataset
-    dataset.consolidate(run_compute_stats=True)
+    # dataset.consolidate(run_compute_stats=True)

    if cfg.push_to_hub:
-        dataset.push_to_hub(cfg.repo_id)
+        dataset.push_to_hub()


 def replay_episode(env, repo_id, root=None, episode=0):
@@ -1334,82 +1324,69 @@ def replay_episode(env, repo_id, root=None, episode=0):
        busy_wait(1 / 10 - dt_s)


-# @parser.wrap()
-# def main(cfg):
-#     robot = make_robot_from_config(cfg.robot)
-#     reward_classifier = None #get_classifier(
-#     # cfg.wrapper.reward_classifier_pretrained_path, cfg.wrapper.reward_classifier_config_file
-#     # )
-#     user_relative_joint_positions = True
-#     env = make_robot_env(cfg, robot)
-#     if cfg.mode == "record":
-#         policy = None
-#         if cfg.pretrained_policy_name_or_path is not None:
-#             from lerobot.common.policies.sac.modeling_sac import SACPolicy
-#             policy = SACPolicy.from_pretrained(cfg.pretrained_policy_name_or_path)
-#             policy.to(cfg.device)
-#             policy.eval()
-#         record_dataset(
-#             env,
-#             cfg.repo_id,
-#             root=cfg.dataset_root,
-#             num_episodes=cfg.num_episodes,
-#             fps=cfg.fps,
-#             task_description=cfg.task,
-#             policy=policy,
-#         )
-#         exit()
-#     if cfg.mode == "replay":
-#         replay_episode(
-#             env,
-#             cfg.replay_repo_id,
-#             root=cfg.dataset_root,
-#             episode=cfg.replay_episode,
-#         )
-#         exit()
-#     env.reset()
-#     # Retrieve the robot's action space for joint commands.
-#     action_space_robot = env.action_space.spaces[0]
-#     # Initialize the smoothed action as a random sample.
-#     smoothed_action = action_space_robot.sample()
-#     # Smoothing coefficient (alpha) defines how much of the new random sample to mix in.
-#     # A value close to 0 makes the trajectory very smooth (slow to change), while a value close to 1 is less smooth.
-#     alpha = 1.0
-#     num_episode = 0
-#     sucesses = []
-#     while num_episode < 20:
-#         start_loop_s = time.perf_counter()
-#         # Sample a new random action from the robot's action space.
-#         new_random_action = action_space_robot.sample()
-#         # Update the smoothed action using an exponential moving average.
-#         smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action
-#         # Execute the step: wrap the NumPy action in a torch tensor.
-#         obs, reward, terminated, truncated, info = env.step((torch.from_numpy(smoothed_action), False))
-#         if terminated or truncated:
-#             sucesses.append(reward)
-#             env.reset()
-#             num_episode += 1
-#         dt_s = time.perf_counter() - start_loop_s
-#         busy_wait(1 / cfg.fps - dt_s)
-#     logging.info(f"Success after 20 steps {sucesses}")
-#     logging.info(f"success rate {sum(sucesses) / len(sucesses)}")
-# if __name__ == "__main__":
-#     main()
+@parser.wrap()
+def main(cfg: EnvConfig):
+    env = make_robot_env(cfg)
+
+    if cfg.mode == "record":
+        policy = None
+        if cfg.pretrained_policy_name_or_path is not None:
+            from lerobot.common.policies.sac.modeling_sac import SACPolicy
+
+            policy = SACPolicy.from_pretrained(cfg.pretrained_policy_name_or_path)
+            policy.to(cfg.device)
+            policy.eval()
+
+        record_dataset(
+            env,
+            policy=None,
+            cfg=cfg,
+        )
+        exit()
+
+    if cfg.mode == "replay":
+        replay_episode(
+            env,
+            cfg.replay_repo_id,
+            root=cfg.dataset_root,
+            episode=cfg.replay_episode,
+        )
+        exit()
+
+    env.reset()
+
+    # Retrieve the robot's action space for joint commands.
+    action_space_robot = env.action_space.spaces[0]
+
+    # Initialize the smoothed action as a random sample.
+    smoothed_action = action_space_robot.sample()
+
+    # Smoothing coefficient (alpha) defines how much of the new random sample to mix in.
+    # A value close to 0 makes the trajectory very smooth (slow to change), while a value close to 1 is less smooth.
+    alpha = 1.0
+
+    num_episode = 0
+    sucesses = []
+    while num_episode < 20:
+        start_loop_s = time.perf_counter()
+        # Sample a new random action from the robot's action space.
+        new_random_action = action_space_robot.sample()
+        # Update the smoothed action using an exponential moving average.
+        smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action
+
+        # Execute the step: wrap the NumPy action in a torch tensor.
+        obs, reward, terminated, truncated, info = env.step((torch.from_numpy(smoothed_action), False))
+        if terminated or truncated:
+            sucesses.append(reward)
+            env.reset()
+            num_episode += 1
+
+        dt_s = time.perf_counter() - start_loop_s
+        busy_wait(1 / cfg.fps - dt_s)
+
+    logging.info(f"Success after 20 steps {sucesses}")
+    logging.info(f"success rate {sum(sucesses) / len(sucesses)}")


 if __name__ == "__main__":
-    make_robot_env()
+    main()
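With HILSerlRobotEnvConfig now registered as an EnvConfig subclass under the name "gym_manipulator", the environment config can be built programmatically (or selected through the usual env config machinery). A hedged sketch, with placeholder robot type and field values; building the env itself needs real hardware, so that call stays commented out:

    from lerobot.common.robot_devices.robots.configs import RobotConfig
    from lerobot.scripts.server.gym_manipulator import EnvWrapperConfig, HILSerlRobotEnvConfig

    robot_config = RobotConfig.get_choice_class("so100")()  # placeholder robot type
    cfg = HILSerlRobotEnvConfig(robot=robot_config, wrapper=EnvWrapperConfig(), fps=10, device="cpu")
    print(type(cfg).__name__, cfg.fps)
    # env = make_robot_env(cfg)  # would connect to the real robot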

View File

@ -15,12 +15,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
import os
import shutil import shutil
import time import time
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from pprint import pformat
import os
from pathlib import Path from pathlib import Path
from pprint import pformat
import draccus import draccus
import grpc import grpc
@ -30,35 +30,42 @@ import hilserl_pb2_grpc # type: ignore
import torch import torch
from termcolor import colored from termcolor import colored
from torch import nn from torch import nn
from torch.multiprocessing import Queue from torch.multiprocessing import Queue
from torch.optim.optimizer import Optimizer from torch.optim.optimizer import Optimizer
from lerobot.common.constants import (
CHECKPOINTS_DIR,
LAST_CHECKPOINT_LINK,
PRETRAINED_MODEL_DIR,
TRAINING_STATE_DIR,
TRAINING_STEP,
)
from lerobot.common.datasets.factory import make_dataset from lerobot.common.datasets.factory import make_dataset
from lerobot.configs.train import TrainPipelineConfig
from lerobot.configs import parser
# TODO: Remove the import of maniskill # TODO: Remove the import of maniskill
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.factory import make_policy from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy, SACConfig from lerobot.common.policies.sac.modeling_sac import SACConfig, SACPolicy
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.random_utils import set_seed
from lerobot.common.utils.train_utils import ( from lerobot.common.utils.train_utils import (
get_step_checkpoint_dir, get_step_checkpoint_dir,
get_step_identifier, get_step_identifier,
load_training_state as utils_load_training_state,
save_checkpoint, save_checkpoint,
update_last_checkpoint,
save_training_state, save_training_state,
update_last_checkpoint,
)
from lerobot.common.utils.train_utils import (
load_training_state as utils_load_training_state,
) )
from lerobot.common.utils.random_utils import set_seed
from lerobot.common.utils.utils import ( from lerobot.common.utils.utils import (
format_big_number, format_big_number,
get_safe_torch_device, get_safe_torch_device,
init_logging, init_logging,
) )
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.wandb_utils import WandBLogger from lerobot.common.utils.wandb_utils import WandBLogger
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.scripts.server import learner_service from lerobot.scripts.server import learner_service
from lerobot.scripts.server.buffer import ( from lerobot.scripts.server.buffer import (
ReplayBuffer, ReplayBuffer,
@ -70,47 +77,39 @@ from lerobot.scripts.server.buffer import (
state_to_bytes, state_to_bytes,
) )
from lerobot.scripts.server.utils import setup_process_handlers from lerobot.scripts.server.utils import setup_process_handlers
from lerobot.common.constants import (
CHECKPOINTS_DIR,
LAST_CHECKPOINT_LINK,
PRETRAINED_MODEL_DIR,
TRAINING_STATE_DIR,
TRAINING_STEP,
)
def handle_resume_logic(cfg: TrainPipelineConfig) -> TrainPipelineConfig: def handle_resume_logic(cfg: TrainPipelineConfig) -> TrainPipelineConfig:
""" """
Handle the resume logic for training. Handle the resume logic for training.
If resume is True: If resume is True:
- Verifies that a checkpoint exists - Verifies that a checkpoint exists
- Loads the checkpoint configuration - Loads the checkpoint configuration
- Logs resumption details - Logs resumption details
- Returns the checkpoint configuration - Returns the checkpoint configuration
If resume is False: If resume is False:
- Checks if an output directory exists (to prevent accidental overwriting) - Checks if an output directory exists (to prevent accidental overwriting)
- Returns the original configuration - Returns the original configuration
Args: Args:
cfg (TrainPipelineConfig): The training configuration cfg (TrainPipelineConfig): The training configuration
Returns: Returns:
TrainPipelineConfig: The updated configuration TrainPipelineConfig: The updated configuration
Raises: Raises:
RuntimeError: If resume is True but no checkpoint found, or if resume is False but directory exists RuntimeError: If resume is True but no checkpoint found, or if resume is False but directory exists
""" """
out_dir = cfg.output_dir out_dir = cfg.output_dir
# Case 1: Not resuming, but need to check if directory exists to prevent overwrites # Case 1: Not resuming, but need to check if directory exists to prevent overwrites
if not cfg.resume: if not cfg.resume:
checkpoint_dir = os.path.join(out_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) checkpoint_dir = os.path.join(out_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK)
if os.path.exists(checkpoint_dir): if os.path.exists(checkpoint_dir):
raise RuntimeError( raise RuntimeError(
f"Output directory {checkpoint_dir} already exists. " f"Output directory {checkpoint_dir} already exists. Use `resume=true` to resume training."
"Use `resume=true` to resume training."
) )
return cfg return cfg
@ -131,7 +130,7 @@ def handle_resume_logic(cfg: TrainPipelineConfig) -> TrainPipelineConfig:
# Load config using Draccus # Load config using Draccus
checkpoint_cfg_path = os.path.join(checkpoint_dir, PRETRAINED_MODEL_DIR, "train_config.json") checkpoint_cfg_path = os.path.join(checkpoint_dir, PRETRAINED_MODEL_DIR, "train_config.json")
checkpoint_cfg = TrainPipelineConfig.from_pretrained(checkpoint_cfg_path) checkpoint_cfg = TrainPipelineConfig.from_pretrained(checkpoint_cfg_path)
# Ensure resume flag is set in returned config # Ensure resume flag is set in returned config
checkpoint_cfg.resume = True checkpoint_cfg.resume = True
return checkpoint_cfg return checkpoint_cfg
@ -143,11 +142,11 @@ def load_training_state(
): ):
""" """
Loads the training state (optimizers, step count, etc.) from a checkpoint. Loads the training state (optimizers, step count, etc.) from a checkpoint.
Args: Args:
cfg (TrainPipelineConfig): Training configuration cfg (TrainPipelineConfig): Training configuration
optimizers (Optimizer | dict): Optimizers to load state into optimizers (Optimizer | dict): Optimizers to load state into
Returns: Returns:
tuple: (optimization_step, interaction_step) or (None, None) if not resuming tuple: (optimization_step, interaction_step) or (None, None) if not resuming
""" """
@ -156,23 +155,23 @@ def load_training_state(
# Construct path to the last checkpoint directory # Construct path to the last checkpoint directory
checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK)
logging.info(f"Loading training state from {checkpoint_dir}") logging.info(f"Loading training state from {checkpoint_dir}")
try: try:
# Use the utility function from train_utils which loads the optimizer state # Use the utility function from train_utils which loads the optimizer state
step, optimizers, _ = utils_load_training_state(Path(checkpoint_dir), optimizers, None) step, optimizers, _ = utils_load_training_state(Path(checkpoint_dir), optimizers, None)
# Load interaction step separately from training_state.pt # Load interaction step separately from training_state.pt
training_state_path = os.path.join(checkpoint_dir, TRAINING_STATE_DIR, "training_state.pt") training_state_path = os.path.join(checkpoint_dir, TRAINING_STATE_DIR, "training_state.pt")
interaction_step = 0 interaction_step = 0
if os.path.exists(training_state_path): if os.path.exists(training_state_path):
training_state = torch.load(training_state_path, weights_only=False) training_state = torch.load(training_state_path, weights_only=False)
interaction_step = training_state.get("interaction_step", 0) interaction_step = training_state.get("interaction_step", 0)
logging.info(f"Resuming from step {step}, interaction step {interaction_step}") logging.info(f"Resuming from step {step}, interaction step {interaction_step}")
return step, interaction_step return step, interaction_step
except Exception as e: except Exception as e:
logging.error(f"Failed to load training state: {e}") logging.error(f"Failed to load training state: {e}")
return None, None return None, None
@ -181,7 +180,7 @@ def load_training_state(
def log_training_info(cfg: TrainPipelineConfig, policy: nn.Module) -> None: def log_training_info(cfg: TrainPipelineConfig, policy: nn.Module) -> None:
""" """
Log information about the training process. Log information about the training process.
Args: Args:
cfg (TrainPipelineConfig): Training configuration cfg (TrainPipelineConfig): Training configuration
policy (nn.Module): Policy model policy (nn.Module): Policy model
@ -189,7 +188,6 @@ def log_training_info(cfg: TrainPipelineConfig, policy: nn.Module) -> None:
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters()) num_total_params = sum(p.numel() for p in policy.parameters())
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}") logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
logging.info(f"{cfg.env.task=}") logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.policy.online_steps=}") logging.info(f"{cfg.policy.online_steps=}")
@ -197,19 +195,15 @@ def log_training_info(cfg: TrainPipelineConfig, policy: nn.Module) -> None:
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
def initialize_replay_buffer( def initialize_replay_buffer(cfg: TrainPipelineConfig, device: str, storage_device: str) -> ReplayBuffer:
cfg: TrainPipelineConfig,
device: str,
storage_device: str
) -> ReplayBuffer:
""" """
Initialize a replay buffer, either empty or from a dataset if resuming. Initialize a replay buffer, either empty or from a dataset if resuming.
Args: Args:
cfg (TrainPipelineConfig): Training configuration cfg (TrainPipelineConfig): Training configuration
device (str): Device to store tensors on device (str): Device to store tensors on
storage_device (str): Device for storage optimization storage_device (str): Device for storage optimization
Returns: Returns:
ReplayBuffer: Initialized replay buffer ReplayBuffer: Initialized replay buffer
""" """
@ -224,7 +218,7 @@ def initialize_replay_buffer(
logging.info("Resume training load the online dataset") logging.info("Resume training load the online dataset")
dataset_path = os.path.join(cfg.output_dir, "dataset") dataset_path = os.path.join(cfg.output_dir, "dataset")
# NOTE: In RL is possible to not have a dataset. # NOTE: In RL is possible to not have a dataset.
repo_id = None repo_id = None
if cfg.dataset is not None: if cfg.dataset is not None:
@ -250,13 +244,13 @@ def initialize_offline_replay_buffer(
) -> ReplayBuffer: ) -> ReplayBuffer:
""" """
Initialize an offline replay buffer from a dataset. Initialize an offline replay buffer from a dataset.
Args: Args:
cfg (TrainPipelineConfig): Training configuration cfg (TrainPipelineConfig): Training configuration
device (str): Device to store tensors on device (str): Device to store tensors on
storage_device (str): Device for storage optimization storage_device (str): Device for storage optimization
active_action_dims (list[int] | None): Active action dimensions for masking active_action_dims (list[int] | None): Active action dimensions for masking
Returns: Returns:
ReplayBuffer: Initialized offline replay buffer ReplayBuffer: Initialized offline replay buffer
""" """
@ -314,7 +308,7 @@ def start_learner_threads(
) -> None: ) -> None:
""" """
Start the learner threads for training. Start the learner threads for training.
Args: Args:
cfg (TrainPipelineConfig): Training configuration cfg (TrainPipelineConfig): Training configuration
wandb_logger (WandBLogger | None): Logger for metrics wandb_logger (WandBLogger | None): Logger for metrics
@ -512,17 +506,19 @@ def add_actor_information_and_train(
logging.info("Initializing policy") logging.info("Initializing policy")
# Get checkpoint dir for resuming # Get checkpoint dir for resuming
checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) if cfg.resume else None checkpoint_dir = (
os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) if cfg.resume else None
)
pretrained_path = os.path.join(checkpoint_dir, PRETRAINED_MODEL_DIR) if checkpoint_dir else None pretrained_path = os.path.join(checkpoint_dir, PRETRAINED_MODEL_DIR) if checkpoint_dir else None
policy: SACPolicy = make_policy( policy: SACPolicy = make_policy(
cfg=cfg.policy, cfg=cfg.policy,
# ds_meta=cfg.dataset, # ds_meta=cfg.dataset,
env_cfg=cfg.env env_cfg=cfg.env,
) )
# Update the policy config with the grad_clip_norm value from training config if it exists # Update the policy config with the grad_clip_norm value from training config if it exists
clip_grad_norm_value:float = cfg.policy.grad_clip_norm clip_grad_norm_value: float = cfg.policy.grad_clip_norm
# compile policy # compile policy
policy = torch.compile(policy) policy = torch.compile(policy)
@ -536,7 +532,7 @@ def add_actor_information_and_train(
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg=cfg, policy=policy) optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg=cfg, policy=policy)
resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers) resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers)
log_training_info(cfg=cfg, policy= policy) log_training_info(cfg=cfg, policy=policy)
replay_buffer = initialize_replay_buffer(cfg, device, storage_device) replay_buffer = initialize_replay_buffer(cfg, device, storage_device)
batch_size = cfg.batch_size batch_size = cfg.batch_size
@ -615,14 +611,10 @@ def add_actor_information_and_train(
interaction_message = bytes_to_python_object(interaction_message) interaction_message = bytes_to_python_object(interaction_message)
# If cfg.resume, shift the interaction step by the last checkpointed step so the logging is not broken # If cfg.resume, shift the interaction step by the last checkpointed step so the logging is not broken
interaction_message["Interaction step"] += interaction_step_shift interaction_message["Interaction step"] += interaction_step_shift
# Log interaction messages with WandB if available # Log interaction messages with WandB if available
if wandb_logger: if wandb_logger:
wandb_logger.log_dict( wandb_logger.log_dict(d=interaction_message, mode="train", custom_step_key="Interaction step")
d=interaction_message,
mode="train",
custom_step_key="Interaction step"
)
logging.debug("[LEARNER] Received interactions") logging.debug("[LEARNER] Received interactions")
@ -636,7 +628,9 @@ def add_actor_information_and_train(
if dataset_repo_id is not None: if dataset_repo_id is not None:
batch_offline = offline_replay_buffer.sample(batch_size=batch_size) batch_offline = offline_replay_buffer.sample(batch_size=batch_size)
batch = concatenate_batch_transitions(left_batch_transitions=batch, right_batch_transition=batch_offline) batch = concatenate_batch_transitions(
left_batch_transitions=batch, right_batch_transition=batch_offline
)
actions = batch["action"] actions = batch["action"]
rewards = batch["reward"] rewards = batch["reward"]
@ -759,14 +753,10 @@ def add_actor_information_and_train(
if offline_replay_buffer is not None: if offline_replay_buffer is not None:
training_infos["offline_replay_buffer_size"] = len(offline_replay_buffer) training_infos["offline_replay_buffer_size"] = len(offline_replay_buffer)
training_infos["Optimization step"] = optimization_step training_infos["Optimization step"] = optimization_step
# Log training metrics # Log training metrics
if wandb_logger: if wandb_logger:
wandb_logger.log_dict( wandb_logger.log_dict(d=training_infos, mode="train", custom_step_key="Optimization step")
d=training_infos,
mode="train",
custom_step_key="Optimization step"
)
time_for_one_optimization_step = time.time() - time_for_one_optimization_step time_for_one_optimization_step = time.time() - time_for_one_optimization_step
frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9) frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)
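A quick worked example of the timing computation above: a 25 ms optimization step corresponds to roughly 40 optimization steps per second, and the 1e-9 term only guards against division by zero.

time_for_one_optimization_step = 0.025  # seconds, illustrative
frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)  # ~40.0 steps/s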
@ -795,29 +785,19 @@ def add_actor_information_and_train(
interaction_step = ( interaction_step = (
interaction_message["Interaction step"] if interaction_message is not None else 0 interaction_message["Interaction step"] if interaction_message is not None else 0
) )
# Create checkpoint directory # Create checkpoint directory
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, online_steps, optimization_step) checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, online_steps, optimization_step)
# Save checkpoint # Save checkpoint
save_checkpoint( save_checkpoint(checkpoint_dir, optimization_step, cfg, policy, optimizers, scheduler=None)
checkpoint_dir,
optimization_step,
cfg,
policy,
optimizers,
scheduler=None
)
# Save interaction step manually # Save interaction step manually
training_state_dir = os.path.join(checkpoint_dir, TRAINING_STATE_DIR) training_state_dir = os.path.join(checkpoint_dir, TRAINING_STATE_DIR)
os.makedirs(training_state_dir, exist_ok=True) os.makedirs(training_state_dir, exist_ok=True)
training_state = { training_state = {"step": optimization_step, "interaction_step": interaction_step}
"step": optimization_step,
"interaction_step": interaction_step
}
torch.save(training_state, os.path.join(training_state_dir, "training_state.pt")) torch.save(training_state, os.path.join(training_state_dir, "training_state.pt"))
# Update the "last" symlink # Update the "last" symlink
update_last_checkpoint(checkpoint_dir) update_last_checkpoint(checkpoint_dir)
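The interaction step is saved by hand next to the regular checkpoint. A hedged sketch of reading it back on resume; it mirrors the torch.save call above, and everything beyond that call (file layout, key lookup) is an assumption rather than code from the diff.

# Hedged sketch of restoring the manually saved interaction step on resume.
training_state_path = os.path.join(checkpoint_dir, TRAINING_STATE_DIR, "training_state.pt")
if os.path.exists(training_state_path):
    training_state = torch.load(training_state_path)
    resumed_interaction_step = training_state.get("interaction_step", 0)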
@ -826,17 +806,13 @@ def add_actor_information_and_train(
dataset_dir = os.path.join(cfg.output_dir, "dataset") dataset_dir = os.path.join(cfg.output_dir, "dataset")
if os.path.exists(dataset_dir) and os.path.isdir(dataset_dir): if os.path.exists(dataset_dir) and os.path.isdir(dataset_dir):
shutil.rmtree(dataset_dir) shutil.rmtree(dataset_dir)
# Save dataset # Save dataset
# NOTE: Handle the case where the dataset repo id is not specified in the config # NOTE: Handle the case where the dataset repo id is not specified in the config
# eg. RL training without demonstrations data # eg. RL training without demonstrations data
repo_id_buffer_save = cfg.env.task if dataset_repo_id is None else dataset_repo_id repo_id_buffer_save = cfg.env.task if dataset_repo_id is None else dataset_repo_id
replay_buffer.to_lerobot_dataset( replay_buffer.to_lerobot_dataset(repo_id=repo_id_buffer_save, fps=fps, root=dataset_dir)
repo_id=repo_id_buffer_save,
fps=fps,
root=dataset_dir
)
if offline_replay_buffer is not None: if offline_replay_buffer is not None:
dataset_offline_dir = os.path.join(cfg.output_dir, "dataset_offline") dataset_offline_dir = os.path.join(cfg.output_dir, "dataset_offline")
if os.path.exists(dataset_offline_dir) and os.path.isdir(dataset_offline_dir): if os.path.exists(dataset_offline_dir) and os.path.isdir(dataset_offline_dir):
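A tiny worked example of the repo id fallback noted above: with no demonstration dataset configured (pure RL training), the environment task name stands in as the dataset repo id. The values below are placeholders.

dataset_repo_id = None            # e.g. RL run without demonstrations
env_task = "PushCube-v0"          # placeholder task name
repo_id_buffer_save = env_task if dataset_repo_id is None else dataset_repo_id
# repo_id_buffer_save == "PushCube-v0"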
@ -882,9 +858,7 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
params=policy.actor.parameters_to_optimize, params=policy.actor.parameters_to_optimize,
lr=cfg.policy.actor_lr, lr=cfg.policy.actor_lr,
) )
optimizer_critic = torch.optim.Adam( optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr
)
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr) optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
lr_scheduler = None lr_scheduler = None
optimizers = { optimizers = {
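The hunk cuts off at the start of the optimizers dict. A hedged sketch of how it is presumably assembled and returned: the dict keys and the actor optimizer's variable name are assumptions, while the return pairing matches the caller `optimizers, lr_scheduler = make_optimizers_and_scheduler(...)` shown earlier in the file.

# Hedged sketch; keys and optimizer_actor are assumptions, not visible in this hunk.
optimizers = {
    "actor": optimizer_actor,
    "critic": optimizer_critic,
    "temperature": optimizer_temperature,
}
return optimizers, lr_scheduler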
@ -898,19 +872,19 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
def train(cfg: TrainPipelineConfig, job_name: str | None = None): def train(cfg: TrainPipelineConfig, job_name: str | None = None):
""" """
Main training function that initializes and runs the training process. Main training function that initializes and runs the training process.
Args: Args:
cfg (TrainPipelineConfig): The training configuration cfg (TrainPipelineConfig): The training configuration
job_name (str | None, optional): Job name for logging. Defaults to None. job_name (str | None, optional): Job name for logging. Defaults to None.
""" """
cfg.validate() cfg.validate()
# if cfg.output_dir is None: # if cfg.output_dir is None:
# raise ValueError("Output directory must be specified in config") # raise ValueError("Output directory must be specified in config")
if job_name is None: if job_name is None:
job_name = cfg.job_name job_name = cfg.job_name
if job_name is None: if job_name is None:
raise ValueError("Job name must be specified either in config or as a parameter") raise ValueError("Job name must be specified either in config or as a parameter")
@ -920,11 +894,12 @@ def train(cfg: TrainPipelineConfig, job_name: str | None = None):
# Setup WandB logging if enabled # Setup WandB logging if enabled
if cfg.wandb.enable and cfg.wandb.project: if cfg.wandb.enable and cfg.wandb.project:
from lerobot.common.utils.wandb_utils import WandBLogger from lerobot.common.utils.wandb_utils import WandBLogger
wandb_logger = WandBLogger(cfg) wandb_logger = WandBLogger(cfg)
else: else:
wandb_logger = None wandb_logger = None
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"])) logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
# Handle resume logic # Handle resume logic
cfg = handle_resume_logic(cfg) cfg = handle_resume_logic(cfg)
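The guard above requires both WandB flags to be set before a logger is created; otherwise metrics stay local. A minimal config sketch, with a placeholder project name:

# Both flags must be set for online logging; values are placeholders.
cfg.wandb.enable = True
cfg.wandb.project = "hil-serl"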
@ -944,9 +919,9 @@ def train(cfg: TrainPipelineConfig, job_name: str | None = None):
@parser.wrap() @parser.wrap()
def train_cli(cfg: TrainPipelineConfig): def train_cli(cfg: TrainPipelineConfig):
if not use_threads(cfg): if not use_threads(cfg):
import torch.multiprocessing as mp import torch.multiprocessing as mp
mp.set_start_method("spawn") mp.set_start_method("spawn")
# Use the job_name from the config # Use the job_name from the config
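The "spawn" start method has to be selected once, before any worker process exists; a second call raises RuntimeError in standard torch.multiprocessing. A minimal sketch of that behavior (the try/except guard is an illustration, not code from the diff):

import torch.multiprocessing as mp

try:
    mp.set_start_method("spawn")
except RuntimeError:
    pass  # start method was already set earlier in this interpreter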

View File

@ -122,6 +122,9 @@ def make_optimizer_and_scheduler(cfg, policy):
optimizer = VQBeTOptimizer(policy, cfg) optimizer = VQBeTOptimizer(policy, cfg)
lr_scheduler = VQBeTScheduler(optimizer, cfg) lr_scheduler = VQBeTScheduler(optimizer, cfg)
elif cfg.policy.name == "hilserl_classifier":
optimizer = torch.optim.AdamW(policy.parameters(), cfg.policy.learning_rate)
lr_scheduler = None
else: else:
raise NotImplementedError() raise NotImplementedError()
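A minimal, self-contained sketch of the classifier optimizer wiring added above, with placeholder values; the linear head and learning rate stand in for the real classifier and for cfg.policy.learning_rate.

import torch
import torch.nn as nn

model = nn.Linear(512, 2)           # stand-in for the reward classifier head
learning_rate = 1e-4                # would come from cfg.policy.learning_rate
optimizer = torch.optim.AdamW(model.parameters(), learning_rate)
lr_scheduler = None                 # the classifier branch uses no scheduler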

View File

@ -16,7 +16,6 @@ import time
from contextlib import nullcontext from contextlib import nullcontext
from pprint import pformat from pprint import pformat
import hydra
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -32,11 +31,8 @@ from tqdm import tqdm
from lerobot.common.datasets.factory import resolve_delta_timestamps from lerobot.common.datasets.factory import resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.logger import Logger from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
ClassifierConfig,
)
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
from lerobot.common.utils.utils import ( from lerobot.common.utils.utils import (
format_big_number, format_big_number,
@ -296,8 +292,6 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
init_logging() init_logging()
logging.info(OmegaConf.to_yaml(cfg)) logging.info(OmegaConf.to_yaml(cfg))
logger = Logger(cfg, out_dir, wandb_job_name=job_name)
# Initialize training environment # Initialize training environment
device = get_safe_torch_device(cfg.device, log=True) device = get_safe_torch_device(cfg.device, log=True)
set_global_seed(cfg.seed) set_global_seed(cfg.seed)