Reward classifier and training (#528)
Co-authored-by: Daniel Ritchie <daniel@brainwavecollective.ai> Co-authored-by: resolver101757 <kelster101757@hotmail.com> Co-authored-by: Jannik Grothusen <56967823+J4nn1K@users.noreply.github.com> Co-authored-by: Remi <re.cadene@gmail.com> Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
This commit is contained in:
parent
b5f1ea3140
commit
67ac81d728
|
@ -0,0 +1,83 @@
|
|||
# Training a HIL-SERL Reward Classifier with LeRobot
|
||||
|
||||
This tutorial provides step-by-step instructions for training a reward classifier using LeRobot.
|
||||
|
||||
---
|
||||
|
||||
## Training Script Overview
|
||||
|
||||
LeRobot includes a ready-to-use training script located at [`lerobot/scripts/train_hilserl_classifier.py`](../../lerobot/scripts/train_hilserl_classifier.py). Here's an outline of its workflow:
|
||||
|
||||
1. **Configuration Loading**
|
||||
The script uses Hydra to load a configuration file for subsequent steps. (Details on Hydra follow below.)
|
||||
|
||||
2. **Dataset Initialization**
|
||||
It loads a `LeRobotDataset` containing images and rewards. To optimize performance, a weighted random sampler is used to balance class sampling.
|
||||
|
||||
3. **Classifier Initialization**
|
||||
A lightweight classification head is built on top of a frozen, pretrained image encoder from HuggingFace. The classifier outputs either:
|
||||
- A single probability (binary classification), or
|
||||
- Logits (multi-class classification).
|
||||
|
||||
4. **Training Loop Execution**
|
||||
The script performs:
|
||||
- Forward and backward passes,
|
||||
- Optimization steps,
|
||||
- Periodic logging, evaluation, and checkpoint saving.
|
||||
|
||||
---
|
||||
|
||||
## Configuring with Hydra
|
||||
|
||||
For detailed information about Hydra usage, refer to [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md). However, note that training the reward classifier differs slightly and requires a separate configuration file.
|
||||
|
||||
### Config File Setup
|
||||
|
||||
The default `default.yaml` cannot launch the reward classifier training directly. Instead, you need a configuration file like [`lerobot/configs/policy/hilserl_classifier.yaml`](../../lerobot/configs/policy/hilserl_classifier.yaml), with the following adjustment:
|
||||
|
||||
Replace the `dataset_repo_id` field with the identifier for your dataset, which contains images and sparse rewards:
|
||||
|
||||
```yaml
|
||||
# Example: lerobot/configs/policy/reward_classifier.yaml
|
||||
dataset_repo_id: "my_dataset_repo_id"
|
||||
## Typical logs and metrics
|
||||
```
|
||||
When you start the training process, you will first see your full configuration being printed in the terminal. You can check it to make sure that you config it correctly and your config is not overrided by other files. The final configuration will also be saved with the checkpoint.
|
||||
|
||||
After that, you will see training log like this one:
|
||||
|
||||
```
|
||||
[2024-11-29 18:26:36,999][root][INFO] -
|
||||
Epoch 5/5
|
||||
Training: 82%|██████████████████████████████████████████████████████████████████████████████▋ | 91/111 [00:50<00:09, 2.04it/s, loss=0.2999, acc=69.99%]
|
||||
```
|
||||
|
||||
or evaluation log like:
|
||||
|
||||
```
|
||||
Validation: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:20<00:00, 1.37it/s]
|
||||
```
|
||||
|
||||
### Metrics Tracking with Weights & Biases (WandB)
|
||||
|
||||
If `wandb.enable` is set to `true`, the training and evaluation logs will also be saved in WandB. This allows you to track key metrics in real-time, including:
|
||||
|
||||
- **Training Metrics**:
|
||||
- `train/accuracy`
|
||||
- `train/loss`
|
||||
- `train/dataloading_s`
|
||||
- **Evaluation Metrics**:
|
||||
- `eval/accuracy`
|
||||
- `eval/loss`
|
||||
- `eval/eval_s`
|
||||
|
||||
#### Additional Features
|
||||
|
||||
You can also log sample predictions during evaluation. Each logged sample will include:
|
||||
|
||||
- The **input image**.
|
||||
- The **predicted label**.
|
||||
- The **true label**.
|
||||
- The **classifier's "confidence" (logits/probability)**.
|
||||
|
||||
These logs can be useful for diagnosing and debugging performance issues.
|
|
@ -291,7 +291,7 @@ class LeRobotDatasetMetadata:
|
|||
obj.root.mkdir(parents=True, exist_ok=False)
|
||||
|
||||
if robot is not None:
|
||||
features = get_features_from_robot(robot, use_videos)
|
||||
features = {**(features or {}), **get_features_from_robot(robot)}
|
||||
robot_type = robot.robot_type
|
||||
if not all(cam.fps == fps for cam in robot.cameras.values()):
|
||||
logging.warning(
|
||||
|
|
|
@ -31,6 +31,7 @@ from termcolor import colored
|
|||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
|
||||
import wandb
|
||||
from lerobot.common.policies.policy_protocol import Policy
|
||||
from lerobot.common.utils.utils import get_global_random_state, set_global_random_state
|
||||
|
||||
|
@ -107,8 +108,6 @@ class Logger:
|
|||
self._wandb = None
|
||||
else:
|
||||
os.environ["WANDB_SILENT"] = "true"
|
||||
import wandb
|
||||
|
||||
wandb_run_id = None
|
||||
if cfg.resume:
|
||||
wandb_run_id = get_wandb_run_id_from_filesystem(self.checkpoints_dir)
|
||||
|
@ -232,7 +231,7 @@ class Logger:
|
|||
# TODO(alexander-soare): Add local text log.
|
||||
if self._wandb is not None:
|
||||
for k, v in d.items():
|
||||
if not isinstance(v, (int, float, str)):
|
||||
if not isinstance(v, (int, float, str, wandb.Table)):
|
||||
logging.warning(
|
||||
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
|
||||
)
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
import json
|
||||
import os
|
||||
from dataclasses import asdict, dataclass
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClassifierConfig:
|
||||
"""Configuration for the Classifier model."""
|
||||
|
||||
num_classes: int = 2
|
||||
hidden_dim: int = 256
|
||||
dropout_rate: float = 0.1
|
||||
model_name: str = "microsoft/resnet-50"
|
||||
device: str = "cuda" if torch.cuda.is_available() else "mps"
|
||||
model_type: str = "cnn" # "transformer" or "cnn"
|
||||
|
||||
def save_pretrained(self, save_dir):
|
||||
"""Save config to json file."""
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
# Convert to dict and save as JSON
|
||||
config_dict = asdict(self)
|
||||
with open(os.path.join(save_dir, "config.json"), "w") as f:
|
||||
json.dump(config_dict, f, indent=2)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path):
|
||||
"""Load config from json file."""
|
||||
config_file = os.path.join(pretrained_model_name_or_path, "config.json")
|
||||
|
||||
with open(config_file) as f:
|
||||
config_dict = json.load(f)
|
||||
|
||||
return cls(**config_dict)
|
|
@ -0,0 +1,134 @@
|
|||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from torch import Tensor, nn
|
||||
from transformers import AutoImageProcessor, AutoModel
|
||||
|
||||
from .configuration_classifier import ClassifierConfig
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ClassifierOutput:
|
||||
"""Wrapper for classifier outputs with additional metadata."""
|
||||
|
||||
def __init__(
|
||||
self, logits: Tensor, probabilities: Optional[Tensor] = None, hidden_states: Optional[Tensor] = None
|
||||
):
|
||||
self.logits = logits
|
||||
self.probabilities = probabilities
|
||||
self.hidden_states = hidden_states
|
||||
|
||||
|
||||
class Classifier(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
# Add Hub metadata
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "vision-classifier"],
|
||||
):
|
||||
"""Image classifier built on top of a pre-trained encoder."""
|
||||
|
||||
# Add name attribute for factory
|
||||
name = "classifier"
|
||||
|
||||
def __init__(self, config: ClassifierConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True)
|
||||
encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)
|
||||
# Extract vision model if we're given a multimodal model
|
||||
if hasattr(encoder, "vision_model"):
|
||||
logging.info("Multimodal model detected - using vision encoder only")
|
||||
self.encoder = encoder.vision_model
|
||||
self.vision_config = encoder.config.vision_config
|
||||
else:
|
||||
self.encoder = encoder
|
||||
self.vision_config = getattr(encoder, "config", None)
|
||||
|
||||
# Model type from config
|
||||
self.is_cnn = self.config.model_type == "cnn"
|
||||
|
||||
# For CNNs, initialize backbone
|
||||
if self.is_cnn:
|
||||
self._setup_cnn_backbone()
|
||||
|
||||
self._freeze_encoder()
|
||||
self._build_classifier_head()
|
||||
|
||||
def _setup_cnn_backbone(self):
|
||||
"""Set up CNN encoder"""
|
||||
if hasattr(self.encoder, "fc"):
|
||||
self.feature_dim = self.encoder.fc.in_features
|
||||
self.encoder = nn.Sequential(*list(self.encoder.children())[:-1])
|
||||
elif hasattr(self.encoder.config, "hidden_sizes"):
|
||||
self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension
|
||||
else:
|
||||
raise ValueError("Unsupported CNN architecture")
|
||||
|
||||
def _freeze_encoder(self) -> None:
|
||||
"""Freeze the encoder parameters."""
|
||||
for param in self.encoder.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def _build_classifier_head(self) -> None:
|
||||
"""Initialize the classifier head architecture."""
|
||||
# Get input dimension based on model type
|
||||
if self.is_cnn:
|
||||
input_dim = self.feature_dim
|
||||
else: # Transformer models
|
||||
if hasattr(self.encoder.config, "hidden_size"):
|
||||
input_dim = self.encoder.config.hidden_size
|
||||
else:
|
||||
raise ValueError("Unsupported transformer architecture since hidden_size is not found")
|
||||
|
||||
self.classifier_head = nn.Sequential(
|
||||
nn.Linear(input_dim, self.config.hidden_dim),
|
||||
nn.Dropout(self.config.dropout_rate),
|
||||
nn.LayerNorm(self.config.hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(self.config.hidden_dim, 1 if self.config.num_classes == 2 else self.config.num_classes),
|
||||
)
|
||||
|
||||
def _get_encoder_output(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Extract the appropriate output from the encoder."""
|
||||
# Process images with the processor (handles resizing and normalization)
|
||||
processed = self.processor(
|
||||
images=x, # LeRobotDataset already provides proper tensor format
|
||||
return_tensors="pt",
|
||||
)
|
||||
processed = processed["pixel_values"].to(x.device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.is_cnn:
|
||||
# The HF ResNet applies pooling internally
|
||||
outputs = self.encoder(processed)
|
||||
# Get pooled output directly
|
||||
features = outputs.pooler_output
|
||||
|
||||
if features.dim() > 2:
|
||||
features = features.squeeze(-1).squeeze(-1)
|
||||
return features
|
||||
else: # Transformer models
|
||||
outputs = self.encoder(processed)
|
||||
if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
|
||||
return outputs.pooler_output
|
||||
return outputs.last_hidden_state[:, 0, :]
|
||||
|
||||
def forward(self, x: torch.Tensor) -> ClassifierOutput:
|
||||
"""Forward pass of the classifier."""
|
||||
# For training, we expect input to be a tensor directly from LeRobotDataset
|
||||
encoder_output = self._get_encoder_output(x)
|
||||
logits = self.classifier_head(encoder_output)
|
||||
|
||||
if self.config.num_classes == 2:
|
||||
logits = logits.squeeze(-1)
|
||||
probabilities = torch.sigmoid(logits)
|
||||
else:
|
||||
probabilities = torch.softmax(logits, dim=-1)
|
||||
|
||||
return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_output)
|
|
@ -120,14 +120,22 @@ def predict_action(observation, policy, device, use_amp):
|
|||
return action
|
||||
|
||||
|
||||
def init_keyboard_listener():
|
||||
# Allow to exit early while recording an episode or resetting the environment,
|
||||
# by tapping the right arrow key '->'. This might require a sudo permission
|
||||
# to allow your terminal to monitor keyboard events.
|
||||
def init_keyboard_listener(assign_rewards=False):
|
||||
"""
|
||||
Initializes a keyboard listener to enable early termination of an episode
|
||||
or environment reset by pressing the right arrow key ('->'). This may require
|
||||
sudo permissions to allow the terminal to monitor keyboard events.
|
||||
|
||||
Args:
|
||||
assign_rewards (bool): If True, allows annotating the collected trajectory
|
||||
with a binary reward at the end of the episode to indicate success.
|
||||
"""
|
||||
events = {}
|
||||
events["exit_early"] = False
|
||||
events["rerecord_episode"] = False
|
||||
events["stop_recording"] = False
|
||||
if assign_rewards:
|
||||
events["next.reward"] = 0
|
||||
|
||||
if is_headless():
|
||||
logging.warning(
|
||||
|
@ -152,6 +160,13 @@ def init_keyboard_listener():
|
|||
print("Escape key pressed. Stopping data recording...")
|
||||
events["stop_recording"] = True
|
||||
events["exit_early"] = True
|
||||
elif assign_rewards and key == keyboard.Key.space:
|
||||
events["next.reward"] = 1 if events["next.reward"] == 0 else 0
|
||||
print(
|
||||
"Space key pressed. Assigning new reward to the subsequent frames. New reward:",
|
||||
events["next.reward"],
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error handling key press: {e}")
|
||||
|
||||
|
@ -272,6 +287,8 @@ def control_loop(
|
|||
|
||||
if dataset is not None:
|
||||
frame = {**observation, **action}
|
||||
if "next.reward" in events:
|
||||
frame["next.reward"] = events["next.reward"]
|
||||
dataset.add_frame(frame)
|
||||
|
||||
if display_cameras and not is_headless():
|
||||
|
@ -301,6 +318,8 @@ def reset_environment(robot, events, reset_time_s):
|
|||
|
||||
timestamp = 0
|
||||
start_vencod_t = time.perf_counter()
|
||||
if "next.reward" in events:
|
||||
events["next.reward"] = 0
|
||||
|
||||
# Wait if necessary
|
||||
with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar:
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
# @package _global_
|
||||
|
||||
defaults:
|
||||
- _self_
|
||||
|
||||
seed: 13
|
||||
dataset_repo_id: "dataset_repo_id"
|
||||
train_split_proportion: 0.8
|
||||
|
||||
# Required by logger
|
||||
env:
|
||||
name: "classifier"
|
||||
task: "binary_classification"
|
||||
|
||||
|
||||
training:
|
||||
num_epochs: 5
|
||||
batch_size: 16
|
||||
learning_rate: 1e-4
|
||||
num_workers: 4
|
||||
grad_clip_norm: 10
|
||||
use_amp: true
|
||||
log_freq: 1
|
||||
eval_freq: 1 # How often to run validation (in epochs)
|
||||
save_freq: 1 # How often to save checkpoints (in epochs)
|
||||
save_checkpoint: true
|
||||
image_key: "observation.images.phone"
|
||||
label_key: "next.reward"
|
||||
|
||||
eval:
|
||||
batch_size: 16
|
||||
num_samples_to_log: 30 # Number of validation samples to log in the table
|
||||
|
||||
policy:
|
||||
name: "hilserl/classifier"
|
||||
model_name: "facebook/convnext-base-224"
|
||||
model_type: "cnn"
|
||||
|
||||
wandb:
|
||||
enable: false
|
||||
project: "classifier-training"
|
||||
entity: "wandb_entity"
|
||||
job_name: "classifier_training_0"
|
||||
disable_artifact: false
|
||||
|
||||
device: "mps"
|
||||
resume: false
|
||||
output_dir: "output"
|
|
@ -191,6 +191,7 @@ def record(
|
|||
single_task: str,
|
||||
pretrained_policy_name_or_path: str | None = None,
|
||||
policy_overrides: List[str] | None = None,
|
||||
assign_rewards: bool = False,
|
||||
fps: int | None = None,
|
||||
warmup_time_s: int | float = 2,
|
||||
episode_time_s: int | float = 10,
|
||||
|
@ -214,6 +215,9 @@ def record(
|
|||
policy = None
|
||||
device = None
|
||||
use_amp = None
|
||||
extra_features = (
|
||||
{"next.reward": {"dtype": "int64", "shape": (1,), "names": None}} if assign_rewards else None
|
||||
)
|
||||
|
||||
if single_task:
|
||||
task = single_task
|
||||
|
@ -254,12 +258,12 @@ def record(
|
|||
use_videos=video,
|
||||
image_writer_processes=num_image_writer_processes,
|
||||
image_writer_threads=num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
features=extra_features,
|
||||
)
|
||||
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
|
||||
listener, events = init_keyboard_listener()
|
||||
listener, events = init_keyboard_listener(assign_rewards=assign_rewards)
|
||||
|
||||
# Execute a few seconds without recording to:
|
||||
# 1. teleoperate the robot to move it in starting position if no policy provided,
|
||||
|
@ -469,12 +473,12 @@ if __name__ == "__main__":
|
|||
default=1,
|
||||
help="Upload dataset to Hugging Face hub.",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--tags",
|
||||
type=str,
|
||||
nargs="*",
|
||||
help="Add tags to your dataset on the hub.",
|
||||
)
|
||||
# parser_record.add_argument(
|
||||
# "--tags",
|
||||
# type=str,
|
||||
# nargs="*",
|
||||
# help="Add tags to your dataset on the hub.",
|
||||
# )
|
||||
parser_record.add_argument(
|
||||
"--num-image-writer-processes",
|
||||
type=int,
|
||||
|
@ -517,6 +521,12 @@ if __name__ == "__main__":
|
|||
nargs="*",
|
||||
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--assign-rewards",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Enables the assignation of rewards to frames (by default no assignation). When enabled, assign a 0 reward to frames until the space bar is pressed which assign a 1 reward. Press the space bar a second time to assign a 0 reward. The reward assigned is reset to 0 when the episode ends.",
|
||||
)
|
||||
|
||||
parser_replay = subparsers.add_parser("replay", parents=[base_parser])
|
||||
parser_replay.add_argument(
|
||||
|
|
|
@ -0,0 +1,310 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
|
||||
import hydra
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from deepdiff import DeepDiff
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from termcolor import colored
|
||||
from torch import optim
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
|
||||
from tqdm import tqdm
|
||||
|
||||
import wandb
|
||||
from lerobot.common.datasets.factory import resolve_delta_timestamps
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.logger import Logger
|
||||
from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
|
||||
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
|
||||
from lerobot.common.utils.utils import (
|
||||
format_big_number,
|
||||
get_safe_torch_device,
|
||||
init_hydra_config,
|
||||
set_global_seed,
|
||||
)
|
||||
|
||||
|
||||
def get_model(cfg, logger):
|
||||
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
|
||||
model = Classifier(classifier_config)
|
||||
if cfg.resume:
|
||||
model.load_state_dict(Classifier.from_pretrained(str(logger.last_pretrained_model_dir)).state_dict())
|
||||
return model
|
||||
|
||||
|
||||
def create_balanced_sampler(dataset, cfg):
|
||||
# Creates a weighted sampler to handle class imbalance
|
||||
|
||||
labels = torch.tensor([item[cfg.training.label_key] for item in dataset])
|
||||
_, counts = torch.unique(labels, return_counts=True)
|
||||
class_weights = 1.0 / counts.float()
|
||||
sample_weights = class_weights[labels]
|
||||
|
||||
return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
|
||||
|
||||
|
||||
def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg):
|
||||
# Single epoch training loop with AMP support and progress tracking
|
||||
model.train()
|
||||
correct = 0
|
||||
total = 0
|
||||
|
||||
pbar = tqdm(train_loader, desc="Training")
|
||||
for batch_idx, batch in enumerate(pbar):
|
||||
start_time = time.perf_counter()
|
||||
images = batch[cfg.training.image_key].to(device)
|
||||
labels = batch[cfg.training.label_key].float().to(device)
|
||||
|
||||
# Forward pass with optional AMP
|
||||
with torch.autocast(device_type=device.type) if cfg.training.use_amp else nullcontext():
|
||||
outputs = model(images)
|
||||
loss = criterion(outputs.logits, labels)
|
||||
|
||||
# Backward pass with gradient scaling if AMP enabled
|
||||
optimizer.zero_grad()
|
||||
if cfg.training.use_amp:
|
||||
grad_scaler.scale(loss).backward()
|
||||
grad_scaler.step(optimizer)
|
||||
grad_scaler.update()
|
||||
else:
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# Track metrics
|
||||
if model.config.num_classes == 2:
|
||||
predictions = (torch.sigmoid(outputs.logits) > 0.5).float()
|
||||
else:
|
||||
predictions = torch.argmax(outputs.logits, dim=1)
|
||||
correct += (predictions == labels).sum().item()
|
||||
total += labels.size(0)
|
||||
|
||||
current_acc = 100 * correct / total
|
||||
train_info = {
|
||||
"loss": loss.item(),
|
||||
"accuracy": current_acc,
|
||||
"dataloading_s": time.perf_counter() - start_time,
|
||||
}
|
||||
|
||||
logger.log_dict(train_info, step + batch_idx, mode="train")
|
||||
pbar.set_postfix({"loss": f"{loss.item():.4f}", "acc": f"{current_acc:.2f}%"})
|
||||
|
||||
|
||||
def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_log=8):
|
||||
# Validation loop with metric tracking and sample logging
|
||||
model.eval()
|
||||
correct = 0
|
||||
total = 0
|
||||
batch_start_time = time.perf_counter()
|
||||
samples = []
|
||||
running_loss = 0
|
||||
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.training.use_amp else nullcontext():
|
||||
for batch in tqdm(val_loader, desc="Validation"):
|
||||
images = batch[cfg.training.image_key].to(device)
|
||||
labels = batch[cfg.training.label_key].float().to(device)
|
||||
|
||||
outputs = model(images)
|
||||
loss = criterion(outputs.logits, labels)
|
||||
|
||||
# Track metrics
|
||||
if model.config.num_classes == 2:
|
||||
predictions = (torch.sigmoid(outputs.logits) > 0.5).float()
|
||||
else:
|
||||
predictions = torch.argmax(outputs.logits, dim=1)
|
||||
correct += (predictions == labels).sum().item()
|
||||
total += labels.size(0)
|
||||
running_loss += loss.item()
|
||||
|
||||
# Log sample predictions for visualization
|
||||
if len(samples) < num_samples_to_log:
|
||||
for i in range(min(num_samples_to_log - len(samples), len(images))):
|
||||
if model.config.num_classes == 2:
|
||||
confidence = round(outputs.probabilities[i].item(), 3)
|
||||
else:
|
||||
confidence = [round(prob, 3) for prob in outputs.probabilities[i].tolist()]
|
||||
samples.append(
|
||||
{
|
||||
"image": wandb.Image(images[i].cpu()),
|
||||
"true_label": labels[i].item(),
|
||||
"predicted": predictions[i].item(),
|
||||
"confidence": confidence,
|
||||
}
|
||||
)
|
||||
|
||||
accuracy = 100 * correct / total
|
||||
avg_loss = running_loss / len(val_loader)
|
||||
|
||||
eval_info = {
|
||||
"loss": avg_loss,
|
||||
"accuracy": accuracy,
|
||||
"eval_s": time.perf_counter() - batch_start_time,
|
||||
"eval/prediction_samples": wandb.Table(
|
||||
data=[[s["image"], s["true_label"], s["predicted"], f"{s['confidence']}"] for s in samples],
|
||||
columns=["Image", "True Label", "Predicted", "Confidence"],
|
||||
)
|
||||
if logger._cfg.wandb.enable
|
||||
else None,
|
||||
}
|
||||
|
||||
return accuracy, eval_info
|
||||
|
||||
|
||||
@hydra.main(version_base="1.2", config_path="../configs", config_name="classifier")
|
||||
def train(cfg: DictConfig) -> None:
|
||||
# Main training pipeline with support for resuming training
|
||||
logging.info(OmegaConf.to_yaml(cfg))
|
||||
|
||||
# Initialize training environment
|
||||
device = get_safe_torch_device(cfg.device, log=True)
|
||||
set_global_seed(cfg.seed)
|
||||
|
||||
out_dir = Path(cfg.output_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
logger = Logger(cfg, out_dir, cfg.wandb.job_name if cfg.wandb.enable else None)
|
||||
|
||||
# Setup dataset and dataloaders
|
||||
dataset = LeRobotDataset(cfg.dataset_repo_id)
|
||||
logging.info(f"Dataset size: {len(dataset)}")
|
||||
|
||||
train_size = int(cfg.train_split_proportion * len(dataset))
|
||||
val_size = len(dataset) - train_size
|
||||
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
|
||||
|
||||
sampler = create_balanced_sampler(train_dataset, cfg)
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=cfg.training.batch_size,
|
||||
num_workers=cfg.training.num_workers,
|
||||
sampler=sampler,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=cfg.eval.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=cfg.training.num_workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
# Resume training if requested
|
||||
step = 0
|
||||
best_val_acc = 0
|
||||
|
||||
if cfg.resume:
|
||||
if not Logger.get_last_checkpoint_dir(out_dir).exists():
|
||||
raise RuntimeError(
|
||||
"You have set resume=True, but there is no model checkpoint in "
|
||||
f"{Logger.get_last_checkpoint_dir(out_dir)}"
|
||||
)
|
||||
checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
|
||||
logging.info(
|
||||
colored(
|
||||
"You have set resume=True, indicating that you wish to resume a run",
|
||||
color="yellow",
|
||||
attrs=["bold"],
|
||||
)
|
||||
)
|
||||
# Load and validate checkpoint configuration
|
||||
checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
|
||||
# Check for differences between the checkpoint configuration and provided configuration.
|
||||
# Hack to resolve the delta_timestamps ahead of time in order to properly diff.
|
||||
resolve_delta_timestamps(cfg)
|
||||
diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
|
||||
# Ignore the `resume` and parameters.
|
||||
if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
|
||||
del diff["values_changed"]["root['resume']"]
|
||||
if len(diff) > 0:
|
||||
logging.warning(
|
||||
"At least one difference was detected between the checkpoint configuration and "
|
||||
f"the provided configuration: \n{pformat(diff)}\nNote that the checkpoint configuration "
|
||||
"takes precedence.",
|
||||
)
|
||||
# Use the checkpoint config instead of the provided config (but keep `resume` parameter).
|
||||
cfg = checkpoint_cfg
|
||||
cfg.resume = True
|
||||
|
||||
# Initialize model and training components
|
||||
model = get_model(cfg=cfg, logger=logger).to(device)
|
||||
|
||||
optimizer = optim.AdamW(model.parameters(), lr=cfg.training.learning_rate)
|
||||
# Use BCEWithLogitsLoss for binary classification and CrossEntropyLoss for multi-class
|
||||
criterion = nn.BCEWithLogitsLoss() if model.config.num_classes == 2 else nn.CrossEntropyLoss()
|
||||
grad_scaler = GradScaler(enabled=cfg.training.use_amp)
|
||||
|
||||
# Log model parameters
|
||||
num_learnable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
num_total_params = sum(p.numel() for p in model.parameters())
|
||||
logging.info(f"Learnable parameters: {format_big_number(num_learnable_params)}")
|
||||
logging.info(f"Total parameters: {format_big_number(num_total_params)}")
|
||||
|
||||
if cfg.resume:
|
||||
step = logger.load_last_training_state(optimizer, None)
|
||||
|
||||
# Training loop with validation and checkpointing
|
||||
for epoch in range(cfg.training.num_epochs):
|
||||
logging.info(f"\nEpoch {epoch+1}/{cfg.training.num_epochs}")
|
||||
|
||||
train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg)
|
||||
|
||||
# Periodic validation
|
||||
if cfg.training.eval_freq > 0 and (epoch + 1) % cfg.training.eval_freq == 0:
|
||||
val_acc, eval_info = validate(
|
||||
model,
|
||||
val_loader,
|
||||
criterion,
|
||||
device,
|
||||
logger,
|
||||
cfg,
|
||||
)
|
||||
logger.log_dict(eval_info, step + len(train_loader), mode="eval")
|
||||
|
||||
# Save best model
|
||||
if val_acc > best_val_acc:
|
||||
best_val_acc = val_acc
|
||||
logger.save_checkpoint(
|
||||
train_step=step + len(train_loader),
|
||||
policy=model,
|
||||
optimizer=optimizer,
|
||||
scheduler=None,
|
||||
identifier="best",
|
||||
)
|
||||
|
||||
# Periodic checkpointing
|
||||
if cfg.training.save_checkpoint and (epoch + 1) % cfg.training.save_freq == 0:
|
||||
logger.save_checkpoint(
|
||||
train_step=step + len(train_loader),
|
||||
policy=model,
|
||||
optimizer=optimizer,
|
||||
scheduler=None,
|
||||
identifier=f"{epoch+1:06d}",
|
||||
)
|
||||
|
||||
step += len(train_loader)
|
||||
|
||||
logging.info("Training completed")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
|
@ -0,0 +1,251 @@
|
|||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from hydra import compose, initialize_config_dir
|
||||
from torch import nn
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
|
||||
from lerobot.scripts.train_hilserl_classifier import (
|
||||
create_balanced_sampler,
|
||||
train,
|
||||
train_epoch,
|
||||
validate,
|
||||
)
|
||||
|
||||
|
||||
class MockDataset(Dataset):
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
self.meta = MagicMock()
|
||||
self.meta.stats = {}
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.data[idx]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
|
||||
def make_dummy_model():
|
||||
model_config = ClassifierConfig(num_classes=2, model_name="hf-tiny-model-private/tiny-random-ResNetModel")
|
||||
model = Classifier(config=model_config)
|
||||
return model
|
||||
|
||||
|
||||
def test_create_balanced_sampler():
|
||||
# Mock dataset with imbalanced classes
|
||||
data = [
|
||||
{"label": 0},
|
||||
{"label": 0},
|
||||
{"label": 1},
|
||||
{"label": 0},
|
||||
{"label": 1},
|
||||
{"label": 1},
|
||||
{"label": 1},
|
||||
{"label": 1},
|
||||
]
|
||||
dataset = MockDataset(data)
|
||||
cfg = MagicMock()
|
||||
cfg.training.label_key = "label"
|
||||
|
||||
sampler = create_balanced_sampler(dataset, cfg)
|
||||
|
||||
# Get weights from the sampler
|
||||
weights = sampler.weights.float()
|
||||
|
||||
# Check that samples have appropriate weights
|
||||
labels = [item["label"] for item in data]
|
||||
class_counts = torch.tensor([labels.count(0), labels.count(1)], dtype=torch.float32)
|
||||
class_weights = 1.0 / class_counts
|
||||
expected_weights = torch.tensor([class_weights[label] for label in labels], dtype=torch.float32)
|
||||
|
||||
# Test that the weights are correct
|
||||
assert torch.allclose(weights, expected_weights)
|
||||
|
||||
|
||||
def test_train_epoch():
|
||||
model = make_dummy_model()
|
||||
# Mock components
|
||||
model.train = MagicMock()
|
||||
|
||||
train_loader = [
|
||||
{
|
||||
"image": torch.rand(2, 3, 224, 224),
|
||||
"label": torch.tensor([0.0, 1.0]),
|
||||
}
|
||||
]
|
||||
|
||||
criterion = nn.BCEWithLogitsLoss()
|
||||
optimizer = MagicMock()
|
||||
grad_scaler = MagicMock()
|
||||
device = torch.device("cpu")
|
||||
logger = MagicMock()
|
||||
step = 0
|
||||
cfg = MagicMock()
|
||||
cfg.training.image_key = "image"
|
||||
cfg.training.label_key = "label"
|
||||
cfg.training.use_amp = False
|
||||
|
||||
# Call the function under test
|
||||
train_epoch(
|
||||
model,
|
||||
train_loader,
|
||||
criterion,
|
||||
optimizer,
|
||||
grad_scaler,
|
||||
device,
|
||||
logger,
|
||||
step,
|
||||
cfg,
|
||||
)
|
||||
|
||||
# Check that model.train() was called
|
||||
model.train.assert_called_once()
|
||||
|
||||
# Check that optimizer.zero_grad() was called
|
||||
optimizer.zero_grad.assert_called()
|
||||
|
||||
# Check that logger.log_dict was called
|
||||
logger.log_dict.assert_called()
|
||||
|
||||
|
||||
def test_validate():
|
||||
model = make_dummy_model()
|
||||
|
||||
# Mock components
|
||||
model.eval = MagicMock()
|
||||
val_loader = [
|
||||
{
|
||||
"image": torch.rand(2, 3, 224, 224),
|
||||
"label": torch.tensor([0.0, 1.0]),
|
||||
}
|
||||
]
|
||||
criterion = nn.BCEWithLogitsLoss()
|
||||
device = torch.device("cpu")
|
||||
logger = MagicMock()
|
||||
cfg = MagicMock()
|
||||
cfg.training.image_key = "image"
|
||||
cfg.training.label_key = "label"
|
||||
cfg.training.use_amp = False
|
||||
|
||||
# Call validate
|
||||
accuracy, eval_info = validate(model, val_loader, criterion, device, logger, cfg)
|
||||
|
||||
# Check that model.eval() was called
|
||||
model.eval.assert_called_once()
|
||||
|
||||
# Check accuracy/eval_info are calculated and of the correct type
|
||||
assert isinstance(accuracy, float)
|
||||
assert isinstance(eval_info, dict)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("resume", [True, False])
|
||||
@patch("lerobot.scripts.train_hilserl_classifier.init_hydra_config")
|
||||
@patch("lerobot.scripts.train_hilserl_classifier.Logger.get_last_checkpoint_dir")
|
||||
@patch("lerobot.scripts.train_hilserl_classifier.Logger.get_last_pretrained_model_dir")
|
||||
@patch("lerobot.scripts.train_hilserl_classifier.Logger")
|
||||
@patch("lerobot.scripts.train_hilserl_classifier.LeRobotDataset")
|
||||
@patch("lerobot.scripts.train_hilserl_classifier.make_policy")
|
||||
def test_resume_function(
|
||||
mock_make_policy,
|
||||
mock_dataset,
|
||||
mock_logger,
|
||||
mock_get_last_pretrained_model_dir,
|
||||
mock_get_last_checkpoint_dir,
|
||||
mock_init_hydra_config,
|
||||
resume,
|
||||
):
|
||||
# Initialize Hydra
|
||||
test_file_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
config_dir = os.path.abspath(os.path.join(test_file_dir, "..", "lerobot", "configs", "policy"))
|
||||
assert os.path.exists(config_dir), f"Config directory does not exist at {config_dir}"
|
||||
|
||||
with initialize_config_dir(config_dir=config_dir, job_name="test_app", version_base="1.2"):
|
||||
cfg = compose(
|
||||
config_name="reward_classifier",
|
||||
overrides=[
|
||||
"device=cpu",
|
||||
"seed=42",
|
||||
f"output_dir={tempfile.mkdtemp()}",
|
||||
"wandb.enable=False",
|
||||
f"resume={resume}",
|
||||
"dataset_repo_id=dataset_repo_id",
|
||||
"train_split_proportion=0.8",
|
||||
"training.num_workers=0",
|
||||
"training.batch_size=2",
|
||||
"training.image_key=image",
|
||||
"training.label_key=label",
|
||||
"training.use_amp=False",
|
||||
"training.num_epochs=1",
|
||||
"eval.batch_size=2",
|
||||
],
|
||||
)
|
||||
|
||||
# Mock the init_hydra_config function to return cfg
|
||||
mock_init_hydra_config.return_value = cfg
|
||||
|
||||
# Mock dataset
|
||||
dataset = MockDataset([{"image": torch.rand(3, 224, 224), "label": i % 2} for i in range(10)])
|
||||
mock_dataset.return_value = dataset
|
||||
|
||||
# Mock checkpoint handling
|
||||
mock_checkpoint_dir = MagicMock(spec=Path)
|
||||
mock_checkpoint_dir.exists.return_value = resume # Only exists if resuming
|
||||
mock_get_last_checkpoint_dir.return_value = mock_checkpoint_dir
|
||||
mock_get_last_pretrained_model_dir.return_value = Path(tempfile.mkdtemp())
|
||||
|
||||
# Mock logger
|
||||
logger = MagicMock()
|
||||
resumed_step = 1000
|
||||
if resume:
|
||||
logger.load_last_training_state.return_value = resumed_step
|
||||
else:
|
||||
logger.load_last_training_state.return_value = 0
|
||||
mock_logger.return_value = logger
|
||||
|
||||
# Instantiate the model and set make_policy to return it
|
||||
model = make_dummy_model()
|
||||
mock_make_policy.return_value = model
|
||||
|
||||
# Call train
|
||||
train(cfg)
|
||||
|
||||
# Check that checkpoint handling methods were called
|
||||
if resume:
|
||||
mock_get_last_checkpoint_dir.assert_called_once_with(Path(cfg.output_dir))
|
||||
mock_get_last_pretrained_model_dir.assert_called_once_with(Path(cfg.output_dir))
|
||||
mock_checkpoint_dir.exists.assert_called_once()
|
||||
logger.load_last_training_state.assert_called_once()
|
||||
else:
|
||||
mock_get_last_checkpoint_dir.assert_not_called()
|
||||
mock_get_last_pretrained_model_dir.assert_not_called()
|
||||
mock_checkpoint_dir.exists.assert_not_called()
|
||||
logger.load_last_training_state.assert_not_called()
|
||||
|
||||
# Collect the steps from logger.log_dict calls
|
||||
train_log_calls = logger.log_dict.call_args_list
|
||||
|
||||
# Extract the steps used in the train logging
|
||||
steps = []
|
||||
for call in train_log_calls:
|
||||
mode = call.kwargs.get("mode", call.args[2] if len(call.args) > 2 else None)
|
||||
if mode == "train":
|
||||
step = call.kwargs.get("step", call.args[1] if len(call.args) > 1 else None)
|
||||
steps.append(step)
|
||||
|
||||
expected_start_step = resumed_step if resume else 0
|
||||
|
||||
# Calculate expected_steps
|
||||
train_size = int(cfg.train_split_proportion * len(dataset))
|
||||
batch_size = cfg.training.batch_size
|
||||
num_batches = (train_size + batch_size - 1) // batch_size
|
||||
|
||||
expected_steps = [expected_start_step + i for i in range(num_batches)]
|
||||
|
||||
assert steps == expected_steps, f"Expected steps {expected_steps}, got {steps}"
|
Loading…
Reference in New Issue