Added normalization schemes and style checks

This commit is contained in:
Michel Aractingi 2024-12-29 12:51:21 +00:00
parent b0e2fcdba7
commit 04da4dd3e3
8 changed files with 23 additions and 21 deletions

View File

@ -25,13 +25,13 @@ from glob import glob
from pathlib import Path from pathlib import Path
import torch import torch
import wandb
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
from termcolor import colored from termcolor import colored
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler from torch.optim.lr_scheduler import LRScheduler
import wandb
from lerobot.common.policies.policy_protocol import Policy from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import get_global_random_state, set_global_random_state from lerobot.common.utils.utils import get_global_random_state, set_global_random_state

View File

@ -2,8 +2,6 @@ import json
import os import os
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
import torch
@dataclass @dataclass
class ClassifierConfig: class ClassifierConfig:

View File

@ -23,9 +23,11 @@ class ClassifierOutput:
self.hidden_states = hidden_states self.hidden_states = hidden_states
def __repr__(self): def __repr__(self):
return (f"ClassifierOutput(logits={self.logits}, " return (
f"probabilities={self.probabilities}, " f"ClassifierOutput(logits={self.logits}, "
f"hidden_states={self.hidden_states})") f"probabilities={self.probabilities}, "
f"hidden_states={self.hidden_states})"
)
class Classifier( class Classifier(
@ -74,7 +76,7 @@ class Classifier(
self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension
else: else:
raise ValueError("Unsupported CNN architecture") raise ValueError("Unsupported CNN architecture")
self.encoder = self.encoder.to(self.config.device) self.encoder = self.encoder.to(self.config.device)
def _freeze_encoder(self) -> None: def _freeze_encoder(self) -> None:

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python #!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. # Copyright 2024 The HuggingFace Inc. team.
# All rights reserved. # All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python #!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. # Copyright 2024 The HuggingFace Inc. team.
# All rights reserved. # All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -26,4 +26,4 @@ class HILSerlPolicy(
repo_url="https://github.com/huggingface/lerobot", repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "hilserl"], tags=["robotics", "hilserl"],
): ):
pass pass

View File

@ -24,7 +24,7 @@ python lerobot/scripts/eval_on_robot.py \
``` ```
**NOTE** (michel-aractingi): This script is incomplete and it is being prepared **NOTE** (michel-aractingi): This script is incomplete and it is being prepared
for running training on the real robot. for running training on the real robot.
""" """
import argparse import argparse
@ -47,7 +47,7 @@ from lerobot.common.utils.utils import (
def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20, use_amp: bool = True) -> dict: def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20, use_amp: bool = True) -> dict:
"""Run a batched policy rollout on the real robot. """Run a batched policy rollout on the real robot.
The return dictionary contains: The return dictionary contains:
"robot": A dictionary of (batch, sequence + 1, *) tensors mapped to observation "robot": A dictionary of (batch, sequence + 1, *) tensors mapped to observation
@ -64,7 +64,7 @@ def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20,
extraneous elements from the sequences above. extraneous elements from the sequences above.
Args: Args:
robot: The robot class that defines the interface with the real robot. robot: The robot class that defines the interface with the real robot.
policy: The policy. Must be a PyTorch nn module. policy: The policy. Must be a PyTorch nn module.
Returns: Returns:
@ -77,7 +77,7 @@ def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20,
listener, events = init_keyboard_listener() listener, events = init_keyboard_listener()
# Reset the policy. TODO (michel-aractingi) add real policy evaluation once the code is ready. # Reset the policy. TODO (michel-aractingi) add real policy evaluation once the code is ready.
# policy.reset() # policy.reset()
# Get observation from real robot # Get observation from real robot
observation = robot.capture_observation() observation = robot.capture_observation()

View File

@ -95,12 +95,14 @@ def make_optimizer_and_scheduler(cfg, policy):
lr_scheduler = None lr_scheduler = None
elif policy.name == "sac": elif policy.name == "sac":
optimizer = torch.optim.Adam([ optimizer = torch.optim.Adam(
{'params': policy.actor.parameters(), 'lr': policy.config.actor_lr}, [
{'params': policy.critic_ensemble.parameters(), 'lr': policy.config.critic_lr}, {"params": policy.actor.parameters(), "lr": policy.config.actor_lr},
{'params': policy.temperature.parameters(), 'lr': policy.config.temperature_lr}, {"params": policy.critic_ensemble.parameters(), "lr": policy.config.critic_lr},
]) {"params": policy.temperature.parameters(), "lr": policy.config.temperature_lr},
lr_scheduler = None ]
)
lr_scheduler = None
elif cfg.policy.name == "vqbet": elif cfg.policy.name == "vqbet":
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler

View File

@ -22,7 +22,6 @@ from pprint import pformat
import hydra import hydra
import torch import torch
import torch.nn as nn import torch.nn as nn
import wandb
from deepdiff import DeepDiff from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
from termcolor import colored from termcolor import colored
@ -31,6 +30,7 @@ from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader, WeightedRandomSampler, random_split from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
from tqdm import tqdm from tqdm import tqdm
import wandb
from lerobot.common.datasets.factory import resolve_delta_timestamps from lerobot.common.datasets.factory import resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.logger import Logger from lerobot.common.logger import Logger