Added normalization schemes and style checks

This commit is contained in:
Michel Aractingi 2024-12-29 12:51:21 +00:00
parent b0e2fcdba7
commit 04da4dd3e3
8 changed files with 23 additions and 21 deletions

View File

@ -25,13 +25,13 @@ from glob import glob
from pathlib import Path from pathlib import Path
import torch import torch
import wandb
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
from termcolor import colored from termcolor import colored
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler from torch.optim.lr_scheduler import LRScheduler
import wandb
from lerobot.common.policies.policy_protocol import Policy from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import get_global_random_state, set_global_random_state from lerobot.common.utils.utils import get_global_random_state, set_global_random_state

View File

@ -2,8 +2,6 @@ import json
import os import os
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
import torch
@dataclass @dataclass
class ClassifierConfig: class ClassifierConfig:

View File

@ -23,9 +23,11 @@ class ClassifierOutput:
self.hidden_states = hidden_states self.hidden_states = hidden_states
def __repr__(self): def __repr__(self):
return (f"ClassifierOutput(logits={self.logits}, " return (
f"ClassifierOutput(logits={self.logits}, "
f"probabilities={self.probabilities}, " f"probabilities={self.probabilities}, "
f"hidden_states={self.hidden_states})") f"hidden_states={self.hidden_states})"
)
class Classifier( class Classifier(

View File

@ -95,11 +95,13 @@ def make_optimizer_and_scheduler(cfg, policy):
lr_scheduler = None lr_scheduler = None
elif policy.name == "sac": elif policy.name == "sac":
optimizer = torch.optim.Adam([ optimizer = torch.optim.Adam(
{'params': policy.actor.parameters(), 'lr': policy.config.actor_lr}, [
{'params': policy.critic_ensemble.parameters(), 'lr': policy.config.critic_lr}, {"params": policy.actor.parameters(), "lr": policy.config.actor_lr},
{'params': policy.temperature.parameters(), 'lr': policy.config.temperature_lr}, {"params": policy.critic_ensemble.parameters(), "lr": policy.config.critic_lr},
]) {"params": policy.temperature.parameters(), "lr": policy.config.temperature_lr},
]
)
lr_scheduler = None lr_scheduler = None
elif cfg.policy.name == "vqbet": elif cfg.policy.name == "vqbet":

View File

@ -22,7 +22,6 @@ from pprint import pformat
import hydra import hydra
import torch import torch
import torch.nn as nn import torch.nn as nn
import wandb
from deepdiff import DeepDiff from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
from termcolor import colored from termcolor import colored
@ -31,6 +30,7 @@ from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader, WeightedRandomSampler, random_split from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
from tqdm import tqdm from tqdm import tqdm
import wandb
from lerobot.common.datasets.factory import resolve_delta_timestamps from lerobot.common.datasets.factory import resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.logger import Logger from lerobot.common.logger import Logger