Reward classifier and training (#528)

Co-authored-by: Daniel Ritchie <daniel@brainwavecollective.ai>
Co-authored-by: resolver101757 <kelster101757@hotmail.com>
Co-authored-by: Jannik Grothusen <56967823+J4nn1K@users.noreply.github.com>
Co-authored-by: Remi <re.cadene@gmail.com>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Yoel 2024-12-09 10:21:50 +01:00 committed by Adil Zouitine
parent d037f4a322
commit 0ebdae8a40
10 changed files with 1131 additions and 7 deletions


@@ -0,0 +1,83 @@
# Training a HIL-SERL Reward Classifier with LeRobot
This tutorial provides step-by-step instructions for training a reward classifier using LeRobot.
---
## Training Script Overview
LeRobot includes a ready-to-use training script located at [`lerobot/scripts/train_hilserl_classifier.py`](../../lerobot/scripts/train_hilserl_classifier.py). Here's an outline of its workflow:
1. **Configuration Loading**
The script uses Hydra to load a configuration file for subsequent steps. (Details on Hydra follow below.)
2. **Dataset Initialization**
It loads a `LeRobotDataset` containing images and rewards. To handle class imbalance, a weighted random sampler is used to balance positive and negative examples during training (see the sketch after this list).
3. **Classifier Initialization**
A lightweight classification head is built on top of a frozen, pretrained image encoder from HuggingFace. The classifier outputs either:
- A single probability (binary classification), or
- Logits (multi-class classification).
4. **Training Loop Execution**
The script performs:
- Forward and backward passes,
- Optimization steps,
- Periodic logging, evaluation, and checkpoint saving.
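As a concrete illustration of the balanced sampling mentioned in step 2, here is a minimal Python sketch; it mirrors `create_balanced_sampler` in [`lerobot/scripts/train_hilserl_classifier.py`](../../lerobot/scripts/train_hilserl_classifier.py), but the function name and the `label_key` default below are illustrative rather than part of the public API.
```python
import torch
from torch.utils.data import WeightedRandomSampler

def make_balanced_sampler(dataset, label_key="next.reward"):
    # Collect the integer class label (e.g. a 0/1 reward) of every sample.
    labels = torch.tensor([int(item[label_key]) for item in dataset])
    # Weight each class inversely to its frequency so the rare class is drawn more often.
    _, counts = torch.unique(labels, return_counts=True)
    class_weights = 1.0 / counts.float()
    sample_weights = class_weights[labels]
    # Sampling with replacement yields roughly class-balanced mini-batches.
    return WeightedRandomSampler(
        weights=sample_weights, num_samples=len(sample_weights), replacement=True
    )
```
The resulting sampler is passed to the training `DataLoader` through its `sampler` argument, which is how the training script uses it.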
---
## Configuring with Hydra
For detailed information about Hydra usage, refer to [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md). However, note that training the reward classifier differs slightly and requires a separate configuration file.
### Config File Setup
The default `default.yaml` cannot be used to launch reward classifier training directly. Instead, use a configuration file such as [`lerobot/configs/policy/hilserl_classifier.yaml`](../../lerobot/configs/policy/hilserl_classifier.yaml), with the following adjustment:
Replace the `dataset_repo_id` field with the identifier of your dataset, which should contain images and sparse rewards:
```yaml
# Example: lerobot/configs/policy/reward_classifier.yaml
dataset_repo_id: "my_dataset_repo_id"
```
## Typical logs and metrics
When you start training, your full configuration is first printed to the terminal. Check it to make sure everything is configured correctly and that your settings are not overridden by other files. The final configuration is also saved with the checkpoint.
After that, you will see training logs like this one:
```
[2024-11-29 18:26:36,999][root][INFO] -
Epoch 5/5
Training: 82%|██████████████████████████████████████████████████████████████████████████████▋ | 91/111 [00:50<00:09, 2.04it/s, loss=0.2999, acc=69.99%]
```
or an evaluation log like this one:
```
Validation: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:20<00:00, 1.37it/s]
```
### Metrics Tracking with Weights & Biases (WandB)
If `wandb.enable` is set to `true`, the training and evaluation logs are also sent to WandB. This allows you to track key metrics in real time (a minimal sketch of how these metrics are logged appears after the list below), including:
- **Training Metrics**:
- `train/accuracy`
- `train/loss`
- `train/dataloading_s`
- **Evaluation Metrics**:
- `eval/accuracy`
- `eval/loss`
- `eval/eval_s`
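Under the hood, these values flow through the `Logger.log_dict` helper added in `lerobot/common/logger.py`, which prefixes every key with `train/` or `eval/`. Below is a minimal sketch of that call, assuming a toy configuration whose field values are placeholders (with WandB disabled the call is a local no-op, but the prefixing behavior is the same):
```python
from omegaconf import OmegaConf
from lerobot.common.logger import Logger

# Placeholder config with just the fields Logger needs to build its run group.
cfg = OmegaConf.create({
    "policy": {"name": "hilserl/classifier"},
    "dataset_repo_id": "my_dataset_repo_id",
    "env": {"name": "classifier"},
    "seed": 13,
    "resume": False,
    "wandb": {"enable": False},
})
logger = Logger(cfg, log_dir="outputs/example")
# With wandb.enable=true, this would show up as eval/accuracy and eval/loss at step 100.
logger.log_dict({"accuracy": 98.1, "loss": 0.07}, step=100, mode="eval")
```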
#### Additional Features
You can also log sample predictions during evaluation. Each logged sample will include:
- The **input image**.
- The **predicted label**.
- The **true label**.
- The **classifier's "confidence" (logits/probability)**.
These logs can be useful for diagnosing and debugging performance issues.
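For reference, here is a hedged sketch of how such a sample table can be assembled with the WandB API, mirroring the table built in `validate()` in the training script; the helper name and the binary-classification assumption are illustrative only.
```python
import torch
import wandb

def build_prediction_table(images, labels, probabilities, threshold=0.5):
    # images: (N, 3, H, W) tensor; labels and probabilities: (N,) tensors (binary case).
    rows = []
    for img, label, prob in zip(images, labels, probabilities):
        predicted = float(prob.item() > threshold)
        rows.append(
            [wandb.Image(img.cpu()), label.item(), predicted, round(prob.item(), 3)]
        )
    return wandb.Table(
        data=rows, columns=["Image", "True Label", "Predicted", "Confidence"]
    )
```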


@@ -318,7 +318,7 @@ class LeRobotDatasetMetadata:
obj.root.mkdir(parents=True, exist_ok=False)
if robot is not None:
features = get_features_from_robot(robot, use_videos)
features = {**(features or {}), **get_features_from_robot(robot)}
robot_type = robot.robot_type
if not all(cam.fps == fps for cam in robot.cameras.values()):
logging.warning(

lerobot/common/logger.py
@@ -0,0 +1,245 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Borrowed from https://github.com/fyhMer/fowm/blob/main/src/logger.py
# TODO(rcadene, alexander-soare): clean this file
"""
import logging
import os
import re
from glob import glob
from pathlib import Path
import torch
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
import wandb
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import get_global_random_state, set_global_random_state
def log_output_dir(out_dir):
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
def cfg_to_group(cfg: DictConfig, return_list: bool = False) -> list[str] | str:
"""Return a group name for logging. Optionally returns group name as list."""
lst = [
f"policy:{cfg.policy.name}",
f"dataset:{cfg.dataset_repo_id}",
f"env:{cfg.env.name}",
f"seed:{cfg.seed}",
]
return lst if return_list else "-".join(lst)
def get_wandb_run_id_from_filesystem(checkpoint_dir: Path) -> str:
# Get the WandB run ID.
paths = glob(str(checkpoint_dir / "../wandb/latest-run/run-*"))
if len(paths) != 1:
raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
match = re.search(r"run-([^\.]+).wandb", paths[0].split("/")[-1])
if match is None:
raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
wandb_run_id = match.groups(0)[0]
return wandb_run_id
class Logger:
"""Primary logger object. Logs either locally or using wandb.
The logger creates the following directory structure:
provided_log_dir
.hydra # hydra's configuration cache
checkpoints
specific_checkpoint_name
pretrained_model # Hugging Face pretrained model directory
...
training_state.pth # optimizer, scheduler, and random states + training step
| another_specific_checkpoint_name
...
| ...
last # a softlink to the last logged checkpoint
"""
pretrained_model_dir_name = "pretrained_model"
training_state_file_name = "training_state.pth"
def __init__(self, cfg: DictConfig, log_dir: str, wandb_job_name: str | None = None):
"""
Args:
log_dir: The directory to save all logs and training outputs to.
job_name: The WandB job name.
"""
self._cfg = cfg
self.log_dir = Path(log_dir)
self.log_dir.mkdir(parents=True, exist_ok=True)
self.checkpoints_dir = self.get_checkpoints_dir(log_dir)
self.last_checkpoint_dir = self.get_last_checkpoint_dir(log_dir)
self.last_pretrained_model_dir = self.get_last_pretrained_model_dir(log_dir)
# Set up WandB.
self._group = cfg_to_group(cfg)
project = cfg.get("wandb", {}).get("project")
entity = cfg.get("wandb", {}).get("entity")
enable_wandb = cfg.get("wandb", {}).get("enable", False)
run_offline = not enable_wandb or not project
if run_offline:
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
self._wandb = None
else:
os.environ["WANDB_SILENT"] = "true"
wandb_run_id = None
if cfg.resume:
wandb_run_id = get_wandb_run_id_from_filesystem(self.checkpoints_dir)
wandb.init(
id=wandb_run_id,
project=project,
entity=entity,
name=wandb_job_name,
notes=cfg.get("wandb", {}).get("notes"),
tags=cfg_to_group(cfg, return_list=True),
dir=log_dir,
config=OmegaConf.to_container(cfg, resolve=True),
# TODO(rcadene): try set to True
save_code=False,
# TODO(rcadene): split train and eval, and run async eval with job_type="eval"
job_type="train_eval",
resume="must" if cfg.resume else None,
)
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
self._wandb = wandb
@classmethod
def get_checkpoints_dir(cls, log_dir: str | Path) -> Path:
"""Given the log directory, get the sub-directory in which checkpoints will be saved."""
return Path(log_dir) / "checkpoints"
@classmethod
def get_last_checkpoint_dir(cls, log_dir: str | Path) -> Path:
"""Given the log directory, get the sub-directory in which the last checkpoint will be saved."""
return cls.get_checkpoints_dir(log_dir) / "last"
@classmethod
def get_last_pretrained_model_dir(cls, log_dir: str | Path) -> Path:
"""
Given the log directory, get the sub-directory in which the last checkpoint's pretrained weights will
be saved.
"""
return cls.get_last_checkpoint_dir(log_dir) / cls.pretrained_model_dir_name
def save_model(self, save_dir: Path, policy: Policy, wandb_artifact_name: str | None = None):
"""Save the weights of the Policy model using PyTorchModelHubMixin.
The weights are saved in a folder called "pretrained_model" under the checkpoint directory.
Optionally also upload the model to WandB.
"""
self.checkpoints_dir.mkdir(parents=True, exist_ok=True)
policy.save_pretrained(save_dir)
# Also save the full Hydra config for the env configuration.
OmegaConf.save(self._cfg, save_dir / "config.yaml")
if self._wandb and not self._cfg.wandb.disable_artifact:
# note wandb artifact does not accept ":" or "/" in its name
artifact = self._wandb.Artifact(wandb_artifact_name, type="model")
artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
self._wandb.log_artifact(artifact)
if self.last_checkpoint_dir.exists():
os.remove(self.last_checkpoint_dir)
def save_training_state(
self,
save_dir: Path,
train_step: int,
optimizer: Optimizer,
scheduler: LRScheduler | None,
):
"""Checkpoint the global training_step, optimizer state, scheduler state, and random state.
All of these are saved as "training_state.pth" under the checkpoint directory.
"""
training_state = {
"step": train_step,
"optimizer": optimizer.state_dict(),
**get_global_random_state(),
}
if scheduler is not None:
training_state["scheduler"] = scheduler.state_dict()
torch.save(training_state, save_dir / self.training_state_file_name)
def save_checkpoint(
self,
train_step: int,
policy: Policy,
optimizer: Optimizer,
scheduler: LRScheduler | None,
identifier: str,
):
"""Checkpoint the model weights and the training state."""
checkpoint_dir = self.checkpoints_dir / str(identifier)
wandb_artifact_name = (
None
if self._wandb is None
else f"{self._group.replace(':', '_').replace('/', '_')}-{self._cfg.seed}-{identifier}"
)
self.save_model(
checkpoint_dir / self.pretrained_model_dir_name, policy, wandb_artifact_name=wandb_artifact_name
)
self.save_training_state(checkpoint_dir, train_step, optimizer, scheduler)
os.symlink(checkpoint_dir.absolute(), self.last_checkpoint_dir)
def load_last_training_state(self, optimizer: Optimizer, scheduler: LRScheduler | None) -> int:
"""
Given the last checkpoint in the logging directory, load the optimizer state, scheduler state, and
random state, and return the global training step.
"""
training_state = torch.load(self.last_checkpoint_dir / self.training_state_file_name)
optimizer.load_state_dict(training_state["optimizer"])
if scheduler is not None:
scheduler.load_state_dict(training_state["scheduler"])
elif "scheduler" in training_state:
raise ValueError(
"The checkpoint contains a scheduler state_dict, but no LRScheduler was provided."
)
# Small hack to get the expected keys: use `get_global_random_state`.
set_global_random_state({k: training_state[k] for k in get_global_random_state()})
return training_state["step"]
def log_dict(self, d, step, mode="train"):
assert mode in {"train", "eval"}
# TODO(alexander-soare): Add local text log.
if self._wandb is not None:
for k, v in d.items():
if not isinstance(v, (int, float, str, wandb.Table)):
logging.warning(
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
)
continue
self._wandb.log({f"{mode}/{k}": v}, step=step)
def log_video(self, video_path: str, step: int, mode: str = "train"):
assert mode in {"train", "eval"}
assert self._wandb is not None
wandb_video = self._wandb.Video(video_path, fps=self._cfg.fps, format="mp4")
self._wandb.log({f"{mode}/video": wandb_video}, step=step)


@@ -0,0 +1,36 @@
import json
import os
from dataclasses import asdict, dataclass
import torch
@dataclass
class ClassifierConfig:
"""Configuration for the Classifier model."""
num_classes: int = 2
hidden_dim: int = 256
dropout_rate: float = 0.1
model_name: str = "microsoft/resnet-50"
device: str = "cuda" if torch.cuda.is_available() else "mps"
model_type: str = "cnn" # "transformer" or "cnn"
def save_pretrained(self, save_dir):
"""Save config to json file."""
os.makedirs(save_dir, exist_ok=True)
# Convert to dict and save as JSON
config_dict = asdict(self)
with open(os.path.join(save_dir, "config.json"), "w") as f:
json.dump(config_dict, f, indent=2)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path):
"""Load config from json file."""
config_file = os.path.join(pretrained_model_name_or_path, "config.json")
with open(config_file) as f:
config_dict = json.load(f)
return cls(**config_dict)


@@ -0,0 +1,134 @@
import logging
from typing import Optional
import torch
from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor, nn
from transformers import AutoImageProcessor, AutoModel
from .configuration_classifier import ClassifierConfig
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
class ClassifierOutput:
"""Wrapper for classifier outputs with additional metadata."""
def __init__(
self, logits: Tensor, probabilities: Optional[Tensor] = None, hidden_states: Optional[Tensor] = None
):
self.logits = logits
self.probabilities = probabilities
self.hidden_states = hidden_states
class Classifier(
nn.Module,
PyTorchModelHubMixin,
# Add Hub metadata
library_name="lerobot",
repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "vision-classifier"],
):
"""Image classifier built on top of a pre-trained encoder."""
# Add name attribute for factory
name = "classifier"
def __init__(self, config: ClassifierConfig):
super().__init__()
self.config = config
self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True)
encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)
# Extract vision model if we're given a multimodal model
if hasattr(encoder, "vision_model"):
logging.info("Multimodal model detected - using vision encoder only")
self.encoder = encoder.vision_model
self.vision_config = encoder.config.vision_config
else:
self.encoder = encoder
self.vision_config = getattr(encoder, "config", None)
# Model type from config
self.is_cnn = self.config.model_type == "cnn"
# For CNNs, initialize backbone
if self.is_cnn:
self._setup_cnn_backbone()
self._freeze_encoder()
self._build_classifier_head()
def _setup_cnn_backbone(self):
"""Set up CNN encoder"""
if hasattr(self.encoder, "fc"):
self.feature_dim = self.encoder.fc.in_features
self.encoder = nn.Sequential(*list(self.encoder.children())[:-1])
elif hasattr(self.encoder.config, "hidden_sizes"):
self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension
else:
raise ValueError("Unsupported CNN architecture")
def _freeze_encoder(self) -> None:
"""Freeze the encoder parameters."""
for param in self.encoder.parameters():
param.requires_grad = False
def _build_classifier_head(self) -> None:
"""Initialize the classifier head architecture."""
# Get input dimension based on model type
if self.is_cnn:
input_dim = self.feature_dim
else: # Transformer models
if hasattr(self.encoder.config, "hidden_size"):
input_dim = self.encoder.config.hidden_size
else:
raise ValueError("Unsupported transformer architecture since hidden_size is not found")
self.classifier_head = nn.Sequential(
nn.Linear(input_dim, self.config.hidden_dim),
nn.Dropout(self.config.dropout_rate),
nn.LayerNorm(self.config.hidden_dim),
nn.ReLU(),
nn.Linear(self.config.hidden_dim, 1 if self.config.num_classes == 2 else self.config.num_classes),
)
def _get_encoder_output(self, x: torch.Tensor) -> torch.Tensor:
"""Extract the appropriate output from the encoder."""
# Process images with the processor (handles resizing and normalization)
processed = self.processor(
images=x, # LeRobotDataset already provides proper tensor format
return_tensors="pt",
)
processed = processed["pixel_values"].to(x.device)
with torch.no_grad():
if self.is_cnn:
# The HF ResNet applies pooling internally
outputs = self.encoder(processed)
# Get pooled output directly
features = outputs.pooler_output
if features.dim() > 2:
features = features.squeeze(-1).squeeze(-1)
return features
else: # Transformer models
outputs = self.encoder(processed)
if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
return outputs.pooler_output
return outputs.last_hidden_state[:, 0, :]
def forward(self, x: torch.Tensor) -> ClassifierOutput:
"""Forward pass of the classifier."""
# For training, we expect input to be a tensor directly from LeRobotDataset
encoder_output = self._get_encoder_output(x)
logits = self.classifier_head(encoder_output)
if self.config.num_classes == 2:
logits = logits.squeeze(-1)
probabilities = torch.sigmoid(logits)
else:
probabilities = torch.softmax(logits, dim=-1)
return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_output)


@@ -128,14 +128,22 @@ def predict_action(observation, policy, device, use_amp):
return action
def init_keyboard_listener():
# Allow to exit early while recording an episode or resetting the environment,
# by tapping the right arrow key '->'. This might require a sudo permission
# to allow your terminal to monitor keyboard events.
def init_keyboard_listener(assign_rewards=False):
"""
Initializes a keyboard listener to enable early termination of an episode
or environment reset by pressing the right arrow key ('->'). This may require
sudo permissions to allow the terminal to monitor keyboard events.
Args:
assign_rewards (bool): If True, allows annotating the collected trajectory
with a binary reward at the end of the episode to indicate success.
"""
events = {}
events["exit_early"] = False
events["rerecord_episode"] = False
events["stop_recording"] = False
if assign_rewards:
events["next.reward"] = 0
if is_headless():
logging.warning(
@@ -160,6 +168,13 @@ def init_keyboard_listener():
print("Escape key pressed. Stopping data recording...")
events["stop_recording"] = True
events["exit_early"] = True
elif assign_rewards and key == keyboard.Key.space:
events["next.reward"] = 1 if events["next.reward"] == 0 else 0
print(
"Space key pressed. Assigning new reward to the subsequent frames. New reward:",
events["next.reward"],
)
except Exception as e:
print(f"Error handling key press: {e}")


@@ -0,0 +1,48 @@
# @package _global_
defaults:
- _self_
seed: 13
dataset_repo_id: "dataset_repo_id"
train_split_proportion: 0.8
# Required by logger
env:
name: "classifier"
task: "binary_classification"
training:
num_epochs: 5
batch_size: 16
learning_rate: 1e-4
num_workers: 4
grad_clip_norm: 10
use_amp: true
log_freq: 1
eval_freq: 1 # How often to run validation (in epochs)
save_freq: 1 # How often to save checkpoints (in epochs)
save_checkpoint: true
image_key: "observation.images.phone"
label_key: "next.reward"
eval:
batch_size: 16
num_samples_to_log: 30 # Number of validation samples to log in the table
policy:
name: "hilserl/classifier"
model_name: "facebook/convnext-base-224"
model_type: "cnn"
wandb:
enable: false
project: "classifier-training"
entity: "wandb_entity"
job_name: "classifier_training_0"
disable_artifact: false
device: "mps"
resume: false
output_dir: "output"


@@ -269,10 +269,12 @@ def record(
# Load pretrained policy
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
# Load pretrained policy
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
if not robot.is_connected:
robot.connect()
listener, events = init_keyboard_listener()
listener, events = init_keyboard_listener(assign_rewards=assign_rewards)
# Execute a few seconds without recording to:
# 1. teleoperate the robot to move it in starting position if no policy provided,


@@ -0,0 +1,310 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from contextlib import nullcontext
from pathlib import Path
from pprint import pformat
import hydra
import torch
import torch.nn as nn
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
from torch import optim
from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
from tqdm import tqdm
import wandb
from lerobot.common.datasets.factory import resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.logger import Logger
from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
from lerobot.common.utils.utils import (
format_big_number,
get_safe_torch_device,
init_hydra_config,
set_global_seed,
)
def get_model(cfg, logger):
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
model = Classifier(classifier_config)
if cfg.resume:
model.load_state_dict(Classifier.from_pretrained(str(logger.last_pretrained_model_dir)).state_dict())
return model
def create_balanced_sampler(dataset, cfg):
# Creates a weighted sampler to handle class imbalance
labels = torch.tensor([item[cfg.training.label_key] for item in dataset])
_, counts = torch.unique(labels, return_counts=True)
class_weights = 1.0 / counts.float()
sample_weights = class_weights[labels]
return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg):
# Single epoch training loop with AMP support and progress tracking
model.train()
correct = 0
total = 0
pbar = tqdm(train_loader, desc="Training")
for batch_idx, batch in enumerate(pbar):
start_time = time.perf_counter()
images = batch[cfg.training.image_key].to(device)
labels = batch[cfg.training.label_key].float().to(device)
# Forward pass with optional AMP
with torch.autocast(device_type=device.type) if cfg.training.use_amp else nullcontext():
outputs = model(images)
loss = criterion(outputs.logits, labels)
# Backward pass with gradient scaling if AMP enabled
optimizer.zero_grad()
if cfg.training.use_amp:
grad_scaler.scale(loss).backward()
grad_scaler.step(optimizer)
grad_scaler.update()
else:
loss.backward()
optimizer.step()
# Track metrics
if model.config.num_classes == 2:
predictions = (torch.sigmoid(outputs.logits) > 0.5).float()
else:
predictions = torch.argmax(outputs.logits, dim=1)
correct += (predictions == labels).sum().item()
total += labels.size(0)
current_acc = 100 * correct / total
train_info = {
"loss": loss.item(),
"accuracy": current_acc,
"dataloading_s": time.perf_counter() - start_time,
}
logger.log_dict(train_info, step + batch_idx, mode="train")
pbar.set_postfix({"loss": f"{loss.item():.4f}", "acc": f"{current_acc:.2f}%"})
def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_log=8):
# Validation loop with metric tracking and sample logging
model.eval()
correct = 0
total = 0
batch_start_time = time.perf_counter()
samples = []
running_loss = 0
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.training.use_amp else nullcontext():
for batch in tqdm(val_loader, desc="Validation"):
images = batch[cfg.training.image_key].to(device)
labels = batch[cfg.training.label_key].float().to(device)
outputs = model(images)
loss = criterion(outputs.logits, labels)
# Track metrics
if model.config.num_classes == 2:
predictions = (torch.sigmoid(outputs.logits) > 0.5).float()
else:
predictions = torch.argmax(outputs.logits, dim=1)
correct += (predictions == labels).sum().item()
total += labels.size(0)
running_loss += loss.item()
# Log sample predictions for visualization
if len(samples) < num_samples_to_log:
for i in range(min(num_samples_to_log - len(samples), len(images))):
if model.config.num_classes == 2:
confidence = round(outputs.probabilities[i].item(), 3)
else:
confidence = [round(prob, 3) for prob in outputs.probabilities[i].tolist()]
samples.append(
{
"image": wandb.Image(images[i].cpu()),
"true_label": labels[i].item(),
"predicted": predictions[i].item(),
"confidence": confidence,
}
)
accuracy = 100 * correct / total
avg_loss = running_loss / len(val_loader)
eval_info = {
"loss": avg_loss,
"accuracy": accuracy,
"eval_s": time.perf_counter() - batch_start_time,
"eval/prediction_samples": wandb.Table(
data=[[s["image"], s["true_label"], s["predicted"], f"{s['confidence']}"] for s in samples],
columns=["Image", "True Label", "Predicted", "Confidence"],
)
if logger._cfg.wandb.enable
else None,
}
return accuracy, eval_info
@hydra.main(version_base="1.2", config_path="../configs", config_name="classifier")
def train(cfg: DictConfig) -> None:
# Main training pipeline with support for resuming training
logging.info(OmegaConf.to_yaml(cfg))
# Initialize training environment
device = get_safe_torch_device(cfg.device, log=True)
set_global_seed(cfg.seed)
out_dir = Path(cfg.output_dir)
out_dir.mkdir(parents=True, exist_ok=True)
logger = Logger(cfg, out_dir, cfg.wandb.job_name if cfg.wandb.enable else None)
# Setup dataset and dataloaders
dataset = LeRobotDataset(cfg.dataset_repo_id)
logging.info(f"Dataset size: {len(dataset)}")
train_size = int(cfg.train_split_proportion * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
sampler = create_balanced_sampler(train_dataset, cfg)
train_loader = DataLoader(
train_dataset,
batch_size=cfg.training.batch_size,
num_workers=cfg.training.num_workers,
sampler=sampler,
pin_memory=True,
)
val_loader = DataLoader(
val_dataset,
batch_size=cfg.eval.batch_size,
shuffle=False,
num_workers=cfg.training.num_workers,
pin_memory=True,
)
# Resume training if requested
step = 0
best_val_acc = 0
if cfg.resume:
if not Logger.get_last_checkpoint_dir(out_dir).exists():
raise RuntimeError(
"You have set resume=True, but there is no model checkpoint in "
f"{Logger.get_last_checkpoint_dir(out_dir)}"
)
checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
logging.info(
colored(
"You have set resume=True, indicating that you wish to resume a run",
color="yellow",
attrs=["bold"],
)
)
# Load and validate checkpoint configuration
checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
# Check for differences between the checkpoint configuration and provided configuration.
# Hack to resolve the delta_timestamps ahead of time in order to properly diff.
resolve_delta_timestamps(cfg)
diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
# Ignore the `resume` and parameters.
if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
del diff["values_changed"]["root['resume']"]
if len(diff) > 0:
logging.warning(
"At least one difference was detected between the checkpoint configuration and "
f"the provided configuration: \n{pformat(diff)}\nNote that the checkpoint configuration "
"takes precedence.",
)
# Use the checkpoint config instead of the provided config (but keep `resume` parameter).
cfg = checkpoint_cfg
cfg.resume = True
# Initialize model and training components
model = get_model(cfg=cfg, logger=logger).to(device)
optimizer = optim.AdamW(model.parameters(), lr=cfg.training.learning_rate)
# Use BCEWithLogitsLoss for binary classification and CrossEntropyLoss for multi-class
criterion = nn.BCEWithLogitsLoss() if model.config.num_classes == 2 else nn.CrossEntropyLoss()
grad_scaler = GradScaler(enabled=cfg.training.use_amp)
# Log model parameters
num_learnable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in model.parameters())
logging.info(f"Learnable parameters: {format_big_number(num_learnable_params)}")
logging.info(f"Total parameters: {format_big_number(num_total_params)}")
if cfg.resume:
step = logger.load_last_training_state(optimizer, None)
# Training loop with validation and checkpointing
for epoch in range(cfg.training.num_epochs):
logging.info(f"\nEpoch {epoch+1}/{cfg.training.num_epochs}")
train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg)
# Periodic validation
if cfg.training.eval_freq > 0 and (epoch + 1) % cfg.training.eval_freq == 0:
val_acc, eval_info = validate(
model,
val_loader,
criterion,
device,
logger,
cfg,
)
logger.log_dict(eval_info, step + len(train_loader), mode="eval")
# Save best model
if val_acc > best_val_acc:
best_val_acc = val_acc
logger.save_checkpoint(
train_step=step + len(train_loader),
policy=model,
optimizer=optimizer,
scheduler=None,
identifier="best",
)
# Periodic checkpointing
if cfg.training.save_checkpoint and (epoch + 1) % cfg.training.save_freq == 0:
logger.save_checkpoint(
train_step=step + len(train_loader),
policy=model,
optimizer=optimizer,
scheduler=None,
identifier=f"{epoch+1:06d}",
)
step += len(train_loader)
logging.info("Training completed")
if __name__ == "__main__":
train()


@@ -0,0 +1,251 @@
import os
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
import torch
from hydra import compose, initialize_config_dir
from torch import nn
from torch.utils.data import Dataset
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
from lerobot.scripts.train_hilserl_classifier import (
create_balanced_sampler,
train,
train_epoch,
validate,
)
class MockDataset(Dataset):
def __init__(self, data):
self.data = data
self.meta = MagicMock()
self.meta.stats = {}
def __getitem__(self, idx):
return self.data[idx]
def __len__(self):
return len(self.data)
def make_dummy_model():
model_config = ClassifierConfig(num_classes=2, model_name="hf-tiny-model-private/tiny-random-ResNetModel")
model = Classifier(config=model_config)
return model
def test_create_balanced_sampler():
# Mock dataset with imbalanced classes
data = [
{"label": 0},
{"label": 0},
{"label": 1},
{"label": 0},
{"label": 1},
{"label": 1},
{"label": 1},
{"label": 1},
]
dataset = MockDataset(data)
cfg = MagicMock()
cfg.training.label_key = "label"
sampler = create_balanced_sampler(dataset, cfg)
# Get weights from the sampler
weights = sampler.weights.float()
# Check that samples have appropriate weights
labels = [item["label"] for item in data]
class_counts = torch.tensor([labels.count(0), labels.count(1)], dtype=torch.float32)
class_weights = 1.0 / class_counts
expected_weights = torch.tensor([class_weights[label] for label in labels], dtype=torch.float32)
# Test that the weights are correct
assert torch.allclose(weights, expected_weights)
def test_train_epoch():
model = make_dummy_model()
# Mock components
model.train = MagicMock()
train_loader = [
{
"image": torch.rand(2, 3, 224, 224),
"label": torch.tensor([0.0, 1.0]),
}
]
criterion = nn.BCEWithLogitsLoss()
optimizer = MagicMock()
grad_scaler = MagicMock()
device = torch.device("cpu")
logger = MagicMock()
step = 0
cfg = MagicMock()
cfg.training.image_key = "image"
cfg.training.label_key = "label"
cfg.training.use_amp = False
# Call the function under test
train_epoch(
model,
train_loader,
criterion,
optimizer,
grad_scaler,
device,
logger,
step,
cfg,
)
# Check that model.train() was called
model.train.assert_called_once()
# Check that optimizer.zero_grad() was called
optimizer.zero_grad.assert_called()
# Check that logger.log_dict was called
logger.log_dict.assert_called()
def test_validate():
model = make_dummy_model()
# Mock components
model.eval = MagicMock()
val_loader = [
{
"image": torch.rand(2, 3, 224, 224),
"label": torch.tensor([0.0, 1.0]),
}
]
criterion = nn.BCEWithLogitsLoss()
device = torch.device("cpu")
logger = MagicMock()
cfg = MagicMock()
cfg.training.image_key = "image"
cfg.training.label_key = "label"
cfg.training.use_amp = False
# Call validate
accuracy, eval_info = validate(model, val_loader, criterion, device, logger, cfg)
# Check that model.eval() was called
model.eval.assert_called_once()
# Check accuracy/eval_info are calculated and of the correct type
assert isinstance(accuracy, float)
assert isinstance(eval_info, dict)
@pytest.mark.parametrize("resume", [True, False])
@patch("lerobot.scripts.train_hilserl_classifier.init_hydra_config")
@patch("lerobot.scripts.train_hilserl_classifier.Logger.get_last_checkpoint_dir")
@patch("lerobot.scripts.train_hilserl_classifier.Logger.get_last_pretrained_model_dir")
@patch("lerobot.scripts.train_hilserl_classifier.Logger")
@patch("lerobot.scripts.train_hilserl_classifier.LeRobotDataset")
@patch("lerobot.scripts.train_hilserl_classifier.make_policy")
def test_resume_function(
mock_make_policy,
mock_dataset,
mock_logger,
mock_get_last_pretrained_model_dir,
mock_get_last_checkpoint_dir,
mock_init_hydra_config,
resume,
):
# Initialize Hydra
test_file_dir = os.path.dirname(os.path.abspath(__file__))
config_dir = os.path.abspath(os.path.join(test_file_dir, "..", "lerobot", "configs", "policy"))
assert os.path.exists(config_dir), f"Config directory does not exist at {config_dir}"
with initialize_config_dir(config_dir=config_dir, job_name="test_app", version_base="1.2"):
cfg = compose(
config_name="reward_classifier",
overrides=[
"device=cpu",
"seed=42",
f"output_dir={tempfile.mkdtemp()}",
"wandb.enable=False",
f"resume={resume}",
"dataset_repo_id=dataset_repo_id",
"train_split_proportion=0.8",
"training.num_workers=0",
"training.batch_size=2",
"training.image_key=image",
"training.label_key=label",
"training.use_amp=False",
"training.num_epochs=1",
"eval.batch_size=2",
],
)
# Mock the init_hydra_config function to return cfg
mock_init_hydra_config.return_value = cfg
# Mock dataset
dataset = MockDataset([{"image": torch.rand(3, 224, 224), "label": i % 2} for i in range(10)])
mock_dataset.return_value = dataset
# Mock checkpoint handling
mock_checkpoint_dir = MagicMock(spec=Path)
mock_checkpoint_dir.exists.return_value = resume # Only exists if resuming
mock_get_last_checkpoint_dir.return_value = mock_checkpoint_dir
mock_get_last_pretrained_model_dir.return_value = Path(tempfile.mkdtemp())
# Mock logger
logger = MagicMock()
resumed_step = 1000
if resume:
logger.load_last_training_state.return_value = resumed_step
else:
logger.load_last_training_state.return_value = 0
mock_logger.return_value = logger
# Instantiate the model and set make_policy to return it
model = make_dummy_model()
mock_make_policy.return_value = model
# Call train
train(cfg)
# Check that checkpoint handling methods were called
if resume:
mock_get_last_checkpoint_dir.assert_called_once_with(Path(cfg.output_dir))
mock_get_last_pretrained_model_dir.assert_called_once_with(Path(cfg.output_dir))
mock_checkpoint_dir.exists.assert_called_once()
logger.load_last_training_state.assert_called_once()
else:
mock_get_last_checkpoint_dir.assert_not_called()
mock_get_last_pretrained_model_dir.assert_not_called()
mock_checkpoint_dir.exists.assert_not_called()
logger.load_last_training_state.assert_not_called()
# Collect the steps from logger.log_dict calls
train_log_calls = logger.log_dict.call_args_list
# Extract the steps used in the train logging
steps = []
for call in train_log_calls:
mode = call.kwargs.get("mode", call.args[2] if len(call.args) > 2 else None)
if mode == "train":
step = call.kwargs.get("step", call.args[1] if len(call.args) > 1 else None)
steps.append(step)
expected_start_step = resumed_step if resume else 0
# Calculate expected_steps
train_size = int(cfg.train_split_proportion * len(dataset))
batch_size = cfg.training.batch_size
num_batches = (train_size + batch_size - 1) // batch_size
expected_steps = [expected_start_step + i for i in range(num_batches)]
assert steps == expected_steps, f"Expected steps {expected_steps}, got {steps}"