Change HILSerlRobotEnvConfig to inherit from EnvConfig
Added support for training the hil_serl classifier with train.py. Run classifier training with `python lerobot/scripts/train.py --policy.type=hilserl_classifier`. Also includes fixes in find_joint_limits, control_robot, and end_effector_control_utils.
This commit is contained in:
parent 052a4acfc2
commit d0b7690bc0
@ -412,7 +412,7 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea

        names = ft["names"]
        # Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
        if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
        if names is not None and names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
            shape = (shape[2], shape[0], shape[1])
    elif key == "observation.environment_state":
        type = FeatureType.ENV
@ -68,4 +68,3 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
    )

    return env
@ -28,6 +28,7 @@ from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType

@ -58,6 +59,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
        from lerobot.common.policies.sac.modeling_sac import SACPolicy

        return SACPolicy
    elif name == "hilserl_classifier":
        from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier

        return Classifier
    else:
        raise NotImplementedError(f"Policy with name {name} is not implemented.")

@ -73,6 +78,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
        return VQBeTConfig(**kwargs)
    elif policy_type == "pi0":
        return PI0Config(**kwargs)
    elif policy_type == "hilserl_classifier":
        return ClassifierConfig(**kwargs)
    else:
        raise ValueError(f"Policy type '{policy_type}' is not available.")
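Illustrative usage (not part of the diff): a minimal sketch of how the new "hilserl_classifier" factory branches might be exercised. It assumes a lerobot checkout with this patch applied; the keyword arguments are examples only.

from lerobot.common.policies.factory import get_policy_class, make_policy_config

# Build the classifier config registered under "hilserl_classifier".
cfg = make_policy_config("hilserl_classifier", num_classes=2, num_cameras=2)

# Resolve the policy class registered under the same name.
policy_cls = get_policy_class("hilserl_classifier")
print(policy_cls.__name__, cfg.num_classes)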
@ -1,12 +1,18 @@
import json
import os
from dataclasses import asdict, dataclass
from dataclasses import dataclass, field
from typing import Dict, List

from lerobot.common.optim.optimizers import AdamWConfig, OptimizerConfig
from lerobot.common.optim.schedulers import LRSchedulerConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, PolicyFeature


@PreTrainedConfig.register_subclass(name="hilserl_classifier")
@dataclass
class ClassifierConfig:
class ClassifierConfig(PreTrainedConfig):
    """Configuration for the Classifier model."""

    name: str = "hilserl_classifier"
    num_classes: int = 2
    hidden_dim: int = 256
    dropout_rate: float = 0.1

@ -14,22 +20,35 @@ class ClassifierConfig:
    device: str = "cpu"
    model_type: str = "cnn" # "transformer" or "cnn"
    num_cameras: int = 2
    learning_rate: float = 1e-4
    normalization_mode = None
    # output_features: Dict[str, PolicyFeature] = field(
    #     default_factory=lambda: {"next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(1,))}
    # )

    def save_pretrained(self, save_dir):
        """Save config to json file."""
        os.makedirs(save_dir, exist_ok=True)
    @property
    def observation_delta_indices(self) -> List | None:
        return None

        # Convert to dict and save as JSON
        config_dict = asdict(self)
        with open(os.path.join(save_dir, "config.json"), "w") as f:
            json.dump(config_dict, f, indent=2)
    @property
    def action_delta_indices(self) -> List | None:
        return None

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path):
        """Load config from json file."""
        config_file = os.path.join(pretrained_model_name_or_path, "config.json")
    @property
    def reward_delta_indices(self) -> List | None:
        return None

        with open(config_file) as f:
            config_dict = json.load(f)
    def get_optimizer_preset(self) -> OptimizerConfig:
        return AdamWConfig(
            lr=self.learning_rate,
            weight_decay=0.01,
            grad_clip_norm=1.0,
        )

        return cls(**config_dict)
    def get_scheduler_preset(self) -> LRSchedulerConfig | None:
        return None

    def validate_features(self) -> None:
        """Validate feature configurations."""
        # Classifier doesn't need specific feature validation
        pass
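Rough usage sketch (assumed, not taken from the commit): once ClassifierConfig inherits from PreTrainedConfig, it exposes the standard optimizer and scheduler presets shown in the hunk above.

from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig

cfg = ClassifierConfig(num_classes=2, learning_rate=1e-4, model_type="cnn")
opt = cfg.get_optimizer_preset()  # AdamWConfig with lr=1e-4, weight_decay=0.01, grad_clip_norm=1.0
print(cfg.name, opt.lr, cfg.get_scheduler_preset())  # hilserl_classifier 0.0001 None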
@ -1,11 +1,15 @@
import logging
from typing import Optional
from typing import Dict, Optional, Tuple

import torch
from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor, nn

from .configuration_classifier import ClassifierConfig
from lerobot.common.constants import OBS_IMAGE
from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
    ClassifierConfig,
)
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
@ -32,25 +36,32 @@ class ClassifierOutput:
        )


class Classifier(
    nn.Module,
    PyTorchModelHubMixin,
    # Add Hub metadata
    library_name="lerobot",
    repo_url="https://github.com/huggingface/lerobot",
    tags=["robotics", "vision-classifier"],
):
class Classifier(PreTrainedPolicy):
    """Image classifier built on top of a pre-trained encoder."""

    # Add name attribute for factory
    name = "classifier"
    name = "hilserl_classifier"
    config_class = ClassifierConfig

    def __init__(self, config: ClassifierConfig):
    def __init__(
        self,
        config: ClassifierConfig,
        dataset_stats: Dict[str, Dict[str, Tensor]] | None = None,
    ):
        from transformers import AutoModel

        super().__init__()
        super().__init__(config)
        self.config = config
        # self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True)

        # Initialize normalization (standardized with the policy framework)
        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
        self.normalize_targets = Normalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )
        self.unnormalize_outputs = Unnormalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )

        # Set up encoder
        encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)
        # Extract vision model if we're given a multimodal model
        if hasattr(encoder, "vision_model"):
@ -81,8 +92,6 @@ class Classifier(
        else:
            raise ValueError("Unsupported CNN architecture")

        self.encoder = self.encoder.to(self.config.device)

    def _freeze_encoder(self) -> None:
        """Freeze the encoder parameters."""
        for param in self.encoder.parameters():
@ -109,22 +118,13 @@ class Classifier(
                1 if self.config.num_classes == 2 else self.config.num_classes,
            ),
        )
        self.classifier_head = self.classifier_head.to(self.config.device)

    def _get_encoder_output(self, x: torch.Tensor) -> torch.Tensor:
        """Extract the appropriate output from the encoder."""
        # Process images with the processor (handles resizing and normalization)
        # processed = self.processor(
        #     images=x, # LeRobotDataset already provides proper tensor format
        #     return_tensors="pt",
        # )
        # processed = processed["pixel_values"].to(x.device)
        processed = x

        with torch.no_grad():
            if self.is_cnn:
                # The HF ResNet applies pooling internally
                outputs = self.encoder(processed)
                outputs = self.encoder(x)
                # Get pooled output directly
                features = outputs.pooler_output
@ -132,14 +132,24 @@ class Classifier(
                features = features.squeeze(-1).squeeze(-1)
                return features
            else: # Transformer models
                outputs = self.encoder(processed)
                outputs = self.encoder(x)
                if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
                    return outputs.pooler_output
                return outputs.last_hidden_state[:, 0, :]

    def forward(self, xs: torch.Tensor) -> ClassifierOutput:
        """Forward pass of the classifier."""
        # For training, we expect input to be a tensor directly from LeRobotDataset
    def extract_images_and_labels(self, batch: Dict[str, Tensor]) -> Tuple[list, Tensor]:
        """Extract image tensors and label tensors from batch."""
        # Find image keys in input features
        image_keys = [key for key in self.config.input_features if key.startswith(OBS_IMAGE)]

        # Extract the images and labels
        images = [batch[key] for key in image_keys]
        labels = batch["next.reward"]

        return images, labels

    def predict(self, xs: list) -> ClassifierOutput:
        """Forward pass of the classifier for inference."""
        encoder_outputs = torch.hstack([self._get_encoder_output(x) for x in xs])
        logits = self.classifier_head(encoder_outputs)
@ -151,10 +161,77 @@ class Classifier(

        return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs)

    def predict_reward(self, x, threshold=0.6):
    def forward(self, batch: Dict[str, Tensor]) -> Tuple[Tensor, Dict[str, Tensor]]:
        """Standard forward pass for training compatible with train.py."""
        # Normalize inputs if needed
        batch = self.normalize_inputs(batch)
        batch = self.normalize_targets(batch)

        # Extract images and labels
        images, labels = self.extract_images_and_labels(batch)

        # Get predictions
        outputs = self.predict(images)

        # Calculate loss
        if self.config.num_classes == 2:
            probs = self.forward(x).probabilities
            # Binary classification
            loss = nn.functional.binary_cross_entropy_with_logits(outputs.logits, labels)
            predictions = (torch.sigmoid(outputs.logits) > 0.5).float()
        else:
            # Multi-class classification
            loss = nn.functional.cross_entropy(outputs.logits, labels.long())
            predictions = torch.argmax(outputs.logits, dim=1)

        # Calculate accuracy for logging
        correct = (predictions == labels).sum().item()
        total = labels.size(0)
        accuracy = 100 * correct / total

        # Return loss and metrics for logging
        output_dict = {
            "accuracy": accuracy,
            "correct": correct,
            "total": total,
        }

        return loss, output_dict

    def predict_reward(self, batch, threshold=0.6):
        """Legacy method for compatibility."""
        images, _ = self.extract_images_and_labels(batch)
        if self.config.num_classes == 2:
            probs = self.predict(images).probabilities
            logging.debug(f"Predicted reward images: {probs}")
            return (probs > threshold).float()
        else:
            return torch.argmax(self.forward(x).probabilities, dim=1)
            return torch.argmax(self.predict(images).probabilities, dim=1)

    # Methods required by PreTrainedPolicy abstract class

    def get_optim_params(self) -> dict:
        """Return optimizer parameters for the policy."""
        return {
            "params": self.parameters(),
            "lr": getattr(self.config, "learning_rate", 1e-4),
            "weight_decay": getattr(self.config, "weight_decay", 0.01),
        }

    def reset(self):
        """Reset any stateful components (required by PreTrainedPolicy)."""
        # Classifier doesn't have stateful components that need resetting
        pass

    def select_action(self, batch: Dict[str, Tensor]) -> Tensor:
        """Return action (class prediction) based on input observation."""
        images, _ = self.extract_images_and_labels(batch)

        with torch.no_grad():
            outputs = self.predict(images)

        if self.config.num_classes == 2:
            # For binary classification return 0 or 1
            return (outputs.probabilities > 0.5).float()
        else:
            # For multi-class return the predicted class
            return torch.argmax(outputs.probabilities, dim=1)
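To show how the classifier now plugs into train.py, here is a hypothetical training step. `policy` is assumed to be a Classifier whose config lists one camera under input_features; the batch shapes and keys are placeholders following extract_images_and_labels above.

import torch

# Assumed: `policy` is a Classifier instance (constructing one downloads the
# encoder named by config.model_name, so it is not built here).
batch = {
    "observation.images.side": torch.rand(8, 3, 224, 224),  # placeholder camera frames
    "next.reward": torch.randint(0, 2, (8, 1)).float(),     # binary success labels
}
loss, metrics = policy.forward(batch)  # returns (loss, {"accuracy": ..., "correct": ..., "total": ...})
loss.backward()
print(metrics["accuracy"])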
@ -23,6 +23,7 @@ class FeatureType(str, Enum):
    VISUAL = "VISUAL"
    ENV = "ENV"
    ACTION = "ACTION"
    REWARD = "REWARD"


class NormalizationMode(str, Enum):
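The new REWARD member pairs with the (currently commented-out) output_features in the classifier config above; a small assumed sketch:

from lerobot.configs.types import FeatureType, PolicyFeature

# One scalar reward per step, keyed the way LeRobotDataset stores it.
output_features = {"next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(1,))}
print(output_features["next.reward"].type)  # FeatureType.REWARD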
@ -274,10 +274,7 @@ def record(
|
|||
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
listener, events = init_keyboard_listener(assign_rewards=assign_rewards)
|
||||
|
||||
if reset_follower:
|
||||
initial_position = robot.follower_arms["main"].read("Present_Position")
|
||||
listener, events = init_keyboard_listener(assign_rewards=cfg.assign_rewards)
|
||||
|
||||
# Execute a few seconds without recording to:
|
||||
# 1. teleoperate the robot to move it in starting position if no policy provided,
|
||||
|
@ -394,6 +391,7 @@ def control_robot(cfg: ControlPipelineConfig):
|
|||
replay(robot, cfg.control)
|
||||
elif isinstance(cfg.control, RemoteRobotConfig):
|
||||
from lerobot.common.robot_devices.robots.lekiwi_remote import run_lekiwi
|
||||
|
||||
run_lekiwi(cfg.robot)
|
||||
|
||||
if robot.is_connected:
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.common.robot_devices.utils import busy_wait
|
||||
from lerobot.scripts.server.kinematics import RobotKinematics
|
||||
import logging
|
||||
import time
|
||||
import torch
|
||||
import numpy as np
|
||||
import argparse
|
||||
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
|
||||
from lerobot.scripts.server.gym_manipulator import make_robot_env, HILSerlRobotEnvConfig
|
||||
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
@ -458,12 +458,13 @@ class GamepadControllerHID(InputController):
|
|||
def test_forward_kinematics(robot, fps=10):
|
||||
logging.info("Testing Forward Kinematics")
|
||||
timestep = time.perf_counter()
|
||||
kinematics = RobotKinematics(robot.robot_type)
|
||||
while time.perf_counter() - timestep < 60.0:
|
||||
loop_start_time = time.perf_counter()
|
||||
robot.teleop_step()
|
||||
obs = robot.capture_observation()
|
||||
joint_positions = obs["observation.state"].cpu().numpy()
|
||||
ee_pos = RobotKinematics.fk_gripper_tip(joint_positions)
|
||||
ee_pos = kinematics.fk_gripper_tip(joint_positions)
|
||||
logging.info(f"EE Position: {ee_pos[:3, 3]}")
|
||||
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
|
||||
|
||||
|
@ -485,21 +486,19 @@ def test_inverse_kinematics(robot, fps=10):
|
|||
|
||||
def teleoperate_inverse_kinematics_with_leader(robot, fps=10):
|
||||
logging.info("Testing Inverse Kinematics")
|
||||
fk_func = RobotKinematics.fk_gripper_tip
|
||||
kinematics = RobotKinematics(robot.robot_type)
|
||||
timestep = time.perf_counter()
|
||||
while time.perf_counter() - timestep < 60.0:
|
||||
loop_start_time = time.perf_counter()
|
||||
obs = robot.capture_observation()
|
||||
joint_positions = obs["observation.state"].cpu().numpy()
|
||||
ee_pos = fk_func(joint_positions)
|
||||
ee_pos = kinematics.fk_gripper_tip(joint_positions)
|
||||
|
||||
leader_joint_positions = robot.leader_arms["main"].read("Present_Position")
|
||||
leader_ee = fk_func(leader_joint_positions)
|
||||
leader_ee = kinematics.fk_gripper_tip(leader_joint_positions)
|
||||
|
||||
desired_ee_pos = leader_ee
|
||||
target_joint_state = RobotKinematics.ik(
|
||||
joint_positions, desired_ee_pos, position_only=True, fk_func=fk_func
|
||||
)
|
||||
target_joint_state = kinematics.ik(joint_positions, desired_ee_pos, position_only=True)
|
||||
robot.send_action(torch.from_numpy(target_joint_state))
|
||||
logging.info(f"Leader EE: {leader_ee[:3, 3]}, Follower EE: {ee_pos[:3, 3]}")
|
||||
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
|
||||
|
@ -513,10 +512,10 @@ def teleoperate_delta_inverse_kinematics_with_leader(robot, fps=10):
|
|||
obs = robot.capture_observation()
|
||||
joint_positions = obs["observation.state"].cpu().numpy()
|
||||
|
||||
fk_func = RobotKinematics.fk_gripper_tip
|
||||
kinematics = RobotKinematics(robot.robot_type)
|
||||
|
||||
leader_joint_positions = robot.leader_arms["main"].read("Present_Position")
|
||||
initial_leader_ee = fk_func(leader_joint_positions)
|
||||
initial_leader_ee = kinematics.fk_gripper_tip(leader_joint_positions)
|
||||
|
||||
desired_ee_pos = np.diag(np.ones(4))
|
||||
|
||||
|
@ -525,13 +524,13 @@ def teleoperate_delta_inverse_kinematics_with_leader(robot, fps=10):
|
|||
|
||||
# Get leader state for teleoperation
|
||||
leader_joint_positions = robot.leader_arms["main"].read("Present_Position")
|
||||
leader_ee = fk_func(leader_joint_positions)
|
||||
leader_ee = kinematics.fk_gripper_tip(leader_joint_positions)
|
||||
|
||||
# Get current state
|
||||
# obs = robot.capture_observation()
|
||||
# joint_positions = obs["observation.state"].cpu().numpy()
|
||||
joint_positions = robot.follower_arms["main"].read("Present_Position")
|
||||
current_ee_pos = fk_func(joint_positions)
|
||||
current_ee_pos = kinematics.fk_gripper_tip(joint_positions)
|
||||
|
||||
# Calculate delta between leader and follower end-effectors
|
||||
# Scaling factor can be adjusted for sensitivity
|
||||
|
@ -545,9 +544,7 @@ def teleoperate_delta_inverse_kinematics_with_leader(robot, fps=10):
|
|||
|
||||
if np.any(np.abs(ee_delta[:3, 3]) > 0.01):
|
||||
# Compute joint targets via inverse kinematics
|
||||
target_joint_state = RobotKinematics.ik(
|
||||
joint_positions, desired_ee_pos, position_only=True, fk_func=fk_func
|
||||
)
|
||||
target_joint_state = kinematics.ik(joint_positions, desired_ee_pos, position_only=True)
|
||||
|
||||
initial_leader_ee = leader_ee.copy()
|
||||
|
||||
|
@ -580,7 +577,8 @@ def teleoperate_delta_inverse_kinematics(robot, controller, fps=10, bounds=None,
|
|||
# Initial position capture
|
||||
obs = robot.capture_observation()
|
||||
joint_positions = obs["observation.state"].cpu().numpy()
|
||||
current_ee_pos = fk_func(joint_positions)
|
||||
kinematics = RobotKinematics(robot.robot_type)
|
||||
current_ee_pos = kinematics.fk_gripper_tip(joint_positions)
|
||||
|
||||
# Initialize desired position with current position
|
||||
desired_ee_pos = np.eye(4) # Identity matrix
|
||||
|
@ -595,7 +593,7 @@ def teleoperate_delta_inverse_kinematics(robot, controller, fps=10, bounds=None,
|
|||
|
||||
# Get currrent robot state
|
||||
joint_positions = robot.follower_arms["main"].read("Present_Position")
|
||||
current_ee_pos = fk_func(joint_positions)
|
||||
current_ee_pos = kinematics.fk_gripper_tip(joint_positions)
|
||||
|
||||
# Get movement deltas from the controller
|
||||
delta_x, delta_y, delta_z = controller.get_deltas()
|
||||
|
@ -612,9 +610,7 @@ def teleoperate_delta_inverse_kinematics(robot, controller, fps=10, bounds=None,
|
|||
# Only send commands if there's actual movement
|
||||
if any([abs(v) > 0.001 for v in [delta_x, delta_y, delta_z]]):
|
||||
# Compute joint targets via inverse kinematics
|
||||
target_joint_state = RobotKinematics.ik(
|
||||
joint_positions, desired_ee_pos, position_only=True, fk_func=fk_func
|
||||
)
|
||||
target_joint_state = kinematics.ik(joint_positions, desired_ee_pos, position_only=True)
|
||||
|
||||
# Send command to robot
|
||||
robot.send_action(torch.from_numpy(target_joint_state))
|
||||
|
@ -676,7 +672,17 @@ def teleoperate_gym_env(env, controller, fps: int = 30):
|
|||
# Close the environment
|
||||
env.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
||||
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
|
||||
from lerobot.scripts.server.gym_manipulator import (
|
||||
EEActionSpaceConfig,
|
||||
EnvWrapperConfig,
|
||||
HILSerlRobotEnvConfig,
|
||||
make_robot_env,
|
||||
)
|
||||
|
||||
parser = argparse.ArgumentParser(description="Test end-effector control")
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
|
@ -698,12 +704,6 @@ if __name__ == "__main__":
|
|||
default="so100",
|
||||
help="Robot type (so100, koch, aloha, etc.)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the config file in json format",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
@ -725,7 +725,10 @@ if __name__ == "__main__":
|
|||
if args.mode.startswith("keyboard"):
|
||||
controller = KeyboardController(x_step_size=0.01, y_step_size=0.01, z_step_size=0.05)
|
||||
elif args.mode.startswith("gamepad"):
|
||||
controller = GamepadController(x_step_size=0.02, y_step_size=0.02, z_step_size=0.05)
|
||||
if sys.platform == "darwin":
|
||||
controller = GamepadControllerHID(x_step_size=0.01, y_step_size=0.01, z_step_size=0.05)
|
||||
else:
|
||||
controller = GamepadController(x_step_size=0.01, y_step_size=0.01, z_step_size=0.05)
|
||||
|
||||
# Handle mode categories
|
||||
if args.mode in ["keyboard", "gamepad"]:
|
||||
|
@ -734,12 +737,14 @@ if __name__ == "__main__":
|
|||
|
||||
elif args.mode in ["keyboard_gym", "gamepad_gym"]:
|
||||
# Gym environment control modes
|
||||
cfg = HILSerlRobotEnvConfig()
|
||||
if args.config_path is not None:
|
||||
cfg = HILSerlRobotEnvConfig.from_json(args.config_path)
|
||||
|
||||
cfg = HILSerlRobotEnvConfig(robot=robot_config, wrapper=EnvWrapperConfig())
|
||||
cfg.wrapper.ee_action_space_params = EEActionSpaceConfig(
|
||||
x_step_size=0.03, y_step_size=0.03, z_step_size=0.03, bounds=bounds
|
||||
)
|
||||
cfg.wrapper.ee_action_space_params.use_gamepad = False
|
||||
cfg.device = "cpu"
|
||||
env = make_robot_env(cfg, robot)
|
||||
teleoperate_gym_env(env, controller, fps=args.fps)
|
||||
teleoperate_gym_env(env, controller, fps=cfg.fps)
|
||||
|
||||
elif args.mode == "leader":
|
||||
# Leader-follower modes don't use controllers
|
||||
|
|
|
@ -63,9 +63,10 @@ def find_ee_bounds(
|
|||
if time.perf_counter() - start_episode_t < 5:
|
||||
continue
|
||||
|
||||
kinematics = RobotKinematics(robot.robot_type)
|
||||
joint_positions = robot.follower_arms["main"].read("Present_Position")
|
||||
print(f"Joint positions: {joint_positions}")
|
||||
ee_list.append(RobotKinematics.fk_gripper_tip(joint_positions)[:3, 3])
|
||||
ee_list.append(kinematics.fk_gripper_tip(joint_positions)[:3, 3])
|
||||
|
||||
if display_cameras and not is_headless():
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
|
@ -81,20 +82,19 @@ def find_ee_bounds(
|
|||
break
|
||||
|
||||
|
||||
def make_robot(robot_type="so100", mock=True):
|
||||
def make_robot(robot_type="so100"):
|
||||
"""
|
||||
Create a robot instance using the appropriate robot config class.
|
||||
|
||||
Args:
|
||||
robot_type: Robot type string (e.g., "so100", "koch", "aloha")
|
||||
mock: Whether to use mock mode for hardware (default: True)
|
||||
|
||||
Returns:
|
||||
Robot instance
|
||||
"""
|
||||
|
||||
# Get the appropriate robot config class based on robot_type
|
||||
robot_config = RobotConfig.get_choice_class(robot_type)(mock=mock)
|
||||
robot_config = RobotConfig.get_choice_class(robot_type)(mock=False)
|
||||
robot_config.leader_arms["main"].port = leader_port
|
||||
robot_config.follower_arms["main"].port = follower_port
|
||||
|
||||
|
@ -122,18 +122,12 @@ if __name__ == "__main__":
|
|||
default="so100",
|
||||
help="Robot type (so100, koch, aloha, etc.)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mock",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Use mock mode for hardware simulation",
|
||||
)
|
||||
|
||||
|
||||
# Only parse known args, leaving robot config args for Hydra if used
|
||||
args, _ = parser.parse_known_args()
|
||||
args = parser.parse_args()
|
||||
|
||||
# Create robot with the appropriate config
|
||||
robot = make_robot(args.robot_type, args.mock)
|
||||
robot = make_robot(args.robot_type)
|
||||
|
||||
if args.mode == "joint":
|
||||
find_joint_bounds(robot, args.control_time_s)
|
||||
|
|
|
@ -1,45 +1,37 @@
import logging
import sys
import time
import sys

from dataclasses import dataclass
from threading import Lock
from typing import Annotated, Any, Dict, Tuple
from typing import Annotated, Any, Dict, Optional, Tuple

import gymnasium as gym
import numpy as np
import torch
import torchvision.transforms.functional as F # noqa: N812
import json

from dataclasses import dataclass

from lerobot.common.envs.utils import preprocess_observation
from lerobot.configs.train import TrainPipelineConfig
from lerobot.common.envs.configs import EnvConfig
from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.robot_devices.control_utils import (
    busy_wait,
    is_headless,
    # reset_follower_position,
    reset_follower_position,
)

from typing import Optional
from lerobot.common.utils.utils import log_say
from lerobot.common.robot_devices.robots.utils import make_robot_from_config

from lerobot.common.robot_devices.robots.configs import RobotConfig

from lerobot.scripts.server.kinematics import RobotKinematics
from lerobot.scripts.server.maniskill_manipulator import ManiskillEnvConfig, make_maniskill
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
from lerobot.common.utils.utils import log_say
from lerobot.configs import parser
from lerobot.scripts.server.kinematics import RobotKinematics

logging.basicConfig(level=logging.INFO)


@dataclass
class EEActionSpaceConfig:
    """Configuration parameters for end-effector action space."""

    x_step_size: float
    y_step_size: float
    y_step_size: float
    z_step_size: float
    bounds: Dict[str, Any] # Contains 'min' and 'max' keys with position bounds
    use_gamepad: bool = False

@ -48,6 +40,7 @@ class EEActionSpaceConfig:
@dataclass
class EnvWrapperConfig:
    """Configuration for environment wrappers."""

    display_cameras: bool = False
    delta_action: float = 0.1
    use_relative_joint_positions: bool = True

@ -64,28 +57,27 @@ class EnvWrapperConfig:
    reward_classifier_config_file: Optional[str] = None


@EnvConfig.register_subclass(name="gym_manipulator")
@dataclass
class HILSerlRobotEnvConfig:
class HILSerlRobotEnvConfig(EnvConfig):
    """Configuration for the HILSerlRobotEnv environment."""
    robot: RobotConfig
    wrapper: EnvWrapperConfig
    env_name: str = "real_robot"

    robot: Optional[RobotConfig] = None
    wrapper: Optional[EnvWrapperConfig] = None
    fps: int = 10
    mode: str = None # Either "record", "replay", None
    repo_id: Optional[str] = None
    dataset_root: Optional[str] = None
    task: str = ""
    num_episodes: int = 10 # only for record mode
    num_episodes: int = 10 # only for record mode
    episode: int = 0
    device: str = "cuda"
    push_to_hub: bool = True
    pretrained_policy_name_or_path: Optional[str] = None

    @classmethod
    def from_json(cls, json_path: str):
        with open(json_path, "r") as f:
            config = json.load(f)
        return cls(**config)
    def gym_kwargs(self) -> dict:
        return {}


class HILSerlRobotEnv(gym.Env):
    """
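With HILSerlRobotEnvConfig registered as an EnvConfig subclass, it can be constructed directly in code, mirroring the end_effector_control_utils changes above. The robot type, bounds, and step sizes below are placeholders, and make_robot_env is assumed to build the robot from cfg.robot as in the updated factory.

import numpy as np

from lerobot.common.robot_devices.robots.configs import RobotConfig
from lerobot.scripts.server.gym_manipulator import (
    EEActionSpaceConfig,
    EnvWrapperConfig,
    HILSerlRobotEnvConfig,
    make_robot_env,
)

robot_config = RobotConfig.get_choice_class("so100")(mock=False)
cfg = HILSerlRobotEnvConfig(robot=robot_config, wrapper=EnvWrapperConfig())
cfg.wrapper.ee_action_space_params = EEActionSpaceConfig(
    x_step_size=0.03,
    y_step_size=0.03,
    z_step_size=0.03,
    bounds={"min": np.array([-0.1, -0.1, 0.0]), "max": np.array([0.1, 0.1, 0.2])},  # placeholder bounds
)
cfg.device = "cpu"
env = make_robot_env(cfg)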
@ -580,8 +572,7 @@ class ImageCropResizeWrapper(gym.Wrapper):
|
|||
if key_crop not in self.env.observation_space.keys(): # noqa: SIM118
|
||||
raise ValueError(f"Key {key_crop} not in observation space")
|
||||
for key in crop_params_dict:
|
||||
top, left, height, width = crop_params_dict[key]
|
||||
new_shape = (top + height, left + width)
|
||||
new_shape = (3, resize_size[0], resize_size[1])
|
||||
self.observation_space[key] = gym.spaces.Box(low=0, high=255, shape=new_shape)
|
||||
|
||||
self.resize_size = resize_size
|
||||
|
@ -1097,9 +1088,7 @@ class ActionScaleWrapper(gym.ActionWrapper):
|
|||
return action * self.scale_vector, is_intervention
|
||||
|
||||
|
||||
def make_robot_env(cfg: EnvConfig) -> gym.vector.VectorEnv:
|
||||
# def make_robot_env(cfg: TrainPipelineConfig) -> gym.vector.VectorEnv:
|
||||
# def make_robot_env(cfg: ManiskillEnvConfig) -> gym.vector.VectorEnv:
|
||||
def make_robot_env(cfg) -> gym.vector.VectorEnv:
|
||||
"""
|
||||
Factory function to create a vectorized robot environment.
|
||||
|
||||
|
@ -1111,16 +1100,16 @@ def make_robot_env(cfg: EnvConfig) -> gym.vector.VectorEnv:
|
|||
Returns:
|
||||
A vectorized gym environment with all the necessary wrappers applied.
|
||||
"""
|
||||
if "maniskill" in cfg.name:
|
||||
from lerobot.scripts.server.maniskill_manipulator import make_maniskill
|
||||
# if "maniskill" in cfg.name:
|
||||
# from lerobot.scripts.server.maniskill_manipulator import make_maniskill
|
||||
|
||||
logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN")
|
||||
env = make_maniskill(
|
||||
cfg=cfg,
|
||||
n_envs=1,
|
||||
)
|
||||
return env
|
||||
robot = cfg.robot
|
||||
# logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN")
|
||||
# env = make_maniskill(
|
||||
# cfg=cfg,
|
||||
# n_envs=1,
|
||||
# )
|
||||
# return env
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
# Create base environment
|
||||
env = HILSerlRobotEnv(
|
||||
robot=robot,
|
||||
|
@ -1150,10 +1139,7 @@ def make_robot_env(cfg: EnvConfig) -> gym.vector.VectorEnv:
|
|||
env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
|
||||
if cfg.wrapper.ee_action_space_params is not None:
|
||||
env = EEActionWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params)
|
||||
if (
|
||||
cfg.wrapper.ee_action_space_params is not None
|
||||
and cfg.wrapper.ee_action_space_params.use_gamepad
|
||||
):
|
||||
if cfg.wrapper.ee_action_space_params is not None and cfg.wrapper.ee_action_space_params.use_gamepad:
|
||||
# env = ActionScaleWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params)
|
||||
env = GamepadControlWrapper(
|
||||
env=env,
|
||||
|
@ -1169,10 +1155,7 @@ def make_robot_env(cfg: EnvConfig) -> gym.vector.VectorEnv:
|
|||
reset_pose=cfg.wrapper.fixed_reset_joint_positions,
|
||||
reset_time_s=cfg.wrapper.reset_time_s,
|
||||
)
|
||||
if (
|
||||
cfg.wrapper.ee_action_space_params is None
|
||||
and cfg.wrapper.joint_masking_action_space is not None
|
||||
):
|
||||
if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None:
|
||||
env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)
|
||||
env = BatchCompitableWrapper(env=env)
|
||||
|
||||
|
@ -1180,7 +1163,10 @@ def make_robot_env(cfg: EnvConfig) -> gym.vector.VectorEnv:
|
|||
|
||||
|
||||
def get_classifier(cfg):
|
||||
if cfg.wrapper.reward_classifier_pretrained_path is None or cfg.wrapper.reward_classifier_config_file is None:
|
||||
if (
|
||||
cfg.wrapper.reward_classifier_pretrained_path is None
|
||||
or cfg.wrapper.reward_classifier_config_file is None
|
||||
):
|
||||
return None
|
||||
|
||||
from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
|
||||
|
@ -1258,7 +1244,8 @@ def record_dataset(env, policy, cfg: HILSerlRobotEnvConfig):
|
|||
|
||||
# Record episodes
|
||||
episode_index = 0
|
||||
while episode_index < cfg.record_num_episodes:
|
||||
recorded_action = None
|
||||
while episode_index < cfg.num_episodes:
|
||||
obs, _ = env.reset()
|
||||
start_episode_t = time.perf_counter()
|
||||
log_say(f"Recording episode {episode_index}", play_sounds=True)
|
||||
|
@ -1279,16 +1266,19 @@ def record_dataset(env, policy, cfg: HILSerlRobotEnvConfig):
|
|||
break
|
||||
|
||||
# For teleop, get action from intervention
|
||||
if policy is None:
|
||||
action = {"action": info["action_intervention"].cpu().squeeze(0).float()}
|
||||
recorded_action = {
|
||||
"action": info["action_intervention"].cpu().squeeze(0).float() if policy is None else action
|
||||
}
|
||||
|
||||
# Process observation for dataset
|
||||
obs = {k: v.cpu().squeeze(0).float() for k, v in obs.items()}
|
||||
obs["observation.images.side"] = torch.clamp(obs["observation.images.side"], 0, 1)
|
||||
|
||||
# Add frame to dataset
|
||||
frame = {**obs, **action}
|
||||
frame["next.reward"] = reward
|
||||
frame["next.done"] = terminated or truncated
|
||||
frame = {**obs, **recorded_action}
|
||||
frame["next.reward"] = np.array([reward], dtype=np.float32)
|
||||
frame["next.done"] = np.array([terminated or truncated], dtype=bool)
|
||||
frame["task"] = cfg.task
|
||||
dataset.add_frame(frame)
|
||||
|
||||
# Maintain consistent timing
|
||||
|
@ -1309,9 +1299,9 @@ def record_dataset(env, policy, cfg: HILSerlRobotEnvConfig):
|
|||
episode_index += 1
|
||||
|
||||
# Finalize dataset
|
||||
dataset.consolidate(run_compute_stats=True)
|
||||
# dataset.consolidate(run_compute_stats=True)
|
||||
if cfg.push_to_hub:
|
||||
dataset.push_to_hub(cfg.repo_id)
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
def replay_episode(env, repo_id, root=None, episode=0):
|
||||
|
@ -1334,82 +1324,69 @@ def replay_episode(env, repo_id, root=None, episode=0):
|
|||
busy_wait(1 / 10 - dt_s)
|
||||
|
||||
|
||||
# @parser.wrap()
|
||||
# def main(cfg):
|
||||
@parser.wrap()
|
||||
def main(cfg: EnvConfig):
|
||||
env = make_robot_env(cfg)
|
||||
|
||||
# robot = make_robot_from_config(cfg.robot)
|
||||
if cfg.mode == "record":
|
||||
policy = None
|
||||
if cfg.pretrained_policy_name_or_path is not None:
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
|
||||
# reward_classifier = None #get_classifier(
|
||||
# # cfg.wrapper.reward_classifier_pretrained_path, cfg.wrapper.reward_classifier_config_file
|
||||
# # )
|
||||
# user_relative_joint_positions = True
|
||||
policy = SACPolicy.from_pretrained(cfg.pretrained_policy_name_or_path)
|
||||
policy.to(cfg.device)
|
||||
policy.eval()
|
||||
|
||||
# env = make_robot_env(cfg, robot)
|
||||
record_dataset(
|
||||
env,
|
||||
policy=None,
|
||||
cfg=cfg,
|
||||
)
|
||||
exit()
|
||||
|
||||
# if cfg.mode == "record":
|
||||
# policy = None
|
||||
# if cfg.pretrained_policy_name_or_path is not None:
|
||||
# from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
if cfg.mode == "replay":
|
||||
replay_episode(
|
||||
env,
|
||||
cfg.replay_repo_id,
|
||||
root=cfg.dataset_root,
|
||||
episode=cfg.replay_episode,
|
||||
)
|
||||
exit()
|
||||
|
||||
# policy = SACPolicy.from_pretrained(cfg.pretrained_policy_name_or_path)
|
||||
# policy.to(cfg.device)
|
||||
# policy.eval()
|
||||
env.reset()
|
||||
|
||||
# record_dataset(
|
||||
# env,
|
||||
# cfg.repo_id,
|
||||
# root=cfg.dataset_root,
|
||||
# num_episodes=cfg.num_episodes,
|
||||
# fps=cfg.fps,
|
||||
# task_description=cfg.task,
|
||||
# policy=policy,
|
||||
# )
|
||||
# exit()
|
||||
# Retrieve the robot's action space for joint commands.
|
||||
action_space_robot = env.action_space.spaces[0]
|
||||
|
||||
# if cfg.mode == "replay":
|
||||
# replay_episode(
|
||||
# env,
|
||||
# cfg.replay_repo_id,
|
||||
# root=cfg.dataset_root,
|
||||
# episode=cfg.replay_episode,
|
||||
# )
|
||||
# exit()
|
||||
# Initialize the smoothed action as a random sample.
|
||||
smoothed_action = action_space_robot.sample()
|
||||
|
||||
# env.reset()
|
||||
# Smoothing coefficient (alpha) defines how much of the new random sample to mix in.
|
||||
# A value close to 0 makes the trajectory very smooth (slow to change), while a value close to 1 is less smooth.
|
||||
alpha = 1.0
|
||||
|
||||
# # Retrieve the robot's action space for joint commands.
|
||||
# action_space_robot = env.action_space.spaces[0]
|
||||
num_episode = 0
|
||||
sucesses = []
|
||||
while num_episode < 20:
|
||||
start_loop_s = time.perf_counter()
|
||||
# Sample a new random action from the robot's action space.
|
||||
new_random_action = action_space_robot.sample()
|
||||
# Update the smoothed action using an exponential moving average.
|
||||
smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action
|
||||
|
||||
# # Initialize the smoothed action as a random sample.
|
||||
# smoothed_action = action_space_robot.sample()
|
||||
# Execute the step: wrap the NumPy action in a torch tensor.
|
||||
obs, reward, terminated, truncated, info = env.step((torch.from_numpy(smoothed_action), False))
|
||||
if terminated or truncated:
|
||||
sucesses.append(reward)
|
||||
env.reset()
|
||||
num_episode += 1
|
||||
|
||||
# # Smoothing coefficient (alpha) defines how much of the new random sample to mix in.
|
||||
# # A value close to 0 makes the trajectory very smooth (slow to change), while a value close to 1 is less smooth.
|
||||
# alpha = 1.0
|
||||
dt_s = time.perf_counter() - start_loop_s
|
||||
busy_wait(1 / cfg.fps - dt_s)
|
||||
|
||||
# num_episode = 0
|
||||
# sucesses = []
|
||||
# while num_episode < 20:
|
||||
# start_loop_s = time.perf_counter()
|
||||
# # Sample a new random action from the robot's action space.
|
||||
# new_random_action = action_space_robot.sample()
|
||||
# # Update the smoothed action using an exponential moving average.
|
||||
# smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action
|
||||
logging.info(f"Success after 20 steps {sucesses}")
|
||||
logging.info(f"success rate {sum(sucesses) / len(sucesses)}")
|
||||
|
||||
# # Execute the step: wrap the NumPy action in a torch tensor.
|
||||
# obs, reward, terminated, truncated, info = env.step((torch.from_numpy(smoothed_action), False))
|
||||
# if terminated or truncated:
|
||||
# sucesses.append(reward)
|
||||
# env.reset()
|
||||
# num_episode += 1
|
||||
|
||||
# dt_s = time.perf_counter() - start_loop_s
|
||||
# busy_wait(1 / cfg.fps - dt_s)
|
||||
|
||||
# logging.info(f"Success after 20 steps {sucesses}")
|
||||
# logging.info(f"success rate {sum(sucesses) / len(sucesses)}")
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# main()
|
||||
if __name__ == "__main__":
|
||||
make_robot_env()
|
||||
main()
|
||||
|
|
|
@ -15,12 +15,12 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pprint import pformat
|
||||
import os
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
|
||||
import draccus
|
||||
import grpc
|
||||
|
@ -30,35 +30,42 @@ import hilserl_pb2_grpc # type: ignore
|
|||
import torch
|
||||
from termcolor import colored
|
||||
from torch import nn
|
||||
|
||||
from torch.multiprocessing import Queue
|
||||
from torch.optim.optimizer import Optimizer
|
||||
|
||||
from lerobot.common.constants import (
|
||||
CHECKPOINTS_DIR,
|
||||
LAST_CHECKPOINT_LINK,
|
||||
PRETRAINED_MODEL_DIR,
|
||||
TRAINING_STATE_DIR,
|
||||
TRAINING_STEP,
|
||||
)
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.configs import parser
|
||||
# TODO: Remove the import of maniskill
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy, SACConfig
|
||||
from lerobot.common.policies.sac.modeling_sac import SACConfig, SACPolicy
|
||||
from lerobot.common.policies.utils import get_device_from_parameters
|
||||
from lerobot.common.utils.random_utils import set_seed
|
||||
from lerobot.common.utils.train_utils import (
|
||||
get_step_checkpoint_dir,
|
||||
get_step_identifier,
|
||||
load_training_state as utils_load_training_state,
|
||||
save_checkpoint,
|
||||
update_last_checkpoint,
|
||||
save_training_state,
|
||||
update_last_checkpoint,
|
||||
)
|
||||
from lerobot.common.utils.train_utils import (
|
||||
load_training_state as utils_load_training_state,
|
||||
)
|
||||
from lerobot.common.utils.random_utils import set_seed
|
||||
from lerobot.common.utils.utils import (
|
||||
format_big_number,
|
||||
get_safe_torch_device,
|
||||
init_logging,
|
||||
)
|
||||
|
||||
from lerobot.common.policies.utils import get_device_from_parameters
|
||||
from lerobot.common.utils.wandb_utils import WandBLogger
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.scripts.server import learner_service
|
||||
from lerobot.scripts.server.buffer import (
|
||||
ReplayBuffer,
|
||||
|
@ -70,47 +77,39 @@ from lerobot.scripts.server.buffer import (
|
|||
state_to_bytes,
|
||||
)
|
||||
from lerobot.scripts.server.utils import setup_process_handlers
|
||||
from lerobot.common.constants import (
|
||||
CHECKPOINTS_DIR,
|
||||
LAST_CHECKPOINT_LINK,
|
||||
PRETRAINED_MODEL_DIR,
|
||||
TRAINING_STATE_DIR,
|
||||
TRAINING_STEP,
|
||||
)
|
||||
|
||||
|
||||
def handle_resume_logic(cfg: TrainPipelineConfig) -> TrainPipelineConfig:
|
||||
"""
|
||||
Handle the resume logic for training.
|
||||
|
||||
|
||||
If resume is True:
|
||||
- Verifies that a checkpoint exists
|
||||
- Loads the checkpoint configuration
|
||||
- Logs resumption details
|
||||
- Returns the checkpoint configuration
|
||||
|
||||
|
||||
If resume is False:
|
||||
- Checks if an output directory exists (to prevent accidental overwriting)
|
||||
- Returns the original configuration
|
||||
|
||||
|
||||
Args:
|
||||
cfg (TrainPipelineConfig): The training configuration
|
||||
|
||||
|
||||
Returns:
|
||||
TrainPipelineConfig: The updated configuration
|
||||
|
||||
|
||||
Raises:
|
||||
RuntimeError: If resume is True but no checkpoint found, or if resume is False but directory exists
|
||||
"""
|
||||
out_dir = cfg.output_dir
|
||||
|
||||
|
||||
# Case 1: Not resuming, but need to check if directory exists to prevent overwrites
|
||||
if not cfg.resume:
|
||||
checkpoint_dir = os.path.join(out_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK)
|
||||
if os.path.exists(checkpoint_dir):
|
||||
raise RuntimeError(
|
||||
f"Output directory {checkpoint_dir} already exists. "
|
||||
"Use `resume=true` to resume training."
|
||||
f"Output directory {checkpoint_dir} already exists. Use `resume=true` to resume training."
|
||||
)
|
||||
return cfg
|
||||
|
||||
|
@ -131,7 +130,7 @@ def handle_resume_logic(cfg: TrainPipelineConfig) -> TrainPipelineConfig:
|
|||
# Load config using Draccus
|
||||
checkpoint_cfg_path = os.path.join(checkpoint_dir, PRETRAINED_MODEL_DIR, "train_config.json")
|
||||
checkpoint_cfg = TrainPipelineConfig.from_pretrained(checkpoint_cfg_path)
|
||||
|
||||
|
||||
# Ensure resume flag is set in returned config
|
||||
checkpoint_cfg.resume = True
|
||||
return checkpoint_cfg
|
||||
|
@ -143,11 +142,11 @@ def load_training_state(
|
|||
):
|
||||
"""
|
||||
Loads the training state (optimizers, step count, etc.) from a checkpoint.
|
||||
|
||||
|
||||
Args:
|
||||
cfg (TrainPipelineConfig): Training configuration
|
||||
optimizers (Optimizer | dict): Optimizers to load state into
|
||||
|
||||
|
||||
Returns:
|
||||
tuple: (optimization_step, interaction_step) or (None, None) if not resuming
|
||||
"""
|
||||
|
@ -156,23 +155,23 @@ def load_training_state(
|
|||
|
||||
# Construct path to the last checkpoint directory
|
||||
checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK)
|
||||
|
||||
|
||||
logging.info(f"Loading training state from {checkpoint_dir}")
|
||||
|
||||
|
||||
try:
|
||||
# Use the utility function from train_utils which loads the optimizer state
|
||||
step, optimizers, _ = utils_load_training_state(Path(checkpoint_dir), optimizers, None)
|
||||
|
||||
|
||||
# Load interaction step separately from training_state.pt
|
||||
training_state_path = os.path.join(checkpoint_dir, TRAINING_STATE_DIR, "training_state.pt")
|
||||
interaction_step = 0
|
||||
if os.path.exists(training_state_path):
|
||||
training_state = torch.load(training_state_path, weights_only=False)
|
||||
interaction_step = training_state.get("interaction_step", 0)
|
||||
|
||||
|
||||
logging.info(f"Resuming from step {step}, interaction step {interaction_step}")
|
||||
return step, interaction_step
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to load training state: {e}")
|
||||
return None, None
|
||||
|
@ -181,7 +180,7 @@ def load_training_state(
|
|||
def log_training_info(cfg: TrainPipelineConfig, policy: nn.Module) -> None:
|
||||
"""
|
||||
Log information about the training process.
|
||||
|
||||
|
||||
Args:
|
||||
cfg (TrainPipelineConfig): Training configuration
|
||||
policy (nn.Module): Policy model
|
||||
|
@ -189,7 +188,6 @@ def log_training_info(cfg: TrainPipelineConfig, policy: nn.Module) -> None:
|
|||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||
|
||||
|
||||
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
|
||||
logging.info(f"{cfg.env.task=}")
|
||||
logging.info(f"{cfg.policy.online_steps=}")
|
||||
|
@ -197,19 +195,15 @@ def log_training_info(cfg: TrainPipelineConfig, policy: nn.Module) -> None:
|
|||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||
|
||||
|
||||
def initialize_replay_buffer(
|
||||
cfg: TrainPipelineConfig,
|
||||
device: str,
|
||||
storage_device: str
|
||||
) -> ReplayBuffer:
|
||||
def initialize_replay_buffer(cfg: TrainPipelineConfig, device: str, storage_device: str) -> ReplayBuffer:
|
||||
"""
|
||||
Initialize a replay buffer, either empty or from a dataset if resuming.
|
||||
|
||||
|
||||
Args:
|
||||
cfg (TrainPipelineConfig): Training configuration
|
||||
device (str): Device to store tensors on
|
||||
storage_device (str): Device for storage optimization
|
||||
|
||||
|
||||
Returns:
|
||||
ReplayBuffer: Initialized replay buffer
|
||||
"""
|
||||
|
@ -224,7 +218,7 @@ def initialize_replay_buffer(
|
|||
|
||||
logging.info("Resume training load the online dataset")
|
||||
dataset_path = os.path.join(cfg.output_dir, "dataset")
|
||||
|
||||
|
||||
# NOTE: In RL is possible to not have a dataset.
|
||||
repo_id = None
|
||||
if cfg.dataset is not None:
|
||||
|
@ -250,13 +244,13 @@ def initialize_offline_replay_buffer(
|
|||
) -> ReplayBuffer:
|
||||
"""
|
||||
Initialize an offline replay buffer from a dataset.
|
||||
|
||||
|
||||
Args:
|
||||
cfg (TrainPipelineConfig): Training configuration
|
||||
device (str): Device to store tensors on
|
||||
storage_device (str): Device for storage optimization
|
||||
active_action_dims (list[int] | None): Active action dimensions for masking
|
||||
|
||||
|
||||
Returns:
|
||||
ReplayBuffer: Initialized offline replay buffer
|
||||
"""
|
||||
|
@ -314,7 +308,7 @@ def start_learner_threads(
|
|||
) -> None:
|
||||
"""
|
||||
Start the learner threads for training.
|
||||
|
||||
|
||||
Args:
|
||||
cfg (TrainPipelineConfig): Training configuration
|
||||
wandb_logger (WandBLogger | None): Logger for metrics
|
||||
|
@ -512,17 +506,19 @@ def add_actor_information_and_train(
|
|||
|
||||
logging.info("Initializing policy")
|
||||
# Get checkpoint dir for resuming
|
||||
checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) if cfg.resume else None
|
||||
checkpoint_dir = (
|
||||
os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) if cfg.resume else None
|
||||
)
|
||||
pretrained_path = os.path.join(checkpoint_dir, PRETRAINED_MODEL_DIR) if checkpoint_dir else None
|
||||
|
||||
|
||||
policy: SACPolicy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
# ds_meta=cfg.dataset,
|
||||
env_cfg=cfg.env
|
||||
env_cfg=cfg.env,
|
||||
)
|
||||
|
||||
# Update the policy config with the grad_clip_norm value from training config if it exists
|
||||
clip_grad_norm_value:float = cfg.policy.grad_clip_norm
|
||||
clip_grad_norm_value: float = cfg.policy.grad_clip_norm
|
||||
|
||||
# compile policy
|
||||
policy = torch.compile(policy)
|
||||
|
@ -536,7 +532,7 @@ def add_actor_information_and_train(
|
|||
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg=cfg, policy=policy)
|
||||
resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers)
|
||||
|
||||
log_training_info(cfg=cfg, policy= policy)
|
||||
log_training_info(cfg=cfg, policy=policy)
|
||||
|
||||
replay_buffer = initialize_replay_buffer(cfg, device, storage_device)
|
||||
batch_size = cfg.batch_size
|
||||
|
@ -615,14 +611,10 @@ def add_actor_information_and_train(
|
|||
interaction_message = bytes_to_python_object(interaction_message)
|
||||
# If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
|
||||
interaction_message["Interaction step"] += interaction_step_shift
|
||||
|
||||
|
||||
# Log interaction messages with WandB if available
|
||||
if wandb_logger:
|
||||
wandb_logger.log_dict(
|
||||
d=interaction_message,
|
||||
mode="train",
|
||||
custom_step_key="Interaction step"
|
||||
)
|
||||
wandb_logger.log_dict(d=interaction_message, mode="train", custom_step_key="Interaction step")
|
||||
|
||||
logging.debug("[LEARNER] Received interactions")
|
||||
|
||||
|
@ -636,7 +628,9 @@ def add_actor_information_and_train(
|
|||
|
||||
if dataset_repo_id is not None:
|
||||
batch_offline = offline_replay_buffer.sample(batch_size=batch_size)
|
||||
batch = concatenate_batch_transitions(left_batch_transitions=batch, right_batch_transition=batch_offline)
|
||||
batch = concatenate_batch_transitions(
|
||||
left_batch_transitions=batch, right_batch_transition=batch_offline
|
||||
)
|
||||
|
||||
actions = batch["action"]
|
||||
rewards = batch["reward"]
|
||||
|
@ -759,14 +753,10 @@ def add_actor_information_and_train(
|
|||
if offline_replay_buffer is not None:
|
||||
training_infos["offline_replay_buffer_size"] = len(offline_replay_buffer)
|
||||
training_infos["Optimization step"] = optimization_step
|
||||
|
||||
|
||||
# Log training metrics
|
||||
if wandb_logger:
|
||||
wandb_logger.log_dict(
|
||||
d=training_infos,
|
||||
mode="train",
|
||||
custom_step_key="Optimization step"
|
||||
)
|
||||
wandb_logger.log_dict(d=training_infos, mode="train", custom_step_key="Optimization step")
|
||||
|
||||
time_for_one_optimization_step = time.time() - time_for_one_optimization_step
|
||||
frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)
|
||||
|
@ -795,29 +785,19 @@ def add_actor_information_and_train(
|
|||
interaction_step = (
|
||||
interaction_message["Interaction step"] if interaction_message is not None else 0
|
||||
)
|
||||
|
||||
|
||||
# Create checkpoint directory
|
||||
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, online_steps, optimization_step)
|
||||
|
||||
|
||||
# Save checkpoint
|
||||
save_checkpoint(
|
||||
checkpoint_dir,
|
||||
optimization_step,
|
||||
cfg,
|
||||
policy,
|
||||
optimizers,
|
||||
scheduler=None
|
||||
)
|
||||
|
||||
save_checkpoint(checkpoint_dir, optimization_step, cfg, policy, optimizers, scheduler=None)
|
||||
|
||||
# Save interaction step manually
|
||||
training_state_dir = os.path.join(checkpoint_dir, TRAINING_STATE_DIR)
|
||||
os.makedirs(training_state_dir, exist_ok=True)
|
||||
training_state = {
|
||||
"step": optimization_step,
|
||||
"interaction_step": interaction_step
|
||||
}
|
||||
training_state = {"step": optimization_step, "interaction_step": interaction_step}
|
||||
torch.save(training_state, os.path.join(training_state_dir, "training_state.pt"))
|
||||
|
||||
|
||||
# Update the "last" symlink
|
||||
update_last_checkpoint(checkpoint_dir)
|
||||
|
||||
|
@ -826,17 +806,13 @@ def add_actor_information_and_train(
|
|||
dataset_dir = os.path.join(cfg.output_dir, "dataset")
|
||||
if os.path.exists(dataset_dir) and os.path.isdir(dataset_dir):
|
||||
shutil.rmtree(dataset_dir)
|
||||
|
||||
|
||||
# Save dataset
|
||||
# NOTE: Handle the case where the dataset repo id is not specified in the config
|
||||
# eg. RL training without demonstrations data
|
||||
# eg. RL training without demonstrations data
|
||||
repo_id_buffer_save = cfg.env.task if dataset_repo_id is None else dataset_repo_id
|
||||
replay_buffer.to_lerobot_dataset(
|
||||
repo_id=repo_id_buffer_save,
|
||||
fps=fps,
|
||||
root=dataset_dir
|
||||
)
|
||||
|
||||
replay_buffer.to_lerobot_dataset(repo_id=repo_id_buffer_save, fps=fps, root=dataset_dir)
|
||||
|
||||
if offline_replay_buffer is not None:
|
||||
dataset_offline_dir = os.path.join(cfg.output_dir, "dataset_offline")
|
||||
if os.path.exists(dataset_offline_dir) and os.path.isdir(dataset_offline_dir):
|
||||
|
@ -882,9 +858,7 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
|
|||
params=policy.actor.parameters_to_optimize,
|
||||
lr=cfg.policy.actor_lr,
|
||||
)
|
||||
optimizer_critic = torch.optim.Adam(
|
||||
params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr
|
||||
)
|
||||
optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
|
||||
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
|
||||
lr_scheduler = None
|
||||
optimizers = {
|
||||
|
@ -898,19 +872,19 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
|
|||
def train(cfg: TrainPipelineConfig, job_name: str | None = None):
|
||||
"""
|
||||
Main training function that initializes and runs the training process.
|
||||
|
||||
|
||||
Args:
|
||||
cfg (TrainPipelineConfig): The training configuration
|
||||
job_name (str | None, optional): Job name for logging. Defaults to None.
|
||||
"""
|
||||
|
||||
|
||||
cfg.validate()
|
||||
# if cfg.output_dir is None:
|
||||
# raise ValueError("Output directory must be specified in config")
|
||||
|
||||
|
||||
if job_name is None:
|
||||
job_name = cfg.job_name
|
||||
|
||||
|
||||
if job_name is None:
|
||||
raise ValueError("Job name must be specified either in config or as a parameter")
|
||||
|
||||
|
@ -920,11 +894,12 @@ def train(cfg: TrainPipelineConfig, job_name: str | None = None):
|
|||
# Setup WandB logging if enabled
|
||||
if cfg.wandb.enable and cfg.wandb.project:
|
||||
from lerobot.common.utils.wandb_utils import WandBLogger
|
||||
|
||||
wandb_logger = WandBLogger(cfg)
|
||||
else:
|
||||
wandb_logger = None
|
||||
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
|
||||
|
||||
|
||||
# Handle resume logic
|
||||
cfg = handle_resume_logic(cfg)
|
||||
|
||||
|
@ -944,9 +919,9 @@ def train(cfg: TrainPipelineConfig, job_name: str | None = None):
|
|||
|
||||
@parser.wrap()
|
||||
def train_cli(cfg: TrainPipelineConfig):
|
||||
|
||||
if not use_threads(cfg):
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
mp.set_start_method("spawn")
|
||||
|
||||
# Use the job_name from the config
|
||||
|
|
|
@ -122,6 +122,9 @@ def make_optimizer_and_scheduler(cfg, policy):
|
|||
|
||||
optimizer = VQBeTOptimizer(policy, cfg)
|
||||
lr_scheduler = VQBeTScheduler(optimizer, cfg)
|
||||
elif cfg.policy.name == "hilserl_classifier":
|
||||
optimizer = torch.optim.AdamW(policy.parameters(), cfg.policy.learning_rate)
|
||||
lr_scheduler = None
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
|
|
@ -16,7 +16,6 @@ import time
|
|||
from contextlib import nullcontext
|
||||
from pprint import pformat
|
||||
|
||||
import hydra
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -32,11 +31,8 @@ from tqdm import tqdm
|
|||
|
||||
from lerobot.common.datasets.factory import resolve_delta_timestamps
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.logger import Logger
|
||||
from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
|
||||
from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
|
||||
ClassifierConfig,
|
||||
)
|
||||
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
|
||||
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
|
||||
from lerobot.common.utils.utils import (
|
||||
format_big_number,
|
||||
|
@ -296,8 +292,6 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
init_logging()
|
||||
logging.info(OmegaConf.to_yaml(cfg))
|
||||
|
||||
logger = Logger(cfg, out_dir, wandb_job_name=job_name)
|
||||
|
||||
# Initialize training environment
|
||||
device = get_safe_torch_device(cfg.device, log=True)
|
||||
set_global_seed(cfg.seed)
|
||||
|
|