[Port HIL_SERL] Final fixes for the Reward Classifier (#598)

Eugene Mironov 2025-01-06 17:34:00 +07:00 committed by Michel Aractingi
parent 13441f0d98
commit 844bfcf484
11 changed files with 59 additions and 19 deletions

View File

@@ -4,7 +4,6 @@ from typing import Optional
import torch
from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor, nn
from transformers import AutoImageProcessor, AutoModel
from .configuration_classifier import ClassifierConfig
@@ -44,6 +43,8 @@ class Classifier(
name = "classifier"
def __init__(self, config: ClassifierConfig):
from transformers import AutoImageProcessor, AutoModel
super().__init__()
self.config = config
self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True)
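
The hunk above moves the transformers import from module level into Classifier.__init__, so the module stays importable when transformers is not installed. A minimal sketch of that optional-dependency pattern (names here are illustrative, not lerobot's):

import importlib.util


class LazyBackbone:
    # Illustrative sketch, not lerobot's Classifier: import transformers only when
    # an instance is actually constructed, with a clear error if it is missing.
    def __init__(self, model_name: str):
        if importlib.util.find_spec("transformers") is None:
            raise ImportError("Install `transformers` to instantiate this classifier.")
        from transformers import AutoImageProcessor, AutoModel

        self.processor = AutoImageProcessor.from_pretrained(model_name, trust_remote_code=True)
        self.encoder = AutoModel.from_pretrained(model_name, trust_remote_code=True)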

View File

@@ -327,7 +327,6 @@ class Critic(nn.Module):
value = self.output_layer(x)
return value.squeeze(-1)
class Policy(nn.Module):
def __init__(
self,

View File

@@ -362,12 +362,16 @@ def sanity_check_dataset_name(repo_id, policy):
def sanity_check_dataset_robot_compatibility(
dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool
dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool, extra_features: dict = None
) -> None:
features_from_robot = get_features_from_robot(robot, use_videos)
if extra_features is not None:
features_from_robot.update(extra_features)
fields = [
("robot_type", dataset.meta.robot_type, robot.robot_type),
("fps", dataset.fps, fps),
("features", dataset.features, get_features_from_robot(robot, use_videos)),
("features", dataset.features, features_from_robot),
]
mismatches = []
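
For reference, a small self-contained sketch of the feature merge performed above; get_features_from_robot is replaced by a stand-in with a made-up state spec, while the next.reward entry matches the one added later in this commit:

def fake_features_from_robot() -> dict:
    # Stand-in for get_features_from_robot(robot, use_videos); the state spec is illustrative.
    return {"observation.state": {"dtype": "float32", "shape": (6,), "names": None}}


extra_features = {"next.reward": {"dtype": "int64", "shape": (1,), "names": None}}

features_from_robot = fake_features_from_robot()
if extra_features is not None:
    features_from_robot.update(extra_features)

print(sorted(features_from_robot))  # ['next.reward', 'observation.state']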

View File

@@ -39,7 +39,6 @@ policy:
wandb:
enable: false
project: "classifier-training"
entity: "wandb_entity"
job_name: "classifier_training_0"
disable_artifact: false

View File

@@ -246,7 +246,7 @@ def record(
num_processes=num_image_writer_processes,
num_threads=num_image_writer_threads_per_camera * len(robot.cameras),
)
sanity_check_dataset_robot_compatibility(dataset, robot, fps, video)
sanity_check_dataset_robot_compatibility(dataset, robot, fps, video, extra_features)
else:
# Create empty dataset or load existing saved episodes
sanity_check_dataset_name(repo_id, policy)

View File

@@ -183,8 +183,14 @@ def record(
resume: bool = False,
local_files_only: bool = False,
run_compute_stats: bool = True,
assign_rewards: bool = False,
) -> LeRobotDataset:
# Load pretrained policy
extra_features = (
{"next.reward": {"dtype": "int64", "shape": (1,), "names": None}} if assign_rewards else None
)
policy = None
if pretrained_policy_name_or_path is not None:
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
@@ -197,7 +203,7 @@ def record(
raise ValueError("Either policy or process_action_fn has to be set to enable control in sim.")
# initialize listener before sim env
listener, events = init_keyboard_listener()
listener, events = init_keyboard_listener(assign_rewards=assign_rewards)
# create sim env
env = env()
@@ -237,6 +243,7 @@
}
features["action"] = {"dtype": "float32", "shape": env.action_space.shape, "names": None}
features = {**features, **extra_features}
# Create empty dataset or load existing saved episodes
sanity_check_dataset_name(repo_id, policy)
@@ -288,6 +295,13 @@
"timestamp": env_timestamp,
}
# Overwrite environment reward with manually assigned reward
if assign_rewards:
frame["next.reward"] = events["next.reward"]
# Should success always be false to match what we do in control_utils?
frame["next.success"] = False
for key in image_keys:
if not key.startswith("observation.image"):
frame["observation.image." + key] = observation[key]
@@ -472,6 +486,13 @@ if __name__ == "__main__":
default=0,
help="Resume recording on an existing dataset.",
)
parser_record.add_argument(
"--assign-rewards",
type=int,
default=0,
help="Enables the assignation of rewards to frames (by default no assignation). When enabled, assign a 0 reward to frames until the space bar is pressed which assign a 1 reward. Press the space bar a second time to assign a 0 reward. The reward assigned is reset to 0 when the episode ends.",
)
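
The help text above describes a space-bar toggle that is wired up inside init_keyboard_listener (not shown in this diff). Below is a hypothetical sketch of such a toggle using pynput, purely for illustration; it is not lerobot's actual listener:

from pynput import keyboard


def init_reward_toggle(events: dict):
    # Hypothetical helper: flip the reward between 0 and 1 every time
    # the space bar is pressed, starting from 0.
    events["next.reward"] = 0

    def on_press(key):
        if key == keyboard.Key.space:
            events["next.reward"] = 1 - events["next.reward"]

    listener = keyboard.Listener(on_press=on_press)
    listener.start()
    return listener, events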
parser_replay = subparsers.add_parser("replay", parents=[base_parser])
parser_replay.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"

View File

@@ -45,7 +45,7 @@ from lerobot.common.utils.utils import (
)
def get_model(cfg, logger):
def get_model(cfg, logger): # noqa I001
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
model = Classifier(classifier_config)
if cfg.resume:
@@ -64,6 +64,12 @@ def create_balanced_sampler(dataset, cfg):
return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
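Only the return statement of create_balanced_sampler is visible in this hunk. A minimal, self-contained sketch of the usual inverse-class-frequency weighting it relies on, assuming binary labels (the real helper may compute the weights differently):

import torch
from torch.utils.data import WeightedRandomSampler


def balanced_sampler_from_labels(labels: torch.Tensor) -> WeightedRandomSampler:
    # Weight each sample by the inverse frequency of its class so that batches
    # drawn with replacement see positives and negatives at similar rates.
    class_counts = torch.bincount(labels.long()).float()
    sample_weights = (1.0 / class_counts)[labels.long()]
    return WeightedRandomSampler(
        weights=sample_weights, num_samples=len(sample_weights), replacement=True
    )


# Example: 8 negatives and 2 positives -> each positive is drawn ~4x more often.
sampler = balanced_sampler_from_labels(torch.tensor([0] * 8 + [1] * 2))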
def support_amp(device: torch.device, cfg: DictConfig) -> bool:
# Check if the device supports AMP
# There are known issue reports that MPS doesn't support AMP properly
return cfg.training.use_amp and device.type in ("cuda", "cpu")
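
support_amp deliberately excludes MPS. An illustrative sketch (names are not lerobot's API) of turning that same check into a context manager around a forward pass:

from contextlib import nullcontext

import torch


def amp_context(device: torch.device, use_amp: bool):
    # Mirror support_amp(): only enter autocast on device types where AMP behaves.
    if use_amp and device.type in ("cuda", "cpu"):
        return torch.autocast(device_type=device.type)
    return nullcontext()


model = torch.nn.Linear(4, 2)
with amp_context(torch.device("cpu"), use_amp=True):
    out = model(torch.randn(1, 4))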
def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg):
# Single epoch training loop with AMP support and progress tracking
model.train()
@@ -77,7 +83,7 @@ def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device,
labels = batch[cfg.training.label_key].float().to(device)
# Forward pass with optional AMP
with torch.autocast(device_type=device.type) if cfg.training.use_amp else nullcontext():
with torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext():
outputs = model(images)
loss = criterion(outputs.logits, labels)
@@ -119,7 +125,10 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
samples = []
running_loss = 0
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.training.use_amp else nullcontext():
with (
torch.no_grad(),
torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext(),
):
for batch in tqdm(val_loader, desc="Validation"):
images = batch[cfg.training.image_key].to(device)
labels = batch[cfg.training.label_key].float().to(device)
@@ -170,7 +179,7 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
return accuracy, eval_info
@hydra.main(version_base="1.2", config_path="../configs", config_name="hilserl_classifier")
@hydra.main(version_base="1.2", config_path="../configs/policy", config_name="hilserl_classifier")
def train(cfg: DictConfig) -> None:
# Main training pipeline with support for resuming training
logging.info(OmegaConf.to_yaml(cfg))

poetry.lock (generated)
View File

@@ -7720,4 +7720,4 @@ xarm = ["gym-xarm"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<3.13"
content-hash = "b9d299916ced6af1d243f961a32b0a4aacbef18e0b95337a5224e8511f5d6dda"
content-hash = "44c74163e398e8ff16973957f69a47bb09b789e92ac4d8fb3ab268defab96427"

View File

@@ -71,8 +71,8 @@ pyrender = {git = "https://github.com/mmatl/pyrender.git", markers = "sys_platfo
hello-robot-stretch-body = {version = ">=0.7.27", markers = "sys_platform == 'linux'", optional = true}
pyserial = {version = ">=3.5", optional = true}
jsonlines = ">=4.0.0"
transformers = {version = "^4.47.0", optional = true}
torchmetrics = {version = "^1.6.0", optional = true}
transformers = {version = ">=4.47.0", optional = true}
torchmetrics = {version = ">=1.6.0", optional = true}
[tool.poetry.extras]

View File

@@ -1,7 +1,6 @@
import torch
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
Classifier,
ClassifierConfig,
ClassifierOutput,
)
@@ -21,6 +20,8 @@ def test_classifier_output():
@require_package("transformers")
def test_binary_classifier_with_default_params():
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
config = ClassifierConfig()
classifier = Classifier(config)
@@ -40,6 +41,8 @@ def test_binary_classifier_with_default_params():
@require_package("transformers")
def test_multiclass_classifier():
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
num_classes = 5
config = ClassifierConfig(num_classes=num_classes)
classifier = Classifier(config)
@@ -60,6 +63,8 @@ def test_multiclass_classifier():
@require_package("transformers")
def test_default_device():
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
config = ClassifierConfig()
assert config.device == "cpu"
@@ -70,6 +75,8 @@ def test_default_device():
@require_package("transformers")
def test_explicit_device_setup():
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
config = ClassifierConfig(device="meta")
assert config.device == "meta"

View File

@@ -151,9 +151,9 @@ def test_validate():
@patch("lerobot.scripts.train_hilserl_classifier.Logger.get_last_pretrained_model_dir")
@patch("lerobot.scripts.train_hilserl_classifier.Logger")
@patch("lerobot.scripts.train_hilserl_classifier.LeRobotDataset")
@patch("lerobot.scripts.train_hilserl_classifier.make_policy")
@patch("lerobot.scripts.train_hilserl_classifier.get_model")
def test_resume_function(
mock_make_policy,
mock_get_model,
mock_dataset,
mock_logger,
mock_get_last_pretrained_model_dir,
@@ -168,7 +168,7 @@ def test_resume_function(
with initialize_config_dir(config_dir=config_dir, job_name="test_app", version_base="1.2"):
cfg = compose(
config_name="reward_classifier",
config_name="hilserl_classifier",
overrides=[
"device=cpu",
"seed=42",
@@ -211,7 +211,7 @@ def test_resume_function(
# Instantiate the model and set make_policy to return it
model = make_dummy_model()
mock_make_policy.return_value = model
mock_get_model.return_value = model
# Call train
train(cfg)