From e35546f58ec40d3f065762fdcb7f57e455314b28 Mon Sep 17 00:00:00 2001 From: Yoel Date: Mon, 9 Dec 2024 10:21:50 +0100 Subject: [PATCH 002/112] Reward classifier and training (#528) Co-authored-by: Daniel Ritchie Co-authored-by: resolver101757 Co-authored-by: Jannik Grothusen <56967823+J4nn1K@users.noreply.github.com> Co-authored-by: Remi Co-authored-by: Michel Aractingi --- examples/12_train_hilserl_classifier.md | 83 +++++ lerobot/common/datasets/lerobot_dataset.py | 2 +- lerobot/common/logger.py | 5 +- .../classifier/configuration_classifier.py | 36 ++ .../hilserl/classifier/modeling_classifier.py | 134 ++++++++ lerobot/common/robot_devices/control_utils.py | 27 +- .../configs/policy/hilserl_classifier.yaml | 48 +++ lerobot/scripts/control_robot.py | 26 +- lerobot/scripts/train_hilserl_classifier.py | 310 ++++++++++++++++++ tests/test_train_hilserl_classifier.py | 251 ++++++++++++++ 10 files changed, 906 insertions(+), 16 deletions(-) create mode 100644 examples/12_train_hilserl_classifier.md create mode 100644 lerobot/common/policies/hilserl/classifier/configuration_classifier.py create mode 100644 lerobot/common/policies/hilserl/classifier/modeling_classifier.py create mode 100644 lerobot/configs/policy/hilserl_classifier.yaml create mode 100644 lerobot/scripts/train_hilserl_classifier.py create mode 100644 tests/test_train_hilserl_classifier.py diff --git a/examples/12_train_hilserl_classifier.md b/examples/12_train_hilserl_classifier.md new file mode 100644 index 00000000..eeaf0f2b --- /dev/null +++ b/examples/12_train_hilserl_classifier.md @@ -0,0 +1,83 @@ +# Training a HIL-SERL Reward Classifier with LeRobot + +This tutorial provides step-by-step instructions for training a reward classifier using LeRobot. + +--- + +## Training Script Overview + +LeRobot includes a ready-to-use training script located at [`lerobot/scripts/train_hilserl_classifier.py`](../../lerobot/scripts/train_hilserl_classifier.py). Here's an outline of its workflow: + +1. **Configuration Loading** + The script uses Hydra to load a configuration file for subsequent steps. (Details on Hydra follow below.) + +2. **Dataset Initialization** + It loads a `LeRobotDataset` containing images and rewards. To optimize performance, a weighted random sampler is used to balance class sampling. + +3. **Classifier Initialization** + A lightweight classification head is built on top of a frozen, pretrained image encoder from HuggingFace. The classifier outputs either: + - A single probability (binary classification), or + - Logits (multi-class classification). + +4. **Training Loop Execution** + The script performs: + - Forward and backward passes, + - Optimization steps, + - Periodic logging, evaluation, and checkpoint saving. + +--- + +## Configuring with Hydra + +For detailed information about Hydra usage, refer to [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md). However, note that training the reward classifier differs slightly and requires a separate configuration file. + +### Config File Setup + +The default `default.yaml` cannot launch the reward classifier training directly. 
Instead, you need a configuration file like [`lerobot/configs/policy/hilserl_classifier.yaml`](../../lerobot/configs/policy/hilserl_classifier.yaml), with the following adjustment: + +Replace the `dataset_repo_id` field with the identifier for your dataset, which contains images and sparse rewards: + +```yaml +# Example: lerobot/configs/policy/reward_classifier.yaml +dataset_repo_id: "my_dataset_repo_id" +## Typical logs and metrics +``` +When you start the training process, you will first see your full configuration being printed in the terminal. You can check it to make sure that you config it correctly and your config is not overrided by other files. The final configuration will also be saved with the checkpoint. + +After that, you will see training log like this one: + +``` +[2024-11-29 18:26:36,999][root][INFO] - +Epoch 5/5 +Training: 82%|██████████████████████████████████████████████████████████████████████████████▋ | 91/111 [00:50<00:09, 2.04it/s, loss=0.2999, acc=69.99%] +``` + +or evaluation log like: + +``` +Validation: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:20<00:00, 1.37it/s] +``` + +### Metrics Tracking with Weights & Biases (WandB) + +If `wandb.enable` is set to `true`, the training and evaluation logs will also be saved in WandB. This allows you to track key metrics in real-time, including: + +- **Training Metrics**: + - `train/accuracy` + - `train/loss` + - `train/dataloading_s` +- **Evaluation Metrics**: + - `eval/accuracy` + - `eval/loss` + - `eval/eval_s` + +#### Additional Features + +You can also log sample predictions during evaluation. Each logged sample will include: + +- The **input image**. +- The **predicted label**. +- The **true label**. +- The **classifier's "confidence" (logits/probability)**. + +These logs can be useful for diagnosing and debugging performance issues. diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index b32cf709..23255805 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -291,7 +291,7 @@ class LeRobotDatasetMetadata: obj.root.mkdir(parents=True, exist_ok=False) if robot is not None: - features = get_features_from_robot(robot, use_videos) + features = {**(features or {}), **get_features_from_robot(robot)} robot_type = robot.robot_type if not all(cam.fps == fps for cam in robot.cameras.values()): logging.warning( diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py index 3bd2df89..dec8b465 100644 --- a/lerobot/common/logger.py +++ b/lerobot/common/logger.py @@ -31,6 +31,7 @@ from termcolor import colored from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler +import wandb from lerobot.common.policies.policy_protocol import Policy from lerobot.common.utils.utils import get_global_random_state, set_global_random_state @@ -107,8 +108,6 @@ class Logger: self._wandb = None else: os.environ["WANDB_SILENT"] = "true" - import wandb - wandb_run_id = None if cfg.resume: wandb_run_id = get_wandb_run_id_from_filesystem(self.checkpoints_dir) @@ -232,7 +231,7 @@ class Logger: # TODO(alexander-soare): Add local text log. if self._wandb is not None: for k, v in d.items(): - if not isinstance(v, (int, float, str)): + if not isinstance(v, (int, float, str, wandb.Table)): logging.warning( f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.' 
) diff --git a/lerobot/common/policies/hilserl/classifier/configuration_classifier.py b/lerobot/common/policies/hilserl/classifier/configuration_classifier.py new file mode 100644 index 00000000..209ff659 --- /dev/null +++ b/lerobot/common/policies/hilserl/classifier/configuration_classifier.py @@ -0,0 +1,36 @@ +import json +import os +from dataclasses import asdict, dataclass + +import torch + + +@dataclass +class ClassifierConfig: + """Configuration for the Classifier model.""" + + num_classes: int = 2 + hidden_dim: int = 256 + dropout_rate: float = 0.1 + model_name: str = "microsoft/resnet-50" + device: str = "cuda" if torch.cuda.is_available() else "mps" + model_type: str = "cnn" # "transformer" or "cnn" + + def save_pretrained(self, save_dir): + """Save config to json file.""" + os.makedirs(save_dir, exist_ok=True) + + # Convert to dict and save as JSON + config_dict = asdict(self) + with open(os.path.join(save_dir, "config.json"), "w") as f: + json.dump(config_dict, f, indent=2) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path): + """Load config from json file.""" + config_file = os.path.join(pretrained_model_name_or_path, "config.json") + + with open(config_file) as f: + config_dict = json.load(f) + + return cls(**config_dict) diff --git a/lerobot/common/policies/hilserl/classifier/modeling_classifier.py b/lerobot/common/policies/hilserl/classifier/modeling_classifier.py new file mode 100644 index 00000000..dbb434a7 --- /dev/null +++ b/lerobot/common/policies/hilserl/classifier/modeling_classifier.py @@ -0,0 +1,134 @@ +import logging +from typing import Optional + +import torch +from huggingface_hub import PyTorchModelHubMixin +from torch import Tensor, nn +from transformers import AutoImageProcessor, AutoModel + +from .configuration_classifier import ClassifierConfig + +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") +logger = logging.getLogger(__name__) + + +class ClassifierOutput: + """Wrapper for classifier outputs with additional metadata.""" + + def __init__( + self, logits: Tensor, probabilities: Optional[Tensor] = None, hidden_states: Optional[Tensor] = None + ): + self.logits = logits + self.probabilities = probabilities + self.hidden_states = hidden_states + + +class Classifier( + nn.Module, + PyTorchModelHubMixin, + # Add Hub metadata + library_name="lerobot", + repo_url="https://github.com/huggingface/lerobot", + tags=["robotics", "vision-classifier"], +): + """Image classifier built on top of a pre-trained encoder.""" + + # Add name attribute for factory + name = "classifier" + + def __init__(self, config: ClassifierConfig): + super().__init__() + self.config = config + self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True) + encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True) + # Extract vision model if we're given a multimodal model + if hasattr(encoder, "vision_model"): + logging.info("Multimodal model detected - using vision encoder only") + self.encoder = encoder.vision_model + self.vision_config = encoder.config.vision_config + else: + self.encoder = encoder + self.vision_config = getattr(encoder, "config", None) + + # Model type from config + self.is_cnn = self.config.model_type == "cnn" + + # For CNNs, initialize backbone + if self.is_cnn: + self._setup_cnn_backbone() + + self._freeze_encoder() + self._build_classifier_head() + + def _setup_cnn_backbone(self): + """Set up CNN encoder""" + if 
hasattr(self.encoder, "fc"): + self.feature_dim = self.encoder.fc.in_features + self.encoder = nn.Sequential(*list(self.encoder.children())[:-1]) + elif hasattr(self.encoder.config, "hidden_sizes"): + self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension + else: + raise ValueError("Unsupported CNN architecture") + + def _freeze_encoder(self) -> None: + """Freeze the encoder parameters.""" + for param in self.encoder.parameters(): + param.requires_grad = False + + def _build_classifier_head(self) -> None: + """Initialize the classifier head architecture.""" + # Get input dimension based on model type + if self.is_cnn: + input_dim = self.feature_dim + else: # Transformer models + if hasattr(self.encoder.config, "hidden_size"): + input_dim = self.encoder.config.hidden_size + else: + raise ValueError("Unsupported transformer architecture since hidden_size is not found") + + self.classifier_head = nn.Sequential( + nn.Linear(input_dim, self.config.hidden_dim), + nn.Dropout(self.config.dropout_rate), + nn.LayerNorm(self.config.hidden_dim), + nn.ReLU(), + nn.Linear(self.config.hidden_dim, 1 if self.config.num_classes == 2 else self.config.num_classes), + ) + + def _get_encoder_output(self, x: torch.Tensor) -> torch.Tensor: + """Extract the appropriate output from the encoder.""" + # Process images with the processor (handles resizing and normalization) + processed = self.processor( + images=x, # LeRobotDataset already provides proper tensor format + return_tensors="pt", + ) + processed = processed["pixel_values"].to(x.device) + + with torch.no_grad(): + if self.is_cnn: + # The HF ResNet applies pooling internally + outputs = self.encoder(processed) + # Get pooled output directly + features = outputs.pooler_output + + if features.dim() > 2: + features = features.squeeze(-1).squeeze(-1) + return features + else: # Transformer models + outputs = self.encoder(processed) + if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None: + return outputs.pooler_output + return outputs.last_hidden_state[:, 0, :] + + def forward(self, x: torch.Tensor) -> ClassifierOutput: + """Forward pass of the classifier.""" + # For training, we expect input to be a tensor directly from LeRobotDataset + encoder_output = self._get_encoder_output(x) + logits = self.classifier_head(encoder_output) + + if self.config.num_classes == 2: + logits = logits.squeeze(-1) + probabilities = torch.sigmoid(logits) + else: + probabilities = torch.softmax(logits, dim=-1) + + return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_output) diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py index 8cc0f326..911a265b 100644 --- a/lerobot/common/robot_devices/control_utils.py +++ b/lerobot/common/robot_devices/control_utils.py @@ -120,14 +120,22 @@ def predict_action(observation, policy, device, use_amp): return action -def init_keyboard_listener(): - # Allow to exit early while recording an episode or resetting the environment, - # by tapping the right arrow key '->'. This might require a sudo permission - # to allow your terminal to monitor keyboard events. +def init_keyboard_listener(assign_rewards=False): + """ + Initializes a keyboard listener to enable early termination of an episode + or environment reset by pressing the right arrow key ('->'). This may require + sudo permissions to allow the terminal to monitor keyboard events. 
+ + Args: + assign_rewards (bool): If True, allows annotating the collected trajectory + with a binary reward at the end of the episode to indicate success. + """ events = {} events["exit_early"] = False events["rerecord_episode"] = False events["stop_recording"] = False + if assign_rewards: + events["next.reward"] = 0 if is_headless(): logging.warning( @@ -152,6 +160,13 @@ def init_keyboard_listener(): print("Escape key pressed. Stopping data recording...") events["stop_recording"] = True events["exit_early"] = True + elif assign_rewards and key == keyboard.Key.space: + events["next.reward"] = 1 if events["next.reward"] == 0 else 0 + print( + "Space key pressed. Assigning new reward to the subsequent frames. New reward:", + events["next.reward"], + ) + except Exception as e: print(f"Error handling key press: {e}") @@ -272,6 +287,8 @@ def control_loop( if dataset is not None: frame = {**observation, **action} + if "next.reward" in events: + frame["next.reward"] = events["next.reward"] dataset.add_frame(frame) if display_cameras and not is_headless(): @@ -301,6 +318,8 @@ def reset_environment(robot, events, reset_time_s): timestamp = 0 start_vencod_t = time.perf_counter() + if "next.reward" in events: + events["next.reward"] = 0 # Wait if necessary with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar: diff --git a/lerobot/configs/policy/hilserl_classifier.yaml b/lerobot/configs/policy/hilserl_classifier.yaml new file mode 100644 index 00000000..be82bc4e --- /dev/null +++ b/lerobot/configs/policy/hilserl_classifier.yaml @@ -0,0 +1,48 @@ +# @package _global_ + +defaults: + - _self_ + +seed: 13 +dataset_repo_id: "dataset_repo_id" +train_split_proportion: 0.8 + +# Required by logger +env: + name: "classifier" + task: "binary_classification" + + +training: + num_epochs: 5 + batch_size: 16 + learning_rate: 1e-4 + num_workers: 4 + grad_clip_norm: 10 + use_amp: true + log_freq: 1 + eval_freq: 1 # How often to run validation (in epochs) + save_freq: 1 # How often to save checkpoints (in epochs) + save_checkpoint: true + image_key: "observation.images.phone" + label_key: "next.reward" + +eval: + batch_size: 16 + num_samples_to_log: 30 # Number of validation samples to log in the table + +policy: + name: "hilserl/classifier" + model_name: "facebook/convnext-base-224" + model_type: "cnn" + +wandb: + enable: false + project: "classifier-training" + entity: "wandb_entity" + job_name: "classifier_training_0" + disable_artifact: false + +device: "mps" +resume: false +output_dir: "output" diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py index 12eaf146..45a6bd66 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -191,6 +191,7 @@ def record( single_task: str, pretrained_policy_name_or_path: str | None = None, policy_overrides: List[str] | None = None, + assign_rewards: bool = False, fps: int | None = None, warmup_time_s: int | float = 2, episode_time_s: int | float = 10, @@ -214,6 +215,9 @@ def record( policy = None device = None use_amp = None + extra_features = ( + {"next.reward": {"dtype": "int64", "shape": (1,), "names": None}} if assign_rewards else None + ) if single_task: task = single_task @@ -254,12 +258,12 @@ def record( use_videos=video, image_writer_processes=num_image_writer_processes, image_writer_threads=num_image_writer_threads_per_camera * len(robot.cameras), + features=extra_features, ) if not robot.is_connected: robot.connect() - - listener, events = init_keyboard_listener() + listener, events = 
init_keyboard_listener(assign_rewards=assign_rewards) # Execute a few seconds without recording to: # 1. teleoperate the robot to move it in starting position if no policy provided, @@ -469,12 +473,12 @@ if __name__ == "__main__": default=1, help="Upload dataset to Hugging Face hub.", ) - parser_record.add_argument( - "--tags", - type=str, - nargs="*", - help="Add tags to your dataset on the hub.", - ) + # parser_record.add_argument( + # "--tags", + # type=str, + # nargs="*", + # help="Add tags to your dataset on the hub.", + # ) parser_record.add_argument( "--num-image-writer-processes", type=int, @@ -517,6 +521,12 @@ if __name__ == "__main__": nargs="*", help="Any key=value arguments to override config values (use dots for.nested=overrides)", ) + parser_record.add_argument( + "--assign-rewards", + type=int, + default=0, + help="Enables the assignation of rewards to frames (by default no assignation). When enabled, assign a 0 reward to frames until the space bar is pressed which assign a 1 reward. Press the space bar a second time to assign a 0 reward. The reward assigned is reset to 0 when the episode ends.", + ) parser_replay = subparsers.add_parser("replay", parents=[base_parser]) parser_replay.add_argument( diff --git a/lerobot/scripts/train_hilserl_classifier.py b/lerobot/scripts/train_hilserl_classifier.py new file mode 100644 index 00000000..8dea68c6 --- /dev/null +++ b/lerobot/scripts/train_hilserl_classifier.py @@ -0,0 +1,310 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
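+
+"""Train a reward classifier on a LeRobotDataset of images and sparse rewards.
+
+Usage sketch (the dataset repo id and override values are placeholders):
+
+```
+python lerobot/scripts/train_hilserl_classifier.py \
+    dataset_repo_id=my_user/my_dataset \
+    training.num_epochs=5 \
+    wandb.enable=false
+```
+"""
+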
+import logging +import time +from contextlib import nullcontext +from pathlib import Path +from pprint import pformat + +import hydra +import torch +import torch.nn as nn +from deepdiff import DeepDiff +from omegaconf import DictConfig, OmegaConf +from termcolor import colored +from torch import optim +from torch.cuda.amp import GradScaler +from torch.utils.data import DataLoader, WeightedRandomSampler, random_split +from tqdm import tqdm + +import wandb +from lerobot.common.datasets.factory import resolve_delta_timestamps +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.logger import Logger +from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg +from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig +from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier +from lerobot.common.utils.utils import ( + format_big_number, + get_safe_torch_device, + init_hydra_config, + set_global_seed, +) + + +def get_model(cfg, logger): + classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg) + model = Classifier(classifier_config) + if cfg.resume: + model.load_state_dict(Classifier.from_pretrained(str(logger.last_pretrained_model_dir)).state_dict()) + return model + + +def create_balanced_sampler(dataset, cfg): + # Creates a weighted sampler to handle class imbalance + + labels = torch.tensor([item[cfg.training.label_key] for item in dataset]) + _, counts = torch.unique(labels, return_counts=True) + class_weights = 1.0 / counts.float() + sample_weights = class_weights[labels] + + return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True) + + +def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg): + # Single epoch training loop with AMP support and progress tracking + model.train() + correct = 0 + total = 0 + + pbar = tqdm(train_loader, desc="Training") + for batch_idx, batch in enumerate(pbar): + start_time = time.perf_counter() + images = batch[cfg.training.image_key].to(device) + labels = batch[cfg.training.label_key].float().to(device) + + # Forward pass with optional AMP + with torch.autocast(device_type=device.type) if cfg.training.use_amp else nullcontext(): + outputs = model(images) + loss = criterion(outputs.logits, labels) + + # Backward pass with gradient scaling if AMP enabled + optimizer.zero_grad() + if cfg.training.use_amp: + grad_scaler.scale(loss).backward() + grad_scaler.step(optimizer) + grad_scaler.update() + else: + loss.backward() + optimizer.step() + + # Track metrics + if model.config.num_classes == 2: + predictions = (torch.sigmoid(outputs.logits) > 0.5).float() + else: + predictions = torch.argmax(outputs.logits, dim=1) + correct += (predictions == labels).sum().item() + total += labels.size(0) + + current_acc = 100 * correct / total + train_info = { + "loss": loss.item(), + "accuracy": current_acc, + "dataloading_s": time.perf_counter() - start_time, + } + + logger.log_dict(train_info, step + batch_idx, mode="train") + pbar.set_postfix({"loss": f"{loss.item():.4f}", "acc": f"{current_acc:.2f}%"}) + + +def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_log=8): + # Validation loop with metric tracking and sample logging + model.eval() + correct = 0 + total = 0 + batch_start_time = time.perf_counter() + samples = [] + running_loss = 0 + + with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.training.use_amp else 
nullcontext(): + for batch in tqdm(val_loader, desc="Validation"): + images = batch[cfg.training.image_key].to(device) + labels = batch[cfg.training.label_key].float().to(device) + + outputs = model(images) + loss = criterion(outputs.logits, labels) + + # Track metrics + if model.config.num_classes == 2: + predictions = (torch.sigmoid(outputs.logits) > 0.5).float() + else: + predictions = torch.argmax(outputs.logits, dim=1) + correct += (predictions == labels).sum().item() + total += labels.size(0) + running_loss += loss.item() + + # Log sample predictions for visualization + if len(samples) < num_samples_to_log: + for i in range(min(num_samples_to_log - len(samples), len(images))): + if model.config.num_classes == 2: + confidence = round(outputs.probabilities[i].item(), 3) + else: + confidence = [round(prob, 3) for prob in outputs.probabilities[i].tolist()] + samples.append( + { + "image": wandb.Image(images[i].cpu()), + "true_label": labels[i].item(), + "predicted": predictions[i].item(), + "confidence": confidence, + } + ) + + accuracy = 100 * correct / total + avg_loss = running_loss / len(val_loader) + + eval_info = { + "loss": avg_loss, + "accuracy": accuracy, + "eval_s": time.perf_counter() - batch_start_time, + "eval/prediction_samples": wandb.Table( + data=[[s["image"], s["true_label"], s["predicted"], f"{s['confidence']}"] for s in samples], + columns=["Image", "True Label", "Predicted", "Confidence"], + ) + if logger._cfg.wandb.enable + else None, + } + + return accuracy, eval_info + + +@hydra.main(version_base="1.2", config_path="../configs", config_name="classifier") +def train(cfg: DictConfig) -> None: + # Main training pipeline with support for resuming training + logging.info(OmegaConf.to_yaml(cfg)) + + # Initialize training environment + device = get_safe_torch_device(cfg.device, log=True) + set_global_seed(cfg.seed) + + out_dir = Path(cfg.output_dir) + out_dir.mkdir(parents=True, exist_ok=True) + logger = Logger(cfg, out_dir, cfg.wandb.job_name if cfg.wandb.enable else None) + + # Setup dataset and dataloaders + dataset = LeRobotDataset(cfg.dataset_repo_id) + logging.info(f"Dataset size: {len(dataset)}") + + train_size = int(cfg.train_split_proportion * len(dataset)) + val_size = len(dataset) - train_size + train_dataset, val_dataset = random_split(dataset, [train_size, val_size]) + + sampler = create_balanced_sampler(train_dataset, cfg) + train_loader = DataLoader( + train_dataset, + batch_size=cfg.training.batch_size, + num_workers=cfg.training.num_workers, + sampler=sampler, + pin_memory=True, + ) + + val_loader = DataLoader( + val_dataset, + batch_size=cfg.eval.batch_size, + shuffle=False, + num_workers=cfg.training.num_workers, + pin_memory=True, + ) + + # Resume training if requested + step = 0 + best_val_acc = 0 + + if cfg.resume: + if not Logger.get_last_checkpoint_dir(out_dir).exists(): + raise RuntimeError( + "You have set resume=True, but there is no model checkpoint in " + f"{Logger.get_last_checkpoint_dir(out_dir)}" + ) + checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml") + logging.info( + colored( + "You have set resume=True, indicating that you wish to resume a run", + color="yellow", + attrs=["bold"], + ) + ) + # Load and validate checkpoint configuration + checkpoint_cfg = init_hydra_config(checkpoint_cfg_path) + # Check for differences between the checkpoint configuration and provided configuration. + # Hack to resolve the delta_timestamps ahead of time in order to properly diff. 
+ resolve_delta_timestamps(cfg) + diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg)) + # Ignore the `resume` and parameters. + if "values_changed" in diff and "root['resume']" in diff["values_changed"]: + del diff["values_changed"]["root['resume']"] + if len(diff) > 0: + logging.warning( + "At least one difference was detected between the checkpoint configuration and " + f"the provided configuration: \n{pformat(diff)}\nNote that the checkpoint configuration " + "takes precedence.", + ) + # Use the checkpoint config instead of the provided config (but keep `resume` parameter). + cfg = checkpoint_cfg + cfg.resume = True + + # Initialize model and training components + model = get_model(cfg=cfg, logger=logger).to(device) + + optimizer = optim.AdamW(model.parameters(), lr=cfg.training.learning_rate) + # Use BCEWithLogitsLoss for binary classification and CrossEntropyLoss for multi-class + criterion = nn.BCEWithLogitsLoss() if model.config.num_classes == 2 else nn.CrossEntropyLoss() + grad_scaler = GradScaler(enabled=cfg.training.use_amp) + + # Log model parameters + num_learnable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + num_total_params = sum(p.numel() for p in model.parameters()) + logging.info(f"Learnable parameters: {format_big_number(num_learnable_params)}") + logging.info(f"Total parameters: {format_big_number(num_total_params)}") + + if cfg.resume: + step = logger.load_last_training_state(optimizer, None) + + # Training loop with validation and checkpointing + for epoch in range(cfg.training.num_epochs): + logging.info(f"\nEpoch {epoch+1}/{cfg.training.num_epochs}") + + train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg) + + # Periodic validation + if cfg.training.eval_freq > 0 and (epoch + 1) % cfg.training.eval_freq == 0: + val_acc, eval_info = validate( + model, + val_loader, + criterion, + device, + logger, + cfg, + ) + logger.log_dict(eval_info, step + len(train_loader), mode="eval") + + # Save best model + if val_acc > best_val_acc: + best_val_acc = val_acc + logger.save_checkpoint( + train_step=step + len(train_loader), + policy=model, + optimizer=optimizer, + scheduler=None, + identifier="best", + ) + + # Periodic checkpointing + if cfg.training.save_checkpoint and (epoch + 1) % cfg.training.save_freq == 0: + logger.save_checkpoint( + train_step=step + len(train_loader), + policy=model, + optimizer=optimizer, + scheduler=None, + identifier=f"{epoch+1:06d}", + ) + + step += len(train_loader) + + logging.info("Training completed") + + +if __name__ == "__main__": + train() diff --git a/tests/test_train_hilserl_classifier.py b/tests/test_train_hilserl_classifier.py new file mode 100644 index 00000000..66d8fbe4 --- /dev/null +++ b/tests/test_train_hilserl_classifier.py @@ -0,0 +1,251 @@ +import os +import tempfile +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest +import torch +from hydra import compose, initialize_config_dir +from torch import nn +from torch.utils.data import Dataset + +from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig +from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier +from lerobot.scripts.train_hilserl_classifier import ( + create_balanced_sampler, + train, + train_epoch, + validate, +) + + +class MockDataset(Dataset): + def __init__(self, data): + self.data = data + self.meta = MagicMock() + self.meta.stats = {} + + def __getitem__(self, 
idx): + return self.data[idx] + + def __len__(self): + return len(self.data) + + +def make_dummy_model(): + model_config = ClassifierConfig(num_classes=2, model_name="hf-tiny-model-private/tiny-random-ResNetModel") + model = Classifier(config=model_config) + return model + + +def test_create_balanced_sampler(): + # Mock dataset with imbalanced classes + data = [ + {"label": 0}, + {"label": 0}, + {"label": 1}, + {"label": 0}, + {"label": 1}, + {"label": 1}, + {"label": 1}, + {"label": 1}, + ] + dataset = MockDataset(data) + cfg = MagicMock() + cfg.training.label_key = "label" + + sampler = create_balanced_sampler(dataset, cfg) + + # Get weights from the sampler + weights = sampler.weights.float() + + # Check that samples have appropriate weights + labels = [item["label"] for item in data] + class_counts = torch.tensor([labels.count(0), labels.count(1)], dtype=torch.float32) + class_weights = 1.0 / class_counts + expected_weights = torch.tensor([class_weights[label] for label in labels], dtype=torch.float32) + + # Test that the weights are correct + assert torch.allclose(weights, expected_weights) + + +def test_train_epoch(): + model = make_dummy_model() + # Mock components + model.train = MagicMock() + + train_loader = [ + { + "image": torch.rand(2, 3, 224, 224), + "label": torch.tensor([0.0, 1.0]), + } + ] + + criterion = nn.BCEWithLogitsLoss() + optimizer = MagicMock() + grad_scaler = MagicMock() + device = torch.device("cpu") + logger = MagicMock() + step = 0 + cfg = MagicMock() + cfg.training.image_key = "image" + cfg.training.label_key = "label" + cfg.training.use_amp = False + + # Call the function under test + train_epoch( + model, + train_loader, + criterion, + optimizer, + grad_scaler, + device, + logger, + step, + cfg, + ) + + # Check that model.train() was called + model.train.assert_called_once() + + # Check that optimizer.zero_grad() was called + optimizer.zero_grad.assert_called() + + # Check that logger.log_dict was called + logger.log_dict.assert_called() + + +def test_validate(): + model = make_dummy_model() + + # Mock components + model.eval = MagicMock() + val_loader = [ + { + "image": torch.rand(2, 3, 224, 224), + "label": torch.tensor([0.0, 1.0]), + } + ] + criterion = nn.BCEWithLogitsLoss() + device = torch.device("cpu") + logger = MagicMock() + cfg = MagicMock() + cfg.training.image_key = "image" + cfg.training.label_key = "label" + cfg.training.use_amp = False + + # Call validate + accuracy, eval_info = validate(model, val_loader, criterion, device, logger, cfg) + + # Check that model.eval() was called + model.eval.assert_called_once() + + # Check accuracy/eval_info are calculated and of the correct type + assert isinstance(accuracy, float) + assert isinstance(eval_info, dict) + + +@pytest.mark.parametrize("resume", [True, False]) +@patch("lerobot.scripts.train_hilserl_classifier.init_hydra_config") +@patch("lerobot.scripts.train_hilserl_classifier.Logger.get_last_checkpoint_dir") +@patch("lerobot.scripts.train_hilserl_classifier.Logger.get_last_pretrained_model_dir") +@patch("lerobot.scripts.train_hilserl_classifier.Logger") +@patch("lerobot.scripts.train_hilserl_classifier.LeRobotDataset") +@patch("lerobot.scripts.train_hilserl_classifier.make_policy") +def test_resume_function( + mock_make_policy, + mock_dataset, + mock_logger, + mock_get_last_pretrained_model_dir, + mock_get_last_checkpoint_dir, + mock_init_hydra_config, + resume, +): + # Initialize Hydra + test_file_dir = os.path.dirname(os.path.abspath(__file__)) + config_dir = 
os.path.abspath(os.path.join(test_file_dir, "..", "lerobot", "configs", "policy")) + assert os.path.exists(config_dir), f"Config directory does not exist at {config_dir}" + + with initialize_config_dir(config_dir=config_dir, job_name="test_app", version_base="1.2"): + cfg = compose( + config_name="reward_classifier", + overrides=[ + "device=cpu", + "seed=42", + f"output_dir={tempfile.mkdtemp()}", + "wandb.enable=False", + f"resume={resume}", + "dataset_repo_id=dataset_repo_id", + "train_split_proportion=0.8", + "training.num_workers=0", + "training.batch_size=2", + "training.image_key=image", + "training.label_key=label", + "training.use_amp=False", + "training.num_epochs=1", + "eval.batch_size=2", + ], + ) + + # Mock the init_hydra_config function to return cfg + mock_init_hydra_config.return_value = cfg + + # Mock dataset + dataset = MockDataset([{"image": torch.rand(3, 224, 224), "label": i % 2} for i in range(10)]) + mock_dataset.return_value = dataset + + # Mock checkpoint handling + mock_checkpoint_dir = MagicMock(spec=Path) + mock_checkpoint_dir.exists.return_value = resume # Only exists if resuming + mock_get_last_checkpoint_dir.return_value = mock_checkpoint_dir + mock_get_last_pretrained_model_dir.return_value = Path(tempfile.mkdtemp()) + + # Mock logger + logger = MagicMock() + resumed_step = 1000 + if resume: + logger.load_last_training_state.return_value = resumed_step + else: + logger.load_last_training_state.return_value = 0 + mock_logger.return_value = logger + + # Instantiate the model and set make_policy to return it + model = make_dummy_model() + mock_make_policy.return_value = model + + # Call train + train(cfg) + + # Check that checkpoint handling methods were called + if resume: + mock_get_last_checkpoint_dir.assert_called_once_with(Path(cfg.output_dir)) + mock_get_last_pretrained_model_dir.assert_called_once_with(Path(cfg.output_dir)) + mock_checkpoint_dir.exists.assert_called_once() + logger.load_last_training_state.assert_called_once() + else: + mock_get_last_checkpoint_dir.assert_not_called() + mock_get_last_pretrained_model_dir.assert_not_called() + mock_checkpoint_dir.exists.assert_not_called() + logger.load_last_training_state.assert_not_called() + + # Collect the steps from logger.log_dict calls + train_log_calls = logger.log_dict.call_args_list + + # Extract the steps used in the train logging + steps = [] + for call in train_log_calls: + mode = call.kwargs.get("mode", call.args[2] if len(call.args) > 2 else None) + if mode == "train": + step = call.kwargs.get("step", call.args[1] if len(call.args) > 1 else None) + steps.append(step) + + expected_start_step = resumed_step if resume else 0 + + # Calculate expected_steps + train_size = int(cfg.train_split_proportion * len(dataset)) + batch_size = cfg.training.batch_size + num_batches = (train_size + batch_size - 1) // batch_size + + expected_steps = [expected_start_step + i for i in range(num_batches)] + + assert steps == expected_steps, f"Expected steps {expected_steps}, got {steps}" From 7fcf638c0d350aa40ac6cfed46ed4b285647b7ea Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Mon, 9 Dec 2024 19:17:47 +0100 Subject: [PATCH 003/112] Add human intervention mechanism and eval_robot script to evaluate policy on the robot (#541) Co-authored-by: Yoel --- lerobot/configs/robot/koch.yaml | 4 +- lerobot/configs/robot/so100.yaml | 2 +- lerobot/scripts/eval_on_robot.py | 335 +++++++++++++++++++++++++++++++ 3 files changed, 338 insertions(+), 3 deletions(-) create mode 100644 lerobot/scripts/eval_on_robot.py diff 
--git a/lerobot/configs/robot/koch.yaml b/lerobot/configs/robot/koch.yaml index 40969dc7..334db830 100644 --- a/lerobot/configs/robot/koch.yaml +++ b/lerobot/configs/robot/koch.yaml @@ -10,7 +10,7 @@ max_relative_target: null leader_arms: main: _target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus - port: /dev/tty.usbmodem575E0031751 + port: /dev/tty.usbmodem58760430441 motors: # name: (index, model) shoulder_pan: [1, "xl330-m077"] @@ -23,7 +23,7 @@ leader_arms: follower_arms: main: _target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus - port: /dev/tty.usbmodem575E0032081 + port: /dev/tty.usbmodem585A0083391 motors: # name: (index, model) shoulder_pan: [1, "xl430-w250"] diff --git a/lerobot/configs/robot/so100.yaml b/lerobot/configs/robot/so100.yaml index ec6f3e3f..0978de64 100644 --- a/lerobot/configs/robot/so100.yaml +++ b/lerobot/configs/robot/so100.yaml @@ -18,7 +18,7 @@ max_relative_target: null leader_arms: main: _target_: lerobot.common.robot_devices.motors.feetech.FeetechMotorsBus - port: /dev/tty.usbmodem585A0077581 + port: /dev/tty.usbmodem58760433331 motors: # name: (index, model) shoulder_pan: [1, "sts3215"] diff --git a/lerobot/scripts/eval_on_robot.py b/lerobot/scripts/eval_on_robot.py new file mode 100644 index 00000000..6a790f0a --- /dev/null +++ b/lerobot/scripts/eval_on_robot.py @@ -0,0 +1,335 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Evaluate a policy by running rollouts on the real robot and computing metrics. + +Usage examples: evaluate a checkpoint from the LeRobot training script for 10 episodes. + +``` +python lerobot/scripts/eval_on_robot.py \ + -p outputs/train/model/checkpoints/005000/pretrained_model \ + eval.n_episodes=10 +``` + +**NOTE** (michel-aractingi): This script is incomplete and it is being prepared +for running training on the real robot. +""" + +import argparse +import logging +import time +from copy import deepcopy + +import numpy as np +import torch +from tqdm import trange + +from lerobot.common.policies.policy_protocol import Policy +from lerobot.common.robot_devices.control_utils import busy_wait, is_headless +from lerobot.common.robot_devices.robots.factory import Robot, make_robot +from lerobot.common.utils.utils import ( + init_hydra_config, + init_logging, + log_say, +) + + +def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20, use_amp: bool = True) -> dict: + """Run a batched policy rollout on the real robot. + + The return dictionary contains: + "robot": A a dictionary of (batch, sequence + 1, *) tensors mapped to observation + keys. NOTE the that this has an extra sequence element relative to the other keys in the + dictionary. This is because an extra observation is included for after the environment is + terminated or truncated. + "action": A (batch, sequence, action_dim) tensor of actions applied based on the observations (not + including the last observations). 
+ "reward": A (batch, sequence) tensor of rewards received for applying the actions. + "success": A (batch, sequence) tensor of success conditions (the only time this can be True is upon + environment termination/truncation). + "done": A (batch, sequence) tensor of **cumulative** done conditions. For any given batch element, + the first True is followed by True's all the way till the end. This can be used for masking + extraneous elements from the sequences above. + + Args: + robot: The robot class that defines the interface with the real robot. + policy: The policy. Must be a PyTorch nn module. + + Returns: + The dictionary described above. + """ + # assert isinstance(policy, nn.Module), "Policy must be a PyTorch nn module." + # device = get_device_from_parameters(policy) + + # define keyboard listener + listener, events = init_keyboard_listener() + + # Reset the policy. TODO (michel-aractingi) add real policy evaluation once the code is ready. + # policy.reset() + + # Get observation from real robot + observation = robot.capture_observation() + + # Calculate reward. TODO (michel-aractingi) + # in HIL-SERL it will be with a reward classifier + reward = calculate_reward(observation) + all_observations = [] + all_actions = [] + all_rewards = [] + all_successes = [] + + start_episode_t = time.perf_counter() + timestamp = 0.0 + while timestamp < control_time_s: + start_loop_t = time.perf_counter() + + all_observations.append(deepcopy(observation)) + # observation = {key: observation[key].to(device, non_blocking=True) for key in observation} + + # Apply the next action. + while events["pause_policy"] and not events["human_intervention_step"]: + busy_wait(0.5) + + if events["human_intervention_step"]: + # take over the robot's actions + observation, action = robot.teleop_step(record_data=True) + action = action["action"] # teleop step returns torch tensors but in a dict + else: + # explore with policy + with torch.inference_mode(): + action = robot.follower_arms["main"].read("Present_Position") + action = torch.from_numpy(action) + robot.send_action(action) + # action = predict_action(observation, policy, device, use_amp) + + observation = robot.capture_observation() + # Calculate reward + # in HIL-SERL it will be with a reward classifier + reward = calculate_reward(observation) + + all_actions.append(action) + all_rewards.append(torch.from_numpy(reward)) + all_successes.append(torch.tensor([False])) + + dt_s = time.perf_counter() - start_loop_t + busy_wait(1 / fps - dt_s) + timestamp = time.perf_counter() - start_episode_t + if events["exit_early"]: + events["exit_early"] = False + events["human_intervention_step"] = False + events["pause_policy"] = False + break + all_observations.append(deepcopy(observation)) + + dones = torch.tensor([False] * len(all_actions)) + dones[-1] = True + # Stack the sequence along the first dimension so that we have (batch, sequence, *) tensors. + ret = { + "action": torch.stack(all_actions, dim=1), + "next.reward": torch.stack(all_rewards, dim=1), + "next.success": torch.stack(all_successes, dim=1), + "done": dones, + } + stacked_observations = {} + for key in all_observations[0]: + stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1) + ret["observation"] = stacked_observations + + listener.stop() + + return ret + + +def eval_policy( + robot: Robot, + policy: torch.nn.Module, + fps: float, + n_episodes: int, + control_time_s: int = 20, + use_amp: bool = True, +) -> dict: + """ + Args: + env: The batch of environments. 
+ policy: The policy. + n_episodes: The number of episodes to evaluate. + Returns: + Dictionary with metrics and data regarding the rollouts. + """ + # TODO (michel-aractingi) comment this out for testing with a fixed policy + # assert isinstance(policy, Policy) + # policy.eval() + + sum_rewards = [] + max_rewards = [] + successes = [] + rollouts = [] + + start_eval = time.perf_counter() + progbar = trange(n_episodes, desc="Evaluating policy on real robot") + for _batch_idx in progbar: + rollout_data = rollout(robot, policy, fps, control_time_s, use_amp) + + rollouts.append(rollout_data) + sum_rewards.append(sum(rollout_data["next.reward"])) + max_rewards.append(max(rollout_data["next.reward"])) + successes.append(rollout_data["next.success"][-1]) + + info = { + "per_episode": [ + { + "episode_ix": i, + "sum_reward": sum_reward, + "max_reward": max_reward, + "pc_success": success * 100, + } + for i, (sum_reward, max_reward, success) in enumerate( + zip( + sum_rewards[:n_episodes], + max_rewards[:n_episodes], + successes[:n_episodes], + strict=False, + ) + ) + ], + "aggregated": { + "avg_sum_reward": float(np.nanmean(torch.cat(sum_rewards[:n_episodes]))), + "avg_max_reward": float(np.nanmean(torch.cat(max_rewards[:n_episodes]))), + "pc_success": float(np.nanmean(torch.cat(successes[:n_episodes])) * 100), + "eval_s": time.time() - start_eval, + "eval_ep_s": (time.time() - start_eval) / n_episodes, + }, + } + + if robot.is_connected: + robot.disconnect() + + return info + + +def calculate_reward(observation): + """ + Method to calculate reward function in some way. + In HIL-SERL this is done through defining a reward classifier + """ + # reward = reward_classifier(observation) + return np.array([0.0]) + + +def init_keyboard_listener(): + # Allow to exit early while recording an episode or resetting the environment, + # by tapping the right arrow key '->'. This might require a sudo permission + # to allow your terminal to monitor keyboard events. + events = {} + events["exit_early"] = False + events["rerecord_episode"] = False + events["pause_policy"] = False + events["human_intervention_step"] = False + + if is_headless(): + logging.warning( + "Headless environment detected. On-screen cameras display and keyboard inputs will not be available." + ) + listener = None + return listener, events + + # Only import pynput if not in a headless environment + from pynput import keyboard + + def on_press(key): + try: + if key == keyboard.Key.right: + print("Right arrow key pressed. Exiting loop...") + events["exit_early"] = True + elif key == keyboard.Key.left: + print("Left arrow key pressed. Exiting loop and rerecord the last episode...") + events["rerecord_episode"] = True + events["exit_early"] = True + elif key == keyboard.Key.space: + # check if first space press then pause the policy for the user to get ready + # if second space press then the user is ready to start intervention + if not events["pause_policy"]: + print( + "Space key pressed. Human intervention required.\n" + "Place the leader in similar pose to the follower and press space again." + ) + events["pause_policy"] = True + log_say("Human intervention stage. Get ready to take over.", play_sounds=True) + else: + events["human_intervention_step"] = True + print("Space key pressed. 
Human intervention starting.") + log_say("Starting human intervention.", play_sounds=True) + + except Exception as e: + print(f"Error handling key press: {e}") + + listener = keyboard.Listener(on_press=on_press) + listener.start() + + return listener, events + + +if __name__ == "__main__": + init_logging() + + parser = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter + ) + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument( + "--robot-path", + type=str, + default="lerobot/configs/robot/koch.yaml", + help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.", + ) + group.add_argument( + "--robot-overrides", + type=str, + nargs="*", + help="Any key=value arguments to override config values (use dots for.nested=overrides)", + ) + group.add_argument( + "-p", + "--pretrained-policy-name-or-path", + help=( + "Either the repo ID of a model hosted on the Hub or a path to a directory containing weights " + "saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch " + "(useful for debugging). This argument is mutually exclusive with `--config`." + ), + ) + group.add_argument( + "--config", + help=( + "Path to a yaml config you want to use for initializing a policy from scratch (useful for " + "debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)." + ), + ) + parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.") + parser.add_argument( + "--out-dir", + help=( + "Where to save the evaluation outputs. If not provided, outputs are saved in " + "outputs/eval/{timestamp}_{env_name}_{policy_name}" + ), + ) + + args = parser.parse_args() + + robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides) + robot = make_robot(robot_cfg) + if not robot.is_connected: + robot.connect() + + eval_policy(robot, None, fps=40, n_episodes=2, control_time_s=100) From 1020bc3108b79f6f9c6d5a6d3e7ea241419dc8fe Mon Sep 17 00:00:00 2001 From: Eugene Mironov Date: Tue, 17 Dec 2024 02:42:53 +0700 Subject: [PATCH 004/112] Fixup --- lerobot/common/logger.py | 2 +- lerobot/common/robot_devices/control_utils.py | 6 +++--- lerobot/scripts/train_hilserl_classifier.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py index dec8b465..4015492d 100644 --- a/lerobot/common/logger.py +++ b/lerobot/common/logger.py @@ -25,13 +25,13 @@ from glob import glob from pathlib import Path import torch +import wandb from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE from omegaconf import DictConfig, OmegaConf from termcolor import colored from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler -import wandb from lerobot.common.policies.policy_protocol import Policy from lerobot.common.utils.utils import get_global_random_state, set_global_random_state diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py index 911a265b..8a6bcfbd 100644 --- a/lerobot/common/robot_devices/control_utils.py +++ b/lerobot/common/robot_devices/control_utils.py @@ -122,12 +122,12 @@ def predict_action(observation, policy, device, use_amp): def init_keyboard_listener(assign_rewards=False): """ - Initializes a keyboard listener to enable early termination of an episode - or environment reset by pressing the right arrow key ('->'). 
This may require + Initializes a keyboard listener to enable early termination of an episode + or environment reset by pressing the right arrow key ('->'). This may require sudo permissions to allow the terminal to monitor keyboard events. Args: - assign_rewards (bool): If True, allows annotating the collected trajectory + assign_rewards (bool): If True, allows annotating the collected trajectory with a binary reward at the end of the episode to indicate success. """ events = {} diff --git a/lerobot/scripts/train_hilserl_classifier.py b/lerobot/scripts/train_hilserl_classifier.py index 8dea68c6..86fa90f2 100644 --- a/lerobot/scripts/train_hilserl_classifier.py +++ b/lerobot/scripts/train_hilserl_classifier.py @@ -22,6 +22,7 @@ from pprint import pformat import hydra import torch import torch.nn as nn +import wandb from deepdiff import DeepDiff from omegaconf import DictConfig, OmegaConf from termcolor import colored @@ -30,7 +31,6 @@ from torch.cuda.amp import GradScaler from torch.utils.data import DataLoader, WeightedRandomSampler, random_split from tqdm import tqdm -import wandb from lerobot.common.datasets.factory import resolve_delta_timestamps from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.logger import Logger From 668d493bf997b7d08178d4288a7177d71bb808cf Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Wed, 11 Dec 2024 00:22:10 +0100 Subject: [PATCH 010/112] Update lerobot/scripts/train_hilserl_classifier.py Co-authored-by: Yoel --- lerobot/scripts/train_hilserl_classifier.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lerobot/scripts/train_hilserl_classifier.py b/lerobot/scripts/train_hilserl_classifier.py index 86fa90f2..78659dc8 100644 --- a/lerobot/scripts/train_hilserl_classifier.py +++ b/lerobot/scripts/train_hilserl_classifier.py @@ -170,7 +170,7 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l return accuracy, eval_info -@hydra.main(version_base="1.2", config_path="../configs", config_name="classifier") +@hydra.main(version_base="1.2", config_path="../configs", config_name="hilserl_classifier") def train(cfg: DictConfig) -> None: # Main training pipeline with support for resuming training logging.info(OmegaConf.to_yaml(cfg)) From ed66c92383da2bb297d76ae488cd178d8642b252 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Wed, 11 Dec 2024 00:30:33 +0100 Subject: [PATCH 011/112] nit in control_robot.py --- .../policies/hilserl/configuration_hilserl.py | 23 +++++++++++++++ .../policies/hilserl/modeling_hilserl.py | 29 +++++++++++++++++++ 2 files changed, 52 insertions(+) create mode 100644 lerobot/common/policies/hilserl/configuration_hilserl.py create mode 100644 lerobot/common/policies/hilserl/modeling_hilserl.py diff --git a/lerobot/common/policies/hilserl/configuration_hilserl.py b/lerobot/common/policies/hilserl/configuration_hilserl.py new file mode 100644 index 00000000..f1bc850f --- /dev/null +++ b/lerobot/common/policies/hilserl/configuration_hilserl.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + + +@dataclass +class HILSerlConfig: + pass diff --git a/lerobot/common/policies/hilserl/modeling_hilserl.py b/lerobot/common/policies/hilserl/modeling_hilserl.py new file mode 100644 index 00000000..236ed433 --- /dev/null +++ b/lerobot/common/policies/hilserl/modeling_hilserl.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch.nn as nn +from huggingface_hub import PyTorchModelHubMixin + + +class HILSerlPolicy( + nn.Module, + PyTorchModelHubMixin, + library_name="lerobot", + repo_url="https://github.com/huggingface/lerobot", + tags=["robotics", "hilserl"], +): + pass \ No newline at end of file From c9af8e36a722d95908ffdf173038863a628f17e3 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Thu, 12 Dec 2024 11:45:30 +0100 Subject: [PATCH 012/112] completed losses --- lerobot/common/policies/sac/modeling_sac.py | 187 ++++++++++++++++++++ 1 file changed, 187 insertions(+) create mode 100644 lerobot/common/policies/sac/modeling_sac.py diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py new file mode 100644 index 00000000..fb2e5542 --- /dev/null +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -0,0 +1,187 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
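+
+# SAC losses computed in `SACPolicy.forward` below (a summary; `alpha` is the learned temperature):
+#   critics:     L_Q     = E[ ( Q_i(s, a) - (r + gamma * min_j Q_target_j(s', a')) )^2 ],  a' ~ pi
+#   actor:       L_pi    = E[ alpha * log pi(a|s) - Q(s, a) ],  a ~ pi
+#   temperature: L_alpha = alpha * ( E[ -log pi(a|s) ] - H_target )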
+ +from collections import deque + +import einops + +import torch +import torch.nn as nn +import torch.nn.functional as F # noqa: N812 +from torch import Tensor + +from huggingface_hub import PyTorchModelHubMixin +from lerobot.common.policies.normalize import Normalize, Unnormalize +from lerobot.common.policies.sac.configuration_sac import SACConfig + +class SACPolicy( + nn.Module, + PyTorchModelHubMixin, + library_name="lerobot", + repo_url="https://github.com/huggingface/lerobot", + tags=["robotics", "RL", "SAC"], +): + + def __init__( + self, config: SACConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None + ): + + super().__init__() + + if config is None: + config = SACConfig() + self.config = config + + if config.input_normalization_modes is not None: + self.normalize_inputs = Normalize( + config.input_shapes, config.input_normalization_modes, dataset_stats + ) + else: + self.normalize_inputs = nn.Identity() + self.normalize_targets = Normalize( + config.output_shapes, config.output_normalization_modes, dataset_stats + ) + self.unnormalize_outputs = Unnormalize( + config.output_shapes, config.output_normalization_modes, dataset_stats + ) + + self.critic_ensemble = ... + self.critic_target = ... + self.actor_network = ... + + self.temperature = ... + + def reset(self): + """ + Clear observation and action queues. Should be called on `env.reset()` + queues are populated during rollout of the policy, they contain the n latest observations and actions + """ + + self._queues = { + "observation.state": deque(maxlen=1), + "action": deque(maxlen=1), + } + if self._use_image: + self._queues["observation.image"] = deque(maxlen=1) + if self._use_env_state: + self._queues["observation.environment_state"] = deque(maxlen=1) + + @torch.no_grad() + def select_action(self, batch: dict[str, Tensor]) -> Tensor: + actions, _ = self.actor_network(batch['observations'])### + + def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: + """Run the batch through the model and compute the loss. + + Returns a dictionary with loss as a tensor, and other information as native floats. + """ + batch = self.normalize_inputs(batch) + # batch shape is (b, 2, ...) where index 1 returns the current observation and + # the next observation for caluculating the right td index. + actions = batch["action"][:, 0] + rewards = batch["next.reward"][:, 0] + observations = {} + next_observations = {} + for k in batch: + if k.startswith("observation."): + observations[k] = batch[k][:, 0] + next_observations[k] = batch[k][:, 1] + + # perform image augmentation + + # reward bias + # from HIL-SERL code base + # add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch + + + # calculate critics loss + # 1- compute actions from policy + action_preds, log_probs = self.actor_network(observations) + # 2- compute q targets + q_targets = self.target_qs(next_observations, action_preds) + + # critics subsample size + min_q = q_targets.min(dim=0) + + # backup entropy + td_target = rewards + self.discount * min_q + + # 3- compute predicted qs + q_preds = self.critic_ensemble(observations, actions) + + # 4- Calculate loss + # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble. + critics_loss = ( + F.mse_loss( + q_preds, + einops.repeat(td_target, "t b -> e t b", e=q_preds.shape[0]), + reduction="none", + ).sum(0) # sum over ensemble + # `q_preds_ensemble` depends on the first observation and the actions. 
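+            # The `~..._is_pad` multiplications below act as a mask: transitions that
+            # come from padding contribute zero to the TD loss, so episode-boundary
+            # padding does not corrupt the critic update.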
+ * ~batch["observation.state_is_pad"][0] + * ~batch["action_is_pad"] + # q_targets depends on the reward and the next observations. + * ~batch["next.reward_is_pad"] + * ~batch["observation.state_is_pad"][1:] + ).sum(0).mean() + + # calculate actors loss + # 1- temperature + temperature = self.temperature() + + # 2- get actions (batch_size, action_dim) and log probs (batch_size,) + actions, log_probs = self.actor_network(observations) \ + + # 3- get q-value predictions + with torch.no_grad(): + q_preds = self.critic_ensemble(observations, actions, return_type="mean") + actor_loss = ( + -(q_preds - temperature * log_probs).mean() + * ~batch["observation.state_is_pad"][0] + * ~batch["action_is_pad"] + ).mean() + + + # calculate temperature loss + # 1- calculate entropy + entropy = -log_probs.mean() + temperature_loss = temperature * (entropy - self.target_entropy).mean() + + loss = critics_loss + actor_loss + temperature_loss + + return { + "critics_loss": critics_loss.item(), + "actor_loss": actor_loss.item(), + "temperature_loss": temperature_loss.item(), + "temperature": temperature.item(), + "entropy": entropy.item(), + "loss": loss, + + } + + def update(self): + self.critic_target.lerp_(self.critic_ensemble, self.config.critic_target_update_weight) + #for target_param, param in zip(self.critic_target.parameters(), self.critic_ensemble.parameters()): + # target_param.data.copy_(target_param.data * (1.0 - self.config.critic_target_update_weight) + param.data * self.critic_target_update_weight) + +class SACObservationEncoder(nn.Module): + """Encode image and/or state vector observations.""" + + def __init__(self, config: SACConfig): + + super().__init__() + self.config = config From def42ff4874bfeda79b8d9746c858789d7fd81fb Mon Sep 17 00:00:00 2001 From: KeWang Date: Tue, 17 Dec 2024 13:26:17 +0000 Subject: [PATCH 013/112] Port SAC WIP (#581) Co-authored-by: KeWang1017 --- .../common/policies/sac/configuration_sac.py | 39 ++ lerobot/common/policies/sac/modeling_sac.py | 508 +++++++++++++++++- 2 files changed, 541 insertions(+), 6 deletions(-) create mode 100644 lerobot/common/policies/sac/configuration_sac.py diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py new file mode 100644 index 00000000..441b3566 --- /dev/null +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
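+
+# Note: the attributes below carry no type annotations, so @dataclass does not turn
+# them into __init__ fields; they behave as shared class-level defaults. Illustrative
+# (hypothetical) usage:
+#
+#     cfg = SACConfig()          # takes no arguments
+#     cfg.discount               # -> 0.99
+#     cfg.num_critics = 10       # per-instance override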
+ +from dataclasses import dataclass + + +@dataclass +class SACConfig: + discount = 0.99 + temperature_init = 1.0 + num_critics = 2 + critic_lr = 3e-4 + actor_lr = 3e-4 + critic_network_kwargs = { + "hidden_dims": [256, 256], + "activate_final": True, + } + actor_network_kwargs = { + "hidden_dims": [256, 256], + "activate_final": True, + } + policy_kwargs = { + "tanh_squash_distribution": True, + "std_parameterization": "uniform", + } diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index fb2e5542..9ea9449d 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -15,7 +15,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +# TODO: (1) better device management + from collections import deque +from copy import deepcopy +from functools import partial import einops @@ -27,6 +31,10 @@ from torch import Tensor from huggingface_hub import PyTorchModelHubMixin from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.sac.configuration_sac import SACConfig +import numpy as np +from typing import Callable, Optional, Tuple, Sequence + + class SACPolicy( nn.Module, @@ -58,12 +66,27 @@ class SACPolicy( self.unnormalize_outputs = Unnormalize( config.output_shapes, config.output_normalization_modes, dataset_stats ) + encoder = SACObservationEncoder(config) + # Define networks + critic_nets = [] + for _ in range(config.num_critics): + critic_net = Critic( + encoder=encoder, + network=MLP(**config.critic_network_kwargs) + ) + critic_nets.append(critic_net) - self.critic_ensemble = ... - self.critic_target = ... - self.actor_network = ... + self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics) + self.critic_target = deepcopy(self.critic_ensemble) - self.temperature = ... + self.actor_network = Policy( + encoder=encoder, + network=MLP(**config.actor_network_kwargs), + action_dim=config.output_shapes["action"][0], + **config.policy_kwargs + ) + + self.temperature = LagrangeMultiplier(init_value=config.temperature_init) def reset(self): """ @@ -178,10 +201,483 @@ class SACPolicy( #for target_param, param in zip(self.critic_target.parameters(), self.critic_ensemble.parameters()): # target_param.data.copy_(target_param.data * (1.0 - self.config.critic_target_update_weight) + param.data * self.critic_target_update_weight) + +class MLP(nn.Module): + def __init__( + self, + config: SACConfig, + activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(), + activate_final: bool = False, + dropout_rate: Optional[float] = None, + ): + super().__init__() + self.activate_final = config.activate_final + layers = [] + + for i, size in enumerate(config.network_hidden_dims): + layers.append(nn.Linear(config.network_hidden_dims[i-1] if i > 0 else config.network_hidden_dims[0], size)) + + if i + 1 < len(config.network_hidden_dims) or activate_final: + if dropout_rate is not None and dropout_rate > 0: + layers.append(nn.Dropout(p=dropout_rate)) + layers.append(nn.LayerNorm(size)) + layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)()) + + self.net = nn.Sequential(*layers) + + def forward(self, x: torch.Tensor, train: bool = False) -> torch.Tensor: + # in training mode or not. 
TODO: find better way to do this + self.train(train) + return self.net(x) + + +class Critic(nn.Module): + def __init__( + self, + encoder: Optional[nn.Module], + network: nn.Module, + init_final: Optional[float] = None, + activate_final: bool = False, + device: str = "cuda" + ): + super().__init__() + self.device = torch.device(device) + self.encoder = encoder + self.network = network + self.init_final = init_final + self.activate_final = activate_final + + # Output layer + if init_final is not None: + if self.activate_final: + self.output_layer = nn.Linear(network.net[-3].out_features, 1) + else: + self.output_layer = nn.Linear(network.net[-2].out_features, 1) + nn.init.uniform_(self.output_layer.weight, -init_final, init_final) + nn.init.uniform_(self.output_layer.bias, -init_final, init_final) + else: + if self.activate_final: + self.output_layer = nn.Linear(network.net[-3].out_features, 1) + else: + self.output_layer = nn.Linear(network.net[-2].out_features, 1) + orthogonal_init()(self.output_layer.weight) + + self.to(self.device) + + def forward( + self, + observations: torch.Tensor, + actions: torch.Tensor, + train: bool = False + ) -> torch.Tensor: + self.train(train) + + observations = observations.to(self.device) + actions = actions.to(self.device) + + if self.encoder is not None: + obs_enc = self.encoder(observations) + else: + obs_enc = observations + + inputs = torch.cat([obs_enc, actions], dim=-1) + x = self.network(inputs) + value = self.output_layer(x) + return value.squeeze(-1) + + def q_value_ensemble( + self, + observations: torch.Tensor, + actions: torch.Tensor, + train: bool = False + ) -> torch.Tensor: + observations = observations.to(self.device) + actions = actions.to(self.device) + + if len(actions.shape) == 3: # [batch_size, num_actions, action_dim] + batch_size, num_actions = actions.shape[:2] + obs_expanded = observations.unsqueeze(1).expand(-1, num_actions, -1) + obs_flat = obs_expanded.reshape(-1, observations.shape[-1]) + actions_flat = actions.reshape(-1, actions.shape[-1]) + q_values = self(obs_flat, actions_flat, train) + return q_values.reshape(batch_size, num_actions) + else: + return self(observations, actions, train) + + +class Policy(nn.Module): + def __init__( + self, + encoder: Optional[nn.Module], + network: nn.Module, + action_dim: int, + std_parameterization: str = "exp", + std_min: float = 1e-5, + std_max: float = 10.0, + tanh_squash_distribution: bool = False, + fixed_std: Optional[torch.Tensor] = None, + init_final: Optional[float] = None, + activate_final: bool = False, + device: str = "cuda" + ): + super().__init__() + self.device = torch.device(device) + self.encoder = encoder + self.network = network + self.action_dim = action_dim + self.std_parameterization = std_parameterization + self.std_min = std_min + self.std_max = std_max + self.tanh_squash_distribution = tanh_squash_distribution + self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None + self.activate_final = activate_final + + # Mean layer + if self.activate_final: + self.mean_layer = nn.Linear(network.net[-3].out_features, action_dim) + else: + self.mean_layer = nn.Linear(network.net[-2].out_features, action_dim) + if init_final is not None: + nn.init.uniform_(self.mean_layer.weight, -init_final, init_final) + nn.init.uniform_(self.mean_layer.bias, -init_final, init_final) + else: + orthogonal_init()(self.mean_layer.weight) + + # Standard deviation layer or parameter + if fixed_std is None: + if std_parameterization == "uniform": + self.log_stds = 
nn.Parameter(torch.zeros(action_dim, device=self.device)) + else: + if self.activate_final: + self.std_layer = nn.Linear(network.net[-3].out_features, action_dim) + else: + self.std_layer = nn.Linear(network.net[-2].out_features, action_dim) + if init_final is not None: + nn.init.uniform_(self.std_layer.weight, -init_final, init_final) + nn.init.uniform_(self.std_layer.bias, -init_final, init_final) + else: + orthogonal_init()(self.std_layer.weight) + + self.to(self.device) + + def forward( + self, + observations: torch.Tensor, + temperature: float = 1.0, + train: bool = False, + non_squash_distribution: bool = False + ) -> torch.distributions.Distribution: + self.train(train) + + # Encode observations if encoder exists + if self.encoder is not None: + with torch.set_grad_enabled(train): + obs_enc = self.encoder(observations, train=train) + else: + obs_enc = observations + # Get network outputs + outputs = self.network(obs_enc) + means = self.mean_layer(outputs) + + # Compute standard deviations + if self.fixed_std is None: + if self.std_parameterization == "exp": + log_stds = self.std_layer(outputs) + stds = torch.exp(log_stds) + elif self.std_parameterization == "softplus": + stds = torch.nn.functional.softplus(self.std_layer(outputs)) + elif self.std_parameterization == "uniform": + stds = torch.exp(self.log_stds).expand_as(means) + else: + raise ValueError( + f"Invalid std_parameterization: {self.std_parameterization}" + ) + else: + assert self.std_parameterization == "fixed" + stds = self.fixed_std.expand_as(means) + + # Clip standard deviations and scale with temperature + temperature = torch.tensor(temperature, device=self.device) + stds = torch.clamp(stds, self.std_min, self.std_max) * torch.sqrt(temperature) + + # Create distribution + if self.tanh_squash_distribution and not non_squash_distribution: + distribution = TanhMultivariateNormalDiag( + loc=means, + scale_diag=stds, + ) + else: + distribution = torch.distributions.Normal( + loc=means, + scale=stds, + ) + + return distribution + + def get_features(self, observations: torch.Tensor) -> torch.Tensor: + """Get encoded features from observations""" + observations = observations.to(self.device) + if self.encoder is not None: + with torch.no_grad(): + return self.encoder(observations, train=False) + return observations + + class SACObservationEncoder(nn.Module): - """Encode image and/or state vector observations.""" + """Encode image and/or state vector observations. + TODO(ke-wang): The original work allows for (1) stacking multiple history frames and (2) using pretrained resnet encoders. + """ def __init__(self, config: SACConfig): - + """ + Creates encoders for pixel and/or state modalities. 
+ """ super().__init__() self.config = config + + if "observation.image" in config.input_shapes: + self.image_enc_layers = nn.Sequential( + nn.Conv2d( + config.input_shapes["observation.image"][0], config.image_encoder_hidden_dim, 7, stride=2 + ), + nn.ReLU(), + nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2), + nn.ReLU(), + nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2), + nn.ReLU(), + nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2), + nn.ReLU(), + ) + dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"]) + with torch.inference_mode(): + out_shape = self.image_enc_layers(dummy_batch).shape[1:] + self.image_enc_layers.extend( + nn.Sequential( + nn.Flatten(), + nn.Linear(np.prod(out_shape), config.latent_dim), + nn.LayerNorm(config.latent_dim), + nn.Tanh(), + ) + ) + if "observation.state" in config.input_shapes: + self.state_enc_layers = nn.Sequential( + nn.Linear(config.input_shapes["observation.state"][0], config.state_encoder_hidden_dim), + nn.ELU(), + nn.Linear(config.state_encoder_hidden_dim, config.latent_dim), + nn.LayerNorm(config.latent_dim), + nn.Tanh(), + ) + if "observation.environment_state" in config.input_shapes: + self.env_state_enc_layers = nn.Sequential( + nn.Linear( + config.input_shapes["observation.environment_state"][0], config.state_encoder_hidden_dim + ), + nn.ELU(), + nn.Linear(config.state_encoder_hidden_dim, config.latent_dim), + nn.LayerNorm(config.latent_dim), + nn.Tanh(), + ) + + def forward(self, obs_dict: dict[str, Tensor]) -> Tensor: + """Encode the image and/or state vector. + + Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken + over all features. + """ + feat = [] + # Concatenate all images along the channel dimension. 
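+        # Each image key is encoded separately by the shared image encoder; the
+        # per-modality features are averaged at the end of this method.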
+ image_keys = [k for k in self.config.input_shapes if k.startswith("observation.image")] + for image_key in image_keys: + feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key])) + if "observation.environment_state" in self.config.input_shapes: + feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"])) + if "observation.state" in self.config.input_shapes: + feat.append(self.state_enc_layers(obs_dict["observation.state"])) + return torch.stack(feat, dim=0).mean(0) + + +class LagrangeMultiplier(nn.Module): + def __init__( + self, + init_value: float = 1.0, + constraint_shape: Sequence[int] = (), + device: str = "cuda" + ): + super().__init__() + self.device = torch.device(device) + init_value = torch.log(torch.exp(torch.tensor(init_value, device=self.device)) - 1) + + # Initialize the Lagrange multiplier as a parameter + self.lagrange = nn.Parameter( + torch.full(constraint_shape, init_value, dtype=torch.float32, device=self.device) + ) + + self.to(self.device) + + def forward( + self, + lhs: Optional[torch.Tensor] = None, + rhs: Optional[torch.Tensor] = None + ) -> torch.Tensor: + # Get the multiplier value based on parameterization + multiplier = torch.nn.functional.softplus(self.lagrange) + + # Return the raw multiplier if no constraint values provided + if lhs is None: + return multiplier + + # Move inputs to device + lhs = lhs.to(self.device) + if rhs is not None: + rhs = rhs.to(self.device) + + # Use the multiplier to compute the Lagrange penalty + if rhs is None: + rhs = torch.zeros_like(lhs, device=self.device) + + diff = lhs - rhs + + assert diff.shape == multiplier.shape, f"Shape mismatch: {diff.shape} vs {multiplier.shape}" + + return multiplier * diff + + +# The TanhMultivariateNormalDiag is a probability distribution that represents a transformed normal (Gaussian) distribution where: +# 1. The base distribution is a diagonal multivariate normal distribution +# 2. The samples from this normal distribution are transformed through a tanh function, which squashes the values to be between -1 and 1 +# 3. 
Optionally, the values can be further transformed to fit within arbitrary bounds [low, high] using an affine transformation +# This type of distribution is commonly used in reinforcement learning, particularly for continuous action spaces +class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution): + def __init__( + self, + loc: torch.Tensor, + scale_diag: torch.Tensor, + low: Optional[torch.Tensor] = None, + high: Optional[torch.Tensor] = None, + ): + # Create base normal distribution + base_distribution = torch.distributions.Normal(loc=loc, scale=scale_diag) + + # Create list of transforms + transforms = [] + + # Add tanh transform + transforms.append(torch.distributions.transforms.TanhTransform()) + + # Add rescaling transform if bounds are provided + if low is not None and high is not None: + transforms.append( + torch.distributions.transforms.AffineTransform( + loc=(high + low) / 2, + scale=(high - low) / 2 + ) + ) + + # Initialize parent class + super().__init__( + base_distribution=base_distribution, + transforms=transforms + ) + + # Store parameters + self.loc = loc + self.scale_diag = scale_diag + self.low = low + self.high = high + + def mode(self) -> torch.Tensor: + """Get the mode of the transformed distribution""" + # The mode of a normal distribution is its mean + mode = self.loc + + # Apply transforms + for transform in self.transforms: + mode = transform(mode) + + return mode + + def rsample(self, sample_shape=torch.Size()) -> torch.Tensor: + """ + Reparameterized sample from the distribution + """ + # Sample from base distribution + x = self.base_dist.rsample(sample_shape) + + # Apply transforms + for transform in self.transforms: + x = transform(x) + + return x + + def log_prob(self, value: torch.Tensor) -> torch.Tensor: + """ + Compute log probability of a value + Includes the log det jacobian for the transforms + """ + # Initialize log prob + log_prob = torch.zeros_like(value[..., 0]) + + # Inverse transforms to get back to normal distribution + q = value + for transform in reversed(self.transforms): + q = transform.inv(q) + log_prob = log_prob - transform.log_abs_det_jacobian(q, transform(q)) + + # Add base distribution log prob + log_prob = log_prob + self.base_dist.log_prob(q).sum(-1) + + return log_prob + + def sample_and_log_prob(self, sample_shape=torch.Size()) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Sample from the distribution and compute log probability + """ + x = self.rsample(sample_shape) + log_prob = self.log_prob(x) + return x, log_prob + + def entropy(self) -> torch.Tensor: + """ + Compute entropy of the distribution + """ + # Start with base distribution entropy + entropy = self.base_dist.entropy().sum(-1) + + # Add log det jacobian for each transform + x = self.rsample() + for transform in self.transforms: + entropy = entropy + transform.log_abs_det_jacobian(x, transform(x)) + x = transform(x) + + return entropy + + +def create_critic_ensemble(critic_class, num_critics: int, device: str = "cuda") -> nn.ModuleList: + """Creates an ensemble of critic networks""" + critics = nn.ModuleList([critic_class() for _ in range(num_critics)]) + return critics.to(device) + + +def orthogonal_init(): + return lambda x: torch.nn.init.orthogonal_(x, gain=1.0) + + +# borrowed from tdmpc +def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor: + """Helper to temporarily flatten extra dims at the start of the image tensor. + + Args: + fn: Callable that the image tensor will be passed to. 
It should accept (B, C, H, W) and return + (B, *), where * is any number of dimensions. + image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and + can be more than 1 dimensions, generally different from *. + Returns: + A return value from the callable reshaped to (**, *). + """ + if image_tensor.ndim == 4: + return fn(image_tensor) + start_dims = image_tensor.shape[:-3] + inp = torch.flatten(image_tensor, end_dim=-4) + flat_out = fn(inp) + return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:])) + From 7e0f20fbf285418a78b8619107371f2f0a6c7fd1 Mon Sep 17 00:00:00 2001 From: KeWang1017 Date: Tue, 17 Dec 2024 15:58:04 +0000 Subject: [PATCH 014/112] Enhance SAC configuration and policy with new parameters and subsampling logic - Added `num_subsample_critics`, `critic_target_update_weight`, and `utd_ratio` to SACConfig. - Implemented target entropy calculation in SACPolicy if not provided. - Introduced subsampling of critics to prevent overfitting during updates. - Updated temperature loss calculation to use the new target entropy. - Added comments for future UTD update implementation. These changes improve the flexibility and performance of the SAC implementation. --- .../common/policies/sac/configuration_sac.py | 3 +++ lerobot/common/policies/sac/modeling_sac.py | 21 +++++++++++++++---- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index 441b3566..d324462e 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -23,8 +23,11 @@ class SACConfig: discount = 0.99 temperature_init = 1.0 num_critics = 2 + num_subsample_critics = None critic_lr = 3e-4 actor_lr = 3e-4 + critic_target_update_weight = 0.005 + utd_ratio = 2 critic_network_kwargs = { "hidden_dims": [256, 256], "activate_final": True, diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index 9ea9449d..7d451b4e 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -85,7 +85,8 @@ class SACPolicy( action_dim=config.output_shapes["action"][0], **config.policy_kwargs ) - + if config.target_entropy is None: + config.target_entropy = -np.prod(config.output_shapes["action"][0]) # (-dim(A)) self.temperature = LagrangeMultiplier(init_value=config.temperature_init) def reset(self): @@ -127,7 +128,6 @@ class SACPolicy( # perform image augmentation # reward bias - # from HIL-SERL code base # add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch @@ -136,11 +136,16 @@ class SACPolicy( action_preds, log_probs = self.actor_network(observations) # 2- compute q targets q_targets = self.target_qs(next_observations, action_preds) + # subsample critics to prevent overfitting if use high UTD (update to date) + if self.config.num_subsample_critics is not None: + indices = torch.randperm(self.config.num_critics) + indices = indices[:self.config.num_subsample_critics] + q_targets = q_targets[indices] # critics subsample size min_q = q_targets.min(dim=0) - # backup entropy + # compute td target td_target = rewards + self.discount * min_q # 3- compute predicted qs @@ -182,7 +187,10 @@ class SACPolicy( # calculate temperature loss # 1- calculate entropy entropy = -log_probs.mean() - temperature_loss = temperature * (entropy - self.target_entropy).mean() + temperature_loss = self.temp( + lhs=entropy, 
+ rhs=self.config.target_entropy + ) loss = critics_loss + actor_loss + temperature_loss @@ -198,6 +206,11 @@ class SACPolicy( def update(self): self.critic_target.lerp_(self.critic_ensemble, self.config.critic_target_update_weight) + # TODO: implement UTD update + #for critic_step in range(self.config.utd_ratio - 1): + # only update critic and critic target + # Then update critic, critic target, actor and temperature + #for target_param, param in zip(self.critic_target.parameters(), self.critic_ensemble.parameters()): # target_param.data.copy_(target_param.data * (1.0 - self.config.critic_target_update_weight) + param.data * self.critic_target_update_weight) From 7b68bfb73b61f6fb90cac3d46a724274a0f184c7 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Tue, 17 Dec 2024 18:03:46 +0100 Subject: [PATCH 015/112] added comments from kewang --- lerobot/common/policies/sac/modeling_sac.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index 7d451b4e..de8283de 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -128,6 +128,7 @@ class SACPolicy( # perform image augmentation # reward bias + # from HIL-SERL code base # add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch @@ -207,6 +208,7 @@ class SACPolicy( def update(self): self.critic_target.lerp_(self.critic_ensemble, self.config.critic_target_update_weight) # TODO: implement UTD update + # First update only critics for utd_ratio-1 times #for critic_step in range(self.config.utd_ratio - 1): # only update critic and critic target # Then update critic, critic target, actor and temperature From 70b652f791b515ea325692439615d366f3712dce Mon Sep 17 00:00:00 2001 From: Eugene Mironov Date: Mon, 23 Dec 2024 16:43:55 +0700 Subject: [PATCH 016/112] [Port Hil-SERL] Add unit tests for the reward classifier & fix imports & check script (#578) --- .../classifier/configuration_classifier.py | 2 +- .../hilserl/classifier/modeling_classifier.py | 8 + poetry.lock | 153 ++++++++++- pyproject.toml | 3 + tests/conftest.py | 13 + .../check_hiserl_reward_classifier.py | 244 ++++++++++++++++++ .../classifier/test_modelling_classifier.py | 78 ++++++ 7 files changed, 499 insertions(+), 2 deletions(-) create mode 100644 tests/policies/hilserl/classifier/check_hiserl_reward_classifier.py create mode 100644 tests/policies/hilserl/classifier/test_modelling_classifier.py diff --git a/lerobot/common/policies/hilserl/classifier/configuration_classifier.py b/lerobot/common/policies/hilserl/classifier/configuration_classifier.py index 209ff659..553e4262 100644 --- a/lerobot/common/policies/hilserl/classifier/configuration_classifier.py +++ b/lerobot/common/policies/hilserl/classifier/configuration_classifier.py @@ -13,7 +13,7 @@ class ClassifierConfig: hidden_dim: int = 256 dropout_rate: float = 0.1 model_name: str = "microsoft/resnet-50" - device: str = "cuda" if torch.cuda.is_available() else "mps" + device: str = "cpu" model_type: str = "cnn" # "transformer" or "cnn" def save_pretrained(self, save_dir): diff --git a/lerobot/common/policies/hilserl/classifier/modeling_classifier.py b/lerobot/common/policies/hilserl/classifier/modeling_classifier.py index dbb434a7..0b8d66ac 100644 --- a/lerobot/common/policies/hilserl/classifier/modeling_classifier.py +++ b/lerobot/common/policies/hilserl/classifier/modeling_classifier.py @@ -22,6 +22,11 @@ class ClassifierOutput: self.probabilities = 
probabilities self.hidden_states = hidden_states + def __repr__(self): + return (f"ClassifierOutput(logits={self.logits}, " + f"probabilities={self.probabilities}, " + f"hidden_states={self.hidden_states})") + class Classifier( nn.Module, @@ -69,6 +74,8 @@ class Classifier( self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension else: raise ValueError("Unsupported CNN architecture") + + self.encoder = self.encoder.to(self.config.device) def _freeze_encoder(self) -> None: """Freeze the encoder parameters.""" @@ -93,6 +100,7 @@ class Classifier( nn.ReLU(), nn.Linear(self.config.hidden_dim, 1 if self.config.num_classes == 2 else self.config.num_classes), ) + self.classifier_head = self.classifier_head.to(self.config.device) def _get_encoder_output(self, x: torch.Tensor) -> torch.Tensor: """Extract the appropriate output from the encoder.""" diff --git a/poetry.lock b/poetry.lock index 8799e67c..919edd18 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3139,6 +3139,27 @@ dev = ["changelist (==0.5)"] lint = ["pre-commit (==3.7.0)"] test = ["pytest (>=7.4)", "pytest-cov (>=4.1)"] +[[package]] +name = "lightning-utilities" +version = "0.11.9" +description = "Lightning toolbox for across the our ecosystem." +optional = true +python-versions = ">=3.8" +files = [ + {file = "lightning_utilities-0.11.9-py3-none-any.whl", hash = "sha256:ac6d4e9e28faf3ff4be997876750fee10dc604753dbc429bf3848a95c5d7e0d2"}, + {file = "lightning_utilities-0.11.9.tar.gz", hash = "sha256:f5052b81344cc2684aa9afd74b7ce8819a8f49a858184ec04548a5a109dfd053"}, +] + +[package.dependencies] +packaging = ">=17.1" +setuptools = "*" +typing-extensions = "*" + +[package.extras] +cli = ["fire"] +docs = ["requests (>=2.0.0)"] +typing = ["mypy (>=1.0.0)", "types-setuptools"] + [[package]] name = "llvmlite" version = "0.43.0" @@ -6798,6 +6819,38 @@ webencodings = ">=0.4" doc = ["sphinx", "sphinx_rtd_theme"] test = ["pytest", "ruff"] +[[package]] +name = "tokenizers" +version = "0.21.0" +description = "" +optional = true +python-versions = ">=3.7" +files = [ + {file = "tokenizers-0.21.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:3c4c93eae637e7d2aaae3d376f06085164e1660f89304c0ab2b1d08a406636b2"}, + {file = "tokenizers-0.21.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:f53ea537c925422a2e0e92a24cce96f6bc5046bbef24a1652a5edc8ba975f62e"}, + {file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b177fb54c4702ef611de0c069d9169f0004233890e0c4c5bd5508ae05abf193"}, + {file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6b43779a269f4629bebb114e19c3fca0223296ae9fea8bb9a7a6c6fb0657ff8e"}, + {file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9aeb255802be90acfd363626753fda0064a8df06031012fe7d52fd9a905eb00e"}, + {file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d8b09dbeb7a8d73ee204a70f94fc06ea0f17dcf0844f16102b9f414f0b7463ba"}, + {file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:400832c0904f77ce87c40f1a8a27493071282f785724ae62144324f171377273"}, + {file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e84ca973b3a96894d1707e189c14a774b701596d579ffc7e69debfc036a61a04"}, + {file = "tokenizers-0.21.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:eb7202d231b273c34ec67767378cd04c767e967fda12d4a9e36208a34e2f137e"}, + {file = 
"tokenizers-0.21.0-cp39-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:089d56db6782a73a27fd8abf3ba21779f5b85d4a9f35e3b493c7bbcbbf0d539b"}, + {file = "tokenizers-0.21.0-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:c87ca3dc48b9b1222d984b6b7490355a6fdb411a2d810f6f05977258400ddb74"}, + {file = "tokenizers-0.21.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:4145505a973116f91bc3ac45988a92e618a6f83eb458f49ea0790df94ee243ff"}, + {file = "tokenizers-0.21.0-cp39-abi3-win32.whl", hash = "sha256:eb1702c2f27d25d9dd5b389cc1f2f51813e99f8ca30d9e25348db6585a97e24a"}, + {file = "tokenizers-0.21.0-cp39-abi3-win_amd64.whl", hash = "sha256:87841da5a25a3a5f70c102de371db120f41873b854ba65e52bccd57df5a3780c"}, + {file = "tokenizers-0.21.0.tar.gz", hash = "sha256:ee0894bf311b75b0c03079f33859ae4b2334d675d4e93f5a4132e1eae2834fe4"}, +] + +[package.dependencies] +huggingface-hub = ">=0.16.4,<1.0" + +[package.extras] +dev = ["tokenizers[testing]"] +docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"] +testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests", "ruff"] + [[package]] name = "tomli" version = "2.0.2" @@ -6863,6 +6916,34 @@ typing-extensions = ">=4.8.0" opt-einsum = ["opt-einsum (>=3.3)"] optree = ["optree (>=0.11.0)"] +[[package]] +name = "torchmetrics" +version = "1.6.0" +description = "PyTorch native Metrics" +optional = true +python-versions = ">=3.9" +files = [ + {file = "torchmetrics-1.6.0-py3-none-any.whl", hash = "sha256:a508cdd87766cedaaf55a419812bf9f493aff8fffc02cc19df5a8e2e7ccb942a"}, + {file = "torchmetrics-1.6.0.tar.gz", hash = "sha256:aebba248708fb90def20cccba6f55bddd134a58de43fb22b0c5ca0f3a89fa984"}, +] + +[package.dependencies] +lightning-utilities = ">=0.8.0" +numpy = ">1.20.0" +packaging = ">17.1" +torch = ">=2.0.0" + +[package.extras] +all = ["SciencePlots (>=2.0.0)", "gammatone (>=1.0.0)", "ipadic (>=1.0.0)", "librosa (>=0.10.0)", "matplotlib (>=3.6.0)", "mecab-python3 (>=1.0.6)", "mypy (==1.13.0)", "nltk (>3.8.1)", "numpy (<2.0)", "onnxruntime (>=1.12.0)", "pesq (>=0.0.4)", "piq (<=0.8.0)", "pycocotools (>2.0.0)", "pystoi (>=0.4.0)", "regex (>=2021.9.24)", "requests (>=2.19.0)", "scipy (>1.0.0)", "sentencepiece (>=0.2.0)", "torch (==2.5.1)", "torch-fidelity (<=0.4.0)", "torchaudio (>=2.0.1)", "torchvision (>=0.15.1)", "tqdm (<4.68.0)", "transformers (>4.4.0)", "transformers (>=4.42.3)", "types-PyYAML", "types-emoji", "types-protobuf", "types-requests", "types-setuptools", "types-six", "types-tabulate"] +audio = ["gammatone (>=1.0.0)", "librosa (>=0.10.0)", "numpy (<2.0)", "onnxruntime (>=1.12.0)", "pesq (>=0.0.4)", "pystoi (>=0.4.0)", "requests (>=2.19.0)", "torchaudio (>=2.0.1)"] +detection = ["pycocotools (>2.0.0)", "torchvision (>=0.15.1)"] +dev = ["PyTDC (==0.4.1)", "SciencePlots (>=2.0.0)", "bert-score (==0.3.13)", "dython (==0.7.6)", "dython (>=0.7.8,<0.8.0)", "fairlearn", "fast-bss-eval (>=0.1.0)", "faster-coco-eval (>=1.6.3)", "gammatone (>=1.0.0)", "huggingface-hub (<0.27)", "ipadic (>=1.0.0)", "jiwer (>=2.3.0)", "kornia (>=0.6.7)", "librosa (>=0.10.0)", "lpips (<=0.1.4)", "matplotlib (>=3.6.0)", "mecab-ko (>=1.0.0,<1.1.0)", "mecab-ko-dic (>=1.0.0)", "mecab-python3 (>=1.0.6)", "mir-eval (>=0.6)", "monai (==1.3.2)", "monai (==1.4.0)", "mypy (==1.13.0)", "netcal (>1.0.0)", "nltk (>3.8.1)", "numpy (<2.0)", "numpy (<2.2.0)", "onnxruntime (>=1.12.0)", "pandas (>1.4.0)", "permetrics (==2.0.0)", "pesq (>=0.0.4)", "piq (<=0.8.0)", "pycocotools (>2.0.0)", "pystoi (>=0.4.0)", "pytorch-msssim (==1.0.0)", "regex (>=2021.9.24)", "requests (>=2.19.0)", 
"rouge-score (>0.1.0)", "sacrebleu (>=2.3.0)", "scikit-image (>=0.19.0)", "scipy (>1.0.0)", "sentencepiece (>=0.2.0)", "sewar (>=0.4.4)", "statsmodels (>0.13.5)", "torch (==2.5.1)", "torch-complex (<0.5.0)", "torch-fidelity (<=0.4.0)", "torchaudio (>=2.0.1)", "torchvision (>=0.15.1)", "tqdm (<4.68.0)", "transformers (>4.4.0)", "transformers (>=4.42.3)", "types-PyYAML", "types-emoji", "types-protobuf", "types-requests", "types-setuptools", "types-six", "types-tabulate"] +image = ["scipy (>1.0.0)", "torch-fidelity (<=0.4.0)", "torchvision (>=0.15.1)"] +multimodal = ["piq (<=0.8.0)", "transformers (>=4.42.3)"] +text = ["ipadic (>=1.0.0)", "mecab-python3 (>=1.0.6)", "nltk (>3.8.1)", "regex (>=2021.9.24)", "sentencepiece (>=0.2.0)", "tqdm (<4.68.0)", "transformers (>4.4.0)"] +typing = ["mypy (==1.13.0)", "torch (==2.5.1)", "types-PyYAML", "types-emoji", "types-protobuf", "types-requests", "types-setuptools", "types-six", "types-tabulate"] +visual = ["SciencePlots (>=2.0.0)", "matplotlib (>=3.6.0)"] + [[package]] name = "torchvision" version = "0.19.1" @@ -6956,6 +7037,75 @@ files = [ docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"] test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0,<8.2)", "pytest-mock", "pytest-mypy-testing"] +[[package]] +name = "transformers" +version = "4.47.0" +description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" +optional = true +python-versions = ">=3.9.0" +files = [ + {file = "transformers-4.47.0-py3-none-any.whl", hash = "sha256:a8e1bafdaae69abdda3cad638fe392e37c86d2ce0ecfcae11d60abb8f949ff4d"}, + {file = "transformers-4.47.0.tar.gz", hash = "sha256:f8ead7a5a4f6937bb507e66508e5e002dc5930f7b6122a9259c37b099d0f3b19"}, +] + +[package.dependencies] +filelock = "*" +huggingface-hub = ">=0.24.0,<1.0" +numpy = ">=1.17" +packaging = ">=20.0" +pyyaml = ">=5.1" +regex = "!=2019.12.17" +requests = "*" +safetensors = ">=0.4.1" +tokenizers = ">=0.21,<0.22" +tqdm = ">=4.27" + +[package.extras] +accelerate = ["accelerate (>=0.26.0)"] +agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch"] +all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch", "torchaudio", "torchvision"] +audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] +benchmark = ["optimum-benchmark (>=0.3.0)"] +codecarbon = ["codecarbon (==1.2.0)"] +deepspeed = ["accelerate (>=0.26.0)", "deepspeed (>=0.9.3)"] +deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.26.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "optuna", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", 
"timeout-decorator"] +dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.21,<0.22)", "urllib3 (<2.0.0)"] +dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "libcst", "librosa", "nltk (<=3.8.1)", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)", "scipy (<1.13.0)"] +flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] +ftfy = ["ftfy"] +integrations = ["optuna", "ray[tune] (>=2.7.0)", "sigopt"] +ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "rhoknp 
(>=1.1.0,<1.3.1)", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)"] +modelcreation = ["cookiecutter (==1.7.3)"] +natten = ["natten (>=0.14.6,<0.15.0)"] +onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"] +onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"] +optuna = ["optuna"] +quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "isort (>=5.5.4)", "libcst", "rich", "ruff (==0.5.1)", "urllib3 (<2.0.0)"] +ray = ["ray[tune] (>=2.7.0)"] +retrieval = ["datasets (!=2.5.0)", "faiss-cpu"] +ruff = ["ruff (==0.5.1)"] +sagemaker = ["sagemaker (>=2.31.0)"] +sentencepiece = ["protobuf", "sentencepiece (>=0.1.91,!=0.1.92)"] +serving = ["fastapi", "pydantic", "starlette", "uvicorn"] +sigopt = ["sigopt"] +sklearn = ["scikit-learn"] +speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] +testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "parameterized", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] +tf = ["keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"] +tf-cpu = ["keras (>2.9,<2.16)", "keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow-cpu (>2.9,<2.16)", "tensorflow-probability (<0.24)", "tensorflow-text (<2.16)", "tf2onnx"] +tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] +tiktoken = ["blobfile", "tiktoken"] +timm = ["timm (<=1.0.11)"] +tokenizers = ["tokenizers (>=0.21,<0.22)"] +torch = ["accelerate (>=0.26.0)", "torch"] +torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] +torch-vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"] +torchhub = ["filelock", "huggingface-hub (>=0.24.0,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.21,<0.22)", "torch", "tqdm (>=4.27)"] +video = ["av (==9.2.0)"] +vision = ["Pillow (>=10.0.1,<=15.0)"] + [[package]] name = "transforms3d" version = "0.4.2" @@ -7558,6 +7708,7 @@ dev = ["debugpy", "pre-commit"] dora = ["gym-dora"] dynamixel = ["dynamixel-sdk", "pynput"] feetech = ["feetech-servo-sdk", "pynput"] +hilserl = ["torchmetrics", "transformers"] intelrealsense = ["pyrealsense2"] pusht = ["gym-pusht"] stretch = ["hello-robot-stretch-body", "pynput", "pyrealsense2", "pyrender"] @@ -7569,4 +7720,4 @@ xarm = ["gym-xarm"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "41344f0eb2d06d9a378abcd10df8205aa3926ff0a08ac5ab1a0b1bcae7440fd8" +content-hash = "b9d299916ced6af1d243f961a32b0a4aacbef18e0b95337a5224e8511f5d6dda" diff --git a/pyproject.toml b/pyproject.toml index 59c2de8b..738903bd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,8 @@ pyrender = {git = "https://github.com/mmatl/pyrender.git", markers = "sys_platfo hello-robot-stretch-body = {version = ">=0.7.27", markers = "sys_platform == 'linux'", optional = true} pyserial = {version = ">=3.5", optional = true} jsonlines = ">=4.0.0" +transformers = {version = "^4.47.0", optional = true} +torchmetrics 
= {version = "^1.6.0", optional = true} [tool.poetry.extras] @@ -86,6 +88,7 @@ dynamixel = ["dynamixel-sdk", "pynput"] feetech = ["feetech-servo-sdk", "pynput"] intelrealsense = ["pyrealsense2"] stretch = ["hello-robot-stretch-body", "pyrender", "pyrealsense2", "pynput"] +hilserl = ["transformers", "torchmetrics"] [tool.ruff] line-length = 110 diff --git a/tests/conftest.py b/tests/conftest.py index 2075c2aa..adf050aa 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,9 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import random import traceback import pytest +import torch from serial import SerialException from lerobot import available_cameras, available_motors, available_robots @@ -124,3 +126,14 @@ def patch_builtins_input(monkeypatch): print(text) monkeypatch.setattr("builtins.input", print_text) + + +def pytest_addoption(parser): + parser.addoption("--seed", action="store", default="42", help="Set random seed for reproducibility") + + +@pytest.fixture(autouse=True) +def set_random_seed(request): + seed = int(request.config.getoption("--seed")) + random.seed(seed) # Python random + torch.manual_seed(seed) # PyTorch diff --git a/tests/policies/hilserl/classifier/check_hiserl_reward_classifier.py b/tests/policies/hilserl/classifier/check_hiserl_reward_classifier.py new file mode 100644 index 00000000..55e6e381 --- /dev/null +++ b/tests/policies/hilserl/classifier/check_hiserl_reward_classifier.py @@ -0,0 +1,244 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
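+
+# Sanity-check script (separate from the pytest unit tests): it trains the reward
+# classifier on CIFAR-10 to verify that optimization converges end to end, once as
+# a 10-class problem (resnet-18 backbone) and once as a binary one-vs-rest problem
+# (resnet-50 backbone), and reports accuracy, precision, recall, F1 and AUROC via
+# torchmetrics.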
+ +import logging + +import numpy as np +import torch +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss +from torch.optim import Adam +from torch.utils.data import DataLoader +from torchmetrics import AUROC, Accuracy, F1Score, Precision, Recall +from torchvision.datasets import CIFAR10 +from torchvision.transforms import ToTensor + +from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier, ClassifierConfig + +BATCH_SIZE = 1000 +LR = 0.1 +EPOCH_NUM = 2 + +if torch.cuda.is_available(): + DEVICE = torch.device("cuda") +elif torch.backends.mps.is_available(): + DEVICE = torch.device("mps") +else: + DEVICE = torch.device("cpu") + + +def train_evaluate_multiclass_classifier(): + logging.info( + f"Start multiclass classifier train eval with {DEVICE} device, batch size {BATCH_SIZE}, learning rate {LR}" + ) + multiclass_config = ClassifierConfig(model_name="microsoft/resnet-18", device=DEVICE, num_classes=10) + multiclass_classifier = Classifier(multiclass_config) + + trainset = CIFAR10(root="data", train=True, download=True, transform=ToTensor()) + testset = CIFAR10(root="data", train=False, download=True, transform=ToTensor()) + + trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True) + testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False) + + multiclass_num_classes = 10 + epoch = 1 + + criterion = CrossEntropyLoss() + optimizer = Adam(multiclass_classifier.parameters(), lr=LR) + + multiclass_classifier.train() + + logging.info("Start multiclass classifier training") + + # Training loop + while epoch < EPOCH_NUM: # loop over the dataset multiple times + for i, data in enumerate(trainloader): + inputs, labels = data + inputs, labels = inputs.to(DEVICE), labels.to(DEVICE) + + # Zero the parameter gradients + optimizer.zero_grad() + + # Forward pass + outputs = multiclass_classifier(inputs) + + loss = criterion(outputs.logits, labels) + loss.backward() + optimizer.step() + + if i % 10 == 0: # print every 10 mini-batches + logging.info(f"[Epoch {epoch}, Batch {i}] loss: {loss.item():.3f}") + + epoch += 1 + + print("Multiclass classifier training finished") + + multiclass_classifier.eval() + + test_loss = 0.0 + test_labels = [] + test_pridections = [] + test_probs = [] + + with torch.no_grad(): + for data in testloader: + images, labels = data + images, labels = images.to(DEVICE), labels.to(DEVICE) + outputs = multiclass_classifier(images) + loss = criterion(outputs.logits, labels) + test_loss += loss.item() * BATCH_SIZE + + _, predicted = torch.max(outputs.logits, 1) + test_labels.extend(labels.cpu()) + test_pridections.extend(predicted.cpu()) + test_probs.extend(outputs.probabilities.cpu()) + + test_loss = test_loss / len(testset) + + logging.info(f"Multiclass classifier test loss {test_loss:.3f}") + + test_labels = torch.stack(test_labels) + test_predictions = torch.stack(test_pridections) + test_probs = torch.stack(test_probs) + + accuracy = Accuracy(task="multiclass", num_classes=multiclass_num_classes) + precision = Precision(task="multiclass", average="weighted", num_classes=multiclass_num_classes) + recall = Recall(task="multiclass", average="weighted", num_classes=multiclass_num_classes) + f1 = F1Score(task="multiclass", average="weighted", num_classes=multiclass_num_classes) + auroc = AUROC(task="multiclass", num_classes=multiclass_num_classes, average="weighted") + + # Calculate metrics + acc = accuracy(test_predictions, test_labels) + prec = precision(test_predictions, test_labels) + rec = recall(test_predictions, 
test_labels) + f1_score = f1(test_predictions, test_labels) + auroc_score = auroc(test_probs, test_labels) + + logging.info(f"Accuracy: {acc:.2f}") + logging.info(f"Precision: {prec:.2f}") + logging.info(f"Recall: {rec:.2f}") + logging.info(f"F1 Score: {f1_score:.2f}") + logging.info(f"AUROC Score: {auroc_score:.2f}") + + +def train_evaluate_binary_classifier(): + logging.info( + f"Start binary classifier train eval with {DEVICE} device, batch size {BATCH_SIZE}, learning rate {LR}" + ) + + target_binary_class = 3 + + def one_vs_rest(dataset, target_class): + new_targets = [] + for _, label in dataset: + new_label = float(1.0) if label == target_class else float(0.0) + new_targets.append(new_label) + + dataset.targets = new_targets # Replace the original labels with the binary ones + return dataset + + binary_train_dataset = CIFAR10(root="data", train=True, download=True, transform=ToTensor()) + binary_test_dataset = CIFAR10(root="data", train=False, download=True, transform=ToTensor()) + + # Apply one-vs-rest labeling + binary_train_dataset = one_vs_rest(binary_train_dataset, target_binary_class) + binary_test_dataset = one_vs_rest(binary_test_dataset, target_binary_class) + + binary_trainloader = DataLoader(binary_train_dataset, batch_size=BATCH_SIZE, shuffle=True) + binary_testloader = DataLoader(binary_test_dataset, batch_size=BATCH_SIZE, shuffle=False) + + binary_epoch = 1 + + binary_config = ClassifierConfig(model_name="microsoft/resnet-50", device=DEVICE) + binary_classifier = Classifier(binary_config) + + class_counts = np.bincount(binary_train_dataset.targets) + n = len(binary_train_dataset) + w0 = n / (2.0 * class_counts[0]) + w1 = n / (2.0 * class_counts[1]) + + binary_criterion = BCEWithLogitsLoss(pos_weight=torch.tensor(w1 / w0)) + binary_optimizer = Adam(binary_classifier.parameters(), lr=LR) + + binary_classifier.train() + + logging.info("Start binary classifier training") + + # Training loop + while binary_epoch < EPOCH_NUM: # loop over the dataset multiple times + for i, data in enumerate(binary_trainloader): + inputs, labels = data + inputs, labels = inputs.to(DEVICE), labels.to(torch.float32).to(DEVICE) + + # Zero the parameter gradients + binary_optimizer.zero_grad() + + # Forward pass + outputs = binary_classifier(inputs) + loss = binary_criterion(outputs.logits, labels) + loss.backward() + binary_optimizer.step() + + if i % 10 == 0: # print every 10 mini-batches + print(f"[Epoch {binary_epoch}, Batch {i}] loss: {loss.item():.3f}") + binary_epoch += 1 + + logging.info("Binary classifier training finished") + logging.info("Start binary classifier evaluation") + + binary_classifier.eval() + + test_loss = 0.0 + test_labels = [] + test_pridections = [] + test_probs = [] + + with torch.no_grad(): + for data in binary_testloader: + images, labels = data + images, labels = images.to(DEVICE), labels.to(torch.float32).to(DEVICE) + outputs = binary_classifier(images) + loss = binary_criterion(outputs.logits, labels) + test_loss += loss.item() * BATCH_SIZE + + test_labels.extend(labels.cpu()) + test_pridections.extend(outputs.logits.cpu()) + test_probs.extend(outputs.probabilities.cpu()) + + test_loss = test_loss / len(binary_test_dataset) + + logging.info(f"Binary classifier test loss {test_loss:.3f}") + + test_labels = torch.stack(test_labels) + test_predictions = torch.stack(test_pridections) + test_probs = torch.stack(test_probs) + + # Calculate metrics + acc = Accuracy(task="binary")(test_predictions, test_labels) + prec = Precision(task="binary", 
average="weighted")(test_predictions, test_labels) + rec = Recall(task="binary", average="weighted")(test_predictions, test_labels) + f1_score = F1Score(task="binary", average="weighted")(test_predictions, test_labels) + auroc_score = AUROC(task="binary", average="weighted")(test_probs, test_labels) + + logging.info(f"Accuracy: {acc:.2f}") + logging.info(f"Precision: {prec:.2f}") + logging.info(f"Recall: {rec:.2f}") + logging.info(f"F1 Score: {f1_score:.2f}") + logging.info(f"AUROC Score: {auroc_score:.2f}") + + +if __name__ == "__main__": + train_evaluate_multiclass_classifier() + train_evaluate_binary_classifier() diff --git a/tests/policies/hilserl/classifier/test_modelling_classifier.py b/tests/policies/hilserl/classifier/test_modelling_classifier.py new file mode 100644 index 00000000..014165eb --- /dev/null +++ b/tests/policies/hilserl/classifier/test_modelling_classifier.py @@ -0,0 +1,78 @@ +import torch + +from lerobot.common.policies.hilserl.classifier.modeling_classifier import ( + Classifier, + ClassifierConfig, + ClassifierOutput, +) +from tests.utils import require_package + + +def test_classifier_output(): + output = ClassifierOutput( + logits=torch.tensor([1, 2, 3]), probabilities=torch.tensor([0.1, 0.2, 0.3]), hidden_states=None + ) + + assert ( + f"{output}" + == "ClassifierOutput(logits=tensor([1, 2, 3]), probabilities=tensor([0.1000, 0.2000, 0.3000]), hidden_states=None)" + ) + + +@require_package("transformers") +def test_binary_classifier_with_default_params(): + config = ClassifierConfig() + classifier = Classifier(config) + + batch_size = 10 + + input = torch.rand(batch_size, 3, 224, 224) + output = classifier(input) + + assert output is not None + assert output.logits.shape == torch.Size([batch_size]) + assert not torch.isnan(output.logits).any(), "Tensor contains NaN values" + assert output.probabilities.shape == torch.Size([batch_size]) + assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values" + assert output.hidden_states.shape == torch.Size([batch_size, 2048]) + assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values" + + +@require_package("transformers") +def test_multiclass_classifier(): + num_classes = 5 + config = ClassifierConfig(num_classes=num_classes) + classifier = Classifier(config) + + batch_size = 10 + + input = torch.rand(batch_size, 3, 224, 224) + output = classifier(input) + + assert output is not None + assert output.logits.shape == torch.Size([batch_size, num_classes]) + assert not torch.isnan(output.logits).any(), "Tensor contains NaN values" + assert output.probabilities.shape == torch.Size([batch_size, num_classes]) + assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values" + assert output.hidden_states.shape == torch.Size([batch_size, 2048]) + assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values" + + +@require_package("transformers") +def test_default_device(): + config = ClassifierConfig() + assert config.device == "cpu" + + classifier = Classifier(config) + for p in classifier.parameters(): + assert p.device == torch.device("cpu") + + +@require_package("transformers") +def test_explicit_device_setup(): + config = ClassifierConfig(device="meta") + assert config.device == "meta" + + classifier = Classifier(config) + for p in classifier.parameters(): + assert p.device == torch.device("meta") From b53d6e0ff254d17aa8e4e1639cfc6aea899e3df6 Mon Sep 17 00:00:00 2001 From: Eugene Mironov Date: Mon, 23 Dec 2024 16:44:29 +0700 Subject: [PATCH 017/112] 
[HIL-SERL PORT] Fix linter issues (#588) --- lerobot/common/policies/sac/modeling_sac.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index de8283de..c5e3f209 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -19,21 +19,18 @@ from collections import deque from copy import deepcopy -from functools import partial +from typing import Callable, Optional, Sequence, Tuple import einops - +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F # noqa: N812 +from huggingface_hub import PyTorchModelHubMixin from torch import Tensor -from huggingface_hub import PyTorchModelHubMixin from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.sac.configuration_sac import SACConfig -import numpy as np -from typing import Callable, Optional, Tuple, Sequence - class SACPolicy( @@ -290,10 +287,7 @@ class Critic(nn.Module): observations = observations.to(self.device) actions = actions.to(self.device) - if self.encoder is not None: - obs_enc = self.encoder(observations) - else: - obs_enc = observations + obs_enc = observations if self.encoder is None else self.encoder(observations) inputs = torch.cat([obs_enc, actions], dim=-1) x = self.network(inputs) @@ -563,6 +557,8 @@ class LagrangeMultiplier(nn.Module): # 3. Optionally, the values can be further transformed to fit within arbitrary bounds [low, high] using an affine transformation # This type of distribution is commonly used in reinforcement learning, particularly for continuous action spaces class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution): + DEFAULT_SAMPLE_SHAPE = torch.Size() + def __init__( self, loc: torch.Tensor, @@ -611,7 +607,7 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution): return mode - def rsample(self, sample_shape=torch.Size()) -> torch.Tensor: + def rsample(self, sample_shape=DEFAULT_SAMPLE_SHAPE) -> torch.Tensor: """ Reparameterized sample from the distribution """ @@ -643,7 +639,7 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution): return log_prob - def sample_and_log_prob(self, sample_shape=torch.Size()) -> Tuple[torch.Tensor, torch.Tensor]: + def sample_and_log_prob(self, sample_shape=DEFAULT_SAMPLE_SHAPE) -> Tuple[torch.Tensor, torch.Tensor]: """ Sample from the distribution and compute log probability """ From 08ec971086488277fc8745bc5c11a445e46c51ea Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Mon, 23 Dec 2024 14:12:03 +0100 Subject: [PATCH 018/112] added optimizer and sac to factory.py --- lerobot/common/policies/factory.py | 6 ++++++ lerobot/common/policies/sac/configuration_sac.py | 1 + lerobot/scripts/train.py | 9 +++++++++ 3 files changed, 16 insertions(+) diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 5cb2fd52..7f550d90 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -66,6 +66,12 @@ def get_policy_and_config_classes(name: str) -> tuple[Policy, object]: from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy return VQBeTPolicy, VQBeTConfig + elif name == "sac": + from lerobot.common.policies.sac.configuration_sac import SACConfig + from lerobot.common.policies.sac.modeling_sac import SACPolicy + + return SACPolicy, SACConfig + else: raise NotImplementedError(f"Policy with 
name {name} is not implemented.") diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index d324462e..6db198e8 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -26,6 +26,7 @@ class SACConfig: num_subsample_critics = None critic_lr = 3e-4 actor_lr = 3e-4 + temperature_lr = 3e-4 critic_target_update_weight = 0.005 utd_ratio = 2 critic_network_kwargs = { diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 9a0b7e4c..346c3acd 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -93,6 +93,15 @@ def make_optimizer_and_scheduler(cfg, policy): elif policy.name == "tdmpc": optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr) lr_scheduler = None + + elif policy.name == "sac": + optimizer = torch.optim.Adam([ + {'params': policy.actor.parameters(), 'lr': policy.config.actor_lr}, + {'params': policy.critic_ensemble.parameters(), 'lr': policy.config.critic_lr}, + {'params': policy.temperature.parameters(), 'lr': policy.config.temperature_lr}, + ]) + lr_scheduler = None + elif cfg.policy.name == "vqbet": from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler From dc54d357ca9106d72b0d70b064e2740f10b8fc53 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Sun, 29 Dec 2024 12:51:21 +0000 Subject: [PATCH 019/112] Added normalization schemes and style checks --- lerobot/common/logger.py | 2 +- .../classifier/configuration_classifier.py | 2 - .../hilserl/classifier/modeling_classifier.py | 10 +- .../policies/hilserl/configuration_hilserl.py | 2 +- .../policies/hilserl/modeling_hilserl.py | 4 +- .../common/policies/sac/configuration_sac.py | 42 +++- lerobot/common/policies/sac/modeling_sac.py | 220 ++++++++---------- lerobot/scripts/eval_on_robot.py | 8 +- lerobot/scripts/train.py | 14 +- lerobot/scripts/train_hilserl_classifier.py | 2 +- 10 files changed, 150 insertions(+), 156 deletions(-) diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py index 4015492d..dec8b465 100644 --- a/lerobot/common/logger.py +++ b/lerobot/common/logger.py @@ -25,13 +25,13 @@ from glob import glob from pathlib import Path import torch -import wandb from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE from omegaconf import DictConfig, OmegaConf from termcolor import colored from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler +import wandb from lerobot.common.policies.policy_protocol import Policy from lerobot.common.utils.utils import get_global_random_state, set_global_random_state diff --git a/lerobot/common/policies/hilserl/classifier/configuration_classifier.py b/lerobot/common/policies/hilserl/classifier/configuration_classifier.py index 553e4262..f0b9352f 100644 --- a/lerobot/common/policies/hilserl/classifier/configuration_classifier.py +++ b/lerobot/common/policies/hilserl/classifier/configuration_classifier.py @@ -2,8 +2,6 @@ import json import os from dataclasses import asdict, dataclass -import torch - @dataclass class ClassifierConfig: diff --git a/lerobot/common/policies/hilserl/classifier/modeling_classifier.py b/lerobot/common/policies/hilserl/classifier/modeling_classifier.py index 0b8d66ac..28b05744 100644 --- a/lerobot/common/policies/hilserl/classifier/modeling_classifier.py +++ b/lerobot/common/policies/hilserl/classifier/modeling_classifier.py @@ -23,9 +23,11 @@ class ClassifierOutput: self.hidden_states = hidden_states def __repr__(self): - 
return (f"ClassifierOutput(logits={self.logits}, " - f"probabilities={self.probabilities}, " - f"hidden_states={self.hidden_states})") + return ( + f"ClassifierOutput(logits={self.logits}, " + f"probabilities={self.probabilities}, " + f"hidden_states={self.hidden_states})" + ) class Classifier( @@ -74,7 +76,7 @@ class Classifier( self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension else: raise ValueError("Unsupported CNN architecture") - + self.encoder = self.encoder.to(self.config.device) def _freeze_encoder(self) -> None: diff --git a/lerobot/common/policies/hilserl/configuration_hilserl.py b/lerobot/common/policies/hilserl/configuration_hilserl.py index f1bc850f..80d2f578 100644 --- a/lerobot/common/policies/hilserl/configuration_hilserl.py +++ b/lerobot/common/policies/hilserl/configuration_hilserl.py @@ -1,6 +1,6 @@ #!/usr/bin/env python -# Copyright 2024 The HuggingFace Inc. team. +# Copyright 2024 The HuggingFace Inc. team. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/lerobot/common/policies/hilserl/modeling_hilserl.py b/lerobot/common/policies/hilserl/modeling_hilserl.py index 236ed433..679eb010 100644 --- a/lerobot/common/policies/hilserl/modeling_hilserl.py +++ b/lerobot/common/policies/hilserl/modeling_hilserl.py @@ -1,6 +1,6 @@ #!/usr/bin/env python -# Copyright 2024 The HuggingFace Inc. team. +# Copyright 2024 The HuggingFace Inc. team. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -26,4 +26,4 @@ class HILSerlPolicy( repo_url="https://github.com/huggingface/lerobot", tags=["robotics", "hilserl"], ): - pass \ No newline at end of file + pass diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index 6db198e8..f4a2bc4c 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -1,6 +1,6 @@ #!/usr/bin/env python -# Copyright 2024 The HuggingFace Inc. team. +# Copyright 2024 The HuggingFace Inc. team. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from dataclasses import dataclass +from dataclasses import dataclass, field @dataclass @@ -30,14 +30,36 @@ class SACConfig: critic_target_update_weight = 0.005 utd_ratio = 2 critic_network_kwargs = { - "hidden_dims": [256, 256], - "activate_final": True, - } + "hidden_dims": [256, 256], + "activate_final": True, + } actor_network_kwargs = { - "hidden_dims": [256, 256], - "activate_final": True, - } + "hidden_dims": [256, 256], + "activate_final": True, + } policy_kwargs = { - "tanh_squash_distribution": True, - "std_parameterization": "uniform", + "tanh_squash_distribution": True, + "std_parameterization": "uniform", + } + + input_shapes: dict[str, list[int]] = field( + default_factory=lambda: { + "observation.image": [3, 84, 84], + "observation.state": [4], } + ) + output_shapes: dict[str, list[int]] = field( + default_factory=lambda: { + "action": [4], + } + ) + + state_encoder_hidden_dim: int = 256 + latent_dim: int = 256 + network_hidden_dims: int = 256 + + # Normalization / Unnormalization + input_normalization_modes: dict[str, str] | None = None + output_normalization_modes: dict[str, str] = field( + default_factory=lambda: {"action": "min_max"}, + ) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index c5e3f209..51258fac 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -1,6 +1,6 @@ #!/usr/bin/env python -# Copyright 2024 The HuggingFace Inc. team. +# Copyright 2024 The HuggingFace Inc. team. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -40,11 +40,9 @@ class SACPolicy( repo_url="https://github.com/huggingface/lerobot", tags=["robotics", "RL", "SAC"], ): - def __init__( self, config: SACConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None ): - super().__init__() if config is None: @@ -67,12 +65,9 @@ class SACPolicy( # Define networks critic_nets = [] for _ in range(config.num_critics): - critic_net = Critic( - encoder=encoder, - network=MLP(**config.critic_network_kwargs) - ) + critic_net = Critic(encoder=encoder, network=MLP(**config.critic_network_kwargs)) critic_nets.append(critic_net) - + self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics) self.critic_target = deepcopy(self.critic_ensemble) @@ -80,11 +75,11 @@ class SACPolicy( encoder=encoder, network=MLP(**config.actor_network_kwargs), action_dim=config.output_shapes["action"][0], - **config.policy_kwargs + **config.policy_kwargs, ) if config.target_entropy is None: - config.target_entropy = -np.prod(config.output_shapes["action"][0]) # (-dim(A)) - self.temperature = LagrangeMultiplier(init_value=config.temperature_init) + config.target_entropy = -np.prod(config.output_shapes["action"][0]) # (-dim(A)) + self.temperature = LagrangeMultiplier(init_value=config.temperature_init) def reset(self): """ @@ -100,10 +95,10 @@ class SACPolicy( self._queues["observation.image"] = deque(maxlen=1) if self._use_env_state: self._queues["observation.environment_state"] = deque(maxlen=1) - + @torch.no_grad() def select_action(self, batch: dict[str, Tensor]) -> Tensor: - actions, _ = self.actor_network(batch['observations'])### + actions, _ = self.actor_network(batch["observations"]) ### def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: """Run the batch through the model and compute the loss. 
@@ -111,8 +106,8 @@ class SACPolicy( Returns a dictionary with loss as a tensor, and other information as native floats. """ batch = self.normalize_inputs(batch) - # batch shape is (b, 2, ...) where index 1 returns the current observation and - # the next observation for caluculating the right td index. + # batch shape is (b, 2, ...) where index 1 returns the current observation and + # the next observation for caluculating the right td index. actions = batch["action"][:, 0] rewards = batch["next.reward"][:, 0] observations = {} @@ -121,13 +116,12 @@ class SACPolicy( if k.startswith("observation."): observations[k] = batch[k][:, 0] next_observations[k] = batch[k][:, 1] - + # perform image augmentation # reward bias - # from HIL-SERL code base + # from HIL-SERL code base # add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch - # calculate critics loss # 1- compute actions from policy @@ -137,7 +131,7 @@ class SACPolicy( # subsample critics to prevent overfitting if use high UTD (update to date) if self.config.num_subsample_critics is not None: indices = torch.randperm(self.config.num_critics) - indices = indices[:self.config.num_subsample_critics] + indices = indices[: self.config.num_subsample_critics] q_targets = q_targets[indices] # critics subsample size @@ -151,8 +145,9 @@ class SACPolicy( # 4- Calculate loss # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble. - critics_loss = ( - F.mse_loss( + critics_loss = ( + ( + F.mse_loss( q_preds, einops.repeat(td_target, "t b -> e t b", e=q_preds.shape[0]), reduction="none", @@ -163,15 +158,17 @@ class SACPolicy( # q_targets depends on the reward and the next observations. * ~batch["next.reward_is_pad"] * ~batch["observation.state_is_pad"][1:] - ).sum(0).mean() - + ) + .sum(0) + .mean() + ) + # calculate actors loss # 1- temperature temperature = self.temperature() # 2- get actions (batch_size, action_dim) and log probs (batch_size,) - actions, log_probs = self.actor_network(observations) \ - + actions, log_probs = self.actor_network(observations) # 3- get q-value predictions with torch.no_grad(): q_preds = self.critic_ensemble(observations, actions, return_type="mean") @@ -181,36 +178,31 @@ class SACPolicy( * ~batch["action_is_pad"] ).mean() - # calculate temperature loss # 1- calculate entropy entropy = -log_probs.mean() - temperature_loss = self.temp( - lhs=entropy, - rhs=self.config.target_entropy - ) + temperature_loss = self.temp(lhs=entropy, rhs=self.config.target_entropy) loss = critics_loss + actor_loss + temperature_loss return { - "critics_loss": critics_loss.item(), - "actor_loss": actor_loss.item(), - "temperature_loss": temperature_loss.item(), - "temperature": temperature.item(), - "entropy": entropy.item(), - "loss": loss, + "critics_loss": critics_loss.item(), + "actor_loss": actor_loss.item(), + "temperature_loss": temperature_loss.item(), + "temperature": temperature.item(), + "entropy": entropy.item(), + "loss": loss, + } - } - def update(self): self.critic_target.lerp_(self.critic_ensemble, self.config.critic_target_update_weight) # TODO: implement UTD update # First update only critics for utd_ratio-1 times - #for critic_step in range(self.config.utd_ratio - 1): - # only update critic and critic target + # for critic_step in range(self.config.utd_ratio - 1): + # only update critic and critic target # Then update critic, critic target, actor and temperature - #for target_param, param in zip(self.critic_target.parameters(), 
self.critic_ensemble.parameters()): + # for target_param, param in zip(self.critic_target.parameters(), self.critic_ensemble.parameters()): # target_param.data.copy_(target_param.data * (1.0 - self.config.critic_target_update_weight) + param.data * self.critic_target_update_weight) @@ -225,24 +217,28 @@ class MLP(nn.Module): super().__init__() self.activate_final = config.activate_final layers = [] - + for i, size in enumerate(config.network_hidden_dims): - layers.append(nn.Linear(config.network_hidden_dims[i-1] if i > 0 else config.network_hidden_dims[0], size)) - + layers.append( + nn.Linear(config.network_hidden_dims[i - 1] if i > 0 else config.network_hidden_dims[0], size) + ) + if i + 1 < len(config.network_hidden_dims) or activate_final: if dropout_rate is not None and dropout_rate > 0: layers.append(nn.Dropout(p=dropout_rate)) layers.append(nn.LayerNorm(size)) - layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)()) - + layers.append( + activations if isinstance(activations, nn.Module) else getattr(nn, activations)() + ) + self.net = nn.Sequential(*layers) def forward(self, x: torch.Tensor, train: bool = False) -> torch.Tensor: # in training mode or not. TODO: find better way to do this - self.train(train) + self.train(train) return self.net(x) - - + + class Critic(nn.Module): def __init__( self, @@ -250,7 +246,7 @@ class Critic(nn.Module): network: nn.Module, init_final: Optional[float] = None, activate_final: bool = False, - device: str = "cuda" + device: str = "cuda", ): super().__init__() self.device = torch.device(device) @@ -258,7 +254,7 @@ class Critic(nn.Module): self.network = network self.init_final = init_final self.activate_final = activate_final - + # Output layer if init_final is not None: if self.activate_final: @@ -273,36 +269,28 @@ class Critic(nn.Module): else: self.output_layer = nn.Linear(network.net[-2].out_features, 1) orthogonal_init()(self.output_layer.weight) - + self.to(self.device) - def forward( - self, - observations: torch.Tensor, - actions: torch.Tensor, - train: bool = False - ) -> torch.Tensor: + def forward(self, observations: torch.Tensor, actions: torch.Tensor, train: bool = False) -> torch.Tensor: self.train(train) - + observations = observations.to(self.device) actions = actions.to(self.device) - + obs_enc = observations if self.encoder is None else self.encoder(observations) - + inputs = torch.cat([obs_enc, actions], dim=-1) x = self.network(inputs) value = self.output_layer(x) return value.squeeze(-1) - + def q_value_ensemble( - self, - observations: torch.Tensor, - actions: torch.Tensor, - train: bool = False + self, observations: torch.Tensor, actions: torch.Tensor, train: bool = False ) -> torch.Tensor: observations = observations.to(self.device) actions = actions.to(self.device) - + if len(actions.shape) == 3: # [batch_size, num_actions, action_dim] batch_size, num_actions = actions.shape[:2] obs_expanded = observations.unsqueeze(1).expand(-1, num_actions, -1) @@ -327,7 +315,7 @@ class Policy(nn.Module): fixed_std: Optional[torch.Tensor] = None, init_final: Optional[float] = None, activate_final: bool = False, - device: str = "cuda" + device: str = "cuda", ): super().__init__() self.device = torch.device(device) @@ -340,7 +328,7 @@ class Policy(nn.Module): self.tanh_squash_distribution = tanh_squash_distribution self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None self.activate_final = activate_final - + # Mean layer if self.activate_final: self.mean_layer = 
nn.Linear(network.net[-3].out_features, action_dim) @@ -351,7 +339,7 @@ class Policy(nn.Module): nn.init.uniform_(self.mean_layer.bias, -init_final, init_final) else: orthogonal_init()(self.mean_layer.weight) - + # Standard deviation layer or parameter if fixed_std is None: if std_parameterization == "uniform": @@ -366,18 +354,18 @@ class Policy(nn.Module): nn.init.uniform_(self.std_layer.bias, -init_final, init_final) else: orthogonal_init()(self.std_layer.weight) - + self.to(self.device) def forward( - self, + self, observations: torch.Tensor, temperature: float = 1.0, train: bool = False, - non_squash_distribution: bool = False + non_squash_distribution: bool = False, ) -> torch.distributions.Distribution: self.train(train) - + # Encode observations if encoder exists if self.encoder is not None: with torch.set_grad_enabled(train): @@ -387,7 +375,7 @@ class Policy(nn.Module): # Get network outputs outputs = self.network(obs_enc) means = self.mean_layer(outputs) - + # Compute standard deviations if self.fixed_std is None: if self.std_parameterization == "exp": @@ -398,9 +386,7 @@ class Policy(nn.Module): elif self.std_parameterization == "uniform": stds = torch.exp(self.log_stds).expand_as(means) else: - raise ValueError( - f"Invalid std_parameterization: {self.std_parameterization}" - ) + raise ValueError(f"Invalid std_parameterization: {self.std_parameterization}") else: assert self.std_parameterization == "fixed" stds = self.fixed_std.expand_as(means) @@ -422,7 +408,7 @@ class Policy(nn.Module): ) return distribution - + def get_features(self, observations: torch.Tensor) -> torch.Tensor: """Get encoded features from observations""" observations = observations.to(self.device) @@ -503,56 +489,47 @@ class SACObservationEncoder(nn.Module): if "observation.state" in self.config.input_shapes: feat.append(self.state_enc_layers(obs_dict["observation.state"])) return torch.stack(feat, dim=0).mean(0) - + class LagrangeMultiplier(nn.Module): - def __init__( - self, - init_value: float = 1.0, - constraint_shape: Sequence[int] = (), - device: str = "cuda" - ): + def __init__(self, init_value: float = 1.0, constraint_shape: Sequence[int] = (), device: str = "cuda"): super().__init__() self.device = torch.device(device) init_value = torch.log(torch.exp(torch.tensor(init_value, device=self.device)) - 1) - + # Initialize the Lagrange multiplier as a parameter self.lagrange = nn.Parameter( torch.full(constraint_shape, init_value, dtype=torch.float32, device=self.device) ) - + self.to(self.device) - def forward( - self, - lhs: Optional[torch.Tensor] = None, - rhs: Optional[torch.Tensor] = None - ) -> torch.Tensor: - # Get the multiplier value based on parameterization + def forward(self, lhs: Optional[torch.Tensor] = None, rhs: Optional[torch.Tensor] = None) -> torch.Tensor: + # Get the multiplier value based on parameterization multiplier = torch.nn.functional.softplus(self.lagrange) - + # Return the raw multiplier if no constraint values provided if lhs is None: return multiplier - + # Move inputs to device lhs = lhs.to(self.device) if rhs is not None: rhs = rhs.to(self.device) - + # Use the multiplier to compute the Lagrange penalty if rhs is None: rhs = torch.zeros_like(lhs, device=self.device) - + diff = lhs - rhs - + assert diff.shape == multiplier.shape, f"Shape mismatch: {diff.shape} vs {multiplier.shape}" - + return multiplier * diff # The TanhMultivariateNormalDiag is a probability distribution that represents a transformed normal (Gaussian) distribution where: -# 1. 
The base distribution is a diagonal multivariate normal distribution +# 1. The base distribution is a diagonal multivariate normal distribution # 2. The samples from this normal distribution are transformed through a tanh function, which squashes the values to be between -1 and 1 # 3. Optionally, the values can be further transformed to fit within arbitrary bounds [low, high] using an affine transformation # This type of distribution is commonly used in reinforcement learning, particularly for continuous action spaces @@ -568,28 +545,22 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution): ): # Create base normal distribution base_distribution = torch.distributions.Normal(loc=loc, scale=scale_diag) - + # Create list of transforms transforms = [] - + # Add tanh transform transforms.append(torch.distributions.transforms.TanhTransform()) - + # Add rescaling transform if bounds are provided if low is not None and high is not None: transforms.append( - torch.distributions.transforms.AffineTransform( - loc=(high + low) / 2, - scale=(high - low) / 2 - ) + torch.distributions.transforms.AffineTransform(loc=(high + low) / 2, scale=(high - low) / 2) ) - + # Initialize parent class - super().__init__( - base_distribution=base_distribution, - transforms=transforms - ) - + super().__init__(base_distribution=base_distribution, transforms=transforms) + # Store parameters self.loc = loc self.scale_diag = scale_diag @@ -600,11 +571,11 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution): """Get the mode of the transformed distribution""" # The mode of a normal distribution is its mean mode = self.loc - + # Apply transforms for transform in self.transforms: mode = transform(mode) - + return mode def rsample(self, sample_shape=DEFAULT_SAMPLE_SHAPE) -> torch.Tensor: @@ -613,11 +584,11 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution): """ # Sample from base distribution x = self.base_dist.rsample(sample_shape) - + # Apply transforms for transform in self.transforms: x = transform(x) - + return x def log_prob(self, value: torch.Tensor) -> torch.Tensor: @@ -627,16 +598,16 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution): """ # Initialize log prob log_prob = torch.zeros_like(value[..., 0]) - + # Inverse transforms to get back to normal distribution q = value for transform in reversed(self.transforms): q = transform.inv(q) log_prob = log_prob - transform.log_abs_det_jacobian(q, transform(q)) - + # Add base distribution log prob log_prob = log_prob + self.base_dist.log_prob(q).sum(-1) - + return log_prob def sample_and_log_prob(self, sample_shape=DEFAULT_SAMPLE_SHAPE) -> Tuple[torch.Tensor, torch.Tensor]: @@ -653,13 +624,13 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution): """ # Start with base distribution entropy entropy = self.base_dist.entropy().sum(-1) - + # Add log det jacobian for each transform x = self.rsample() for transform in self.transforms: entropy = entropy + transform.log_abs_det_jacobian(x, transform(x)) x = transform(x) - + return entropy @@ -680,7 +651,7 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens Args: fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return (B, *), where * is any number of dimensions. 
- image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and + image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and can be more than 1 dimensions, generally different from *. Returns: A return value from the callable reshaped to (**, *). @@ -691,4 +662,3 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens inp = torch.flatten(image_tensor, end_dim=-4) flat_out = fn(inp) return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:])) - diff --git a/lerobot/scripts/eval_on_robot.py b/lerobot/scripts/eval_on_robot.py index 6a790f0a..92daa860 100644 --- a/lerobot/scripts/eval_on_robot.py +++ b/lerobot/scripts/eval_on_robot.py @@ -24,7 +24,7 @@ python lerobot/scripts/eval_on_robot.py \ ``` **NOTE** (michel-aractingi): This script is incomplete and it is being prepared -for running training on the real robot. +for running training on the real robot. """ import argparse @@ -47,7 +47,7 @@ from lerobot.common.utils.utils import ( def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20, use_amp: bool = True) -> dict: - """Run a batched policy rollout on the real robot. + """Run a batched policy rollout on the real robot. The return dictionary contains: "robot": A a dictionary of (batch, sequence + 1, *) tensors mapped to observation @@ -64,7 +64,7 @@ def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20, extraneous elements from the sequences above. Args: - robot: The robot class that defines the interface with the real robot. + robot: The robot class that defines the interface with the real robot. policy: The policy. Must be a PyTorch nn module. Returns: @@ -77,7 +77,7 @@ def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20, listener, events = init_keyboard_listener() # Reset the policy. TODO (michel-aractingi) add real policy evaluation once the code is ready. 
- # policy.reset() + # policy.reset() # Get observation from real robot observation = robot.capture_observation() diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 346c3acd..fbe7927d 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -95,12 +95,14 @@ def make_optimizer_and_scheduler(cfg, policy): lr_scheduler = None elif policy.name == "sac": - optimizer = torch.optim.Adam([ - {'params': policy.actor.parameters(), 'lr': policy.config.actor_lr}, - {'params': policy.critic_ensemble.parameters(), 'lr': policy.config.critic_lr}, - {'params': policy.temperature.parameters(), 'lr': policy.config.temperature_lr}, - ]) - lr_scheduler = None + optimizer = torch.optim.Adam( + [ + {"params": policy.actor.parameters(), "lr": policy.config.actor_lr}, + {"params": policy.critic_ensemble.parameters(), "lr": policy.config.critic_lr}, + {"params": policy.temperature.parameters(), "lr": policy.config.temperature_lr}, + ] + ) + lr_scheduler = None elif cfg.policy.name == "vqbet": from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler diff --git a/lerobot/scripts/train_hilserl_classifier.py b/lerobot/scripts/train_hilserl_classifier.py index 78659dc8..ea8336a9 100644 --- a/lerobot/scripts/train_hilserl_classifier.py +++ b/lerobot/scripts/train_hilserl_classifier.py @@ -22,7 +22,6 @@ from pprint import pformat import hydra import torch import torch.nn as nn -import wandb from deepdiff import DeepDiff from omegaconf import DictConfig, OmegaConf from termcolor import colored @@ -31,6 +30,7 @@ from torch.cuda.amp import GradScaler from torch.utils.data import DataLoader, WeightedRandomSampler, random_split from tqdm import tqdm +import wandb from lerobot.common.datasets.factory import resolve_delta_timestamps from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.logger import Logger From 18a45989861a3c737dfd92d612271040fd897c19 Mon Sep 17 00:00:00 2001 From: KeWang1017 Date: Thu, 26 Dec 2024 23:38:46 +0000 Subject: [PATCH 020/112] trying to get sac running --- .../common/policies/sac/configuration_sac.py | 21 +++++ lerobot/common/policies/sac/modeling_sac.py | 79 ++++++++-------- .../configs/policy/sac_pusht_keypoints.yaml | 89 +++++++++++++++++++ 3 files changed, 149 insertions(+), 40 deletions(-) create mode 100644 lerobot/configs/policy/sac_pusht_keypoints.yaml diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index f4a2bc4c..6df94761 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -20,6 +20,24 @@ from dataclasses import dataclass, field @dataclass class SACConfig: + input_shapes: dict[str, list[int]] = field( + default_factory=lambda: { + "observation.image": [3, 84, 84], + "observation.state": [4], + } + ) + output_shapes: dict[str, list[int]] = field( + default_factory=lambda: { + "action": [4], + } + ) + + # Normalization / Unnormalization + input_normalization_modes: dict[str, str] | None = None + output_normalization_modes: dict[str, str] = field( + default_factory=lambda: {"action": "min_max"}, + ) + discount = 0.99 temperature_init = 1.0 num_critics = 2 @@ -29,6 +47,9 @@ class SACConfig: temperature_lr = 3e-4 critic_target_update_weight = 0.005 utd_ratio = 2 + state_encoder_hidden_dim = 256 + latent_dim = 50 + target_entropy = None critic_network_kwargs = { "hidden_dims": [256, 256], "activate_final": True, diff --git a/lerobot/common/policies/sac/modeling_sac.py 
b/lerobot/common/policies/sac/modeling_sac.py index 51258fac..87170d20 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -40,6 +40,8 @@ class SACPolicy( repo_url="https://github.com/huggingface/lerobot", tags=["robotics", "RL", "SAC"], ): + name = "sac" + def __init__( self, config: SACConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None ): @@ -71,7 +73,7 @@ class SACPolicy( self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics) self.critic_target = deepcopy(self.critic_ensemble) - self.actor_network = Policy( + self.actor = Policy( encoder=encoder, network=MLP(**config.actor_network_kwargs), action_dim=config.output_shapes["action"][0], @@ -91,14 +93,14 @@ class SACPolicy( "observation.state": deque(maxlen=1), "action": deque(maxlen=1), } - if self._use_image: + if "observation.image" in self.config.input_shapes: self._queues["observation.image"] = deque(maxlen=1) - if self._use_env_state: + if "observation.environment_state" in self.config.input_shapes: self._queues["observation.environment_state"] = deque(maxlen=1) @torch.no_grad() def select_action(self, batch: dict[str, Tensor]) -> Tensor: - actions, _ = self.actor_network(batch["observations"]) ### + actions, _ = self.actor(batch['observations']) def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: """Run the batch through the model and compute the loss. @@ -119,19 +121,18 @@ class SACPolicy( # perform image augmentation - # reward bias - # from HIL-SERL code base + # reward bias from HIL-SERL code base # add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch # calculate critics loss # 1- compute actions from policy - action_preds, log_probs = self.actor_network(observations) + action_preds, log_probs = self.actor(observations) # 2- compute q targets q_targets = self.target_qs(next_observations, action_preds) # subsample critics to prevent overfitting if use high UTD (update to date) if self.config.num_subsample_critics is not None: indices = torch.randperm(self.config.num_critics) - indices = indices[: self.config.num_subsample_critics] + indices = indices[:self.config.num_subsample_critics] q_targets = q_targets[indices] # critics subsample size @@ -168,7 +169,8 @@ class SACPolicy( temperature = self.temperature() # 2- get actions (batch_size, action_dim) and log probs (batch_size,) - actions, log_probs = self.actor_network(observations) + actions, log_probs = self.actor(observations) \ + # 3- get q-value predictions with torch.no_grad(): q_preds = self.critic_ensemble(observations, actions, return_type="mean") @@ -209,21 +211,19 @@ class SACPolicy( class MLP(nn.Module): def __init__( self, - config: SACConfig, + hidden_dims: list[int], activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(), activate_final: bool = False, dropout_rate: Optional[float] = None, ): super().__init__() - self.activate_final = config.activate_final + self.activate_final = activate_final layers = [] - - for i, size in enumerate(config.network_hidden_dims): - layers.append( - nn.Linear(config.network_hidden_dims[i - 1] if i > 0 else config.network_hidden_dims[0], size) - ) - - if i + 1 < len(config.network_hidden_dims) or activate_final: + + for i, size in enumerate(hidden_dims): + layers.append(nn.Linear(hidden_dims[i-1] if i > 0 else hidden_dims[0], size)) + + if i + 1 < len(hidden_dims) or activate_final: if dropout_rate is not None and dropout_rate > 0: 
layers.append(nn.Dropout(p=dropout_rate)) layers.append(nn.LayerNorm(size)) @@ -254,20 +254,20 @@ class Critic(nn.Module): self.network = network self.init_final = init_final self.activate_final = activate_final - + + # Find the last Linear layer's output dimension + for layer in reversed(network.net): + if isinstance(layer, nn.Linear): + out_features = layer.out_features + break + # Output layer if init_final is not None: - if self.activate_final: - self.output_layer = nn.Linear(network.net[-3].out_features, 1) - else: - self.output_layer = nn.Linear(network.net[-2].out_features, 1) + self.output_layer = nn.Linear(out_features, 1) nn.init.uniform_(self.output_layer.weight, -init_final, init_final) nn.init.uniform_(self.output_layer.bias, -init_final, init_final) else: - if self.activate_final: - self.output_layer = nn.Linear(network.net[-3].out_features, 1) - else: - self.output_layer = nn.Linear(network.net[-2].out_features, 1) + self.output_layer = nn.Linear(out_features, 1) orthogonal_init()(self.output_layer.weight) self.to(self.device) @@ -328,12 +328,15 @@ class Policy(nn.Module): self.tanh_squash_distribution = tanh_squash_distribution self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None self.activate_final = activate_final - + + # Find the last Linear layer's output dimension + for layer in reversed(network.net): + if isinstance(layer, nn.Linear): + out_features = layer.out_features + break + # Mean layer - if self.activate_final: - self.mean_layer = nn.Linear(network.net[-3].out_features, action_dim) - else: - self.mean_layer = nn.Linear(network.net[-2].out_features, action_dim) + self.mean_layer = nn.Linear(out_features, action_dim) if init_final is not None: nn.init.uniform_(self.mean_layer.weight, -init_final, init_final) nn.init.uniform_(self.mean_layer.bias, -init_final, init_final) @@ -345,10 +348,7 @@ class Policy(nn.Module): if std_parameterization == "uniform": self.log_stds = nn.Parameter(torch.zeros(action_dim, device=self.device)) else: - if self.activate_final: - self.std_layer = nn.Linear(network.net[-3].out_features, action_dim) - else: - self.std_layer = nn.Linear(network.net[-2].out_features, action_dim) + self.std_layer = nn.Linear(out_features, action_dim) if init_final is not None: nn.init.uniform_(self.std_layer.weight, -init_final, init_final) nn.init.uniform_(self.std_layer.bias, -init_final, init_final) @@ -571,7 +571,6 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution): """Get the mode of the transformed distribution""" # The mode of a normal distribution is its mean mode = self.loc - # Apply transforms for transform in self.transforms: mode = transform(mode) @@ -634,10 +633,10 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution): return entropy -def create_critic_ensemble(critic_class, num_critics: int, device: str = "cuda") -> nn.ModuleList: +def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cuda") -> nn.ModuleList: """Creates an ensemble of critic networks""" - critics = nn.ModuleList([critic_class() for _ in range(num_critics)]) - return critics.to(device) + assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}" + return nn.ModuleList(critics).to(device) def orthogonal_init(): diff --git a/lerobot/configs/policy/sac_pusht_keypoints.yaml b/lerobot/configs/policy/sac_pusht_keypoints.yaml new file mode 100644 index 00000000..19af60d4 --- /dev/null +++ b/lerobot/configs/policy/sac_pusht_keypoints.yaml @@ 
-0,0 +1,89 @@ +# @package _global_ + +# Train with: +# +# python lerobot/scripts/train.py \ +# env=pusht \ +# +dataset=lerobot/pusht_keypoints + +seed: 1 +dataset_repo_id: lerobot/pusht_keypoints + +training: + offline_steps: 0 + + # Offline training dataloader + num_workers: 4 + + batch_size: 128 + grad_clip_norm: 10.0 + lr: 3e-4 + + eval_freq: 10000 + log_freq: 500 + save_freq: 50000 + + online_steps: 1000000 + online_rollout_n_episodes: 10 + online_rollout_batch_size: 10 + online_steps_between_rollouts: 1000 + online_sampling_ratio: 1.0 + online_env_seed: 10000 + online_buffer_capacity: 40000 + online_buffer_seed_size: 0 + do_online_rollout_async: false + + delta_timestamps: + observation.environment_state: "[i / ${fps} for i in range(${policy.horizon} + 1)]" + observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]" + action: "[i / ${fps} for i in range(${policy.horizon})]" + next.reward: "[i / ${fps} for i in range(${policy.horizon})]" + +policy: + name: sac + + pretrained_model_path: + + # Input / output structure. + n_action_repeats: 1 + horizon: 5 + n_action_steps: 5 + + input_shapes: + # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env? + observation.environment_state: [16] + observation.state: ["${env.state_dim}"] + output_shapes: + action: ["${env.action_dim}"] + + # Normalization / Unnormalization + input_normalization_modes: + observation.environment_state: min_max + observation.state: min_max + output_normalization_modes: + action: min_max + + # Architecture / modeling. + # Neural networks. + # image_encoder_hidden_dim: 32 + discount: 0.99 + temperature_init: 1.0 + num_critics: 2 + num_subsample_critics: None + critic_lr: 3e-4 + actor_lr: 3e-4 + temperature_lr: 3e-4 + critic_target_update_weight: 0.005 + utd_ratio: 2 + + + # # Loss coefficients. + # reward_coeff: 0.5 + # expectile_weight: 0.9 + # value_coeff: 0.1 + # consistency_coeff: 20.0 + # advantage_scaling: 3.0 + # pi_coeff: 0.5 + # temporal_decay_coeff: 0.5 + # # Target model. + # target_model_momentum: 0.995 From ca74a13d616bdb9df6a8c2b3544837d73c2641a4 Mon Sep 17 00:00:00 2001 From: KeWang1017 Date: Sat, 28 Dec 2024 18:07:15 +0000 Subject: [PATCH 021/112] Refactor SACPolicy for improved action sampling and standard deviation handling - Updated action selection to use distribution sampling and log probabilities for better stochastic behavior. - Enhanced standard deviation clamping to prevent extreme values, ensuring stability in policy outputs. - Cleaned up code by removing unnecessary comments and improving readability. These changes aim to refine the SAC implementation, enhancing its robustness and performance during training and inference. 
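The change described above amounts to bounding the policy's standard deviation before the action distribution is built, so sampled actions neither collapse to a point nor blow up early in training. As a rough illustration, independent of the LeRobot classes, a clamped Gaussian head might look like the sketch below; the `exp`/`softplus` branches and the 0.05/2.0 bounds are illustrative assumptions for the example, not the repository's defaults.

```python
# Minimal sketch of a clamped std-dev head for a Gaussian policy.
# The parameterization names and bounds below are illustrative only.
import math

import torch
import torch.nn as nn


class GaussianHead(nn.Module):
    def __init__(self, in_features: int, action_dim: int,
                 std_parameterization: str = "softplus",
                 std_min: float = 0.05, std_max: float = 2.0):
        super().__init__()
        self.mean_layer = nn.Linear(in_features, action_dim)
        self.std_layer = nn.Linear(in_features, action_dim)
        self.std_parameterization = std_parameterization
        self.std_min = std_min
        self.std_max = std_max

    def forward(self, features: torch.Tensor) -> torch.distributions.Normal:
        means = self.mean_layer(features)
        if self.std_parameterization == "exp":
            # Clamp in log space so exp() can neither overflow nor vanish.
            log_stds = self.std_layer(features)
            log_stds = torch.clamp(log_stds, math.log(self.std_min), math.log(self.std_max))
            stds = torch.exp(log_stds)
        else:  # "softplus"
            # Softplus keeps stds positive; the clamp bounds the result.
            stds = torch.nn.functional.softplus(self.std_layer(features))
            stds = torch.clamp(stds, self.std_min, self.std_max)
        return torch.distributions.Normal(means, stds)


if __name__ == "__main__":
    head = GaussianHead(in_features=256, action_dim=4)
    dist = head(torch.randn(8, 256))
    actions = dist.rsample()                # reparameterized sample, keeps gradients
    log_probs = dist.log_prob(actions).sum(-1)
    print(actions.shape, log_probs.shape)   # torch.Size([8, 4]) torch.Size([8])
```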
--- lerobot/common/policies/sac/modeling_sac.py | 75 ++++++++++++++------- 1 file changed, 52 insertions(+), 23 deletions(-) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index 87170d20..821cb93f 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -19,6 +19,7 @@ from collections import deque from copy import deepcopy +import math from typing import Callable, Optional, Sequence, Tuple import einops @@ -100,7 +101,12 @@ class SACPolicy( @torch.no_grad() def select_action(self, batch: dict[str, Tensor]) -> Tensor: - actions, _ = self.actor(batch['observations']) + """Select action for inference/evaluation""" + distribution = self.actor(batch) + # Sample from the distribution and return just the actions + actions = distribution.mode() # or distribution.sample() for stochastic actions + actions = self.unnormalize_outputs({"action": actions})["action"] + return actions def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: """Run the batch through the model and compute the loss. @@ -126,7 +132,10 @@ class SACPolicy( # calculate critics loss # 1- compute actions from policy - action_preds, log_probs = self.actor(observations) + distribution = self.actor(observations) + action_preds = distribution.sample() + log_probs = distribution.log_prob(action_preds) + action_preds = torch.clamp(action_preds, -1, +1) # 2- compute q targets q_targets = self.target_qs(next_observations, action_preds) # subsample critics to prevent overfitting if use high UTD (update to date) @@ -146,31 +155,46 @@ class SACPolicy( # 4- Calculate loss # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble. - critics_loss = ( - ( - F.mse_loss( - q_preds, - einops.repeat(td_target, "t b -> e t b", e=q_preds.shape[0]), - reduction="none", - ).sum(0) # sum over ensemble - # `q_preds_ensemble` depends on the first observation and the actions. - * ~batch["observation.state_is_pad"][0] - * ~batch["action_is_pad"] - # q_targets depends on the reward and the next observations. - * ~batch["next.reward_is_pad"] - * ~batch["observation.state_is_pad"][1:] - ) - .sum(0) - .mean() - ) + #critics_loss = ( + # ( + # F.mse_loss( + # q_preds, + # einops.repeat(td_target, "t b -> e t b", e=q_preds.shape[0]), + # reduction="none", + # ).sum(0) # sum over ensemble + # # `q_preds_ensemble` depends on the first observation and the actions. + # * ~batch["observation.state_is_pad"][0] + # * ~batch["action_is_pad"] + # # q_targets depends on the reward and the next observations. + # * ~batch["next.reward_is_pad"] + # * ~batch["observation.state_is_pad"][1:] + # ) + # .sum(0) + # .mean() + #) + # 4- Calculate loss + # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble. 
+ critics_loss = F.mse_loss( + q_preds, # shape: [num_critics, batch_size] + einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]), # expand td_target to match q_preds shape + reduction="none" + ).sum(0).mean() + # breakpoint() # calculate actors loss # 1- temperature temperature = self.temperature() # 2- get actions (batch_size, action_dim) and log probs (batch_size,) +<<<<<<< HEAD actions, log_probs = self.actor(observations) \ +======= + distribution = self.actor(observations) + actions = distribution.sample() + log_probs = distribution.log_prob(actions) + actions = torch.clamp(actions, -1, +1) +>>>>>>> d3c62b92 (Refactor SACPolicy for improved action sampling and standard deviation handling) # 3- get q-value predictions with torch.no_grad(): q_preds = self.critic_ensemble(observations, actions, return_type="mean") @@ -309,8 +333,8 @@ class Policy(nn.Module): network: nn.Module, action_dim: int, std_parameterization: str = "exp", - std_min: float = 1e-5, - std_max: float = 10.0, + std_min: float = 0.05, + std_max: float = 2.0, tanh_squash_distribution: bool = False, fixed_std: Optional[torch.Tensor] = None, init_final: Optional[float] = None, @@ -372,6 +396,7 @@ class Policy(nn.Module): obs_enc = self.encoder(observations, train=train) else: obs_enc = observations + # Get network outputs outputs = self.network(obs_enc) means = self.mean_layer(outputs) @@ -380,18 +405,22 @@ class Policy(nn.Module): if self.fixed_std is None: if self.std_parameterization == "exp": log_stds = self.std_layer(outputs) + # Clamp log_stds to prevent too large or small values + log_stds = torch.clamp(log_stds, math.log(self.std_min), math.log(self.std_max)) stds = torch.exp(log_stds) elif self.std_parameterization == "softplus": stds = torch.nn.functional.softplus(self.std_layer(outputs)) + stds = torch.clamp(stds, self.std_min, self.std_max) elif self.std_parameterization == "uniform": - stds = torch.exp(self.log_stds).expand_as(means) + log_stds = torch.clamp(self.log_stds, math.log(self.std_min), math.log(self.std_max)) + stds = torch.exp(log_stds).expand_as(means) else: raise ValueError(f"Invalid std_parameterization: {self.std_parameterization}") else: assert self.std_parameterization == "fixed" stds = self.fixed_std.expand_as(means) - # Clip standard deviations and scale with temperature + # Scale with temperature temperature = torch.tensor(temperature, device=self.device) stds = torch.clamp(stds, self.std_min, self.std_max) * torch.sqrt(temperature) From 22fbc9ea4a8b7d168f8227b463f9270b897fed56 Mon Sep 17 00:00:00 2001 From: KeWang1017 Date: Sat, 28 Dec 2024 22:11:34 +0000 Subject: [PATCH 022/112] Refine SAC configuration and policy for enhanced performance - Updated standard deviation parameterization in SACConfig to 'softplus' with defined min and max values for improved stability. - Modified action sampling in SACPolicy to use reparameterized sampling, ensuring better gradient flow and log probability calculations. - Cleaned up log probability calculations in TanhMultivariateNormalDiag for clarity and efficiency. - Increased evaluation frequency in YAML configuration to 50000 for more efficient training cycles. These changes aim to enhance the robustness and performance of the SAC implementation during training and inference. 
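The log-probability cleanup referenced above follows from the change of variables for a tanh squash: if a = tanh(u) with u drawn from N(mu, sigma), then log p(a) = log N(u; mu, sigma) - sum_i log(1 - tanh(u_i)^2). A minimal sketch of that correction is shown below; the 1e-6 epsilon is only a numerical guard added for the example and is not taken from the repository.

```python
# Sketch of the tanh change-of-variables correction used by squashed Gaussian policies.
# Assumption: actions live in [-1, 1]; the epsilon below exists only for numerical safety.
import torch


def sample_squashed_gaussian(means: torch.Tensor, stds: torch.Tensor):
    """Return tanh-squashed actions and their log-probabilities."""
    base = torch.distributions.Normal(means, stds)
    u = base.rsample()                      # pre-squash sample, reparameterized
    actions = torch.tanh(u)                 # squash into [-1, 1]
    # log p(a) = log N(u) - sum_i log(1 - tanh(u_i)^2)
    log_probs = base.log_prob(u) - torch.log(1.0 - actions.pow(2) + 1e-6)
    return actions, log_probs.sum(dim=-1)


if __name__ == "__main__":
    means = torch.zeros(8, 4)
    stds = torch.full((8, 4), 0.5)
    actions, log_probs = sample_squashed_gaussian(means, stds)
    assert actions.abs().max() <= 1.0
    print(actions.shape, log_probs.shape)   # torch.Size([8, 4]) torch.Size([8])
```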
--- .../common/policies/sac/configuration_sac.py | 12 ++-- lerobot/common/policies/sac/modeling_sac.py | 56 +++++++++---------- .../configs/policy/sac_pusht_keypoints.yaml | 2 +- 3 files changed, 31 insertions(+), 39 deletions(-) diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index 6df94761..7a4bd364 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -59,14 +59,10 @@ class SACConfig: "activate_final": True, } policy_kwargs = { - "tanh_squash_distribution": True, - "std_parameterization": "uniform", - } - - input_shapes: dict[str, list[int]] = field( - default_factory=lambda: { - "observation.image": [3, 84, 84], - "observation.state": [4], + "tanh_squash_distribution": True, + "std_parameterization": "softplus", + "std_min": 0.005, + "std_max": 5.0, } ) output_shapes: dict[str, list[int]] = field( diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index 821cb93f..806cb767 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -134,7 +134,6 @@ class SACPolicy( # 1- compute actions from policy distribution = self.actor(observations) action_preds = distribution.sample() - log_probs = distribution.log_prob(action_preds) action_preds = torch.clamp(action_preds, -1, +1) # 2- compute q targets q_targets = self.target_qs(next_observations, action_preds) @@ -186,15 +185,11 @@ class SACPolicy( temperature = self.temperature() # 2- get actions (batch_size, action_dim) and log probs (batch_size,) -<<<<<<< HEAD - actions, log_probs = self.actor(observations) \ - -======= distribution = self.actor(observations) - actions = distribution.sample() - log_probs = distribution.log_prob(actions) + actions = distribution.rsample() + log_probs = distribution.log_prob(actions).sum(-1) + # breakpoint() actions = torch.clamp(actions, -1, +1) ->>>>>>> d3c62b92 (Refactor SACPolicy for improved action sampling and standard deviation handling) # 3- get q-value predictions with torch.no_grad(): q_preds = self.critic_ensemble(observations, actions, return_type="mean") @@ -610,7 +605,7 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution): """ Reparameterized sample from the distribution """ - # Sample from base distribution + # Sample from base distributionrsample x = self.base_dist.rsample(sample_shape) # Apply transforms @@ -625,17 +620,18 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution): Includes the log det jacobian for the transforms """ # Initialize log prob - log_prob = torch.zeros_like(value[..., 0]) - + log_prob = torch.zeros_like(value) + # Inverse transforms to get back to normal distribution q = value for transform in reversed(self.transforms): - q = transform.inv(q) - log_prob = log_prob - transform.log_abs_det_jacobian(q, transform(q)) - + q_prev = transform.inv(q) # Get the pre-transform value + log_prob = log_prob - transform.log_abs_det_jacobian(q_prev, q) # Sum over action dimensions + q = q_prev + # Add base distribution log prob - log_prob = log_prob + self.base_dist.log_prob(q).sum(-1) - + log_prob = log_prob + self.base_dist.log_prob(q) # Sum over action dimensions + return log_prob def sample_and_log_prob(self, sample_shape=DEFAULT_SAMPLE_SHAPE) -> Tuple[torch.Tensor, torch.Tensor]: @@ -646,20 +642,20 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution): log_prob = 
self.log_prob(x) return x, log_prob - def entropy(self) -> torch.Tensor: - """ - Compute entropy of the distribution - """ - # Start with base distribution entropy - entropy = self.base_dist.entropy().sum(-1) - - # Add log det jacobian for each transform - x = self.rsample() - for transform in self.transforms: - entropy = entropy + transform.log_abs_det_jacobian(x, transform(x)) - x = transform(x) - - return entropy + # def entropy(self) -> torch.Tensor: + # """ + # Compute entropy of the distribution + # """ + # # Start with base distribution entropy + # entropy = self.base_dist.entropy().sum(-1) + + # # Add log det jacobian for each transform + # x = self.rsample() + # for transform in self.transforms: + # entropy = entropy + transform.log_abs_det_jacobian(x, transform(x)) + # x = transform(x) + + # return entropy def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cuda") -> nn.ModuleList: diff --git a/lerobot/configs/policy/sac_pusht_keypoints.yaml b/lerobot/configs/policy/sac_pusht_keypoints.yaml index 19af60d4..6d8971a2 100644 --- a/lerobot/configs/policy/sac_pusht_keypoints.yaml +++ b/lerobot/configs/policy/sac_pusht_keypoints.yaml @@ -19,7 +19,7 @@ training: grad_clip_norm: 10.0 lr: 3e-4 - eval_freq: 10000 + eval_freq: 50000 log_freq: 500 save_freq: 50000 From 5b4adc00bb3da018cf10cbde6e120fd5e890c179 Mon Sep 17 00:00:00 2001 From: KeWang1017 Date: Sun, 29 Dec 2024 12:30:39 +0000 Subject: [PATCH 023/112] Refactor SAC configuration and policy for improved action sampling and stability - Updated SACConfig to replace standard deviation parameterization with log_std_min and log_std_max for better control over action distributions. - Modified SACPolicy to streamline action selection and log probability calculations, enhancing stochastic behavior. - Removed deprecated TanhMultivariateNormalDiag class to simplify the codebase and improve maintainability. These changes aim to enhance the robustness and performance of the SAC implementation during training and inference. 
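With the distribution class removed, the actor is expected to hand back `(actions, log_probs)` directly, which is all the SAC objectives need. A hedged sketch of how that pair feeds the usual actor and temperature losses follows; `toy_actor`, `toy_critic`, `log_alpha`, and `target_entropy` are placeholders invented for the example, not the repository's API.

```python
# Sketch of SAC actor/temperature losses given an actor that returns (actions, log_probs).
# The actor/critic callables and target_entropy below are illustrative placeholders.
import torch


def sac_actor_and_temperature_losses(actor, critic, log_alpha, observations, target_entropy):
    actions, log_probs = actor(observations)           # (B, action_dim), (B,)
    alpha = log_alpha.exp().detach()                    # temperature, detached for the actor update

    # Actor loss: maximize Q while keeping entropy high (minimize alpha*log_pi - Q).
    q_values = critic(observations, actions)            # (B,)
    actor_loss = (alpha * log_probs - q_values).mean()

    # Temperature loss: drive the policy entropy toward target_entropy.
    temperature_loss = -(log_alpha.exp() * (log_probs.detach() + target_entropy)).mean()
    return actor_loss, temperature_loss


if __name__ == "__main__":
    def toy_actor(obs):
        dist = torch.distributions.Normal(torch.zeros(obs.shape[0], 2), torch.ones(obs.shape[0], 2))
        u = dist.rsample()
        a = torch.tanh(u)
        logp = (dist.log_prob(u) - torch.log(1 - a.pow(2) + 1e-6)).sum(-1)
        return a, logp

    def toy_critic(obs, act):
        return -(act ** 2).sum(-1)                       # stand-in Q function for the example

    log_alpha = torch.zeros((), requires_grad=True)
    obs = torch.randn(8, 4)
    a_loss, t_loss = sac_actor_and_temperature_losses(
        toy_actor, toy_critic, log_alpha, obs, target_entropy=-2.0
    )
    print(float(a_loss), float(t_loss))
```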
--- .../common/policies/sac/configuration_sac.py | 27 +- lerobot/common/policies/sac/modeling_sac.py | 233 +++--------------- 2 files changed, 43 insertions(+), 217 deletions(-) diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index 7a4bd364..52c564a6 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -53,30 +53,13 @@ class SACConfig: critic_network_kwargs = { "hidden_dims": [256, 256], "activate_final": True, - } + } actor_network_kwargs = { "hidden_dims": [256, 256], "activate_final": True, - } + } policy_kwargs = { - "tanh_squash_distribution": True, - "std_parameterization": "softplus", - "std_min": 0.005, - "std_max": 5.0, + "use_tanh_squash": True, + "log_std_min": -5, + "log_std_max": 2, } - ) - output_shapes: dict[str, list[int]] = field( - default_factory=lambda: { - "action": [4], - } - ) - - state_encoder_hidden_dim: int = 256 - latent_dim: int = 256 - network_hidden_dims: int = 256 - - # Normalization / Unnormalization - input_normalization_modes: dict[str, str] | None = None - output_normalization_modes: dict[str, str] = field( - default_factory=lambda: {"action": "min_max"}, - ) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index 806cb767..1e7fd92b 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -102,9 +102,7 @@ class SACPolicy( @torch.no_grad() def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select action for inference/evaluation""" - distribution = self.actor(batch) - # Sample from the distribution and return just the actions - actions = distribution.mode() # or distribution.sample() for stochastic actions + actions, _ = self.actor(batch) actions = self.unnormalize_outputs({"action": actions})["action"] return actions @@ -129,12 +127,11 @@ class SACPolicy( # reward bias from HIL-SERL code base # add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch - + # calculate critics loss # 1- compute actions from policy - distribution = self.actor(observations) - action_preds = distribution.sample() - action_preds = torch.clamp(action_preds, -1, +1) + action_preds, log_probs = self.actor(next_observations) + # 2- compute q targets q_targets = self.target_qs(next_observations, action_preds) # subsample critics to prevent overfitting if use high UTD (update to date) @@ -147,7 +144,7 @@ class SACPolicy( min_q = q_targets.min(dim=0) # compute td target - td_target = rewards + self.discount * min_q + td_target = rewards + self.config.discount * min_q #+ self.config.discount * self.temperature() * log_probs # add entropy term # 3- compute predicted qs q_preds = self.critic_ensemble(observations, actions) @@ -178,18 +175,12 @@ class SACPolicy( einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]), # expand td_target to match q_preds shape reduction="none" ).sum(0).mean() - # breakpoint() # calculate actors loss # 1- temperature temperature = self.temperature() - # 2- get actions (batch_size, action_dim) and log probs (batch_size,) - distribution = self.actor(observations) - actions = distribution.rsample() - log_probs = distribution.log_prob(actions).sum(-1) - # breakpoint() - actions = torch.clamp(actions, -1, +1) + actions, log_probs = self.actor(observations) # 3- get q-value predictions with torch.no_grad(): q_preds = self.critic_ensemble(observations, actions, return_type="mean") @@ 
-264,15 +255,13 @@ class Critic(nn.Module): encoder: Optional[nn.Module], network: nn.Module, init_final: Optional[float] = None, - activate_final: bool = False, - device: str = "cuda", + device: str = "cuda" ): super().__init__() self.device = torch.device(device) self.encoder = encoder self.network = network self.init_final = init_final - self.activate_final = activate_final # Find the last Linear layer's output dimension for layer in reversed(network.net): @@ -304,22 +293,6 @@ class Critic(nn.Module): value = self.output_layer(x) return value.squeeze(-1) - def q_value_ensemble( - self, observations: torch.Tensor, actions: torch.Tensor, train: bool = False - ) -> torch.Tensor: - observations = observations.to(self.device) - actions = actions.to(self.device) - - if len(actions.shape) == 3: # [batch_size, num_actions, action_dim] - batch_size, num_actions = actions.shape[:2] - obs_expanded = observations.unsqueeze(1).expand(-1, num_actions, -1) - obs_flat = obs_expanded.reshape(-1, observations.shape[-1]) - actions_flat = actions.reshape(-1, actions.shape[-1]) - q_values = self(obs_flat, actions_flat, train) - return q_values.reshape(batch_size, num_actions) - else: - return self(observations, actions, train) - class Policy(nn.Module): def __init__( @@ -327,26 +300,22 @@ class Policy(nn.Module): encoder: Optional[nn.Module], network: nn.Module, action_dim: int, - std_parameterization: str = "exp", - std_min: float = 0.05, - std_max: float = 2.0, - tanh_squash_distribution: bool = False, + log_std_min: float = -5, + log_std_max: float = 2, fixed_std: Optional[torch.Tensor] = None, init_final: Optional[float] = None, - activate_final: bool = False, - device: str = "cuda", + use_tanh_squash: bool = False, + device: str = "cuda" ): super().__init__() self.device = torch.device(device) self.encoder = encoder self.network = network self.action_dim = action_dim - self.std_parameterization = std_parameterization - self.std_min = std_min - self.std_max = std_max - self.tanh_squash_distribution = tanh_squash_distribution + self.log_std_min = log_std_min + self.log_std_max = log_std_max self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None - self.activate_final = activate_final + self.use_tanh_squash = use_tanh_squash # Find the last Linear layer's output dimension for layer in reversed(network.net): @@ -364,27 +333,20 @@ class Policy(nn.Module): # Standard deviation layer or parameter if fixed_std is None: - if std_parameterization == "uniform": - self.log_stds = nn.Parameter(torch.zeros(action_dim, device=self.device)) + self.std_layer = nn.Linear(out_features, action_dim) + if init_final is not None: + nn.init.uniform_(self.std_layer.weight, -init_final, init_final) + nn.init.uniform_(self.std_layer.bias, -init_final, init_final) else: - self.std_layer = nn.Linear(out_features, action_dim) - if init_final is not None: - nn.init.uniform_(self.std_layer.weight, -init_final, init_final) - nn.init.uniform_(self.std_layer.bias, -init_final, init_final) - else: - orthogonal_init()(self.std_layer.weight) - + orthogonal_init()(self.std_layer.weight) + self.to(self.device) def forward( self, observations: torch.Tensor, - temperature: float = 1.0, - train: bool = False, - non_squash_distribution: bool = False, - ) -> torch.distributions.Distribution: - self.train(train) - + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Encode observations if encoder exists if self.encoder is not None: with torch.set_grad_enabled(train): @@ -398,41 +360,24 @@ class Policy(nn.Module): # Compute standard 
deviations if self.fixed_std is None: - if self.std_parameterization == "exp": - log_stds = self.std_layer(outputs) - # Clamp log_stds to prevent too large or small values - log_stds = torch.clamp(log_stds, math.log(self.std_min), math.log(self.std_max)) - stds = torch.exp(log_stds) - elif self.std_parameterization == "softplus": - stds = torch.nn.functional.softplus(self.std_layer(outputs)) - stds = torch.clamp(stds, self.std_min, self.std_max) - elif self.std_parameterization == "uniform": - log_stds = torch.clamp(self.log_stds, math.log(self.std_min), math.log(self.std_max)) - stds = torch.exp(log_stds).expand_as(means) - else: - raise ValueError(f"Invalid std_parameterization: {self.std_parameterization}") + log_std = self.std_layer(outputs) + if self.use_tanh_squash: + log_std = torch.tanh(log_std) + log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max) else: - assert self.std_parameterization == "fixed" stds = self.fixed_std.expand_as(means) - # Scale with temperature - temperature = torch.tensor(temperature, device=self.device) - stds = torch.clamp(stds, self.std_min, self.std_max) * torch.sqrt(temperature) - - # Create distribution - if self.tanh_squash_distribution and not non_squash_distribution: - distribution = TanhMultivariateNormalDiag( - loc=means, - scale_diag=stds, - ) - else: - distribution = torch.distributions.Normal( - loc=means, - scale=stds, - ) - - return distribution + # uses tahn activation function to squash the action to be in the range of [-1, 1] + normal = torch.distributions.Normal(means, stds) + x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1)) + log_probs = normal.log_prob(x_t) + if self.use_tanh_squash: + actions = torch.tanh(x_t) + log_probs -= torch.log((1 - actions.pow(2)) + 1e-6) + log_probs = log_probs.sum(-1) # sum over action dim + return actions, log_probs + def get_features(self, observations: torch.Tensor) -> torch.Tensor: """Get encoded features from observations""" observations = observations.to(self.device) @@ -552,110 +497,8 @@ class LagrangeMultiplier(nn.Module): return multiplier * diff -# The TanhMultivariateNormalDiag is a probability distribution that represents a transformed normal (Gaussian) distribution where: -# 1. The base distribution is a diagonal multivariate normal distribution -# 2. The samples from this normal distribution are transformed through a tanh function, which squashes the values to be between -1 and 1 -# 3. 
Optionally, the values can be further transformed to fit within arbitrary bounds [low, high] using an affine transformation -# This type of distribution is commonly used in reinforcement learning, particularly for continuous action spaces -class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution): - DEFAULT_SAMPLE_SHAPE = torch.Size() - - def __init__( - self, - loc: torch.Tensor, - scale_diag: torch.Tensor, - low: Optional[torch.Tensor] = None, - high: Optional[torch.Tensor] = None, - ): - # Create base normal distribution - base_distribution = torch.distributions.Normal(loc=loc, scale=scale_diag) - - # Create list of transforms - transforms = [] - - # Add tanh transform - transforms.append(torch.distributions.transforms.TanhTransform()) - - # Add rescaling transform if bounds are provided - if low is not None and high is not None: - transforms.append( - torch.distributions.transforms.AffineTransform(loc=(high + low) / 2, scale=(high - low) / 2) - ) - - # Initialize parent class - super().__init__(base_distribution=base_distribution, transforms=transforms) - - # Store parameters - self.loc = loc - self.scale_diag = scale_diag - self.low = low - self.high = high - - def mode(self) -> torch.Tensor: - """Get the mode of the transformed distribution""" - # The mode of a normal distribution is its mean - mode = self.loc - # Apply transforms - for transform in self.transforms: - mode = transform(mode) - - return mode - - def rsample(self, sample_shape=DEFAULT_SAMPLE_SHAPE) -> torch.Tensor: - """ - Reparameterized sample from the distribution - """ - # Sample from base distributionrsample - x = self.base_dist.rsample(sample_shape) - - # Apply transforms - for transform in self.transforms: - x = transform(x) - - return x - - def log_prob(self, value: torch.Tensor) -> torch.Tensor: - """ - Compute log probability of a value - Includes the log det jacobian for the transforms - """ - # Initialize log prob - log_prob = torch.zeros_like(value) - - # Inverse transforms to get back to normal distribution - q = value - for transform in reversed(self.transforms): - q_prev = transform.inv(q) # Get the pre-transform value - log_prob = log_prob - transform.log_abs_det_jacobian(q_prev, q) # Sum over action dimensions - q = q_prev - - # Add base distribution log prob - log_prob = log_prob + self.base_dist.log_prob(q) # Sum over action dimensions - - return log_prob - - def sample_and_log_prob(self, sample_shape=DEFAULT_SAMPLE_SHAPE) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Sample from the distribution and compute log probability - """ - x = self.rsample(sample_shape) - log_prob = self.log_prob(x) - return x, log_prob - - # def entropy(self) -> torch.Tensor: - # """ - # Compute entropy of the distribution - # """ - # # Start with base distribution entropy - # entropy = self.base_dist.entropy().sum(-1) - - # # Add log det jacobian for each transform - # x = self.rsample() - # for transform in self.transforms: - # entropy = entropy + transform.log_abs_det_jacobian(x, transform(x)) - # x = transform(x) - - # return entropy +def orthogonal_init(): + return lambda x: torch.nn.init.orthogonal_(x, gain=1.0) def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cuda") -> nn.ModuleList: From bae3b02928c7de4d3243eb3fead4c67f236ee167 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Sun, 29 Dec 2024 14:35:21 +0000 Subject: [PATCH 024/112] style fixes --- .../common/policies/sac/configuration_sac.py | 6 +- lerobot/common/policies/sac/modeling_sac.py | 70 
+++++++++---------- 2 files changed, 37 insertions(+), 39 deletions(-) diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index 52c564a6..a324294c 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -53,13 +53,13 @@ class SACConfig: critic_network_kwargs = { "hidden_dims": [256, 256], "activate_final": True, - } + } actor_network_kwargs = { "hidden_dims": [256, 256], "activate_final": True, - } + } policy_kwargs = { "use_tanh_squash": True, "log_std_min": -5, "log_std_max": 2, - } + } diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index 1e7fd92b..9df2c859 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -19,7 +19,6 @@ from collections import deque from copy import deepcopy -import math from typing import Callable, Optional, Sequence, Tuple import einops @@ -125,9 +124,9 @@ class SACPolicy( # perform image augmentation - # reward bias from HIL-SERL code base + # reward bias from HIL-SERL code base # add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch - + # calculate critics loss # 1- compute actions from policy action_preds, log_probs = self.actor(next_observations) @@ -137,21 +136,23 @@ class SACPolicy( # subsample critics to prevent overfitting if use high UTD (update to date) if self.config.num_subsample_critics is not None: indices = torch.randperm(self.config.num_critics) - indices = indices[:self.config.num_subsample_critics] + indices = indices[: self.config.num_subsample_critics] q_targets = q_targets[indices] # critics subsample size min_q = q_targets.min(dim=0) # compute td target - td_target = rewards + self.config.discount * min_q #+ self.config.discount * self.temperature() * log_probs # add entropy term + td_target = ( + rewards + self.config.discount * min_q + ) # + self.config.discount * self.temperature() * log_probs # add entropy term # 3- compute predicted qs q_preds = self.critic_ensemble(observations, actions) # 4- Calculate loss # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble. - #critics_loss = ( + # critics_loss = ( # ( # F.mse_loss( # q_preds, @@ -167,14 +168,20 @@ class SACPolicy( # ) # .sum(0) # .mean() - #) + # ) # 4- Calculate loss # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble. 
- critics_loss = F.mse_loss( - q_preds, # shape: [num_critics, batch_size] - einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]), # expand td_target to match q_preds shape - reduction="none" - ).sum(0).mean() + critics_loss = ( + F.mse_loss( + q_preds, # shape: [num_critics, batch_size] + einops.repeat( + td_target, "b -> e b", e=q_preds.shape[0] + ), # expand td_target to match q_preds shape + reduction="none", + ) + .sum(0) + .mean() + ) # calculate actors loss # 1- temperature @@ -229,10 +236,10 @@ class MLP(nn.Module): super().__init__() self.activate_final = activate_final layers = [] - + for i, size in enumerate(hidden_dims): - layers.append(nn.Linear(hidden_dims[i-1] if i > 0 else hidden_dims[0], size)) - + layers.append(nn.Linear(hidden_dims[i - 1] if i > 0 else hidden_dims[0], size)) + if i + 1 < len(hidden_dims) or activate_final: if dropout_rate is not None and dropout_rate > 0: layers.append(nn.Dropout(p=dropout_rate)) @@ -255,20 +262,20 @@ class Critic(nn.Module): encoder: Optional[nn.Module], network: nn.Module, init_final: Optional[float] = None, - device: str = "cuda" + device: str = "cuda", ): super().__init__() self.device = torch.device(device) self.encoder = encoder self.network = network self.init_final = init_final - + # Find the last Linear layer's output dimension for layer in reversed(network.net): if isinstance(layer, nn.Linear): out_features = layer.out_features break - + # Output layer if init_final is not None: self.output_layer = nn.Linear(out_features, 1) @@ -305,7 +312,7 @@ class Policy(nn.Module): fixed_std: Optional[torch.Tensor] = None, init_final: Optional[float] = None, use_tanh_squash: bool = False, - device: str = "cuda" + device: str = "cuda", ): super().__init__() self.device = torch.device(device) @@ -316,13 +323,13 @@ class Policy(nn.Module): self.log_std_max = log_std_max self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None self.use_tanh_squash = use_tanh_squash - + # Find the last Linear layer's output dimension for layer in reversed(network.net): if isinstance(layer, nn.Linear): out_features = layer.out_features break - + # Mean layer self.mean_layer = nn.Linear(out_features, action_dim) if init_final is not None: @@ -339,21 +346,16 @@ class Policy(nn.Module): nn.init.uniform_(self.std_layer.bias, -init_final, init_final) else: orthogonal_init()(self.std_layer.weight) - + self.to(self.device) def forward( self, observations: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: - # Encode observations if encoder exists - if self.encoder is not None: - with torch.set_grad_enabled(train): - obs_enc = self.encoder(observations, train=train) - else: - obs_enc = observations - + obs_enc = observations if self.encoder is not None else self.encoder(observations) + # Get network outputs outputs = self.network(obs_enc) means = self.mean_layer(outputs) @@ -369,15 +371,15 @@ class Policy(nn.Module): # uses tahn activation function to squash the action to be in the range of [-1, 1] normal = torch.distributions.Normal(means, stds) - x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1)) + x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1)) log_probs = normal.log_prob(x_t) if self.use_tanh_squash: actions = torch.tanh(x_t) log_probs -= torch.log((1 - actions.pow(2)) + 1e-6) - log_probs = log_probs.sum(-1) # sum over action dim + log_probs = log_probs.sum(-1) # sum over action dim return actions, log_probs - + def get_features(self, observations: torch.Tensor) -> torch.Tensor: """Get 
encoded features from observations""" observations = observations.to(self.device) @@ -507,10 +509,6 @@ def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: s return nn.ModuleList(critics).to(device) -def orthogonal_init(): - return lambda x: torch.nn.init.orthogonal_(x, gain=1.0) - - # borrowed from tdmpc def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor: """Helper to temporarily flatten extra dims at the start of the image tensor. From ee306e2f9b5bdfb5abebcb0228334536f260817d Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Sun, 29 Dec 2024 23:59:39 +0000 Subject: [PATCH 025/112] split encoder for critic and actor --- .../common/policies/sac/configuration_sac.py | 2 +- lerobot/common/policies/sac/modeling_sac.py | 306 ++++++++++-------- 2 files changed, 177 insertions(+), 131 deletions(-) diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index a324294c..5f676933 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -48,7 +48,7 @@ class SACConfig: critic_target_update_weight = 0.005 utd_ratio = 2 state_encoder_hidden_dim = 256 - latent_dim = 50 + latent_dim = 128 target_entropy = None critic_network_kwargs = { "hidden_dims": [256, 256], diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index 9df2c859..bd77408e 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -63,25 +63,35 @@ class SACPolicy( self.unnormalize_outputs = Unnormalize( config.output_shapes, config.output_normalization_modes, dataset_stats ) - encoder = SACObservationEncoder(config) + encoder_critic = SACObservationEncoder(config) + encoder_actor = SACObservationEncoder(config) # Define networks critic_nets = [] for _ in range(config.num_critics): - critic_net = Critic(encoder=encoder, network=MLP(**config.critic_network_kwargs)) + critic_net = Critic( + encoder=encoder_critic, + network=MLP( + input_dim=encoder_critic.output_dim + config.output_shapes["action"][0], + **config.critic_network_kwargs + ) + ) critic_nets.append(critic_net) self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics) self.critic_target = deepcopy(self.critic_ensemble) self.actor = Policy( - encoder=encoder, - network=MLP(**config.actor_network_kwargs), + encoder=encoder_actor, + network=MLP( + input_dim=encoder_actor.output_dim, + **config.actor_network_kwargs + ), action_dim=config.output_shapes["action"][0], - **config.policy_kwargs, + **config.policy_kwargs ) if config.target_entropy is None: - config.target_entropy = -np.prod(config.output_shapes["action"][0]) # (-dim(A)) - self.temperature = LagrangeMultiplier(init_value=config.temperature_init) + config.target_entropy = -np.prod(config.output_shapes["action"][0]) # (-dim(A)) + self.temperature = LagrangeMultiplier(init_value=config.temperature_init) def reset(self): """ @@ -104,15 +114,31 @@ class SACPolicy( actions, _ = self.actor(batch) actions = self.unnormalize_outputs({"action": actions})["action"] return actions + + def critic_forward(self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False) -> Tensor: + """Forward pass through a critic network ensemble + + Args: + observations: Dictionary of observations + actions: Action tensor + use_target: If True, use target critics, otherwise use ensemble critics + + Returns: + Tensor of 
Q-values from all critics + """ + critics = self.critic_target if use_target else self.critic_ensemble + q_values = torch.stack([critic(observations, actions) for critic in critics]) + return q_values + def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: """Run the batch through the model and compute the loss. - + Returns a dictionary with loss as a tensor, and other information as native floats. """ batch = self.normalize_inputs(batch) - # batch shape is (b, 2, ...) where index 1 returns the current observation and - # the next observation for caluculating the right td index. + # batch shape is (b, 2, ...) where index 1 returns the current observation and + # the next observation for calculating the right td index. actions = batch["action"][:, 0] rewards = batch["next.reward"][:, 0] observations = {} @@ -121,113 +147,109 @@ class SACPolicy( if k.startswith("observation."): observations[k] = batch[k][:, 0] next_observations[k] = batch[k][:, 1] - + # perform image augmentation - # reward bias from HIL-SERL code base + # reward bias from HIL-SERL code base # add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch - + # calculate critics loss # 1- compute actions from policy action_preds, log_probs = self.actor(next_observations) # 2- compute q targets - q_targets = self.target_qs(next_observations, action_preds) + q_targets = self.critic_forward(next_observations, action_preds, use_target=True) + # subsample critics to prevent overfitting if use high UTD (update to date) if self.config.num_subsample_critics is not None: indices = torch.randperm(self.config.num_critics) - indices = indices[: self.config.num_subsample_critics] + indices = indices[:self.config.num_subsample_critics] q_targets = q_targets[indices] # critics subsample size - min_q = q_targets.min(dim=0) + min_q, _ = q_targets.min(dim=0) # Get values from min operation # compute td target - td_target = ( - rewards + self.config.discount * min_q - ) # + self.config.discount * self.temperature() * log_probs # add entropy term + td_target = rewards + self.config.discount * min_q #+ self.config.discount * self.temperature() * log_probs # add entropy term # 3- compute predicted qs - q_preds = self.critic_ensemble(observations, actions) + q_preds = self.critic_forward(observations, actions, use_target=False) # 4- Calculate loss # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble. - # critics_loss = ( - # ( - # F.mse_loss( - # q_preds, - # einops.repeat(td_target, "t b -> e t b", e=q_preds.shape[0]), - # reduction="none", - # ).sum(0) # sum over ensemble - # # `q_preds_ensemble` depends on the first observation and the actions. - # * ~batch["observation.state_is_pad"][0] - # * ~batch["action_is_pad"] - # # q_targets depends on the reward and the next observations. - # * ~batch["next.reward_is_pad"] - # * ~batch["observation.state_is_pad"][1:] - # ) - # .sum(0) - # .mean() - # ) - # 4- Calculate loss - # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble. 
- critics_loss = ( - F.mse_loss( - q_preds, # shape: [num_critics, batch_size] - einops.repeat( - td_target, "b -> e b", e=q_preds.shape[0] - ), # expand td_target to match q_preds shape - reduction="none", - ) - .sum(0) - .mean() - ) + critics_loss = F.mse_loss( + q_preds, # shape: [num_critics, batch_size] + einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]), # expand td_target to match q_preds shape + reduction="none" + ).sum(0).mean() + # critics_loss = ( + # F.mse_loss( + # q_preds, + # einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]), + # reduction="none", + # ).sum(0) # sum over ensemble + # # `q_preds_ensemble` depends on the first observation and the actions. + # * ~batch["observation.state_is_pad"][0] + # * ~batch["action_is_pad"] + # # q_targets depends on the reward and the next observations. + # * ~batch["next.reward_is_pad"] + # * ~batch["observation.state_is_pad"][1:] + # ).sum(0).mean() + # calculate actors loss # 1- temperature temperature = self.temperature() # 2- get actions (batch_size, action_dim) and log probs (batch_size,) actions, log_probs = self.actor(observations) # 3- get q-value predictions - with torch.no_grad(): - q_preds = self.critic_ensemble(observations, actions, return_type="mean") + with torch.inference_mode(): + q_preds = self.critic_forward(observations, actions, use_target=False) actor_loss = ( -(q_preds - temperature * log_probs).mean() - * ~batch["observation.state_is_pad"][0] - * ~batch["action_is_pad"] + # * ~batch["observation.state_is_pad"][0] + # * ~batch["action_is_pad"] ).mean() + # calculate temperature loss # 1- calculate entropy entropy = -log_probs.mean() - temperature_loss = self.temp(lhs=entropy, rhs=self.config.target_entropy) + temperature_loss = self.temperature( + lhs=entropy, + rhs=self.config.target_entropy + ) loss = critics_loss + actor_loss + temperature_loss return { - "critics_loss": critics_loss.item(), - "actor_loss": actor_loss.item(), - "temperature_loss": temperature_loss.item(), - "temperature": temperature.item(), - "entropy": entropy.item(), - "loss": loss, - } - + "critics_loss": critics_loss.item(), + "actor_loss": actor_loss.item(), + "temperature_loss": temperature_loss.item(), + "temperature": temperature.item(), + "entropy": entropy.item(), + "loss": loss, + } + def update(self): - self.critic_target.lerp_(self.critic_ensemble, self.config.critic_target_update_weight) # TODO: implement UTD update # First update only critics for utd_ratio-1 times - # for critic_step in range(self.config.utd_ratio - 1): - # only update critic and critic target + #for critic_step in range(self.config.utd_ratio - 1): + # only update critic and critic target # Then update critic, critic target, actor and temperature - - # for target_param, param in zip(self.critic_target.parameters(), self.critic_ensemble.parameters()): - # target_param.data.copy_(target_param.data * (1.0 - self.config.critic_target_update_weight) + param.data * self.critic_target_update_weight) - - + """Update target networks with exponential moving average""" + with torch.no_grad(): + for target_critic, critic in zip(self.critic_target, self.critic_ensemble, strict=False): + for target_param, param in zip(target_critic.parameters(), critic.parameters(), strict=False): + target_param.data.copy_( + target_param.data * self.config.critic_target_update_weight + + param.data * (1.0 - self.config.critic_target_update_weight) + ) + class MLP(nn.Module): def __init__( self, + input_dim: int, hidden_dims: list[int], activations: Callable[[torch.Tensor], 
torch.Tensor] | str = nn.SiLU(), activate_final: bool = False, @@ -236,46 +258,52 @@ class MLP(nn.Module): super().__init__() self.activate_final = activate_final layers = [] - - for i, size in enumerate(hidden_dims): - layers.append(nn.Linear(hidden_dims[i - 1] if i > 0 else hidden_dims[0], size)) - + + # First layer uses input_dim + layers.append(nn.Linear(input_dim, hidden_dims[0])) + + # Add activation after first layer + if dropout_rate is not None and dropout_rate > 0: + layers.append(nn.Dropout(p=dropout_rate)) + layers.append(nn.LayerNorm(hidden_dims[0])) + layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)()) + + # Rest of the layers + for i in range(1, len(hidden_dims)): + layers.append(nn.Linear(hidden_dims[i-1], hidden_dims[i])) + if i + 1 < len(hidden_dims) or activate_final: if dropout_rate is not None and dropout_rate > 0: layers.append(nn.Dropout(p=dropout_rate)) - layers.append(nn.LayerNorm(size)) - layers.append( - activations if isinstance(activations, nn.Module) else getattr(nn, activations)() - ) - + layers.append(nn.LayerNorm(hidden_dims[i])) + layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)()) + self.net = nn.Sequential(*layers) - def forward(self, x: torch.Tensor, train: bool = False) -> torch.Tensor: - # in training mode or not. TODO: find better way to do this - self.train(train) + def forward(self, x: torch.Tensor) -> torch.Tensor: return self.net(x) - - + + class Critic(nn.Module): def __init__( self, encoder: Optional[nn.Module], network: nn.Module, init_final: Optional[float] = None, - device: str = "cuda", + device: str = "cuda" ): super().__init__() self.device = torch.device(device) self.encoder = encoder self.network = network self.init_final = init_final - + # Find the last Linear layer's output dimension for layer in reversed(network.net): if isinstance(layer, nn.Linear): out_features = layer.out_features break - + # Output layer if init_final is not None: self.output_layer = nn.Linear(out_features, 1) @@ -284,17 +312,22 @@ class Critic(nn.Module): else: self.output_layer = nn.Linear(out_features, 1) orthogonal_init()(self.output_layer.weight) - + self.to(self.device) - def forward(self, observations: torch.Tensor, actions: torch.Tensor, train: bool = False) -> torch.Tensor: - self.train(train) - - observations = observations.to(self.device) + def forward( + self, + observations: dict[str, torch.Tensor], + actions: torch.Tensor, + ) -> torch.Tensor: + # Move each tensor in observations to device + observations = { + k: v.to(self.device) for k, v in observations.items() + } actions = actions.to(self.device) - + obs_enc = observations if self.encoder is None else self.encoder(observations) - + inputs = torch.cat([obs_enc, actions], dim=-1) x = self.network(inputs) value = self.output_layer(x) @@ -312,7 +345,7 @@ class Policy(nn.Module): fixed_std: Optional[torch.Tensor] = None, init_final: Optional[float] = None, use_tanh_squash: bool = False, - device: str = "cuda", + device: str = "cuda" ): super().__init__() self.device = torch.device(device) @@ -323,13 +356,13 @@ class Policy(nn.Module): self.log_std_max = log_std_max self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None self.use_tanh_squash = use_tanh_squash - + # Find the last Linear layer's output dimension for layer in reversed(network.net): if isinstance(layer, nn.Linear): out_features = layer.out_features break - + # Mean layer self.mean_layer = nn.Linear(out_features, action_dim) if 
init_final is not None: @@ -337,7 +370,7 @@ class Policy(nn.Module): nn.init.uniform_(self.mean_layer.bias, -init_final, init_final) else: orthogonal_init()(self.mean_layer.weight) - + # Standard deviation layer or parameter if fixed_std is None: self.std_layer = nn.Linear(out_features, action_dim) @@ -346,20 +379,21 @@ class Policy(nn.Module): nn.init.uniform_(self.std_layer.bias, -init_final, init_final) else: orthogonal_init()(self.std_layer.weight) - + self.to(self.device) def forward( - self, + self, observations: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: + # Encode observations if encoder exists - obs_enc = observations if self.encoder is not None else self.encoder(observations) + obs_enc = observations if self.encoder is None else self.encoder(observations) # Get network outputs outputs = self.network(obs_enc) means = self.mean_layer(outputs) - + # Compute standard deviations if self.fixed_std is None: log_std = self.std_layer(outputs) @@ -367,25 +401,25 @@ class Policy(nn.Module): log_std = torch.tanh(log_std) log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max) else: - stds = self.fixed_std.expand_as(means) - + log_std = self.fixed_std.expand_as(means) + # uses tahn activation function to squash the action to be in the range of [-1, 1] - normal = torch.distributions.Normal(means, stds) - x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1)) + normal = torch.distributions.Normal(means, torch.exp(log_std)) + x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1)) log_probs = normal.log_prob(x_t) if self.use_tanh_squash: actions = torch.tanh(x_t) log_probs -= torch.log((1 - actions.pow(2)) + 1e-6) - log_probs = log_probs.sum(-1) # sum over action dim + log_probs = log_probs.sum(-1) # sum over action dim return actions, log_probs - + def get_features(self, observations: torch.Tensor) -> torch.Tensor: """Get encoded features from observations""" observations = observations.to(self.device) if self.encoder is not None: - with torch.no_grad(): - return self.encoder(observations, train=False) + with torch.inference_mode(): + return self.encoder(observations) return observations @@ -459,43 +493,56 @@ class SACObservationEncoder(nn.Module): feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"])) if "observation.state" in self.config.input_shapes: feat.append(self.state_enc_layers(obs_dict["observation.state"])) + # TODO(ke-wang): currently average over all features, concatenate all features maybe a better way return torch.stack(feat, dim=0).mean(0) + + @property + def output_dim(self) -> int: + """Returns the dimension of the encoder output""" + return self.config.latent_dim class LagrangeMultiplier(nn.Module): - def __init__(self, init_value: float = 1.0, constraint_shape: Sequence[int] = (), device: str = "cuda"): + def __init__( + self, + init_value: float = 1.0, + constraint_shape: Sequence[int] = (), + device: str = "cuda" + ): super().__init__() self.device = torch.device(device) init_value = torch.log(torch.exp(torch.tensor(init_value, device=self.device)) - 1) - + # Initialize the Lagrange multiplier as a parameter self.lagrange = nn.Parameter( torch.full(constraint_shape, init_value, dtype=torch.float32, device=self.device) ) - + self.to(self.device) - def forward(self, lhs: Optional[torch.Tensor] = None, rhs: Optional[torch.Tensor] = None) -> torch.Tensor: - # Get the multiplier value based on parameterization + def forward( + self, + lhs: Optional[torch.Tensor | float | int] = None, + 
rhs: Optional[torch.Tensor | float | int] = None + ) -> torch.Tensor: + # Get the multiplier value based on parameterization multiplier = torch.nn.functional.softplus(self.lagrange) - + # Return the raw multiplier if no constraint values provided if lhs is None: return multiplier - - # Move inputs to device - lhs = lhs.to(self.device) + + # Convert inputs to tensors and move to device + lhs = torch.tensor(lhs, device=self.device) if not isinstance(lhs, torch.Tensor) else lhs.to(self.device) if rhs is not None: - rhs = rhs.to(self.device) - - # Use the multiplier to compute the Lagrange penalty - if rhs is None: + rhs = torch.tensor(rhs, device=self.device) if not isinstance(rhs, torch.Tensor) else rhs.to(self.device) + else: rhs = torch.zeros_like(lhs, device=self.device) - + diff = lhs - rhs - + assert diff.shape == multiplier.shape, f"Shape mismatch: {diff.shape} vs {multiplier.shape}" - + return multiplier * diff @@ -508,7 +555,6 @@ def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: s assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}" return nn.ModuleList(critics).to(device) - # borrowed from tdmpc def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor: """Helper to temporarily flatten extra dims at the start of the image tensor. @@ -516,7 +562,7 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens Args: fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return (B, *), where * is any number of dimensions. - image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and + image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and can be more than 1 dimensions, generally different from *. Returns: A return value from the callable reshaped to (**, *). 
@@ -526,4 +572,4 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens start_dims = image_tensor.shape[:-3] inp = torch.flatten(image_tensor, end_dim=-4) flat_out = fn(inp) - return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:])) + return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:])) \ No newline at end of file From 35de91ef2bed8d25ef6aa40e6ff8514a39666436 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Mon, 30 Dec 2024 13:47:28 +0000 Subject: [PATCH 026/112] added temporary fix for missing task_index key in online environment --- lerobot/common/policies/sac/configuration_sac.py | 1 + lerobot/scripts/train.py | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index 5f676933..4ae6e5d4 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -50,6 +50,7 @@ class SACConfig: state_encoder_hidden_dim = 256 latent_dim = 128 target_entropy = None + backup_entropy = True critic_network_kwargs = { "hidden_dims": [256, 256], "activate_final": True, diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index fbe7927d..a4eb3528 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -322,6 +322,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No logging.info("make_dataset") offline_dataset = make_dataset(cfg) + # TODO (michel-aractingi): temporary fix to avoid datasets with task_index key that doesn't exist in online environment + # i.e., pusht + if "task_index" in offline_dataset.hf_dataset[0]: + offline_dataset.hf_dataset = offline_dataset.hf_dataset.remove_columns(["task_index"]) + if isinstance(offline_dataset, MultiLeRobotDataset): logging.info( "Multiple datasets were provided. 
Applied the following index mapping to the provided datasets: " From c5bca1cf0f1055f898d58c46432b80d70b615cda Mon Sep 17 00:00:00 2001 From: Eugene Mironov Date: Mon, 6 Jan 2025 17:34:00 +0700 Subject: [PATCH 027/112] [Port HIL_SERL] Final fixes for the Reward Classifier (#598) --- .../hilserl/classifier/modeling_classifier.py | 3 ++- lerobot/common/policies/sac/modeling_sac.py | 1 - lerobot/common/robot_devices/control_utils.py | 8 +++++-- .../configs/policy/hilserl_classifier.yaml | 1 - lerobot/scripts/control_robot.py | 2 +- lerobot/scripts/control_sim_robot.py | 23 ++++++++++++++++++- lerobot/scripts/train_hilserl_classifier.py | 17 ++++++++++---- poetry.lock | 2 +- pyproject.toml | 4 ++-- .../classifier/test_modelling_classifier.py | 9 +++++++- tests/test_train_hilserl_classifier.py | 8 +++---- 11 files changed, 59 insertions(+), 19 deletions(-) diff --git a/lerobot/common/policies/hilserl/classifier/modeling_classifier.py b/lerobot/common/policies/hilserl/classifier/modeling_classifier.py index 28b05744..d7bd42cd 100644 --- a/lerobot/common/policies/hilserl/classifier/modeling_classifier.py +++ b/lerobot/common/policies/hilserl/classifier/modeling_classifier.py @@ -4,7 +4,6 @@ from typing import Optional import torch from huggingface_hub import PyTorchModelHubMixin from torch import Tensor, nn -from transformers import AutoImageProcessor, AutoModel from .configuration_classifier import ClassifierConfig @@ -44,6 +43,8 @@ class Classifier( name = "classifier" def __init__(self, config: ClassifierConfig): + from transformers import AutoImageProcessor, AutoModel + super().__init__() self.config = config self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index bd77408e..62725ce1 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -333,7 +333,6 @@ class Critic(nn.Module): value = self.output_layer(x) return value.squeeze(-1) - class Policy(nn.Module): def __init__( self, diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py index 8a6bcfbd..ad6f5632 100644 --- a/lerobot/common/robot_devices/control_utils.py +++ b/lerobot/common/robot_devices/control_utils.py @@ -362,12 +362,16 @@ def sanity_check_dataset_name(repo_id, policy): def sanity_check_dataset_robot_compatibility( - dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool + dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool, extra_features: dict = None ) -> None: + features_from_robot = get_features_from_robot(robot, use_videos) + if extra_features is not None: + features_from_robot.update(extra_features) + fields = [ ("robot_type", dataset.meta.robot_type, robot.robot_type), ("fps", dataset.fps, fps), - ("features", dataset.features, get_features_from_robot(robot, use_videos)), + ("features", dataset.features, features_from_robot), ] mismatches = [] diff --git a/lerobot/configs/policy/hilserl_classifier.yaml b/lerobot/configs/policy/hilserl_classifier.yaml index be82bc4e..498c9983 100644 --- a/lerobot/configs/policy/hilserl_classifier.yaml +++ b/lerobot/configs/policy/hilserl_classifier.yaml @@ -39,7 +39,6 @@ policy: wandb: enable: false project: "classifier-training" - entity: "wandb_entity" job_name: "classifier_training_0" disable_artifact: false diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py index 
45a6bd66..f45e6b48 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -246,7 +246,7 @@ def record( num_processes=num_image_writer_processes, num_threads=num_image_writer_threads_per_camera * len(robot.cameras), ) - sanity_check_dataset_robot_compatibility(dataset, robot, fps, video) + sanity_check_dataset_robot_compatibility(dataset, robot, fps, video, extra_features) else: # Create empty dataset or load existing saved episodes sanity_check_dataset_name(repo_id, policy) diff --git a/lerobot/scripts/control_sim_robot.py b/lerobot/scripts/control_sim_robot.py index 4fffa8c7..67bdfb85 100644 --- a/lerobot/scripts/control_sim_robot.py +++ b/lerobot/scripts/control_sim_robot.py @@ -183,8 +183,14 @@ def record( resume: bool = False, local_files_only: bool = False, run_compute_stats: bool = True, + assign_rewards: bool = False, ) -> LeRobotDataset: # Load pretrained policy + + extra_features = ( + {"next.reward": {"dtype": "int64", "shape": (1,), "names": None}} if assign_rewards else None + ) + policy = None if pretrained_policy_name_or_path is not None: policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides) @@ -197,7 +203,7 @@ def record( raise ValueError("Either policy or process_action_fn has to be set to enable control in sim.") # initialize listener before sim env - listener, events = init_keyboard_listener() + listener, events = init_keyboard_listener(assign_rewards=assign_rewards) # create sim env env = env() @@ -237,6 +243,7 @@ def record( } features["action"] = {"dtype": "float32", "shape": env.action_space.shape, "names": None} + features = {**features, **extra_features} # Create empty dataset or load existing saved episodes sanity_check_dataset_name(repo_id, policy) @@ -288,6 +295,13 @@ def record( "timestamp": env_timestamp, } + # Overwrite environment reward with manually assigned reward + if assign_rewards: + frame["next.reward"] = events["next.reward"] + + # Should success always be false to match what we do in control_utils? + frame["next.success"] = False + for key in image_keys: if not key.startswith("observation.image"): frame["observation.image." + key] = observation[key] @@ -472,6 +486,13 @@ if __name__ == "__main__": default=0, help="Resume recording on an existing dataset.", ) + parser_record.add_argument( + "--assign-rewards", + type=int, + default=0, + help="Enables the assignation of rewards to frames (by default no assignation). When enabled, assign a 0 reward to frames until the space bar is pressed which assign a 1 reward. Press the space bar a second time to assign a 0 reward. 
The reward assigned is reset to 0 when the episode ends.", + ) + parser_replay = subparsers.add_parser("replay", parents=[base_parser]) parser_replay.add_argument( "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)" diff --git a/lerobot/scripts/train_hilserl_classifier.py b/lerobot/scripts/train_hilserl_classifier.py index ea8336a9..22ff2957 100644 --- a/lerobot/scripts/train_hilserl_classifier.py +++ b/lerobot/scripts/train_hilserl_classifier.py @@ -45,7 +45,7 @@ from lerobot.common.utils.utils import ( ) -def get_model(cfg, logger): +def get_model(cfg, logger): # noqa I001 classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg) model = Classifier(classifier_config) if cfg.resume: @@ -64,6 +64,12 @@ def create_balanced_sampler(dataset, cfg): return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True) +def support_amp(device: torch.device, cfg: DictConfig) -> bool: + # Check if the device supports AMP + # Here is an example of the issue that says that MPS doesn't support AMP properply + return cfg.training.use_amp and device.type in ("cuda", "cpu") + + def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg): # Single epoch training loop with AMP support and progress tracking model.train() @@ -77,7 +83,7 @@ def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, labels = batch[cfg.training.label_key].float().to(device) # Forward pass with optional AMP - with torch.autocast(device_type=device.type) if cfg.training.use_amp else nullcontext(): + with torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext(): outputs = model(images) loss = criterion(outputs.logits, labels) @@ -119,7 +125,10 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l samples = [] running_loss = 0 - with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.training.use_amp else nullcontext(): + with ( + torch.no_grad(), + torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext(), + ): for batch in tqdm(val_loader, desc="Validation"): images = batch[cfg.training.image_key].to(device) labels = batch[cfg.training.label_key].float().to(device) @@ -170,7 +179,7 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l return accuracy, eval_info -@hydra.main(version_base="1.2", config_path="../configs", config_name="hilserl_classifier") +@hydra.main(version_base="1.2", config_path="../configs/policy", config_name="hilserl_classifier") def train(cfg: DictConfig) -> None: # Main training pipeline with support for resuming training logging.info(OmegaConf.to_yaml(cfg)) diff --git a/poetry.lock b/poetry.lock index 919edd18..81462fe8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -7720,4 +7720,4 @@ xarm = ["gym-xarm"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "b9d299916ced6af1d243f961a32b0a4aacbef18e0b95337a5224e8511f5d6dda" +content-hash = "44c74163e398e8ff16973957f69a47bb09b789e92ac4d8fb3ab268defab96427" diff --git a/pyproject.toml b/pyproject.toml index 738903bd..05ab921a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,8 +71,8 @@ pyrender = {git = "https://github.com/mmatl/pyrender.git", markers = "sys_platfo hello-robot-stretch-body = {version = ">=0.7.27", markers = "sys_platform == 'linux'", optional = true} pyserial = {version = ">=3.5", optional = true} jsonlines = ">=4.0.0" -transformers = 
{version = "^4.47.0", optional = true} -torchmetrics = {version = "^1.6.0", optional = true} +transformers = {version = ">=4.47.0", optional = true} +torchmetrics = {version = ">=1.6.0", optional = true} [tool.poetry.extras] diff --git a/tests/policies/hilserl/classifier/test_modelling_classifier.py b/tests/policies/hilserl/classifier/test_modelling_classifier.py index 014165eb..a3db4211 100644 --- a/tests/policies/hilserl/classifier/test_modelling_classifier.py +++ b/tests/policies/hilserl/classifier/test_modelling_classifier.py @@ -1,7 +1,6 @@ import torch from lerobot.common.policies.hilserl.classifier.modeling_classifier import ( - Classifier, ClassifierConfig, ClassifierOutput, ) @@ -21,6 +20,8 @@ def test_classifier_output(): @require_package("transformers") def test_binary_classifier_with_default_params(): + from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier + config = ClassifierConfig() classifier = Classifier(config) @@ -40,6 +41,8 @@ def test_binary_classifier_with_default_params(): @require_package("transformers") def test_multiclass_classifier(): + from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier + num_classes = 5 config = ClassifierConfig(num_classes=num_classes) classifier = Classifier(config) @@ -60,6 +63,8 @@ def test_multiclass_classifier(): @require_package("transformers") def test_default_device(): + from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier + config = ClassifierConfig() assert config.device == "cpu" @@ -70,6 +75,8 @@ def test_default_device(): @require_package("transformers") def test_explicit_device_setup(): + from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier + config = ClassifierConfig(device="meta") assert config.device == "meta" diff --git a/tests/test_train_hilserl_classifier.py b/tests/test_train_hilserl_classifier.py index 66d8fbe4..c1d854ac 100644 --- a/tests/test_train_hilserl_classifier.py +++ b/tests/test_train_hilserl_classifier.py @@ -151,9 +151,9 @@ def test_validate(): @patch("lerobot.scripts.train_hilserl_classifier.Logger.get_last_pretrained_model_dir") @patch("lerobot.scripts.train_hilserl_classifier.Logger") @patch("lerobot.scripts.train_hilserl_classifier.LeRobotDataset") -@patch("lerobot.scripts.train_hilserl_classifier.make_policy") +@patch("lerobot.scripts.train_hilserl_classifier.get_model") def test_resume_function( - mock_make_policy, + mock_get_model, mock_dataset, mock_logger, mock_get_last_pretrained_model_dir, @@ -168,7 +168,7 @@ def test_resume_function( with initialize_config_dir(config_dir=config_dir, job_name="test_app", version_base="1.2"): cfg = compose( - config_name="reward_classifier", + config_name="hilserl_classifier", overrides=[ "device=cpu", "seed=42", @@ -211,7 +211,7 @@ def test_resume_function( # Instantiate the model and set make_policy to return it model = make_dummy_model() - mock_make_policy.return_value = model + mock_get_model.return_value = model # Call train train(cfg) From 3bb5ed5e91a2b78d9a8f5883171ce45da3c496ff Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Mon, 13 Jan 2025 13:57:49 +0100 Subject: [PATCH 028/112] Extend reward classifier for multiple camera views (#626) --- lerobot/common/logger.py | 2 +- .../classifier/configuration_classifier.py | 1 + .../hilserl/classifier/modeling_classifier.py | 16 ++- lerobot/common/robot_devices/control_utils.py | 9 ++ .../configs/policy/hilserl_classifier.yaml | 9 +- lerobot/scripts/control_robot.py | 13 ++ 
lerobot/scripts/eval_on_robot.py | 123 +++++++++++++----- lerobot/scripts/train_hilserl_classifier.py | 7 +- tests/test_train_hilserl_classifier.py | 61 ++++++++- 9 files changed, 192 insertions(+), 49 deletions(-) diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py index dec8b465..4015492d 100644 --- a/lerobot/common/logger.py +++ b/lerobot/common/logger.py @@ -25,13 +25,13 @@ from glob import glob from pathlib import Path import torch +import wandb from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE from omegaconf import DictConfig, OmegaConf from termcolor import colored from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler -import wandb from lerobot.common.policies.policy_protocol import Policy from lerobot.common.utils.utils import get_global_random_state, set_global_random_state diff --git a/lerobot/common/policies/hilserl/classifier/configuration_classifier.py b/lerobot/common/policies/hilserl/classifier/configuration_classifier.py index f0b9352f..de3742ec 100644 --- a/lerobot/common/policies/hilserl/classifier/configuration_classifier.py +++ b/lerobot/common/policies/hilserl/classifier/configuration_classifier.py @@ -13,6 +13,7 @@ class ClassifierConfig: model_name: str = "microsoft/resnet-50" device: str = "cpu" model_type: str = "cnn" # "transformer" or "cnn" + num_cameras: int = 2 def save_pretrained(self, save_dir): """Save config to json file.""" diff --git a/lerobot/common/policies/hilserl/classifier/modeling_classifier.py b/lerobot/common/policies/hilserl/classifier/modeling_classifier.py index d7bd42cd..4a022335 100644 --- a/lerobot/common/policies/hilserl/classifier/modeling_classifier.py +++ b/lerobot/common/policies/hilserl/classifier/modeling_classifier.py @@ -97,7 +97,7 @@ class Classifier( raise ValueError("Unsupported transformer architecture since hidden_size is not found") self.classifier_head = nn.Sequential( - nn.Linear(input_dim, self.config.hidden_dim), + nn.Linear(input_dim * self.config.num_cameras, self.config.hidden_dim), nn.Dropout(self.config.dropout_rate), nn.LayerNorm(self.config.hidden_dim), nn.ReLU(), @@ -130,11 +130,11 @@ class Classifier( return outputs.pooler_output return outputs.last_hidden_state[:, 0, :] - def forward(self, x: torch.Tensor) -> ClassifierOutput: + def forward(self, xs: torch.Tensor) -> ClassifierOutput: """Forward pass of the classifier.""" # For training, we expect input to be a tensor directly from LeRobotDataset - encoder_output = self._get_encoder_output(x) - logits = self.classifier_head(encoder_output) + encoder_outputs = torch.hstack([self._get_encoder_output(x) for x in xs]) + logits = self.classifier_head(encoder_outputs) if self.config.num_classes == 2: logits = logits.squeeze(-1) @@ -142,4 +142,10 @@ class Classifier( else: probabilities = torch.softmax(logits, dim=-1) - return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_output) + return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs) + + def predict_reward(self, x): + if self.config.num_classes == 2: + return (self.forward(x).probabilities > 0.5).float() + else: + return torch.argmax(self.forward(x).probabilities, dim=1) diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py index ad6f5632..10cb9f5c 100644 --- a/lerobot/common/robot_devices/control_utils.py +++ b/lerobot/common/robot_devices/control_utils.py @@ -11,6 +11,7 @@ from copy import copy from functools import cache import 
cv2 +import numpy as np import torch import tqdm from deepdiff import DeepDiff @@ -332,6 +333,14 @@ def reset_environment(robot, events, reset_time_s): break +def reset_follower_position(robot: Robot, target_position): + current_position = robot.follower_arms["main"].read("Present_Position") + trajectory = torch.from_numpy(np.linspace(current_position, target_position, 30)) # NOTE: 30 is just an arbitrary number + for pose in trajectory: + robot.send_action(pose) + busy_wait(0.015) + + def stop_recording(robot, listener, display_cameras): robot.disconnect() diff --git a/lerobot/configs/policy/hilserl_classifier.yaml b/lerobot/configs/policy/hilserl_classifier.yaml index 498c9983..f8137b69 100644 --- a/lerobot/configs/policy/hilserl_classifier.yaml +++ b/lerobot/configs/policy/hilserl_classifier.yaml @@ -4,7 +4,7 @@ defaults: - _self_ seed: 13 -dataset_repo_id: "dataset_repo_id" +dataset_repo_id: aractingi/pick_place_lego_cube_1 train_split_proportion: 0.8 # Required by logger @@ -24,7 +24,7 @@ training: eval_freq: 1 # How often to run validation (in epochs) save_freq: 1 # How often to save checkpoints (in epochs) save_checkpoint: true - image_key: "observation.images.phone" + image_keys: ["observation.images.top", "observation.images.wrist"] label_key: "next.reward" eval: @@ -32,9 +32,10 @@ eval: num_samples_to_log: 30 # Number of validation samples to log in the table policy: - name: "hilserl/classifier" + name: "hilserl/classifier/pick_place_lego_cube_1" model_name: "facebook/convnext-base-224" model_type: "cnn" + num_cameras: 2 # Has to be len(training.image_keys) wandb: enable: false @@ -44,4 +45,4 @@ wandb: device: "mps" resume: false -output_dir: "output" +output_dir: "outputs/classifier" diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py index f45e6b48..8187e8a3 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -109,6 +109,7 @@ from lerobot.common.robot_devices.control_utils import ( log_control_info, record_episode, reset_environment, + reset_follower_position, sanity_check_dataset_name, sanity_check_dataset_robot_compatibility, stop_recording, @@ -205,6 +206,7 @@ def record( num_image_writer_threads_per_camera: int = 4, display_cameras: bool = True, play_sounds: bool = True, + reset_follower: bool = False, resume: bool = False, # TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument local_files_only: bool = False, @@ -265,6 +267,9 @@ def record( robot.connect() listener, events = init_keyboard_listener(assign_rewards=assign_rewards) + if reset_follower: + initial_position = robot.follower_arms["main"].read("Present_Position") + # Execute a few seconds without recording to: # 1. teleoperate the robot to move it in starting position if no policy provided, # 2. give times to the robot devices to connect and start synchronizing, @@ -307,6 +312,8 @@ def record( (dataset.num_episodes < num_episodes - 1) or events["rerecord_episode"] ): log_say("Reset the environment", play_sounds) + if reset_follower: + reset_follower_position(robot, initial_position) reset_environment(robot, events, reset_time_s) if events["rerecord_episode"]: @@ -527,6 +534,12 @@ if __name__ == "__main__": default=0, help="Enables the assignation of rewards to frames (by default no assignation). When enabled, assign a 0 reward to frames until the space bar is pressed which assign a 1 reward. Press the space bar a second time to assign a 0 reward. 
The reward assigned is reset to 0 when the episode ends.", ) + parser_record.add_argument( + "--reset-follower", + type=int, + default=0, + help="Resets the follower to the initial position while resetting the environment; this avoids having the follower start at an awkward position in the next episode", + ) parser_replay = subparsers.add_parser("replay", parents=[base_parser]) parser_replay.add_argument( diff --git a/lerobot/scripts/eval_on_robot.py b/lerobot/scripts/eval_on_robot.py index 92daa860..842c1a28 100644 --- a/lerobot/scripts/eval_on_robot.py +++ b/lerobot/scripts/eval_on_robot.py @@ -23,6 +23,15 @@ python lerobot/scripts/eval_on_robot.py \ eval.n_episodes=10 ``` +Test reward classifier with teleoperation (you need to press space to take over) +``` +python lerobot/scripts/eval_on_robot.py \ + --robot-path lerobot/configs/robot/so100.yaml \ + --reward-classifier-pretrained-path outputs/classifier/checkpoints/best/pretrained_model \ + --reward-classifier-config-file lerobot/configs/policy/hilserl_classifier.yaml \ + --display-cameras 1 +``` + **NOTE** (michel-aractingi): This script is incomplete and it is being prepared for running training on the real robot. """ @@ -30,14 +39,14 @@ for running training on the real robot. import argparse import logging import time -from copy import deepcopy +import cv2 import numpy as np import torch from tqdm import trange from lerobot.common.policies.policy_protocol import Policy -from lerobot.common.robot_devices.control_utils import busy_wait, is_headless +from lerobot.common.robot_devices.control_utils import busy_wait, is_headless, reset_follower_position from lerobot.common.robot_devices.robots.factory import Robot, make_robot from lerobot.common.utils.utils import ( init_hydra_config, @@ -46,7 +55,33 @@ from lerobot.common.utils.utils import ( ) -def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20, use_amp: bool = True) -> dict: +def get_classifier(pretrained_path, config_path): + if pretrained_path is None or config_path is None: + return + + from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg + from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig + from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier + + cfg = init_hydra_config(config_path) + + classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg) + classifier_config.num_cameras = len(cfg.training.image_keys) # TODO automate these paths + model = Classifier(classifier_config) + model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict()) + model = model.to("mps") + return model + + +def rollout( + robot: Robot, + policy: Policy, + reward_classifier, + fps: int, + control_time_s: float = 20, + use_amp: bool = True, + display_cameras: bool = False, +) -> dict: """Run a batched policy rollout on the real robot. The return dictionary contains: @@ -70,6 +105,7 @@ def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20, Returns: The dictionary described above. """ + # TODO (michel-aractingi): Infer the device from policy parameters when policy is added # assert isinstance(policy, nn.Module), "Policy must be a PyTorch nn module." # device = get_device_from_parameters(policy) @@ -79,25 +115,21 @@ def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20, # Reset the policy. TODO (michel-aractingi) add real policy evaluation once the code is ready. 
# policy.reset() - # Get observation from real robot + # NOTE: sorting to make sure the key sequence is the same during training and testing. observation = robot.capture_observation() + image_keys = [key for key in observation if "image" in key] + image_keys.sort() - # Calculate reward. TODO (michel-aractingi) - # in HIL-SERL it will be with a reward classifier - reward = calculate_reward(observation) - all_observations = [] all_actions = [] all_rewards = [] all_successes = [] start_episode_t = time.perf_counter() + init_pos = robot.follower_arms["main"].read("Present_Position") timestamp = 0.0 while timestamp < control_time_s: start_loop_t = time.perf_counter() - all_observations.append(deepcopy(observation)) - # observation = {key: observation[key].to(device, non_blocking=True) for key in observation} - # Apply the next action. while events["pause_policy"] and not events["human_intervention_step"]: busy_wait(0.5) @@ -109,18 +141,26 @@ def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20, else: # explore with policy with torch.inference_mode(): + # TODO (michel-aractingi) replace this part with policy (predict_action) action = robot.follower_arms["main"].read("Present_Position") action = torch.from_numpy(action) robot.send_action(action) # action = predict_action(observation, policy, device, use_amp) observation = robot.capture_observation() - # Calculate reward - # in HIL-SERL it will be with a reward classifier - reward = calculate_reward(observation) + images = [] + for key in image_keys: + if display_cameras: + cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)) + cv2.waitKey(1) + images.append(observation[key].to("mps")) + + reward = reward_classifier.predict_reward(images) if reward_classifier is not None else 0.0 + all_rewards.append(reward) + + # print("REWARD : ", reward) all_actions.append(action) - all_rewards.append(torch.from_numpy(reward)) all_successes.append(torch.tensor([False])) dt_s = time.perf_counter() - start_loop_t @@ -131,7 +171,8 @@ def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20, events["human_intervention_step"] = False events["pause_policy"] = False break - all_observations.append(deepcopy(observation)) + + reset_follower_position(robot, target_position=init_pos) dones = torch.tensor([False] * len(all_actions)) dones[-1] = True @@ -142,10 +183,6 @@ def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20, "next.success": torch.stack(all_successes, dim=1), "done": dones, } - stacked_observations = {} - for key in all_observations[0]: - stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1) - ret["observation"] = stacked_observations listener.stop() @@ -159,6 +196,9 @@ def eval_policy( n_episodes: int, control_time_s: int = 20, use_amp: bool = True, + display_cameras: bool = False, + reward_classifier_pretrained_path: str | None = None, + reward_classifier_config_file: str | None = None, ) -> dict: """ Args: @@ -179,8 +219,12 @@ def eval_policy( start_eval = time.perf_counter() progbar = trange(n_episodes, desc="Evaluating policy on real robot") - for _batch_idx in progbar: - rollout_data = rollout(robot, policy, fps, control_time_s, use_amp) + reward_classifier = get_classifier(reward_classifier_pretrained_path, reward_classifier_config_file) + + for _ in progbar: + rollout_data = rollout( + robot, policy, reward_classifier, fps, control_time_s, use_amp, display_cameras + ) rollouts.append(rollout_data) 
sum_rewards.append(sum(rollout_data["next.reward"])) @@ -219,15 +263,6 @@ def eval_policy( return info -def calculate_reward(observation): - """ - Method to calculate reward function in some way. - In HIL-SERL this is done through defining a reward classifier - """ - # reward = reward_classifier(observation) - return np.array([0.0]) - - def init_keyboard_listener(): # Allow to exit early while recording an episode or resetting the environment, # by tapping the right arrow key '->'. This might require a sudo permission @@ -324,6 +359,21 @@ if __name__ == "__main__": "outputs/eval/{timestamp}_{env_name}_{policy_name}" ), ) + parser.add_argument( + "--display-cameras", help=("Whether to display the camera feed while the rollout is happening") + ) + parser.add_argument( + "--reward-classifier-pretrained-path", + type=str, + default=None, + help="Path to the pretrained classifier weights.", + ) + parser.add_argument( + "--reward-classifier-config-file", + type=str, + default=None, + help="Path to a yaml config file that is necessary to build the reward classifier model.", + ) args = parser.parse_args() @@ -332,4 +382,13 @@ if __name__ == "__main__": if not robot.is_connected: robot.connect() - eval_policy(robot, None, fps=40, n_episodes=2, control_time_s=100) + eval_policy( + robot, + None, + fps=40, + n_episodes=2, + control_time_s=100, + display_cameras=args.display_cameras, + reward_classifier_config_file=args.reward_classifier_config_file, + reward_classifier_pretrained_path=args.reward_classifier_pretrained_path, + ) diff --git a/lerobot/scripts/train_hilserl_classifier.py b/lerobot/scripts/train_hilserl_classifier.py index 22ff2957..458e3ff1 100644 --- a/lerobot/scripts/train_hilserl_classifier.py +++ b/lerobot/scripts/train_hilserl_classifier.py @@ -22,6 +22,7 @@ from pprint import pformat import hydra import torch import torch.nn as nn +import wandb from deepdiff import DeepDiff from omegaconf import DictConfig, OmegaConf from termcolor import colored @@ -30,7 +31,6 @@ from torch.cuda.amp import GradScaler from torch.utils.data import DataLoader, WeightedRandomSampler, random_split from tqdm import tqdm -import wandb from lerobot.common.datasets.factory import resolve_delta_timestamps from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.logger import Logger @@ -79,7 +79,7 @@ def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, pbar = tqdm(train_loader, desc="Training") for batch_idx, batch in enumerate(pbar): start_time = time.perf_counter() - images = batch[cfg.training.image_key].to(device) + images = [batch[img_key].to(device) for img_key in cfg.training.image_keys] labels = batch[cfg.training.label_key].float().to(device) # Forward pass with optional AMP @@ -130,7 +130,7 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext(), ): for batch in tqdm(val_loader, desc="Validation"): - images = batch[cfg.training.image_key].to(device) + images = [batch[img_key].to(device) for img_key in cfg.training.image_keys] labels = batch[cfg.training.label_key].float().to(device) outputs = model(images) @@ -163,6 +163,7 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l accuracy = 100 * correct / total avg_loss = running_loss / len(val_loader) + print(f"Average validation loss {avg_loss}, and accuracy {accuracy}") eval_info = { "loss": avg_loss, diff --git 
a/tests/test_train_hilserl_classifier.py b/tests/test_train_hilserl_classifier.py index c1d854ac..8c1ad453 100644 --- a/tests/test_train_hilserl_classifier.py +++ b/tests/test_train_hilserl_classifier.py @@ -33,7 +33,9 @@ class MockDataset(Dataset): def make_dummy_model(): - model_config = ClassifierConfig(num_classes=2, model_name="hf-tiny-model-private/tiny-random-ResNetModel") + model_config = ClassifierConfig( + num_classes=2, model_name="hf-tiny-model-private/tiny-random-ResNetModel", num_cameras=1 + ) model = Classifier(config=model_config) return model @@ -88,7 +90,7 @@ def test_train_epoch(): logger = MagicMock() step = 0 cfg = MagicMock() - cfg.training.image_key = "image" + cfg.training.image_keys = ["image"] cfg.training.label_key = "label" cfg.training.use_amp = False @@ -130,7 +132,7 @@ def test_validate(): device = torch.device("cpu") logger = MagicMock() cfg = MagicMock() - cfg.training.image_key = "image" + cfg.training.image_keys = ["image"] cfg.training.label_key = "label" cfg.training.use_amp = False @@ -145,6 +147,57 @@ def test_validate(): assert isinstance(eval_info, dict) +def test_train_epoch_multiple_cameras(): + model_config = ClassifierConfig( + num_classes=2, model_name="hf-tiny-model-private/tiny-random-ResNetModel", num_cameras=2 + ) + model = Classifier(config=model_config) + + # Mock components + model.train = MagicMock() + + train_loader = [ + { + "image_1": torch.rand(2, 3, 224, 224), + "image_2": torch.rand(2, 3, 224, 224), + "label": torch.tensor([0.0, 1.0]), + } + ] + + criterion = nn.BCEWithLogitsLoss() + optimizer = MagicMock() + grad_scaler = MagicMock() + device = torch.device("cpu") + logger = MagicMock() + step = 0 + cfg = MagicMock() + cfg.training.image_keys = ["image_1", "image_2"] + cfg.training.label_key = "label" + cfg.training.use_amp = False + + # Call the function under test + train_epoch( + model, + train_loader, + criterion, + optimizer, + grad_scaler, + device, + logger, + step, + cfg, + ) + + # Check that model.train() was called + model.train.assert_called_once() + + # Check that optimizer.zero_grad() was called + optimizer.zero_grad.assert_called() + + # Check that logger.log_dict was called + logger.log_dict.assert_called() + + @pytest.mark.parametrize("resume", [True, False]) @patch("lerobot.scripts.train_hilserl_classifier.init_hydra_config") @patch("lerobot.scripts.train_hilserl_classifier.Logger.get_last_checkpoint_dir") @@ -179,7 +232,7 @@ def test_resume_function( "train_split_proportion=0.8", "training.num_workers=0", "training.batch_size=2", - "training.image_key=image", + "training.image_keys=[image]", "training.label_key=label", "training.use_amp=False", "training.num_epochs=1", From 0a4e9e25d0d6008cb364aaceb80c759236af376c Mon Sep 17 00:00:00 2001 From: Mishig Date: Fri, 20 Dec 2024 16:26:23 +0100 Subject: [PATCH 029/112] [vizualizer] for LeRobodDataset V2 (#576) --- lerobot/common/datasets/utils.py | 57 +++ lerobot/scripts/visualize_dataset_html.py | 327 ++++++++++++++---- .../templates/visualize_dataset_homepage.html | 68 ++++ .../templates/visualize_dataset_template.html | 80 +++-- tests/test_visualize_dataset_html.py | 30 -- 5 files changed, 428 insertions(+), 134 deletions(-) create mode 100644 lerobot/templates/visualize_dataset_homepage.html delete mode 100644 tests/test_visualize_dataset_html.py diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index af5b03cc..1490adda 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -17,9 +17,11 @@ import 
importlib.resources import json import logging import textwrap +from collections.abc import Iterator from itertools import accumulate from pathlib import Path from pprint import pformat +from types import SimpleNamespace from typing import Any import datasets @@ -502,3 +504,58 @@ def create_lerobot_dataset_card( template_path=str(card_template_path), **kwargs, ) + + +class IterableNamespace(SimpleNamespace): + """ + A namespace object that supports both dictionary-like iteration and dot notation access. + Automatically converts nested dictionaries into IterableNamespaces. + + This class extends SimpleNamespace to provide: + - Dictionary-style iteration over keys + - Access to items via both dot notation (obj.key) and brackets (obj["key"]) + - Dictionary-like methods: items(), keys(), values() + - Recursive conversion of nested dictionaries + + Args: + dictionary: Optional dictionary to initialize the namespace + **kwargs: Additional keyword arguments passed to SimpleNamespace + + Examples: + >>> data = {"name": "Alice", "details": {"age": 25}} + >>> ns = IterableNamespace(data) + >>> ns.name + 'Alice' + >>> ns.details.age + 25 + >>> list(ns.keys()) + ['name', 'details'] + >>> for key, value in ns.items(): + ... print(f"{key}: {value}") + name: Alice + details: IterableNamespace(age=25) + """ + + def __init__(self, dictionary: dict[str, Any] = None, **kwargs): + super().__init__(**kwargs) + if dictionary is not None: + for key, value in dictionary.items(): + if isinstance(value, dict): + setattr(self, key, IterableNamespace(value)) + else: + setattr(self, key, value) + + def __iter__(self) -> Iterator[str]: + return iter(vars(self)) + + def __getitem__(self, key: str) -> Any: + return vars(self)[key] + + def items(self): + return vars(self).items() + + def values(self): + return vars(self).values() + + def keys(self): + return vars(self).keys() diff --git a/lerobot/scripts/visualize_dataset_html.py b/lerobot/scripts/visualize_dataset_html.py index 2c81fbfc..ec6eca22 100644 --- a/lerobot/scripts/visualize_dataset_html.py +++ b/lerobot/scripts/visualize_dataset_html.py @@ -53,20 +53,29 @@ python lerobot/scripts/visualize_dataset_html.py \ """ import argparse +import csv +import json import logging +import re import shutil +import tempfile +from io import StringIO from pathlib import Path -import tqdm -from flask import Flask, redirect, render_template, url_for +import numpy as np +import pandas as pd +import requests +from flask import Flask, redirect, render_template, request, url_for +from lerobot import available_datasets from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.datasets.utils import IterableNamespace from lerobot.common.utils.utils import init_logging def run_server( - dataset: LeRobotDataset, - episodes: list[int], + dataset: LeRobotDataset | IterableNamespace | None, + episodes: list[int] | None, host: str, port: str, static_folder: Path, @@ -76,10 +85,50 @@ def run_server( app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache @app.route("/") - def index(): - # home page redirects to the first episode page - [dataset_namespace, dataset_name] = dataset.repo_id.split("/") - first_episode_id = episodes[0] + def hommepage(dataset=dataset): + if dataset: + dataset_namespace, dataset_name = dataset.repo_id.split("/") + return redirect( + url_for( + "show_episode", + dataset_namespace=dataset_namespace, + dataset_name=dataset_name, + episode_id=0, + ) + ) + + dataset_param, episode_param = None, None + all_params = request.args + 
if "dataset" in all_params: + dataset_param = all_params["dataset"] + if "episode" in all_params: + episode_param = int(all_params["episode"]) + + if dataset_param: + dataset_namespace, dataset_name = dataset_param.split("/") + return redirect( + url_for( + "show_episode", + dataset_namespace=dataset_namespace, + dataset_name=dataset_name, + episode_id=episode_param if episode_param is not None else 0, + ) + ) + + featured_datasets = [ + "lerobot/aloha_static_cups_open", + "lerobot/columbia_cairlab_pusht_real", + "lerobot/taco_play", + ] + return render_template( + "visualize_dataset_homepage.html", + featured_datasets=featured_datasets, + lerobot_datasets=available_datasets, + ) + + @app.route("//") + def show_first_episode(dataset_namespace, dataset_name): + first_episode_id = 0 return redirect( url_for( "show_episode", @@ -90,30 +139,85 @@ def run_server( ) @app.route("///episode_") - def show_episode(dataset_namespace, dataset_name, episode_id): + def show_episode(dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes): + repo_id = f"{dataset_namespace}/{dataset_name}" + try: + if dataset is None: + dataset = get_dataset_info(repo_id) + except FileNotFoundError: + return ( + "Make sure to convert your LeRobotDataset to v2 & above. See how to convert your dataset at https://github.com/huggingface/lerobot/pull/461", + 400, + ) + dataset_version = ( + dataset.meta._version if isinstance(dataset, LeRobotDataset) else dataset.codebase_version + ) + match = re.search(r"v(\d+)\.", dataset_version) + if match: + major_version = int(match.group(1)) + if major_version < 2: + return "Make sure to convert your LeRobotDataset to v2 & above." + + episode_data_csv_str, columns = get_episode_data(dataset, episode_id) dataset_info = { - "repo_id": dataset.repo_id, - "num_samples": dataset.num_frames, - "num_episodes": dataset.num_episodes, + "repo_id": f"{dataset_namespace}/{dataset_name}", + "num_samples": dataset.num_frames + if isinstance(dataset, LeRobotDataset) + else dataset.total_frames, + "num_episodes": dataset.num_episodes + if isinstance(dataset, LeRobotDataset) + else dataset.total_episodes, "fps": dataset.fps, } - video_paths = [dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys] - tasks = dataset.meta.episodes[episode_id]["tasks"] - videos_info = [ - {"url": url_for("static", filename=video_path), "filename": video_path.name} - for video_path in video_paths - ] + if isinstance(dataset, LeRobotDataset): + video_paths = [ + dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys + ] + videos_info = [ + {"url": url_for("static", filename=video_path), "filename": video_path.parent.name} + for video_path in video_paths + ] + tasks = dataset.meta.episodes[0]["tasks"] + else: + video_keys = [key for key, ft in dataset.features.items() if ft["dtype"] == "video"] + videos_info = [ + { + "url": f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + + dataset.video_path.format( + episode_chunk=int(episode_id) // dataset.chunks_size, + video_key=video_key, + episode_index=episode_id, + ), + "filename": video_key, + } + for video_key in video_keys + ] + + response = requests.get( + f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl" + ) + response.raise_for_status() + # Split into lines and parse each line as JSON + tasks_jsonl = [json.loads(line) for line in response.text.splitlines() if line.strip()] + + filtered_tasks_jsonl = [row for row in tasks_jsonl if row["episode_index"] == 
episode_id] + tasks = filtered_tasks_jsonl[0]["tasks"] + videos_info[0]["language_instruction"] = tasks - ep_csv_url = url_for("static", filename=get_ep_csv_fname(episode_id)) + if episodes is None: + episodes = list( + range(dataset.num_episodes if isinstance(dataset, LeRobotDataset) else dataset.total_episodes) + ) + return render_template( "visualize_dataset_template.html", episode_id=episode_id, episodes=episodes, dataset_info=dataset_info, videos_info=videos_info, - ep_csv_url=ep_csv_url, - has_policy=False, + episode_data_csv_str=episode_data_csv_str, + columns=columns, ) app.run(host=host, port=port) @@ -124,46 +228,84 @@ def get_ep_csv_fname(episode_id: int): return ep_csv_fname -def write_episode_data_csv(output_dir, file_name, episode_index, dataset): - """Write a csv file containg timeseries data of an episode (e.g. state and action). +def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index): + """Get a csv str containing timeseries data of an episode (e.g. state and action). This file will be loaded by Dygraph javascript to plot data in real time.""" - from_idx = dataset.episode_data_index["from"][episode_index] - to_idx = dataset.episode_data_index["to"][episode_index] - + columns = [] has_state = "observation.state" in dataset.features has_action = "action" in dataset.features # init header of csv with state and action names header = ["timestamp"] if has_state: - dim_state = dataset.meta.shapes["observation.state"][0] + dim_state = ( + dataset.meta.shapes["observation.state"][0] + if isinstance(dataset, LeRobotDataset) + else dataset.features["observation.state"].shape[0] + ) header += [f"state_{i}" for i in range(dim_state)] + column_names = dataset.features["observation.state"]["names"] + while not isinstance(column_names, list): + column_names = list(column_names.values())[0] + columns.append({"key": "state", "value": column_names}) if has_action: - dim_action = dataset.meta.shapes["action"][0] + dim_action = ( + dataset.meta.shapes["action"][0] + if isinstance(dataset, LeRobotDataset) + else dataset.features.action.shape[0] + ) header += [f"action_{i}" for i in range(dim_action)] + column_names = dataset.features["action"]["names"] + while not isinstance(column_names, list): + column_names = list(column_names.values())[0] + columns.append({"key": "action", "value": column_names}) - columns = ["timestamp"] - if has_state: - columns += ["observation.state"] - if has_action: - columns += ["action"] - - rows = [] - data = dataset.hf_dataset.select_columns(columns) - for i in range(from_idx, to_idx): - row = [data[i]["timestamp"].item()] + if isinstance(dataset, LeRobotDataset): + from_idx = dataset.episode_data_index["from"][episode_index] + to_idx = dataset.episode_data_index["to"][episode_index] + selected_columns = ["timestamp"] if has_state: - row += data[i]["observation.state"].tolist() + selected_columns += ["observation.state"] if has_action: - row += data[i]["action"].tolist() - rows.append(row) + selected_columns += ["action"] + data = ( + dataset.hf_dataset.select(range(from_idx, to_idx)) + .select_columns(selected_columns) + .with_format("numpy") + ) + rows = np.hstack( + (np.expand_dims(data["timestamp"], axis=1), *[data[col] for col in selected_columns[1:]]) + ).tolist() + else: + repo_id = dataset.repo_id + selected_columns = ["timestamp"] + if "observation.state" in dataset.features: + selected_columns.append("observation.state") + if "action" in dataset.features: + selected_columns.append("action") - output_dir.mkdir(parents=True, 
exist_ok=True) - with open(output_dir / file_name, "w") as f: - f.write(",".join(header) + "\n") - for row in rows: - row_str = [str(col) for col in row] - f.write(",".join(row_str) + "\n") + url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + dataset.data_path.format( + episode_chunk=int(episode_index) // dataset.chunks_size, episode_index=episode_index + ) + df = pd.read_parquet(url) + data = df[selected_columns] # Select specific columns + rows = np.hstack( + ( + np.expand_dims(data["timestamp"], axis=1), + *[np.vstack(data[col]) for col in selected_columns[1:]], + ) + ).tolist() + + # Convert data to CSV string + csv_buffer = StringIO() + csv_writer = csv.writer(csv_buffer) + # Write header + csv_writer.writerow(header) + # Write data rows + csv_writer.writerows(rows) + csv_string = csv_buffer.getvalue() + + return csv_string, columns def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]: @@ -175,9 +317,31 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str] ] +def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> list[str]: + # check if the dataset has language instructions + if "language_instruction" not in dataset.features: + return None + + # get first frame index + first_frame_idx = dataset.episode_data_index["from"][ep_index].item() + + language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"] + # TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored + # with the tf.tensor appearing in the string + return language_instruction.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)") + + +def get_dataset_info(repo_id: str) -> IterableNamespace: + response = requests.get(f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json") + response.raise_for_status() # Raises an HTTPError for bad responses + dataset_info = response.json() + dataset_info["repo_id"] = repo_id + return IterableNamespace(dataset_info) + + def visualize_dataset_html( - dataset: LeRobotDataset, - episodes: list[int] = None, + dataset: LeRobotDataset | None, + episodes: list[int] | None = None, output_dir: Path | None = None, serve: bool = True, host: str = "127.0.0.1", @@ -186,11 +350,11 @@ def visualize_dataset_html( ) -> Path | None: init_logging() - if len(dataset.meta.image_keys) > 0: - raise NotImplementedError(f"Image keys ({dataset.meta.image_keys=}) are currently not supported.") + template_dir = Path(__file__).resolve().parent.parent / "templates" if output_dir is None: - output_dir = f"outputs/visualize_dataset_html/{dataset.repo_id}" + # Create a temporary directory that will be automatically cleaned up + output_dir = tempfile.mkdtemp(prefix="lerobot_visualize_dataset_") output_dir = Path(output_dir) if output_dir.exists(): @@ -201,28 +365,33 @@ def visualize_dataset_html( output_dir.mkdir(parents=True, exist_ok=True) - # Create a simlink from the dataset video folder containg mp4 files to the output directory - # so that the http server can get access to the mp4 files. 
static_dir = output_dir / "static" static_dir.mkdir(parents=True, exist_ok=True) - ln_videos_dir = static_dir / "videos" - if not ln_videos_dir.exists(): - ln_videos_dir.symlink_to((dataset.root / "videos").resolve()) - template_dir = Path(__file__).resolve().parent.parent / "templates" + if dataset is None: + if serve: + run_server( + dataset=None, + episodes=None, + host=host, + port=port, + static_folder=static_dir, + template_folder=template_dir, + ) + else: + image_keys = dataset.meta.image_keys if isinstance(dataset, LeRobotDataset) else [] + if len(image_keys) > 0: + raise NotImplementedError(f"Image keys ({image_keys=}) are currently not supported.") - if episodes is None: - episodes = list(range(dataset.num_episodes)) + # Create a simlink from the dataset video folder containg mp4 files to the output directory + # so that the http server can get access to the mp4 files. + if isinstance(dataset, LeRobotDataset): + ln_videos_dir = static_dir / "videos" + if not ln_videos_dir.exists(): + ln_videos_dir.symlink_to((dataset.root / "videos").resolve()) - logging.info("Writing CSV files") - for episode_index in tqdm.tqdm(episodes): - # write states and actions in a csv (it can be slow for big datasets) - ep_csv_fname = get_ep_csv_fname(episode_index) - # TODO(rcadene): speedup script by loading directly from dataset, pyarrow, parquet, safetensors? - write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset) - - if serve: - run_server(dataset, episodes, host, port, static_dir, template_dir) + if serve: + run_server(dataset, episodes, host, port, static_dir, template_dir) def main(): @@ -231,7 +400,7 @@ def main(): parser.add_argument( "--repo-id", type=str, - required=True, + default=None, help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).", ) parser.add_argument( @@ -246,6 +415,12 @@ def main(): default=None, help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.", ) + parser.add_argument( + "--load-from-hf-hub", + type=int, + default=0, + help="Load videos and parquet files from HF Hub rather than local system.", + ) parser.add_argument( "--episodes", type=int, @@ -287,11 +462,19 @@ def main(): args = parser.parse_args() kwargs = vars(args) repo_id = kwargs.pop("repo_id") + load_from_hf_hub = kwargs.pop("load_from_hf_hub") root = kwargs.pop("root") local_files_only = kwargs.pop("local_files_only") - dataset = LeRobotDataset(repo_id, root=root, local_files_only=local_files_only) - visualize_dataset_html(dataset, **kwargs) + dataset = None + if repo_id: + dataset = ( + LeRobotDataset(repo_id, root=root, local_files_only=local_files_only) + if not load_from_hf_hub + else get_dataset_info(repo_id) + ) + + visualize_dataset_html(dataset, **vars(args)) if __name__ == "__main__": diff --git a/lerobot/templates/visualize_dataset_homepage.html b/lerobot/templates/visualize_dataset_homepage.html new file mode 100644 index 00000000..adff07be --- /dev/null +++ b/lerobot/templates/visualize_dataset_homepage.html @@ -0,0 +1,68 @@ + + + + + + Interactive Video Background Page + + + + +
+ +
+
+
+
+

LeRobot Dataset Visualizer

+ + create & train your own robots + +

+
+

Example Datasets:

+
    + {% for dataset in featured_datasets %} +
  • {{ dataset }}
  • + {% endfor %} +
+
+
+
+ + +
+ +
+ More example datasets +
    + {% for dataset in lerobot_datasets %} +
  • {{ dataset }}
  • + {% endfor %} +
+
+
+ + \ No newline at end of file diff --git a/lerobot/templates/visualize_dataset_template.html b/lerobot/templates/visualize_dataset_template.html index 0fa1e713..12d6e991 100644 --- a/lerobot/templates/visualize_dataset_template.html +++ b/lerobot/templates/visualize_dataset_template.html @@ -31,11 +31,16 @@ }">
-

{{ dataset_info.repo_id }}

+ + +

{{ dataset_info.repo_id }}

+
  • - Number of samples/frames: {{ dataset_info.num_frames }} + Number of samples/frames: {{ dataset_info.num_samples }}
  • Number of episodes: {{ dataset_info.num_episodes }} @@ -93,10 +98,10 @@
-
+
{% for video_info in videos_info %} -
-

{{ video_info.filename }}

+
+

{{ video_info.filename }}