From 04da4dd3e36ae8307f598570aec4fe37a28a4f72 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Sun, 29 Dec 2024 12:51:21 +0000 Subject: [PATCH] Added normalization schemes and style checks --- lerobot/common/logger.py | 2 +- .../hilserl/classifier/configuration_classifier.py | 2 -- .../hilserl/classifier/modeling_classifier.py | 10 ++++++---- .../policies/hilserl/configuration_hilserl.py | 2 +- .../common/policies/hilserl/modeling_hilserl.py | 4 ++-- lerobot/scripts/eval_on_robot.py | 8 ++++---- lerobot/scripts/train.py | 14 ++++++++------ lerobot/scripts/train_hilserl_classifier.py | 2 +- 8 files changed, 23 insertions(+), 21 deletions(-) diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py index 4015492d..dec8b465 100644 --- a/lerobot/common/logger.py +++ b/lerobot/common/logger.py @@ -25,13 +25,13 @@ from glob import glob from pathlib import Path import torch -import wandb from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE from omegaconf import DictConfig, OmegaConf from termcolor import colored from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler +import wandb from lerobot.common.policies.policy_protocol import Policy from lerobot.common.utils.utils import get_global_random_state, set_global_random_state diff --git a/lerobot/common/policies/hilserl/classifier/configuration_classifier.py b/lerobot/common/policies/hilserl/classifier/configuration_classifier.py index 553e4262..f0b9352f 100644 --- a/lerobot/common/policies/hilserl/classifier/configuration_classifier.py +++ b/lerobot/common/policies/hilserl/classifier/configuration_classifier.py @@ -2,8 +2,6 @@ import json import os from dataclasses import asdict, dataclass -import torch - @dataclass class ClassifierConfig: diff --git a/lerobot/common/policies/hilserl/classifier/modeling_classifier.py b/lerobot/common/policies/hilserl/classifier/modeling_classifier.py index 0b8d66ac..28b05744 100644 --- a/lerobot/common/policies/hilserl/classifier/modeling_classifier.py +++ b/lerobot/common/policies/hilserl/classifier/modeling_classifier.py @@ -23,9 +23,11 @@ class ClassifierOutput: self.hidden_states = hidden_states def __repr__(self): - return (f"ClassifierOutput(logits={self.logits}, " - f"probabilities={self.probabilities}, " - f"hidden_states={self.hidden_states})") + return ( + f"ClassifierOutput(logits={self.logits}, " + f"probabilities={self.probabilities}, " + f"hidden_states={self.hidden_states})" + ) class Classifier( @@ -74,7 +76,7 @@ class Classifier( self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension else: raise ValueError("Unsupported CNN architecture") - + self.encoder = self.encoder.to(self.config.device) def _freeze_encoder(self) -> None: diff --git a/lerobot/common/policies/hilserl/configuration_hilserl.py b/lerobot/common/policies/hilserl/configuration_hilserl.py index f1bc850f..80d2f578 100644 --- a/lerobot/common/policies/hilserl/configuration_hilserl.py +++ b/lerobot/common/policies/hilserl/configuration_hilserl.py @@ -1,6 +1,6 @@ #!/usr/bin/env python -# Copyright 2024 The HuggingFace Inc. team. +# Copyright 2024 The HuggingFace Inc. team. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/lerobot/common/policies/hilserl/modeling_hilserl.py b/lerobot/common/policies/hilserl/modeling_hilserl.py index 236ed433..679eb010 100644 --- a/lerobot/common/policies/hilserl/modeling_hilserl.py +++ b/lerobot/common/policies/hilserl/modeling_hilserl.py @@ -1,6 +1,6 @@ #!/usr/bin/env python -# Copyright 2024 The HuggingFace Inc. team. +# Copyright 2024 The HuggingFace Inc. team. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -26,4 +26,4 @@ class HILSerlPolicy( repo_url="https://github.com/huggingface/lerobot", tags=["robotics", "hilserl"], ): - pass \ No newline at end of file + pass diff --git a/lerobot/scripts/eval_on_robot.py b/lerobot/scripts/eval_on_robot.py index 6a790f0a..92daa860 100644 --- a/lerobot/scripts/eval_on_robot.py +++ b/lerobot/scripts/eval_on_robot.py @@ -24,7 +24,7 @@ python lerobot/scripts/eval_on_robot.py \ ``` **NOTE** (michel-aractingi): This script is incomplete and it is being prepared -for running training on the real robot. +for running training on the real robot. """ import argparse @@ -47,7 +47,7 @@ from lerobot.common.utils.utils import ( def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20, use_amp: bool = True) -> dict: - """Run a batched policy rollout on the real robot. + """Run a batched policy rollout on the real robot. The return dictionary contains: "robot": A a dictionary of (batch, sequence + 1, *) tensors mapped to observation @@ -64,7 +64,7 @@ def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20, extraneous elements from the sequences above. Args: - robot: The robot class that defines the interface with the real robot. + robot: The robot class that defines the interface with the real robot. policy: The policy. Must be a PyTorch nn module. Returns: @@ -77,7 +77,7 @@ def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20, listener, events = init_keyboard_listener() # Reset the policy. TODO (michel-aractingi) add real policy evaluation once the code is ready. - # policy.reset() + # policy.reset() # Get observation from real robot observation = robot.capture_observation() diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 346c3acd..fbe7927d 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -95,12 +95,14 @@ def make_optimizer_and_scheduler(cfg, policy): lr_scheduler = None elif policy.name == "sac": - optimizer = torch.optim.Adam([ - {'params': policy.actor.parameters(), 'lr': policy.config.actor_lr}, - {'params': policy.critic_ensemble.parameters(), 'lr': policy.config.critic_lr}, - {'params': policy.temperature.parameters(), 'lr': policy.config.temperature_lr}, - ]) - lr_scheduler = None + optimizer = torch.optim.Adam( + [ + {"params": policy.actor.parameters(), "lr": policy.config.actor_lr}, + {"params": policy.critic_ensemble.parameters(), "lr": policy.config.critic_lr}, + {"params": policy.temperature.parameters(), "lr": policy.config.temperature_lr}, + ] + ) + lr_scheduler = None elif cfg.policy.name == "vqbet": from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler diff --git a/lerobot/scripts/train_hilserl_classifier.py b/lerobot/scripts/train_hilserl_classifier.py index 78659dc8..ea8336a9 100644 --- a/lerobot/scripts/train_hilserl_classifier.py +++ b/lerobot/scripts/train_hilserl_classifier.py @@ -22,7 +22,6 @@ from pprint import pformat import hydra import torch import torch.nn as nn -import wandb from deepdiff import DeepDiff from omegaconf import DictConfig, OmegaConf from termcolor import colored @@ -31,6 +30,7 @@ from torch.cuda.amp import GradScaler from torch.utils.data import DataLoader, WeightedRandomSampler, random_split from tqdm import tqdm +import wandb from lerobot.common.datasets.factory import resolve_delta_timestamps from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.logger import Logger