Change HILSerlRobotEnvConfig to inherit from EnvConfig

Added support for training the hil_serl classifier with train.py.
Run classifier training with: python lerobot/scripts/train.py --policy.type=hilserl_classifier
Fixes in find_joint_limits, control_robot, and end_effector_control_utils.
Michel Aractingi 2025-03-27 10:23:14 +01:00 committed by AdilZouitine
parent 052a4acfc2
commit d0b7690bc0
13 changed files with 389 additions and 340 deletions
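
For context, a minimal sketch (not part of the commit; the keyword arguments are illustrative) of what the new factory registration below enables. With these changes, --policy.type=hilserl_classifier resolves through the standard policy helpers:

    from lerobot.common.policies.factory import get_policy_class, make_policy_config

    # Build the classifier config and class the same way train.py does for any policy type.
    config = make_policy_config("hilserl_classifier", num_classes=2, model_type="cnn")
    policy_cls = get_policy_class("hilserl_classifier")  # -> Classifier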

View File

@@ -412,7 +412,7 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
             names = ft["names"]
             # Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
-            if names[2] in ["channel", "channels"]:  # (h, w, c) -> (c, h, w)
+            if names is not None and names[2] in ["channel", "channels"]:  # (h, w, c) -> (c, h, w)
                 shape = (shape[2], shape[0], shape[1])
         elif key == "observation.environment_state":
             type = FeatureType.ENV

View File

@@ -68,4 +68,3 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
     )
     return env

View File

@@ -28,6 +28,7 @@ from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
 from lerobot.common.policies.pretrained import PreTrainedPolicy
 from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
 from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
+from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
 from lerobot.configs.policies import PreTrainedConfig
 from lerobot.configs.types import FeatureType
@@ -58,6 +59,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
         from lerobot.common.policies.sac.modeling_sac import SACPolicy

         return SACPolicy
+    elif name == "hilserl_classifier":
+        from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
+
+        return Classifier
     else:
         raise NotImplementedError(f"Policy with name {name} is not implemented.")
@@ -73,6 +78,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
         return VQBeTConfig(**kwargs)
     elif policy_type == "pi0":
         return PI0Config(**kwargs)
+    elif policy_type == "hilserl_classifier":
+        return ClassifierConfig(**kwargs)
     else:
         raise ValueError(f"Policy type '{policy_type}' is not available.")

View File

@@ -1,12 +1,18 @@
-import json
-import os
-from dataclasses import asdict, dataclass
+from dataclasses import dataclass, field
+from typing import Dict, List
+
+from lerobot.common.optim.optimizers import AdamWConfig, OptimizerConfig
+from lerobot.common.optim.schedulers import LRSchedulerConfig
+from lerobot.configs.policies import PreTrainedConfig
+from lerobot.configs.types import FeatureType, PolicyFeature


+@PreTrainedConfig.register_subclass(name="hilserl_classifier")
 @dataclass
-class ClassifierConfig:
+class ClassifierConfig(PreTrainedConfig):
     """Configuration for the Classifier model."""

+    name: str = "hilserl_classifier"
     num_classes: int = 2
     hidden_dim: int = 256
     dropout_rate: float = 0.1
@@ -14,22 +20,35 @@ class ClassifierConfig:
     device: str = "cpu"
     model_type: str = "cnn"  # "transformer" or "cnn"
     num_cameras: int = 2
+    learning_rate: float = 1e-4
+    normalization_mode = None
+    # output_features: Dict[str, PolicyFeature] = field(
+    #     default_factory=lambda: {"next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(1,))}
+    # )

-    def save_pretrained(self, save_dir):
-        """Save config to json file."""
-        os.makedirs(save_dir, exist_ok=True)
-
-        # Convert to dict and save as JSON
-        config_dict = asdict(self)
-        with open(os.path.join(save_dir, "config.json"), "w") as f:
-            json.dump(config_dict, f, indent=2)
-
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path):
-        """Load config from json file."""
-        config_file = os.path.join(pretrained_model_name_or_path, "config.json")
-        with open(config_file) as f:
-            config_dict = json.load(f)
-        return cls(**config_dict)
+    @property
+    def observation_delta_indices(self) -> List | None:
+        return None
+
+    @property
+    def action_delta_indices(self) -> List | None:
+        return None
+
+    @property
+    def reward_delta_indices(self) -> List | None:
+        return None
+
+    def get_optimizer_preset(self) -> OptimizerConfig:
+        return AdamWConfig(
+            lr=self.learning_rate,
+            weight_decay=0.01,
+            grad_clip_norm=1.0,
+        )
+
+    def get_scheduler_preset(self) -> LRSchedulerConfig | None:
+        return None
+
+    def validate_features(self) -> None:
+        """Validate feature configurations."""
+        # Classifier doesn't need specific feature validation
+        pass
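
A small sketch (assumption, not part of the diff) of how the new PreTrainedConfig interface is consumed: the training script can now derive an optimizer preset straight from the config instead of the config managing its own JSON serialization.

    from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig

    cfg = ClassifierConfig(learning_rate=1e-4)
    adamw_cfg = cfg.get_optimizer_preset()   # AdamWConfig(lr=1e-4, weight_decay=0.01, grad_clip_norm=1.0)
    assert cfg.get_scheduler_preset() is None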

View File

@@ -1,11 +1,15 @@
 import logging
-from typing import Optional
+from typing import Dict, Optional, Tuple

 import torch
-from huggingface_hub import PyTorchModelHubMixin
 from torch import Tensor, nn

-from .configuration_classifier import ClassifierConfig
+from lerobot.common.constants import OBS_IMAGE
+from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
+    ClassifierConfig,
+)
+from lerobot.common.policies.normalize import Normalize, Unnormalize
+from lerobot.common.policies.pretrained import PreTrainedPolicy

 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
 logger = logging.getLogger(__name__)
@@ -32,25 +36,32 @@ class ClassifierOutput:
         )


-class Classifier(
-    nn.Module,
-    PyTorchModelHubMixin,
-    # Add Hub metadata
-    library_name="lerobot",
-    repo_url="https://github.com/huggingface/lerobot",
-    tags=["robotics", "vision-classifier"],
-):
+class Classifier(PreTrainedPolicy):
     """Image classifier built on top of a pre-trained encoder."""

-    # Add name attribute for factory
-    name = "classifier"
+    name = "hilserl_classifier"
+    config_class = ClassifierConfig

-    def __init__(self, config: ClassifierConfig):
+    def __init__(
+        self,
+        config: ClassifierConfig,
+        dataset_stats: Dict[str, Dict[str, Tensor]] | None = None,
+    ):
         from transformers import AutoModel

-        super().__init__()
+        super().__init__(config)
         self.config = config
-        # self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True)
+
+        # Initialize normalization (standardized with the policy framework)
+        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
+        self.normalize_targets = Normalize(
+            config.output_features, config.normalization_mapping, dataset_stats
+        )
+        self.unnormalize_outputs = Unnormalize(
+            config.output_features, config.normalization_mapping, dataset_stats
+        )
+
+        # Set up encoder
         encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)
         # Extract vision model if we're given a multimodal model
         if hasattr(encoder, "vision_model"):
@@ -81,8 +92,6 @@ class Classifier(
             else:
                 raise ValueError("Unsupported CNN architecture")

-        self.encoder = self.encoder.to(self.config.device)
-
     def _freeze_encoder(self) -> None:
         """Freeze the encoder parameters."""
         for param in self.encoder.parameters():
@@ -109,22 +118,13 @@ class Classifier(
                 1 if self.config.num_classes == 2 else self.config.num_classes,
             ),
         )
-        self.classifier_head = self.classifier_head.to(self.config.device)

     def _get_encoder_output(self, x: torch.Tensor) -> torch.Tensor:
         """Extract the appropriate output from the encoder."""
-        # Process images with the processor (handles resizing and normalization)
-        # processed = self.processor(
-        #     images=x,  # LeRobotDataset already provides proper tensor format
-        #     return_tensors="pt",
-        # )
-        # processed = processed["pixel_values"].to(x.device)
-        processed = x
-
         with torch.no_grad():
             if self.is_cnn:
                 # The HF ResNet applies pooling internally
-                outputs = self.encoder(processed)
+                outputs = self.encoder(x)
                 # Get pooled output directly
                 features = outputs.pooler_output
@@ -132,14 +132,24 @@ class Classifier(
                 features = features.squeeze(-1).squeeze(-1)
                 return features
             else:  # Transformer models
-                outputs = self.encoder(processed)
+                outputs = self.encoder(x)
                 if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
                     return outputs.pooler_output
                 return outputs.last_hidden_state[:, 0, :]

-    def forward(self, xs: torch.Tensor) -> ClassifierOutput:
-        """Forward pass of the classifier."""
-        # For training, we expect input to be a tensor directly from LeRobotDataset
+    def extract_images_and_labels(self, batch: Dict[str, Tensor]) -> Tuple[list, Tensor]:
+        """Extract image tensors and label tensors from batch."""
+        # Find image keys in input features
+        image_keys = [key for key in self.config.input_features if key.startswith(OBS_IMAGE)]
+
+        # Extract the images and labels
+        images = [batch[key] for key in image_keys]
+        labels = batch["next.reward"]
+
+        return images, labels
+
+    def predict(self, xs: list) -> ClassifierOutput:
+        """Forward pass of the classifier for inference."""
         encoder_outputs = torch.hstack([self._get_encoder_output(x) for x in xs])
         logits = self.classifier_head(encoder_outputs)
@@ -151,10 +161,77 @@ class Classifier(

         return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs)

-    def predict_reward(self, x, threshold=0.6):
+    def forward(self, batch: Dict[str, Tensor]) -> Tuple[Tensor, Dict[str, Tensor]]:
+        """Standard forward pass for training compatible with train.py."""
+        # Normalize inputs if needed
+        batch = self.normalize_inputs(batch)
+        batch = self.normalize_targets(batch)
+
+        # Extract images and labels
+        images, labels = self.extract_images_and_labels(batch)
+
+        # Get predictions
+        outputs = self.predict(images)
+
+        # Calculate loss
         if self.config.num_classes == 2:
-            probs = self.forward(x).probabilities
+            # Binary classification
+            loss = nn.functional.binary_cross_entropy_with_logits(outputs.logits, labels)
+            predictions = (torch.sigmoid(outputs.logits) > 0.5).float()
+        else:
+            # Multi-class classification
+            loss = nn.functional.cross_entropy(outputs.logits, labels.long())
+            predictions = torch.argmax(outputs.logits, dim=1)
+
+        # Calculate accuracy for logging
+        correct = (predictions == labels).sum().item()
+        total = labels.size(0)
+        accuracy = 100 * correct / total
+
+        # Return loss and metrics for logging
+        output_dict = {
+            "accuracy": accuracy,
+            "correct": correct,
+            "total": total,
+        }
+
+        return loss, output_dict
+
+    def predict_reward(self, batch, threshold=0.6):
+        """Legacy method for compatibility."""
+        images, _ = self.extract_images_and_labels(batch)
+        if self.config.num_classes == 2:
+            probs = self.predict(images).probabilities
             logging.debug(f"Predicted reward images: {probs}")
             return (probs > threshold).float()
         else:
-            return torch.argmax(self.forward(x).probabilities, dim=1)
+            return torch.argmax(self.predict(images).probabilities, dim=1)
+
+    # Methods required by PreTrainedPolicy abstract class
+    def get_optim_params(self) -> dict:
+        """Return optimizer parameters for the policy."""
+        return {
+            "params": self.parameters(),
+            "lr": getattr(self.config, "learning_rate", 1e-4),
+            "weight_decay": getattr(self.config, "weight_decay", 0.01),
+        }
+
+    def reset(self):
+        """Reset any stateful components (required by PreTrainedPolicy)."""
+        # Classifier doesn't have stateful components that need resetting
+        pass
+
+    def select_action(self, batch: Dict[str, Tensor]) -> Tensor:
+        """Return action (class prediction) based on input observation."""
+        images, _ = self.extract_images_and_labels(batch)
+        with torch.no_grad():
+            outputs = self.predict(images)
+
+            if self.config.num_classes == 2:
+                # For binary classification return 0 or 1
+                return (outputs.probabilities > 0.5).float()
+            else:
+                # For multi-class return the predicted class
+                return torch.argmax(outputs.probabilities, dim=1)
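
To illustrate the new train.py-compatible interface, a rough sketch under assumptions: the camera key, tensor shapes, and dataset statistics below are made up, and classifier is an already constructed Classifier whose input_features contain one "observation.image" entry.

    import torch

    # classifier = get_policy_class("hilserl_classifier")(cfg)  # assumed constructed elsewhere
    batch = {
        "observation.image": torch.rand(8, 3, 224, 224),       # hypothetical camera key and image size
        "next.reward": torch.randint(0, 2, (8, 1)).float(),    # binary success labels
    }
    loss, metrics = classifier(batch)   # forward() returns (loss, {"accuracy", "correct", "total"})
    loss.backward()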

View File

@@ -23,6 +23,7 @@ class FeatureType(str, Enum):
     VISUAL = "VISUAL"
     ENV = "ENV"
     ACTION = "ACTION"
+    REWARD = "REWARD"


 class NormalizationMode(str, Enum):
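
For illustration (hypothetical usage, mirroring the commented-out output_features in the classifier config above), the new enum member lets a reward output be declared as a policy feature:

    from lerobot.configs.types import FeatureType, PolicyFeature

    reward_feature = PolicyFeature(type=FeatureType.REWARD, shape=(1,))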

View File

@@ -274,10 +274,7 @@ def record(
     if not robot.is_connected:
         robot.connect()

-    listener, events = init_keyboard_listener(assign_rewards=assign_rewards)
-
-    if reset_follower:
-        initial_position = robot.follower_arms["main"].read("Present_Position")
+    listener, events = init_keyboard_listener(assign_rewards=cfg.assign_rewards)

     # Execute a few seconds without recording to:
     # 1. teleoperate the robot to move it in starting position if no policy provided,
@@ -394,6 +391,7 @@ def control_robot(cfg: ControlPipelineConfig):
         replay(robot, cfg.control)
     elif isinstance(cfg.control, RemoteRobotConfig):
         from lerobot.common.robot_devices.robots.lekiwi_remote import run_lekiwi
+
         run_lekiwi(cfg.robot)

     if robot.is_connected:

View File

@@ -1,13 +1,13 @@
+import argparse
+import logging
+import sys
+import time
+
+import numpy as np
+import torch
+
 from lerobot.common.robot_devices.utils import busy_wait
 from lerobot.scripts.server.kinematics import RobotKinematics
-import logging
-import time
-import torch
-import numpy as np
-import argparse
-from lerobot.common.robot_devices.robots.utils import make_robot_from_config
-from lerobot.scripts.server.gym_manipulator import make_robot_env, HILSerlRobotEnvConfig
-from lerobot.common.robot_devices.robots.configs import RobotConfig

 logging.basicConfig(level=logging.INFO)
@@ -458,12 +458,13 @@ class GamepadControllerHID(InputController):
 def test_forward_kinematics(robot, fps=10):
     logging.info("Testing Forward Kinematics")
     timestep = time.perf_counter()
+    kinematics = RobotKinematics(robot.robot_type)
     while time.perf_counter() - timestep < 60.0:
         loop_start_time = time.perf_counter()
         robot.teleop_step()
         obs = robot.capture_observation()
         joint_positions = obs["observation.state"].cpu().numpy()
-        ee_pos = RobotKinematics.fk_gripper_tip(joint_positions)
+        ee_pos = kinematics.fk_gripper_tip(joint_positions)
         logging.info(f"EE Position: {ee_pos[:3, 3]}")
         busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
@@ -485,21 +486,19 @@ def test_inverse_kinematics(robot, fps=10):

 def teleoperate_inverse_kinematics_with_leader(robot, fps=10):
     logging.info("Testing Inverse Kinematics")
-    fk_func = RobotKinematics.fk_gripper_tip
+    kinematics = RobotKinematics(robot.robot_type)
     timestep = time.perf_counter()
     while time.perf_counter() - timestep < 60.0:
         loop_start_time = time.perf_counter()

         obs = robot.capture_observation()
         joint_positions = obs["observation.state"].cpu().numpy()
-        ee_pos = fk_func(joint_positions)
+        ee_pos = kinematics.fk_gripper_tip(joint_positions)

         leader_joint_positions = robot.leader_arms["main"].read("Present_Position")
-        leader_ee = fk_func(leader_joint_positions)
+        leader_ee = kinematics.fk_gripper_tip(leader_joint_positions)
         desired_ee_pos = leader_ee
-        target_joint_state = RobotKinematics.ik(
-            joint_positions, desired_ee_pos, position_only=True, fk_func=fk_func
-        )
+        target_joint_state = kinematics.ik(joint_positions, desired_ee_pos, position_only=True)
         robot.send_action(torch.from_numpy(target_joint_state))
         logging.info(f"Leader EE: {leader_ee[:3, 3]}, Follower EE: {ee_pos[:3, 3]}")
         busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
@@ -513,10 +512,10 @@ def teleoperate_delta_inverse_kinematics_with_leader(robot, fps=10):
     obs = robot.capture_observation()
     joint_positions = obs["observation.state"].cpu().numpy()

-    fk_func = RobotKinematics.fk_gripper_tip
+    kinematics = RobotKinematics(robot.robot_type)

     leader_joint_positions = robot.leader_arms["main"].read("Present_Position")
-    initial_leader_ee = fk_func(leader_joint_positions)
+    initial_leader_ee = kinematics.fk_gripper_tip(leader_joint_positions)

     desired_ee_pos = np.diag(np.ones(4))
@@ -525,13 +524,13 @@ def teleoperate_delta_inverse_kinematics_with_leader(robot, fps=10):

         # Get leader state for teleoperation
         leader_joint_positions = robot.leader_arms["main"].read("Present_Position")
-        leader_ee = fk_func(leader_joint_positions)
+        leader_ee = kinematics.fk_gripper_tip(leader_joint_positions)

         # Get current state
         # obs = robot.capture_observation()
         # joint_positions = obs["observation.state"].cpu().numpy()
         joint_positions = robot.follower_arms["main"].read("Present_Position")
-        current_ee_pos = fk_func(joint_positions)
+        current_ee_pos = kinematics.fk_gripper_tip(joint_positions)

         # Calculate delta between leader and follower end-effectors
         # Scaling factor can be adjusted for sensitivity
@@ -545,9 +544,7 @@ def teleoperate_delta_inverse_kinematics_with_leader(robot, fps=10):
         if np.any(np.abs(ee_delta[:3, 3]) > 0.01):
             # Compute joint targets via inverse kinematics
-            target_joint_state = RobotKinematics.ik(
-                joint_positions, desired_ee_pos, position_only=True, fk_func=fk_func
-            )
+            target_joint_state = kinematics.ik(joint_positions, desired_ee_pos, position_only=True)

             initial_leader_ee = leader_ee.copy()
@@ -580,7 +577,8 @@ def teleoperate_delta_inverse_kinematics(robot, controller, fps=10, bounds=None,
     # Initial position capture
     obs = robot.capture_observation()
     joint_positions = obs["observation.state"].cpu().numpy()
-    current_ee_pos = fk_func(joint_positions)
+    kinematics = RobotKinematics(robot.robot_type)
+    current_ee_pos = kinematics.fk_gripper_tip(joint_positions)

     # Initialize desired position with current position
     desired_ee_pos = np.eye(4)  # Identity matrix
@@ -595,7 +593,7 @@ def teleoperate_delta_inverse_kinematics(robot, controller, fps=10, bounds=None,
         # Get currrent robot state
         joint_positions = robot.follower_arms["main"].read("Present_Position")
-        current_ee_pos = fk_func(joint_positions)
+        current_ee_pos = kinematics.fk_gripper_tip(joint_positions)

         # Get movement deltas from the controller
         delta_x, delta_y, delta_z = controller.get_deltas()
@@ -612,9 +610,7 @@ def teleoperate_delta_inverse_kinematics(robot, controller, fps=10, bounds=None,
         # Only send commands if there's actual movement
         if any([abs(v) > 0.001 for v in [delta_x, delta_y, delta_z]]):
             # Compute joint targets via inverse kinematics
-            target_joint_state = RobotKinematics.ik(
-                joint_positions, desired_ee_pos, position_only=True, fk_func=fk_func
-            )
+            target_joint_state = kinematics.ik(joint_positions, desired_ee_pos, position_only=True)

             # Send command to robot
             robot.send_action(torch.from_numpy(target_joint_state))
@@ -676,7 +672,17 @@ def teleoperate_gym_env(env, controller, fps: int = 30):
         # Close the environment
         env.close()


 if __name__ == "__main__":
+    from lerobot.common.robot_devices.robots.configs import RobotConfig
+    from lerobot.common.robot_devices.robots.utils import make_robot_from_config
+    from lerobot.scripts.server.gym_manipulator import (
+        EEActionSpaceConfig,
+        EnvWrapperConfig,
+        HILSerlRobotEnvConfig,
+        make_robot_env,
+    )
+
     parser = argparse.ArgumentParser(description="Test end-effector control")
     parser.add_argument(
         "--mode",
@@ -698,12 +704,6 @@ if __name__ == "__main__":
         default="so100",
         help="Robot type (so100, koch, aloha, etc.)",
     )
-    parser.add_argument(
-        "--config-path",
-        type=str,
-        default=None,
-        help="Path to the config file in json format",
-    )

     args = parser.parse_args()
@@ -725,7 +725,10 @@ if __name__ == "__main__":
     if args.mode.startswith("keyboard"):
         controller = KeyboardController(x_step_size=0.01, y_step_size=0.01, z_step_size=0.05)
     elif args.mode.startswith("gamepad"):
-        controller = GamepadController(x_step_size=0.02, y_step_size=0.02, z_step_size=0.05)
+        if sys.platform == "darwin":
+            controller = GamepadControllerHID(x_step_size=0.01, y_step_size=0.01, z_step_size=0.05)
+        else:
+            controller = GamepadController(x_step_size=0.01, y_step_size=0.01, z_step_size=0.05)

     # Handle mode categories
     if args.mode in ["keyboard", "gamepad"]:
@@ -734,12 +737,14 @@ if __name__ == "__main__":
     elif args.mode in ["keyboard_gym", "gamepad_gym"]:
         # Gym environment control modes
-        cfg = HILSerlRobotEnvConfig()
-        if args.config_path is not None:
-            cfg = HILSerlRobotEnvConfig.from_json(args.config_path)
+        cfg = HILSerlRobotEnvConfig(robot=robot_config, wrapper=EnvWrapperConfig())
+        cfg.wrapper.ee_action_space_params = EEActionSpaceConfig(
+            x_step_size=0.03, y_step_size=0.03, z_step_size=0.03, bounds=bounds
+        )
+        cfg.wrapper.ee_action_space_params.use_gamepad = False
+        cfg.device = "cpu"
         env = make_robot_env(cfg, robot)
-        teleoperate_gym_env(env, controller, fps=args.fps)
+        teleoperate_gym_env(env, controller, fps=cfg.fps)

     elif args.mode == "leader":
         # Leader-follower modes don't use controllers
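
The kinematics helpers are now used through an instance instead of classmethods. A rough sketch of the new call pattern (assuming, as in the code above, that the constructor takes the robot type string; placeholder values below are illustrative):

    import numpy as np
    from lerobot.scripts.server.kinematics import RobotKinematics

    kinematics = RobotKinematics("so100")
    joint_positions = np.zeros(6)          # placeholder joint vector; size depends on the robot
    desired_ee_pose = np.eye(4)            # identity target pose, as in the teleop helpers above
    ee_pose = kinematics.fk_gripper_tip(joint_positions)                          # 4x4 end-effector pose
    target_q = kinematics.ik(joint_positions, desired_ee_pose, position_only=True)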

View File

@@ -63,9 +63,10 @@ def find_ee_bounds(
         if time.perf_counter() - start_episode_t < 5:
             continue

+        kinematics = RobotKinematics(robot.robot_type)
         joint_positions = robot.follower_arms["main"].read("Present_Position")
         print(f"Joint positions: {joint_positions}")
-        ee_list.append(RobotKinematics.fk_gripper_tip(joint_positions)[:3, 3])
+        ee_list.append(kinematics.fk_gripper_tip(joint_positions)[:3, 3])

         if display_cameras and not is_headless():
             image_keys = [key for key in observation if "image" in key]
@@ -81,20 +82,19 @@ def find_ee_bounds(
             break


-def make_robot(robot_type="so100", mock=True):
+def make_robot(robot_type="so100"):
     """
     Create a robot instance using the appropriate robot config class.

     Args:
         robot_type: Robot type string (e.g., "so100", "koch", "aloha")
-        mock: Whether to use mock mode for hardware (default: True)

     Returns:
         Robot instance
     """
     # Get the appropriate robot config class based on robot_type
-    robot_config = RobotConfig.get_choice_class(robot_type)(mock=mock)
+    robot_config = RobotConfig.get_choice_class(robot_type)(mock=False)

     robot_config.leader_arms["main"].port = leader_port
     robot_config.follower_arms["main"].port = follower_port
@@ -122,18 +122,12 @@ if __name__ == "__main__":
         default="so100",
         help="Robot type (so100, koch, aloha, etc.)",
     )
-    parser.add_argument(
-        "--mock",
-        type=int,
-        default=1,
-        help="Use mock mode for hardware simulation",
-    )

     # Only parse known args, leaving robot config args for Hydra if used
-    args, _ = parser.parse_known_args()
+    args = parser.parse_args()

     # Create robot with the appropriate config
-    robot = make_robot(args.robot_type, args.mock)
+    robot = make_robot(args.robot_type)

     if args.mode == "joint":
         find_joint_bounds(robot, args.control_time_s)

View File

@@ -1,43 +1,35 @@
 import logging
 import sys
 import time
-import sys
+from dataclasses import dataclass
 from threading import Lock
-from typing import Annotated, Any, Dict, Tuple
+from typing import Annotated, Any, Dict, Optional, Tuple

 import gymnasium as gym
 import numpy as np
 import torch
 import torchvision.transforms.functional as F  # noqa: N812
-import json
-from dataclasses import dataclass
-from lerobot.common.envs.utils import preprocess_observation
-from lerobot.configs.train import TrainPipelineConfig

 from lerobot.common.envs.configs import EnvConfig
+from lerobot.common.envs.utils import preprocess_observation
 from lerobot.common.robot_devices.control_utils import (
     busy_wait,
     is_headless,
-    # reset_follower_position,
+    reset_follower_position,
 )
-from typing import Optional
-from lerobot.common.utils.utils import log_say
-from lerobot.common.robot_devices.robots.utils import make_robot_from_config
 from lerobot.common.robot_devices.robots.configs import RobotConfig
-from lerobot.scripts.server.kinematics import RobotKinematics
-from lerobot.scripts.server.maniskill_manipulator import ManiskillEnvConfig, make_maniskill
+from lerobot.common.robot_devices.robots.utils import make_robot_from_config
+from lerobot.common.utils.utils import log_say
 from lerobot.configs import parser
+from lerobot.scripts.server.kinematics import RobotKinematics

 logging.basicConfig(level=logging.INFO)


 @dataclass
 class EEActionSpaceConfig:
     """Configuration parameters for end-effector action space."""

     x_step_size: float
     y_step_size: float
     z_step_size: float
@@ -48,6 +40,7 @@ class EEActionSpaceConfig:

 @dataclass
 class EnvWrapperConfig:
     """Configuration for environment wrappers."""
+
     display_cameras: bool = False
     delta_action: float = 0.1
     use_relative_joint_positions: bool = True
@@ -64,12 +57,13 @@ class EnvWrapperConfig:
     reward_classifier_config_file: Optional[str] = None


+@EnvConfig.register_subclass(name="gym_manipulator")
 @dataclass
-class HILSerlRobotEnvConfig:
+class HILSerlRobotEnvConfig(EnvConfig):
     """Configuration for the HILSerlRobotEnv environment."""
-    robot: RobotConfig
-    wrapper: EnvWrapperConfig
-    env_name: str = "real_robot"
+
+    robot: Optional[RobotConfig] = None
+    wrapper: Optional[EnvWrapperConfig] = None
     fps: int = 10
     mode: str = None  # Either "record", "replay", None
     repo_id: Optional[str] = None
@@ -81,11 +75,9 @@ class HILSerlRobotEnvConfig:
     push_to_hub: bool = True
     pretrained_policy_name_or_path: Optional[str] = None

-    @classmethod
-    def from_json(cls, json_path: str):
-        with open(json_path, "r") as f:
-            config = json.load(f)
-        return cls(**config)
+    def gym_kwargs(self) -> dict:
+        return {}


 class HILSerlRobotEnv(gym.Env):
     """
@@ -580,8 +572,7 @@ class ImageCropResizeWrapper(gym.Wrapper):
             if key_crop not in self.env.observation_space.keys():  # noqa: SIM118
                 raise ValueError(f"Key {key_crop} not in observation space")
         for key in crop_params_dict:
-            top, left, height, width = crop_params_dict[key]
-            new_shape = (top + height, left + width)
+            new_shape = (3, resize_size[0], resize_size[1])
             self.observation_space[key] = gym.spaces.Box(low=0, high=255, shape=new_shape)

         self.resize_size = resize_size
@@ -1097,9 +1088,7 @@ class ActionScaleWrapper(gym.ActionWrapper):
         return action * self.scale_vector, is_intervention


-def make_robot_env(cfg: EnvConfig) -> gym.vector.VectorEnv:
-    # def make_robot_env(cfg: TrainPipelineConfig) -> gym.vector.VectorEnv:
-    # def make_robot_env(cfg: ManiskillEnvConfig) -> gym.vector.VectorEnv:
+def make_robot_env(cfg) -> gym.vector.VectorEnv:
     """
     Factory function to create a vectorized robot environment.
@@ -1111,16 +1100,16 @@ def make_robot_env(cfg: EnvConfig) -> gym.vector.VectorEnv:
     Returns:
         A vectorized gym environment with all the necessary wrappers applied.
     """
-    if "maniskill" in cfg.name:
-        from lerobot.scripts.server.maniskill_manipulator import make_maniskill
-
-        logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN")
-        env = make_maniskill(
-            cfg=cfg,
-            n_envs=1,
-        )
-        return env
-    robot = cfg.robot
+    # if "maniskill" in cfg.name:
+    #     from lerobot.scripts.server.maniskill_manipulator import make_maniskill
+
+    #     logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN")
+    #     env = make_maniskill(
+    #         cfg=cfg,
+    #         n_envs=1,
+    #     )
+    #     return env
+    robot = make_robot_from_config(cfg.robot)
     # Create base environment
     env = HILSerlRobotEnv(
         robot=robot,
@@ -1150,10 +1139,7 @@ def make_robot_env(cfg: EnvConfig) -> gym.vector.VectorEnv:
         env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
     if cfg.wrapper.ee_action_space_params is not None:
         env = EEActionWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params)
-    if (
-        cfg.wrapper.ee_action_space_params is not None
-        and cfg.wrapper.ee_action_space_params.use_gamepad
-    ):
+    if cfg.wrapper.ee_action_space_params is not None and cfg.wrapper.ee_action_space_params.use_gamepad:
         # env = ActionScaleWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params)
         env = GamepadControlWrapper(
             env=env,
@@ -1169,10 +1155,7 @@ def make_robot_env(cfg: EnvConfig) -> gym.vector.VectorEnv:
         reset_pose=cfg.wrapper.fixed_reset_joint_positions,
         reset_time_s=cfg.wrapper.reset_time_s,
     )
-    if (
-        cfg.wrapper.ee_action_space_params is None
-        and cfg.wrapper.joint_masking_action_space is not None
-    ):
+    if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None:
         env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)
     env = BatchCompitableWrapper(env=env)
@@ -1180,7 +1163,10 @@ def make_robot_env(cfg: EnvConfig) -> gym.vector.VectorEnv:

 def get_classifier(cfg):
-    if cfg.wrapper.reward_classifier_pretrained_path is None or cfg.wrapper.reward_classifier_config_file is None:
+    if (
+        cfg.wrapper.reward_classifier_pretrained_path is None
+        or cfg.wrapper.reward_classifier_config_file is None
+    ):
         return None

     from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
@@ -1258,7 +1244,8 @@ def record_dataset(env, policy, cfg: HILSerlRobotEnvConfig):
     # Record episodes
     episode_index = 0
-    while episode_index < cfg.record_num_episodes:
+    recorded_action = None
+    while episode_index < cfg.num_episodes:
         obs, _ = env.reset()
         start_episode_t = time.perf_counter()
         log_say(f"Recording episode {episode_index}", play_sounds=True)
@@ -1279,16 +1266,19 @@ def record_dataset(env, policy, cfg: HILSerlRobotEnvConfig):
                 break

             # For teleop, get action from intervention
-            if policy is None:
-                action = {"action": info["action_intervention"].cpu().squeeze(0).float()}
+            recorded_action = {
+                "action": info["action_intervention"].cpu().squeeze(0).float() if policy is None else action
+            }

             # Process observation for dataset
             obs = {k: v.cpu().squeeze(0).float() for k, v in obs.items()}
+            obs["observation.images.side"] = torch.clamp(obs["observation.images.side"], 0, 1)

             # Add frame to dataset
-            frame = {**obs, **action}
-            frame["next.reward"] = reward
-            frame["next.done"] = terminated or truncated
+            frame = {**obs, **recorded_action}
+            frame["next.reward"] = np.array([reward], dtype=np.float32)
+            frame["next.done"] = np.array([terminated or truncated], dtype=bool)
+            frame["task"] = cfg.task
             dataset.add_frame(frame)

             # Maintain consistent timing
@@ -1309,9 +1299,9 @@ def record_dataset(env, policy, cfg: HILSerlRobotEnvConfig):
         episode_index += 1

     # Finalize dataset
-    dataset.consolidate(run_compute_stats=True)
+    # dataset.consolidate(run_compute_stats=True)
     if cfg.push_to_hub:
-        dataset.push_to_hub(cfg.repo_id)
+        dataset.push_to_hub()


 def replay_episode(env, repo_id, root=None, episode=0):
@@ -1334,82 +1324,69 @@ def replay_episode(env, repo_id, root=None, episode=0):
         busy_wait(1 / 10 - dt_s)


-# @parser.wrap()
-# def main(cfg):
-#     robot = make_robot_from_config(cfg.robot)
-#     reward_classifier = None #get_classifier(
-#     #     cfg.wrapper.reward_classifier_pretrained_path, cfg.wrapper.reward_classifier_config_file
-#     # )
-#     user_relative_joint_positions = True
-#     env = make_robot_env(cfg, robot)
-#     if cfg.mode == "record":
-#         policy = None
-#         if cfg.pretrained_policy_name_or_path is not None:
-#             from lerobot.common.policies.sac.modeling_sac import SACPolicy
-#             policy = SACPolicy.from_pretrained(cfg.pretrained_policy_name_or_path)
-#             policy.to(cfg.device)
-#             policy.eval()
-#         record_dataset(
-#             env,
-#             cfg.repo_id,
-#             root=cfg.dataset_root,
-#             num_episodes=cfg.num_episodes,
-#             fps=cfg.fps,
-#             task_description=cfg.task,
-#             policy=policy,
-#         )
-#         exit()
-#     if cfg.mode == "replay":
-#         replay_episode(
-#             env,
-#             cfg.replay_repo_id,
-#             root=cfg.dataset_root,
-#             episode=cfg.replay_episode,
-#         )
-#         exit()
-#     env.reset()
-#     # Retrieve the robot's action space for joint commands.
-#     action_space_robot = env.action_space.spaces[0]
-#     # Initialize the smoothed action as a random sample.
-#     smoothed_action = action_space_robot.sample()
-#     # Smoothing coefficient (alpha) defines how much of the new random sample to mix in.
-#     # A value close to 0 makes the trajectory very smooth (slow to change), while a value close to 1 is less smooth.
-#     alpha = 1.0
-#     num_episode = 0
-#     sucesses = []
-#     while num_episode < 20:
-#         start_loop_s = time.perf_counter()
-#         # Sample a new random action from the robot's action space.
-#         new_random_action = action_space_robot.sample()
-#         # Update the smoothed action using an exponential moving average.
-#         smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action
-#         # Execute the step: wrap the NumPy action in a torch tensor.
-#         obs, reward, terminated, truncated, info = env.step((torch.from_numpy(smoothed_action), False))
-#         if terminated or truncated:
-#             sucesses.append(reward)
-#             env.reset()
-#             num_episode += 1
-#         dt_s = time.perf_counter() - start_loop_s
-#         busy_wait(1 / cfg.fps - dt_s)
-#     logging.info(f"Success after 20 steps {sucesses}")
-#     logging.info(f"success rate {sum(sucesses) / len(sucesses)}")
-# if __name__ == "__main__":
-#     main()
+@parser.wrap()
+def main(cfg: EnvConfig):
+    env = make_robot_env(cfg)
+
+    if cfg.mode == "record":
+        policy = None
+        if cfg.pretrained_policy_name_or_path is not None:
+            from lerobot.common.policies.sac.modeling_sac import SACPolicy
+
+            policy = SACPolicy.from_pretrained(cfg.pretrained_policy_name_or_path)
+            policy.to(cfg.device)
+            policy.eval()
+
+        record_dataset(
+            env,
+            policy=None,
+            cfg=cfg,
+        )
+        exit()
+
+    if cfg.mode == "replay":
+        replay_episode(
+            env,
+            cfg.replay_repo_id,
+            root=cfg.dataset_root,
+            episode=cfg.replay_episode,
+        )
+        exit()
+
+    env.reset()
+
+    # Retrieve the robot's action space for joint commands.
+    action_space_robot = env.action_space.spaces[0]
+
+    # Initialize the smoothed action as a random sample.
+    smoothed_action = action_space_robot.sample()
+
+    # Smoothing coefficient (alpha) defines how much of the new random sample to mix in.
+    # A value close to 0 makes the trajectory very smooth (slow to change), while a value close to 1 is less smooth.
+    alpha = 1.0
+
+    num_episode = 0
+    sucesses = []
+    while num_episode < 20:
+        start_loop_s = time.perf_counter()
+        # Sample a new random action from the robot's action space.
+        new_random_action = action_space_robot.sample()
+        # Update the smoothed action using an exponential moving average.
+        smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action
+
+        # Execute the step: wrap the NumPy action in a torch tensor.
+        obs, reward, terminated, truncated, info = env.step((torch.from_numpy(smoothed_action), False))
+        if terminated or truncated:
+            sucesses.append(reward)
+            env.reset()
+            num_episode += 1
+
+        dt_s = time.perf_counter() - start_loop_s
+        busy_wait(1 / cfg.fps - dt_s)
+
+    logging.info(f"Success after 20 steps {sucesses}")
+    logging.info(f"success rate {sum(sucesses) / len(sucesses)}")


 if __name__ == "__main__":
-    make_robot_env()
+    main()
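
A minimal sketch of driving the refactored entry point programmatically rather than through the draccus CLI (assumptions: the robot's ports and cameras are configured elsewhere, and the values below are illustrative):

    from lerobot.common.robot_devices.robots.configs import RobotConfig
    from lerobot.scripts.server.gym_manipulator import (
        EnvWrapperConfig,
        HILSerlRobotEnvConfig,
        make_robot_env,
    )

    cfg = HILSerlRobotEnvConfig(
        robot=RobotConfig.get_choice_class("so100")(),
        wrapper=EnvWrapperConfig(),
        fps=10,
    )
    env = make_robot_env(cfg)   # builds the robot from cfg.robot and applies the wrappers
    obs, info = env.reset()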

View File

@@ -15,12 +15,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import os
 import shutil
 import time
 from concurrent.futures import ThreadPoolExecutor
-from pprint import pformat
-import os
 from pathlib import Path
+from pprint import pformat

 import draccus
 import grpc
@@ -30,35 +30,42 @@ import hilserl_pb2_grpc  # type: ignore
 import torch
 from termcolor import colored
 from torch import nn
 from torch.multiprocessing import Queue
 from torch.optim.optimizer import Optimizer

+from lerobot.common.constants import (
+    CHECKPOINTS_DIR,
+    LAST_CHECKPOINT_LINK,
+    PRETRAINED_MODEL_DIR,
+    TRAINING_STATE_DIR,
+    TRAINING_STEP,
+)
 from lerobot.common.datasets.factory import make_dataset
-from lerobot.configs.train import TrainPipelineConfig
-from lerobot.configs import parser

 # TODO: Remove the import of maniskill
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.policies.factory import make_policy
-from lerobot.common.policies.sac.modeling_sac import SACPolicy, SACConfig
+from lerobot.common.policies.sac.modeling_sac import SACConfig, SACPolicy
+from lerobot.common.policies.utils import get_device_from_parameters
+from lerobot.common.utils.random_utils import set_seed
 from lerobot.common.utils.train_utils import (
     get_step_checkpoint_dir,
     get_step_identifier,
-    load_training_state as utils_load_training_state,
     save_checkpoint,
-    update_last_checkpoint,
     save_training_state,
+    update_last_checkpoint,
+)
+from lerobot.common.utils.train_utils import (
+    load_training_state as utils_load_training_state,
 )
-from lerobot.common.utils.random_utils import set_seed
 from lerobot.common.utils.utils import (
     format_big_number,
     get_safe_torch_device,
     init_logging,
 )
-from lerobot.common.policies.utils import get_device_from_parameters
 from lerobot.common.utils.wandb_utils import WandBLogger
+from lerobot.configs import parser
+from lerobot.configs.train import TrainPipelineConfig
 from lerobot.scripts.server import learner_service
 from lerobot.scripts.server.buffer import (
     ReplayBuffer,
@@ -70,13 +77,6 @@ from lerobot.scripts.server.buffer import (
     state_to_bytes,
 )
 from lerobot.scripts.server.utils import setup_process_handlers
-from lerobot.common.constants import (
-    CHECKPOINTS_DIR,
-    LAST_CHECKPOINT_LINK,
-    PRETRAINED_MODEL_DIR,
-    TRAINING_STATE_DIR,
-    TRAINING_STEP,
-)


 def handle_resume_logic(cfg: TrainPipelineConfig) -> TrainPipelineConfig:
@@ -109,8 +109,7 @@ def handle_resume_logic(cfg: TrainPipelineConfig) -> TrainPipelineConfig:
         checkpoint_dir = os.path.join(out_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK)
         if os.path.exists(checkpoint_dir):
             raise RuntimeError(
-                f"Output directory {checkpoint_dir} already exists. "
-                "Use `resume=true` to resume training."
+                f"Output directory {checkpoint_dir} already exists. Use `resume=true` to resume training."
             )

         return cfg
@@ -189,7 +188,6 @@ def log_training_info(cfg: TrainPipelineConfig, policy: nn.Module) -> None:
     num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())
-
     logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
     logging.info(f"{cfg.env.task=}")
     logging.info(f"{cfg.policy.online_steps=}")
@@ -197,11 +195,7 @@ def log_training_info(cfg: TrainPipelineConfig, policy: nn.Module) -> None:
     logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")


-def initialize_replay_buffer(
-    cfg: TrainPipelineConfig,
-    device: str,
-    storage_device: str
-) -> ReplayBuffer:
+def initialize_replay_buffer(cfg: TrainPipelineConfig, device: str, storage_device: str) -> ReplayBuffer:
     """
     Initialize a replay buffer, either empty or from a dataset if resuming.
@@ -512,13 +506,15 @@ def add_actor_information_and_train(
     logging.info("Initializing policy")

     # Get checkpoint dir for resuming
-    checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) if cfg.resume else None
+    checkpoint_dir = (
+        os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) if cfg.resume else None
+    )
     pretrained_path = os.path.join(checkpoint_dir, PRETRAINED_MODEL_DIR) if checkpoint_dir else None

     policy: SACPolicy = make_policy(
         cfg=cfg.policy,
         # ds_meta=cfg.dataset,
-        env_cfg=cfg.env
+        env_cfg=cfg.env,
     )

     # Update the policy config with the grad_clip_norm value from training config if it exists
@@ -618,11 +614,7 @@ def add_actor_information_and_train(
             # Log interaction messages with WandB if available
             if wandb_logger:
-                wandb_logger.log_dict(
-                    d=interaction_message,
-                    mode="train",
-                    custom_step_key="Interaction step"
-                )
+                wandb_logger.log_dict(d=interaction_message, mode="train", custom_step_key="Interaction step")

             logging.debug("[LEARNER] Received interactions")
@@ -636,7 +628,9 @@ def add_actor_information_and_train(
             if dataset_repo_id is not None:
                 batch_offline = offline_replay_buffer.sample(batch_size=batch_size)
-                batch = concatenate_batch_transitions(left_batch_transitions=batch, right_batch_transition=batch_offline)
+                batch = concatenate_batch_transitions(
+                    left_batch_transitions=batch, right_batch_transition=batch_offline
+                )

                 actions = batch["action"]
                 rewards = batch["reward"]
@@ -762,11 +756,7 @@ def add_actor_information_and_train(
             # Log training metrics
             if wandb_logger:
-                wandb_logger.log_dict(
-                    d=training_infos,
-                    mode="train",
-                    custom_step_key="Optimization step"
-                )
+                wandb_logger.log_dict(d=training_infos, mode="train", custom_step_key="Optimization step")

             time_for_one_optimization_step = time.time() - time_for_one_optimization_step
             frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)
@@ -800,22 +790,12 @@ def add_actor_information_and_train(
                 checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, online_steps, optimization_step)

                 # Save checkpoint
-                save_checkpoint(
-                    checkpoint_dir,
-                    optimization_step,
-                    cfg,
-                    policy,
-                    optimizers,
-                    scheduler=None
-                )
+                save_checkpoint(checkpoint_dir, optimization_step, cfg, policy, optimizers, scheduler=None)

                 # Save interaction step manually
                 training_state_dir = os.path.join(checkpoint_dir, TRAINING_STATE_DIR)
                 os.makedirs(training_state_dir, exist_ok=True)
-                training_state = {
-                    "step": optimization_step,
-                    "interaction_step": interaction_step
-                }
+                training_state = {"step": optimization_step, "interaction_step": interaction_step}
                 torch.save(training_state, os.path.join(training_state_dir, "training_state.pt"))

                 # Update the "last" symlink
@@ -831,11 +811,7 @@ def add_actor_information_and_train(
                 # NOTE: Handle the case where the dataset repo id is not specified in the config
                 # eg. RL training without demonstrations data
                 repo_id_buffer_save = cfg.env.task if dataset_repo_id is None else dataset_repo_id
-                replay_buffer.to_lerobot_dataset(
-                    repo_id=repo_id_buffer_save,
-                    fps=fps,
-                    root=dataset_dir
-                )
+                replay_buffer.to_lerobot_dataset(repo_id=repo_id_buffer_save, fps=fps, root=dataset_dir)

                 if offline_replay_buffer is not None:
                     dataset_offline_dir = os.path.join(cfg.output_dir, "dataset_offline")
@@ -882,9 +858,7 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
         params=policy.actor.parameters_to_optimize,
         lr=cfg.policy.actor_lr,
     )
-    optimizer_critic = torch.optim.Adam(
-        params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr
-    )
+    optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
     optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
     lr_scheduler = None
     optimizers = {
@@ -920,6 +894,7 @@ def train(cfg: TrainPipelineConfig, job_name: str | None = None):
     # Setup WandB logging if enabled
     if cfg.wandb.enable and cfg.wandb.project:
         from lerobot.common.utils.wandb_utils import WandBLogger
+
         wandb_logger = WandBLogger(cfg)
     else:
         wandb_logger = None
@@ -944,9 +919,9 @@ def train(cfg: TrainPipelineConfig, job_name: str | None = None):
 @parser.wrap()
 def train_cli(cfg: TrainPipelineConfig):
     if not use_threads(cfg):
         import torch.multiprocessing as mp
+
         mp.set_start_method("spawn")

     # Use the job_name from the config

View File

@@ -122,6 +122,9 @@ def make_optimizer_and_scheduler(cfg, policy):
         optimizer = VQBeTOptimizer(policy, cfg)
         lr_scheduler = VQBeTScheduler(optimizer, cfg)
+    elif cfg.policy.name == "hilserl_classifier":
+        optimizer = torch.optim.AdamW(policy.parameters(), cfg.policy.learning_rate)
+        lr_scheduler = None
     else:
         raise NotImplementedError()

View File

@@ -16,7 +16,6 @@ import time
 from contextlib import nullcontext
 from pprint import pformat

-import hydra
 import numpy as np
 import torch
 import torch.nn as nn
@@ -32,11 +31,8 @@ from tqdm import tqdm
 from lerobot.common.datasets.factory import resolve_delta_timestamps
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.common.logger import Logger
-from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
-from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
-    ClassifierConfig,
-)
+from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
 from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
 from lerobot.common.utils.utils import (
     format_big_number,
@@ -296,8 +292,6 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     init_logging()
     logging.info(OmegaConf.to_yaml(cfg))

-    logger = Logger(cfg, out_dir, wandb_job_name=job_name)
-
     # Initialize training environment
     device = get_safe_torch_device(cfg.device, log=True)
     set_global_seed(cfg.seed)