Added normalization schemes and style checks
This commit is contained in:
parent
b0e2fcdba7
commit
04da4dd3e3
|
@ -25,13 +25,13 @@ from glob import glob
|
|||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import wandb
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from termcolor import colored
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
|
||||
import wandb
|
||||
from lerobot.common.policies.policy_protocol import Policy
|
||||
from lerobot.common.utils.utils import get_global_random_state, set_global_random_state
|
||||
|
||||
|
|
|
@ -2,8 +2,6 @@ import json
|
|||
import os
|
||||
from dataclasses import asdict, dataclass
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClassifierConfig:
|
||||
|
|
|
@ -23,9 +23,11 @@ class ClassifierOutput:
|
|||
self.hidden_states = hidden_states
|
||||
|
||||
def __repr__(self):
|
||||
return (f"ClassifierOutput(logits={self.logits}, "
|
||||
f"probabilities={self.probabilities}, "
|
||||
f"hidden_states={self.hidden_states})")
|
||||
return (
|
||||
f"ClassifierOutput(logits={self.logits}, "
|
||||
f"probabilities={self.probabilities}, "
|
||||
f"hidden_states={self.hidden_states})"
|
||||
)
|
||||
|
||||
|
||||
class Classifier(
|
||||
|
|
|
@ -95,11 +95,13 @@ def make_optimizer_and_scheduler(cfg, policy):
|
|||
lr_scheduler = None
|
||||
|
||||
elif policy.name == "sac":
|
||||
optimizer = torch.optim.Adam([
|
||||
{'params': policy.actor.parameters(), 'lr': policy.config.actor_lr},
|
||||
{'params': policy.critic_ensemble.parameters(), 'lr': policy.config.critic_lr},
|
||||
{'params': policy.temperature.parameters(), 'lr': policy.config.temperature_lr},
|
||||
])
|
||||
optimizer = torch.optim.Adam(
|
||||
[
|
||||
{"params": policy.actor.parameters(), "lr": policy.config.actor_lr},
|
||||
{"params": policy.critic_ensemble.parameters(), "lr": policy.config.critic_lr},
|
||||
{"params": policy.temperature.parameters(), "lr": policy.config.temperature_lr},
|
||||
]
|
||||
)
|
||||
lr_scheduler = None
|
||||
|
||||
elif cfg.policy.name == "vqbet":
|
||||
|
|
|
@ -22,7 +22,6 @@ from pprint import pformat
|
|||
import hydra
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import wandb
|
||||
from deepdiff import DeepDiff
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from termcolor import colored
|
||||
|
@ -31,6 +30,7 @@ from torch.cuda.amp import GradScaler
|
|||
from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
|
||||
from tqdm import tqdm
|
||||
|
||||
import wandb
|
||||
from lerobot.common.datasets.factory import resolve_delta_timestamps
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.logger import Logger
|
||||
|
|
Loading…
Reference in New Issue