From 844bfcf484666f265bb13838504634f32f1a69c0 Mon Sep 17 00:00:00 2001 From: Eugene Mironov Date: Mon, 6 Jan 2025 17:34:00 +0700 Subject: [PATCH] [Port HIL_SERL] Final fixes for the Reward Classifier (#598) --- .../hilserl/classifier/modeling_classifier.py | 3 ++- lerobot/common/policies/sac/modeling_sac.py | 1 - lerobot/common/robot_devices/control_utils.py | 8 +++++-- .../configs/policy/hilserl_classifier.yaml | 1 - lerobot/scripts/control_robot.py | 2 +- lerobot/scripts/control_sim_robot.py | 23 ++++++++++++++++++- lerobot/scripts/train_hilserl_classifier.py | 17 ++++++++++---- poetry.lock | 2 +- pyproject.toml | 4 ++-- .../classifier/test_modelling_classifier.py | 9 +++++++- tests/test_train_hilserl_classifier.py | 8 +++---- 11 files changed, 59 insertions(+), 19 deletions(-) diff --git a/lerobot/common/policies/hilserl/classifier/modeling_classifier.py b/lerobot/common/policies/hilserl/classifier/modeling_classifier.py index 28b05744..d7bd42cd 100644 --- a/lerobot/common/policies/hilserl/classifier/modeling_classifier.py +++ b/lerobot/common/policies/hilserl/classifier/modeling_classifier.py @@ -4,7 +4,6 @@ from typing import Optional import torch from huggingface_hub import PyTorchModelHubMixin from torch import Tensor, nn -from transformers import AutoImageProcessor, AutoModel from .configuration_classifier import ClassifierConfig @@ -44,6 +43,8 @@ class Classifier( name = "classifier" def __init__(self, config: ClassifierConfig): + from transformers import AutoImageProcessor, AutoModel + super().__init__() self.config = config self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index aada2714..bf357cae 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -327,7 +327,6 @@ class Critic(nn.Module): value = self.output_layer(x) return value.squeeze(-1) - 
class Policy(nn.Module): def __init__( self, diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py index 8a6bcfbd..ad6f5632 100644 --- a/lerobot/common/robot_devices/control_utils.py +++ b/lerobot/common/robot_devices/control_utils.py @@ -362,12 +362,16 @@ def sanity_check_dataset_name(repo_id, policy): def sanity_check_dataset_robot_compatibility( - dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool + dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool, extra_features: dict = None ) -> None: + features_from_robot = get_features_from_robot(robot, use_videos) + if extra_features is not None: + features_from_robot.update(extra_features) + fields = [ ("robot_type", dataset.meta.robot_type, robot.robot_type), ("fps", dataset.fps, fps), - ("features", dataset.features, get_features_from_robot(robot, use_videos)), + ("features", dataset.features, features_from_robot), ] mismatches = [] diff --git a/lerobot/configs/policy/hilserl_classifier.yaml b/lerobot/configs/policy/hilserl_classifier.yaml index be82bc4e..498c9983 100644 --- a/lerobot/configs/policy/hilserl_classifier.yaml +++ b/lerobot/configs/policy/hilserl_classifier.yaml @@ -39,7 +39,6 @@ policy: wandb: enable: false project: "classifier-training" - entity: "wandb_entity" job_name: "classifier_training_0" disable_artifact: false diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py index 6286244a..22529653 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -246,7 +246,7 @@ def record( num_processes=num_image_writer_processes, num_threads=num_image_writer_threads_per_camera * len(robot.cameras), ) - sanity_check_dataset_robot_compatibility(dataset, robot, fps, video) + sanity_check_dataset_robot_compatibility(dataset, robot, fps, video, extra_features) else: # Create empty dataset or load existing saved episodes sanity_check_dataset_name(repo_id, policy) diff --git 
a/lerobot/scripts/control_sim_robot.py b/lerobot/scripts/control_sim_robot.py index 4fffa8c7..67bdfb85 100644 --- a/lerobot/scripts/control_sim_robot.py +++ b/lerobot/scripts/control_sim_robot.py @@ -183,8 +183,14 @@ def record( resume: bool = False, local_files_only: bool = False, run_compute_stats: bool = True, + assign_rewards: bool = False, ) -> LeRobotDataset: # Load pretrained policy + + extra_features = ( + {"next.reward": {"dtype": "int64", "shape": (1,), "names": None}} if assign_rewards else None + ) + policy = None if pretrained_policy_name_or_path is not None: policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides) @@ -197,7 +203,7 @@ def record( raise ValueError("Either policy or process_action_fn has to be set to enable control in sim.") # initialize listener before sim env - listener, events = init_keyboard_listener() + listener, events = init_keyboard_listener(assign_rewards=assign_rewards) # create sim env env = env() @@ -237,6 +243,7 @@ def record( } features["action"] = {"dtype": "float32", "shape": env.action_space.shape, "names": None} + features = {**features, **(extra_features or {})} # Create empty dataset or load existing saved episodes sanity_check_dataset_name(repo_id, policy) @@ -288,6 +295,13 @@ def record( "timestamp": env_timestamp, } + # Overwrite environment reward with manually assigned reward + if assign_rewards: + frame["next.reward"] = events["next.reward"] + + # Should success always be false to match what we do in control_utils? + frame["next.success"] = False + for key in image_keys: if not key.startswith("observation.image"): frame["observation.image." + key] = observation[key] @@ -472,6 +486,13 @@ if __name__ == "__main__": default=0, help="Resume recording on an existing dataset.", ) + parser_record.add_argument( + "--assign-rewards", + type=int, + default=0, + help="Enables the assignation of rewards to frames (by default no assignation). 
When enabled, assign a 0 reward to frames until the space bar is pressed which assign a 1 reward. Press the space bar a second time to assign a 0 reward. The reward assigned is reset to 0 when the episode ends.", + ) + parser_replay = subparsers.add_parser("replay", parents=[base_parser]) parser_replay.add_argument( "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)" diff --git a/lerobot/scripts/train_hilserl_classifier.py b/lerobot/scripts/train_hilserl_classifier.py index ea8336a9..22ff2957 100644 --- a/lerobot/scripts/train_hilserl_classifier.py +++ b/lerobot/scripts/train_hilserl_classifier.py @@ -45,7 +45,7 @@ from lerobot.common.utils.utils import ( ) -def get_model(cfg, logger): +def get_model(cfg, logger): # noqa I001 classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg) model = Classifier(classifier_config) if cfg.resume: @@ -64,6 +64,12 @@ def create_balanced_sampler(dataset, cfg): return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True) +def support_amp(device: torch.device, cfg: DictConfig) -> bool: + # Check if the device supports AMP + # Here is an example of the issue that says that MPS doesn't support AMP properly + return cfg.training.use_amp and device.type in ("cuda", "cpu") + + def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg): # Single epoch training loop with AMP support and progress tracking model.train() @@ -77,7 +83,7 @@ def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, labels = batch[cfg.training.label_key].float().to(device) # Forward pass with optional AMP - with torch.autocast(device_type=device.type) if cfg.training.use_amp else nullcontext(): + with torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext(): outputs = model(images) loss = criterion(outputs.logits, labels) @@ -119,7 +125,10 @@ def validate(model, val_loader, 
criterion, device, logger, cfg, num_samples_to_l samples = [] running_loss = 0 - with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.training.use_amp else nullcontext(): + with ( + torch.no_grad(), + torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext(), + ): for batch in tqdm(val_loader, desc="Validation"): images = batch[cfg.training.image_key].to(device) labels = batch[cfg.training.label_key].float().to(device) @@ -170,7 +179,7 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l return accuracy, eval_info -@hydra.main(version_base="1.2", config_path="../configs", config_name="hilserl_classifier") +@hydra.main(version_base="1.2", config_path="../configs/policy", config_name="hilserl_classifier") def train(cfg: DictConfig) -> None: # Main training pipeline with support for resuming training logging.info(OmegaConf.to_yaml(cfg)) diff --git a/poetry.lock b/poetry.lock index 919edd18..81462fe8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -7720,4 +7720,4 @@ xarm = ["gym-xarm"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "b9d299916ced6af1d243f961a32b0a4aacbef18e0b95337a5224e8511f5d6dda" +content-hash = "44c74163e398e8ff16973957f69a47bb09b789e92ac4d8fb3ab268defab96427" diff --git a/pyproject.toml b/pyproject.toml index 738903bd..05ab921a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,8 +71,8 @@ pyrender = {git = "https://github.com/mmatl/pyrender.git", markers = "sys_platfo hello-robot-stretch-body = {version = ">=0.7.27", markers = "sys_platform == 'linux'", optional = true} pyserial = {version = ">=3.5", optional = true} jsonlines = ">=4.0.0" -transformers = {version = "^4.47.0", optional = true} -torchmetrics = {version = "^1.6.0", optional = true} +transformers = {version = ">=4.47.0", optional = true} +torchmetrics = {version = ">=1.6.0", optional = true} [tool.poetry.extras] diff --git 
a/tests/policies/hilserl/classifier/test_modelling_classifier.py b/tests/policies/hilserl/classifier/test_modelling_classifier.py index 014165eb..a3db4211 100644 --- a/tests/policies/hilserl/classifier/test_modelling_classifier.py +++ b/tests/policies/hilserl/classifier/test_modelling_classifier.py @@ -1,7 +1,6 @@ import torch from lerobot.common.policies.hilserl.classifier.modeling_classifier import ( - Classifier, ClassifierConfig, ClassifierOutput, ) @@ -21,6 +20,8 @@ def test_classifier_output(): @require_package("transformers") def test_binary_classifier_with_default_params(): + from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier + config = ClassifierConfig() classifier = Classifier(config) @@ -40,6 +41,8 @@ def test_binary_classifier_with_default_params(): @require_package("transformers") def test_multiclass_classifier(): + from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier + num_classes = 5 config = ClassifierConfig(num_classes=num_classes) classifier = Classifier(config) @@ -60,6 +63,8 @@ def test_multiclass_classifier(): @require_package("transformers") def test_default_device(): + from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier + config = ClassifierConfig() assert config.device == "cpu" @@ -70,6 +75,8 @@ def test_default_device(): @require_package("transformers") def test_explicit_device_setup(): + from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier + config = ClassifierConfig(device="meta") assert config.device == "meta" diff --git a/tests/test_train_hilserl_classifier.py b/tests/test_train_hilserl_classifier.py index 66d8fbe4..c1d854ac 100644 --- a/tests/test_train_hilserl_classifier.py +++ b/tests/test_train_hilserl_classifier.py @@ -151,9 +151,9 @@ def test_validate(): @patch("lerobot.scripts.train_hilserl_classifier.Logger.get_last_pretrained_model_dir") @patch("lerobot.scripts.train_hilserl_classifier.Logger") 
@patch("lerobot.scripts.train_hilserl_classifier.LeRobotDataset") -@patch("lerobot.scripts.train_hilserl_classifier.make_policy") +@patch("lerobot.scripts.train_hilserl_classifier.get_model") def test_resume_function( - mock_make_policy, + mock_get_model, mock_dataset, mock_logger, mock_get_last_pretrained_model_dir, @@ -168,7 +168,7 @@ def test_resume_function( with initialize_config_dir(config_dir=config_dir, job_name="test_app", version_base="1.2"): cfg = compose( - config_name="reward_classifier", + config_name="hilserl_classifier", overrides=[ "device=cpu", "seed=42", @@ -211,7 +211,7 @@ def test_resume_function( # Instantiate the model and set make_policy to return it model = make_dummy_model() - mock_make_policy.return_value = model + mock_get_model.return_value = model # Call train train(cfg)