[Port HIL_SERL] Final fixes for the Reward Classifier (#598)

This commit is contained in:
Eugene Mironov 2025-01-06 17:34:00 +07:00 committed by Michel Aractingi
parent 13441f0d98
commit 844bfcf484
11 changed files with 59 additions and 19 deletions

View File

@ -4,7 +4,6 @@ from typing import Optional
import torch import torch
from huggingface_hub import PyTorchModelHubMixin from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor, nn from torch import Tensor, nn
from transformers import AutoImageProcessor, AutoModel
from .configuration_classifier import ClassifierConfig from .configuration_classifier import ClassifierConfig
@ -44,6 +43,8 @@ class Classifier(
name = "classifier" name = "classifier"
def __init__(self, config: ClassifierConfig): def __init__(self, config: ClassifierConfig):
from transformers import AutoImageProcessor, AutoModel
super().__init__() super().__init__()
self.config = config self.config = config
self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True) self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True)

View File

@ -327,7 +327,6 @@ class Critic(nn.Module):
value = self.output_layer(x) value = self.output_layer(x)
return value.squeeze(-1) return value.squeeze(-1)
class Policy(nn.Module): class Policy(nn.Module):
def __init__( def __init__(
self, self,

View File

@ -362,12 +362,16 @@ def sanity_check_dataset_name(repo_id, policy):
def sanity_check_dataset_robot_compatibility( def sanity_check_dataset_robot_compatibility(
dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool, extra_features: dict = None
) -> None: ) -> None:
features_from_robot = get_features_from_robot(robot, use_videos)
if extra_features is not None:
features_from_robot.update(extra_features)
fields = [ fields = [
("robot_type", dataset.meta.robot_type, robot.robot_type), ("robot_type", dataset.meta.robot_type, robot.robot_type),
("fps", dataset.fps, fps), ("fps", dataset.fps, fps),
("features", dataset.features, get_features_from_robot(robot, use_videos)), ("features", dataset.features, features_from_robot),
] ]
mismatches = [] mismatches = []

View File

@ -39,7 +39,6 @@ policy:
wandb: wandb:
enable: false enable: false
project: "classifier-training" project: "classifier-training"
entity: "wandb_entity"
job_name: "classifier_training_0" job_name: "classifier_training_0"
disable_artifact: false disable_artifact: false

View File

@ -246,7 +246,7 @@ def record(
num_processes=num_image_writer_processes, num_processes=num_image_writer_processes,
num_threads=num_image_writer_threads_per_camera * len(robot.cameras), num_threads=num_image_writer_threads_per_camera * len(robot.cameras),
) )
sanity_check_dataset_robot_compatibility(dataset, robot, fps, video) sanity_check_dataset_robot_compatibility(dataset, robot, fps, video, extra_features)
else: else:
# Create empty dataset or load existing saved episodes # Create empty dataset or load existing saved episodes
sanity_check_dataset_name(repo_id, policy) sanity_check_dataset_name(repo_id, policy)

View File

@ -183,8 +183,14 @@ def record(
resume: bool = False, resume: bool = False,
local_files_only: bool = False, local_files_only: bool = False,
run_compute_stats: bool = True, run_compute_stats: bool = True,
assign_rewards: bool = False,
) -> LeRobotDataset: ) -> LeRobotDataset:
# Load pretrained policy # Load pretrained policy
extra_features = (
{"next.reward": {"dtype": "int64", "shape": (1,), "names": None}} if assign_rewards else None
)
policy = None policy = None
if pretrained_policy_name_or_path is not None: if pretrained_policy_name_or_path is not None:
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides) policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
@ -197,7 +203,7 @@ def record(
raise ValueError("Either policy or process_action_fn has to be set to enable control in sim.") raise ValueError("Either policy or process_action_fn has to be set to enable control in sim.")
# initialize listener before sim env # initialize listener before sim env
listener, events = init_keyboard_listener() listener, events = init_keyboard_listener(assign_rewards=assign_rewards)
# create sim env # create sim env
env = env() env = env()
@ -237,6 +243,7 @@ def record(
} }
features["action"] = {"dtype": "float32", "shape": env.action_space.shape, "names": None} features["action"] = {"dtype": "float32", "shape": env.action_space.shape, "names": None}
features = {**features, **extra_features}
# Create empty dataset or load existing saved episodes # Create empty dataset or load existing saved episodes
sanity_check_dataset_name(repo_id, policy) sanity_check_dataset_name(repo_id, policy)
@ -288,6 +295,13 @@ def record(
"timestamp": env_timestamp, "timestamp": env_timestamp,
} }
# Overwrite environment reward with manually assigned reward
if assign_rewards:
frame["next.reward"] = events["next.reward"]
# Should success always be false to match what we do in control_utils?
frame["next.success"] = False
for key in image_keys: for key in image_keys:
if not key.startswith("observation.image"): if not key.startswith("observation.image"):
frame["observation.image." + key] = observation[key] frame["observation.image." + key] = observation[key]
@ -472,6 +486,13 @@ if __name__ == "__main__":
default=0, default=0,
help="Resume recording on an existing dataset.", help="Resume recording on an existing dataset.",
) )
parser_record.add_argument(
"--assign-rewards",
type=int,
default=0,
help="Enables the assignation of rewards to frames (by default no assignation). When enabled, assign a 0 reward to frames until the space bar is pressed which assign a 1 reward. Press the space bar a second time to assign a 0 reward. The reward assigned is reset to 0 when the episode ends.",
)
parser_replay = subparsers.add_parser("replay", parents=[base_parser]) parser_replay = subparsers.add_parser("replay", parents=[base_parser])
parser_replay.add_argument( parser_replay.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)" "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"

View File

@ -45,7 +45,7 @@ from lerobot.common.utils.utils import (
) )
def get_model(cfg, logger): def get_model(cfg, logger): # noqa I001
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg) classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
model = Classifier(classifier_config) model = Classifier(classifier_config)
if cfg.resume: if cfg.resume:
@ -64,6 +64,12 @@ def create_balanced_sampler(dataset, cfg):
return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True) return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
def support_amp(device: torch.device, cfg: DictConfig) -> bool:
# Check if the device supports AMP
# Here is an example of the issue that says that MPS doesn't support AMP properply
return cfg.training.use_amp and device.type in ("cuda", "cpu")
def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg): def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg):
# Single epoch training loop with AMP support and progress tracking # Single epoch training loop with AMP support and progress tracking
model.train() model.train()
@ -77,7 +83,7 @@ def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device,
labels = batch[cfg.training.label_key].float().to(device) labels = batch[cfg.training.label_key].float().to(device)
# Forward pass with optional AMP # Forward pass with optional AMP
with torch.autocast(device_type=device.type) if cfg.training.use_amp else nullcontext(): with torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext():
outputs = model(images) outputs = model(images)
loss = criterion(outputs.logits, labels) loss = criterion(outputs.logits, labels)
@ -119,7 +125,10 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
samples = [] samples = []
running_loss = 0 running_loss = 0
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.training.use_amp else nullcontext(): with (
torch.no_grad(),
torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext(),
):
for batch in tqdm(val_loader, desc="Validation"): for batch in tqdm(val_loader, desc="Validation"):
images = batch[cfg.training.image_key].to(device) images = batch[cfg.training.image_key].to(device)
labels = batch[cfg.training.label_key].float().to(device) labels = batch[cfg.training.label_key].float().to(device)
@ -170,7 +179,7 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
return accuracy, eval_info return accuracy, eval_info
@hydra.main(version_base="1.2", config_path="../configs", config_name="hilserl_classifier") @hydra.main(version_base="1.2", config_path="../configs/policy", config_name="hilserl_classifier")
def train(cfg: DictConfig) -> None: def train(cfg: DictConfig) -> None:
# Main training pipeline with support for resuming training # Main training pipeline with support for resuming training
logging.info(OmegaConf.to_yaml(cfg)) logging.info(OmegaConf.to_yaml(cfg))

2
poetry.lock generated
View File

@ -7720,4 +7720,4 @@ xarm = ["gym-xarm"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.10,<3.13" python-versions = ">=3.10,<3.13"
content-hash = "b9d299916ced6af1d243f961a32b0a4aacbef18e0b95337a5224e8511f5d6dda" content-hash = "44c74163e398e8ff16973957f69a47bb09b789e92ac4d8fb3ab268defab96427"

View File

@ -71,8 +71,8 @@ pyrender = {git = "https://github.com/mmatl/pyrender.git", markers = "sys_platfo
hello-robot-stretch-body = {version = ">=0.7.27", markers = "sys_platform == 'linux'", optional = true} hello-robot-stretch-body = {version = ">=0.7.27", markers = "sys_platform == 'linux'", optional = true}
pyserial = {version = ">=3.5", optional = true} pyserial = {version = ">=3.5", optional = true}
jsonlines = ">=4.0.0" jsonlines = ">=4.0.0"
transformers = {version = "^4.47.0", optional = true} transformers = {version = ">=4.47.0", optional = true}
torchmetrics = {version = "^1.6.0", optional = true} torchmetrics = {version = ">=1.6.0", optional = true}
[tool.poetry.extras] [tool.poetry.extras]

View File

@ -1,7 +1,6 @@
import torch import torch
from lerobot.common.policies.hilserl.classifier.modeling_classifier import ( from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
Classifier,
ClassifierConfig, ClassifierConfig,
ClassifierOutput, ClassifierOutput,
) )
@ -21,6 +20,8 @@ def test_classifier_output():
@require_package("transformers") @require_package("transformers")
def test_binary_classifier_with_default_params(): def test_binary_classifier_with_default_params():
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
config = ClassifierConfig() config = ClassifierConfig()
classifier = Classifier(config) classifier = Classifier(config)
@ -40,6 +41,8 @@ def test_binary_classifier_with_default_params():
@require_package("transformers") @require_package("transformers")
def test_multiclass_classifier(): def test_multiclass_classifier():
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
num_classes = 5 num_classes = 5
config = ClassifierConfig(num_classes=num_classes) config = ClassifierConfig(num_classes=num_classes)
classifier = Classifier(config) classifier = Classifier(config)
@ -60,6 +63,8 @@ def test_multiclass_classifier():
@require_package("transformers") @require_package("transformers")
def test_default_device(): def test_default_device():
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
config = ClassifierConfig() config = ClassifierConfig()
assert config.device == "cpu" assert config.device == "cpu"
@ -70,6 +75,8 @@ def test_default_device():
@require_package("transformers") @require_package("transformers")
def test_explicit_device_setup(): def test_explicit_device_setup():
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
config = ClassifierConfig(device="meta") config = ClassifierConfig(device="meta")
assert config.device == "meta" assert config.device == "meta"

View File

@ -151,9 +151,9 @@ def test_validate():
@patch("lerobot.scripts.train_hilserl_classifier.Logger.get_last_pretrained_model_dir") @patch("lerobot.scripts.train_hilserl_classifier.Logger.get_last_pretrained_model_dir")
@patch("lerobot.scripts.train_hilserl_classifier.Logger") @patch("lerobot.scripts.train_hilserl_classifier.Logger")
@patch("lerobot.scripts.train_hilserl_classifier.LeRobotDataset") @patch("lerobot.scripts.train_hilserl_classifier.LeRobotDataset")
@patch("lerobot.scripts.train_hilserl_classifier.make_policy") @patch("lerobot.scripts.train_hilserl_classifier.get_model")
def test_resume_function( def test_resume_function(
mock_make_policy, mock_get_model,
mock_dataset, mock_dataset,
mock_logger, mock_logger,
mock_get_last_pretrained_model_dir, mock_get_last_pretrained_model_dir,
@ -168,7 +168,7 @@ def test_resume_function(
with initialize_config_dir(config_dir=config_dir, job_name="test_app", version_base="1.2"): with initialize_config_dir(config_dir=config_dir, job_name="test_app", version_base="1.2"):
cfg = compose( cfg = compose(
config_name="reward_classifier", config_name="hilserl_classifier",
overrides=[ overrides=[
"device=cpu", "device=cpu",
"seed=42", "seed=42",
@ -211,7 +211,7 @@ def test_resume_function(
# Instantiate the model and set make_policy to return it # Instantiate the model and set make_policy to return it
model = make_dummy_model() model = make_dummy_model()
mock_make_policy.return_value = model mock_get_model.return_value = model
# Call train # Call train
train(cfg) train(cfg)