Added normalization schemes and style checks

This commit is contained in:
Michel Aractingi 2024-12-29 12:51:21 +00:00
parent b0e2fcdba7
commit 04da4dd3e3
8 changed files with 23 additions and 21 deletions

View File

@ -25,13 +25,13 @@ from glob import glob
from pathlib import Path
import torch
import wandb
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
import wandb
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import get_global_random_state, set_global_random_state

View File

@ -2,8 +2,6 @@ import json
import os
from dataclasses import asdict, dataclass
import torch
@dataclass
class ClassifierConfig:

View File

@ -23,9 +23,11 @@ class ClassifierOutput:
self.hidden_states = hidden_states
def __repr__(self):
return (f"ClassifierOutput(logits={self.logits}, "
f"probabilities={self.probabilities}, "
f"hidden_states={self.hidden_states})")
return (
f"ClassifierOutput(logits={self.logits}, "
f"probabilities={self.probabilities}, "
f"hidden_states={self.hidden_states})"
)
class Classifier(

View File

@ -95,11 +95,13 @@ def make_optimizer_and_scheduler(cfg, policy):
lr_scheduler = None
elif policy.name == "sac":
optimizer = torch.optim.Adam([
{'params': policy.actor.parameters(), 'lr': policy.config.actor_lr},
{'params': policy.critic_ensemble.parameters(), 'lr': policy.config.critic_lr},
{'params': policy.temperature.parameters(), 'lr': policy.config.temperature_lr},
])
optimizer = torch.optim.Adam(
[
{"params": policy.actor.parameters(), "lr": policy.config.actor_lr},
{"params": policy.critic_ensemble.parameters(), "lr": policy.config.critic_lr},
{"params": policy.temperature.parameters(), "lr": policy.config.temperature_lr},
]
)
lr_scheduler = None
elif cfg.policy.name == "vqbet":

View File

@ -22,7 +22,6 @@ from pprint import pformat
import hydra
import torch
import torch.nn as nn
import wandb
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
@ -31,6 +30,7 @@ from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
from tqdm import tqdm
import wandb
from lerobot.common.datasets.factory import resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.logger import Logger