Added normalization schemes and style checks
This commit is contained in:
parent
3b07766c33
commit
cc85bca2b5
|
@ -25,13 +25,13 @@ from glob import glob
|
|||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import wandb
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from termcolor import colored
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
|
||||
import wandb
|
||||
from lerobot.common.policies.policy_protocol import Policy
|
||||
from lerobot.common.utils.utils import get_global_random_state, set_global_random_state
|
||||
|
||||
|
|
|
@ -2,8 +2,6 @@ import json
|
|||
import os
|
||||
from dataclasses import asdict, dataclass
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClassifierConfig:
|
||||
|
|
|
@ -23,9 +23,11 @@ class ClassifierOutput:
|
|||
self.hidden_states = hidden_states
|
||||
|
||||
def __repr__(self):
|
||||
return (f"ClassifierOutput(logits={self.logits}, "
|
||||
f"probabilities={self.probabilities}, "
|
||||
f"hidden_states={self.hidden_states})")
|
||||
return (
|
||||
f"ClassifierOutput(logits={self.logits}, "
|
||||
f"probabilities={self.probabilities}, "
|
||||
f"hidden_states={self.hidden_states})"
|
||||
)
|
||||
|
||||
|
||||
class Classifier(
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -30,14 +30,36 @@ class SACConfig:
|
|||
critic_target_update_weight = 0.005
|
||||
utd_ratio = 2
|
||||
critic_network_kwargs = {
|
||||
"hidden_dims": [256, 256],
|
||||
"activate_final": True,
|
||||
}
|
||||
"hidden_dims": [256, 256],
|
||||
"activate_final": True,
|
||||
}
|
||||
actor_network_kwargs = {
|
||||
"hidden_dims": [256, 256],
|
||||
"activate_final": True,
|
||||
}
|
||||
"hidden_dims": [256, 256],
|
||||
"activate_final": True,
|
||||
}
|
||||
policy_kwargs = {
|
||||
"tanh_squash_distribution": True,
|
||||
"std_parameterization": "uniform",
|
||||
"tanh_squash_distribution": True,
|
||||
"std_parameterization": "uniform",
|
||||
}
|
||||
|
||||
input_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": [3, 84, 84],
|
||||
"observation.state": [4],
|
||||
}
|
||||
)
|
||||
output_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"action": [4],
|
||||
}
|
||||
)
|
||||
|
||||
state_encoder_hidden_dim: int = 256
|
||||
latent_dim: int = 256
|
||||
network_hidden_dims: int = 256
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes: dict[str, str] | None = None
|
||||
output_normalization_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {"action": "min_max"},
|
||||
)
|
||||
|
|
|
@ -40,11 +40,9 @@ class SACPolicy(
|
|||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "RL", "SAC"],
|
||||
):
|
||||
|
||||
def __init__(
|
||||
self, config: SACConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
if config is None:
|
||||
|
@ -67,10 +65,7 @@ class SACPolicy(
|
|||
# Define networks
|
||||
critic_nets = []
|
||||
for _ in range(config.num_critics):
|
||||
critic_net = Critic(
|
||||
encoder=encoder,
|
||||
network=MLP(**config.critic_network_kwargs)
|
||||
)
|
||||
critic_net = Critic(encoder=encoder, network=MLP(**config.critic_network_kwargs))
|
||||
critic_nets.append(critic_net)
|
||||
|
||||
self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
|
||||
|
@ -80,10 +75,10 @@ class SACPolicy(
|
|||
encoder=encoder,
|
||||
network=MLP(**config.actor_network_kwargs),
|
||||
action_dim=config.output_shapes["action"][0],
|
||||
**config.policy_kwargs
|
||||
**config.policy_kwargs,
|
||||
)
|
||||
if config.target_entropy is None:
|
||||
config.target_entropy = -np.prod(config.output_shapes["action"][0]) # (-dim(A))
|
||||
config.target_entropy = -np.prod(config.output_shapes["action"][0]) # (-dim(A))
|
||||
self.temperature = LagrangeMultiplier(init_value=config.temperature_init)
|
||||
|
||||
def reset(self):
|
||||
|
@ -103,7 +98,7 @@ class SACPolicy(
|
|||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
actions, _ = self.actor_network(batch['observations'])###
|
||||
actions, _ = self.actor_network(batch["observations"]) ###
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
|
||||
"""Run the batch through the model and compute the loss.
|
||||
|
@ -128,7 +123,6 @@ class SACPolicy(
|
|||
# from HIL-SERL code base
|
||||
# add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch
|
||||
|
||||
|
||||
# calculate critics loss
|
||||
# 1- compute actions from policy
|
||||
action_preds, log_probs = self.actor_network(observations)
|
||||
|
@ -137,7 +131,7 @@ class SACPolicy(
|
|||
# subsample critics to prevent overfitting if use high UTD (update to date)
|
||||
if self.config.num_subsample_critics is not None:
|
||||
indices = torch.randperm(self.config.num_critics)
|
||||
indices = indices[:self.config.num_subsample_critics]
|
||||
indices = indices[: self.config.num_subsample_critics]
|
||||
q_targets = q_targets[indices]
|
||||
|
||||
# critics subsample size
|
||||
|
@ -152,7 +146,8 @@ class SACPolicy(
|
|||
# 4- Calculate loss
|
||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||
critics_loss = (
|
||||
F.mse_loss(
|
||||
(
|
||||
F.mse_loss(
|
||||
q_preds,
|
||||
einops.repeat(td_target, "t b -> e t b", e=q_preds.shape[0]),
|
||||
reduction="none",
|
||||
|
@ -163,15 +158,17 @@ class SACPolicy(
|
|||
# q_targets depends on the reward and the next observations.
|
||||
* ~batch["next.reward_is_pad"]
|
||||
* ~batch["observation.state_is_pad"][1:]
|
||||
).sum(0).mean()
|
||||
)
|
||||
.sum(0)
|
||||
.mean()
|
||||
)
|
||||
|
||||
# calculate actors loss
|
||||
# 1- temperature
|
||||
temperature = self.temperature()
|
||||
|
||||
# 2- get actions (batch_size, action_dim) and log probs (batch_size,)
|
||||
actions, log_probs = self.actor_network(observations) \
|
||||
|
||||
actions, log_probs = self.actor_network(observations)
|
||||
# 3- get q-value predictions
|
||||
with torch.no_grad():
|
||||
q_preds = self.critic_ensemble(observations, actions, return_type="mean")
|
||||
|
@ -181,36 +178,31 @@ class SACPolicy(
|
|||
* ~batch["action_is_pad"]
|
||||
).mean()
|
||||
|
||||
|
||||
# calculate temperature loss
|
||||
# 1- calculate entropy
|
||||
entropy = -log_probs.mean()
|
||||
temperature_loss = self.temp(
|
||||
lhs=entropy,
|
||||
rhs=self.config.target_entropy
|
||||
)
|
||||
temperature_loss = self.temp(lhs=entropy, rhs=self.config.target_entropy)
|
||||
|
||||
loss = critics_loss + actor_loss + temperature_loss
|
||||
|
||||
return {
|
||||
"critics_loss": critics_loss.item(),
|
||||
"actor_loss": actor_loss.item(),
|
||||
"temperature_loss": temperature_loss.item(),
|
||||
"temperature": temperature.item(),
|
||||
"entropy": entropy.item(),
|
||||
"loss": loss,
|
||||
|
||||
}
|
||||
"critics_loss": critics_loss.item(),
|
||||
"actor_loss": actor_loss.item(),
|
||||
"temperature_loss": temperature_loss.item(),
|
||||
"temperature": temperature.item(),
|
||||
"entropy": entropy.item(),
|
||||
"loss": loss,
|
||||
}
|
||||
|
||||
def update(self):
|
||||
self.critic_target.lerp_(self.critic_ensemble, self.config.critic_target_update_weight)
|
||||
# TODO: implement UTD update
|
||||
# First update only critics for utd_ratio-1 times
|
||||
#for critic_step in range(self.config.utd_ratio - 1):
|
||||
# only update critic and critic target
|
||||
# for critic_step in range(self.config.utd_ratio - 1):
|
||||
# only update critic and critic target
|
||||
# Then update critic, critic target, actor and temperature
|
||||
|
||||
#for target_param, param in zip(self.critic_target.parameters(), self.critic_ensemble.parameters()):
|
||||
# for target_param, param in zip(self.critic_target.parameters(), self.critic_ensemble.parameters()):
|
||||
# target_param.data.copy_(target_param.data * (1.0 - self.config.critic_target_update_weight) + param.data * self.critic_target_update_weight)
|
||||
|
||||
|
||||
|
@ -227,13 +219,17 @@ class MLP(nn.Module):
|
|||
layers = []
|
||||
|
||||
for i, size in enumerate(config.network_hidden_dims):
|
||||
layers.append(nn.Linear(config.network_hidden_dims[i-1] if i > 0 else config.network_hidden_dims[0], size))
|
||||
layers.append(
|
||||
nn.Linear(config.network_hidden_dims[i - 1] if i > 0 else config.network_hidden_dims[0], size)
|
||||
)
|
||||
|
||||
if i + 1 < len(config.network_hidden_dims) or activate_final:
|
||||
if dropout_rate is not None and dropout_rate > 0:
|
||||
layers.append(nn.Dropout(p=dropout_rate))
|
||||
layers.append(nn.LayerNorm(size))
|
||||
layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())
|
||||
layers.append(
|
||||
activations if isinstance(activations, nn.Module) else getattr(nn, activations)()
|
||||
)
|
||||
|
||||
self.net = nn.Sequential(*layers)
|
||||
|
||||
|
@ -250,7 +246,7 @@ class Critic(nn.Module):
|
|||
network: nn.Module,
|
||||
init_final: Optional[float] = None,
|
||||
activate_final: bool = False,
|
||||
device: str = "cuda"
|
||||
device: str = "cuda",
|
||||
):
|
||||
super().__init__()
|
||||
self.device = torch.device(device)
|
||||
|
@ -276,12 +272,7 @@ class Critic(nn.Module):
|
|||
|
||||
self.to(self.device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
observations: torch.Tensor,
|
||||
actions: torch.Tensor,
|
||||
train: bool = False
|
||||
) -> torch.Tensor:
|
||||
def forward(self, observations: torch.Tensor, actions: torch.Tensor, train: bool = False) -> torch.Tensor:
|
||||
self.train(train)
|
||||
|
||||
observations = observations.to(self.device)
|
||||
|
@ -295,10 +286,7 @@ class Critic(nn.Module):
|
|||
return value.squeeze(-1)
|
||||
|
||||
def q_value_ensemble(
|
||||
self,
|
||||
observations: torch.Tensor,
|
||||
actions: torch.Tensor,
|
||||
train: bool = False
|
||||
self, observations: torch.Tensor, actions: torch.Tensor, train: bool = False
|
||||
) -> torch.Tensor:
|
||||
observations = observations.to(self.device)
|
||||
actions = actions.to(self.device)
|
||||
|
@ -327,7 +315,7 @@ class Policy(nn.Module):
|
|||
fixed_std: Optional[torch.Tensor] = None,
|
||||
init_final: Optional[float] = None,
|
||||
activate_final: bool = False,
|
||||
device: str = "cuda"
|
||||
device: str = "cuda",
|
||||
):
|
||||
super().__init__()
|
||||
self.device = torch.device(device)
|
||||
|
@ -374,7 +362,7 @@ class Policy(nn.Module):
|
|||
observations: torch.Tensor,
|
||||
temperature: float = 1.0,
|
||||
train: bool = False,
|
||||
non_squash_distribution: bool = False
|
||||
non_squash_distribution: bool = False,
|
||||
) -> torch.distributions.Distribution:
|
||||
self.train(train)
|
||||
|
||||
|
@ -398,9 +386,7 @@ class Policy(nn.Module):
|
|||
elif self.std_parameterization == "uniform":
|
||||
stds = torch.exp(self.log_stds).expand_as(means)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid std_parameterization: {self.std_parameterization}"
|
||||
)
|
||||
raise ValueError(f"Invalid std_parameterization: {self.std_parameterization}")
|
||||
else:
|
||||
assert self.std_parameterization == "fixed"
|
||||
stds = self.fixed_std.expand_as(means)
|
||||
|
@ -506,12 +492,7 @@ class SACObservationEncoder(nn.Module):
|
|||
|
||||
|
||||
class LagrangeMultiplier(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
init_value: float = 1.0,
|
||||
constraint_shape: Sequence[int] = (),
|
||||
device: str = "cuda"
|
||||
):
|
||||
def __init__(self, init_value: float = 1.0, constraint_shape: Sequence[int] = (), device: str = "cuda"):
|
||||
super().__init__()
|
||||
self.device = torch.device(device)
|
||||
init_value = torch.log(torch.exp(torch.tensor(init_value, device=self.device)) - 1)
|
||||
|
@ -523,11 +504,7 @@ class LagrangeMultiplier(nn.Module):
|
|||
|
||||
self.to(self.device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
lhs: Optional[torch.Tensor] = None,
|
||||
rhs: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
def forward(self, lhs: Optional[torch.Tensor] = None, rhs: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
# Get the multiplier value based on parameterization
|
||||
multiplier = torch.nn.functional.softplus(self.lagrange)
|
||||
|
||||
|
@ -578,17 +555,11 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
|
|||
# Add rescaling transform if bounds are provided
|
||||
if low is not None and high is not None:
|
||||
transforms.append(
|
||||
torch.distributions.transforms.AffineTransform(
|
||||
loc=(high + low) / 2,
|
||||
scale=(high - low) / 2
|
||||
)
|
||||
torch.distributions.transforms.AffineTransform(loc=(high + low) / 2, scale=(high - low) / 2)
|
||||
)
|
||||
|
||||
# Initialize parent class
|
||||
super().__init__(
|
||||
base_distribution=base_distribution,
|
||||
transforms=transforms
|
||||
)
|
||||
super().__init__(base_distribution=base_distribution, transforms=transforms)
|
||||
|
||||
# Store parameters
|
||||
self.loc = loc
|
||||
|
@ -691,4 +662,3 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens
|
|||
inp = torch.flatten(image_tensor, end_dim=-4)
|
||||
flat_out = fn(inp)
|
||||
return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
|
||||
|
||||
|
|
|
@ -53,6 +53,70 @@ from lerobot.configs.train import TrainPipelineConfig
|
|||
from lerobot.scripts.eval import eval_policy
|
||||
|
||||
|
||||
def make_optimizer_and_scheduler(cfg, policy):
|
||||
if cfg.policy.name == "act":
|
||||
optimizer_params_dicts = [
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in policy.named_parameters()
|
||||
if not n.startswith("model.backbone") and p.requires_grad
|
||||
]
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in policy.named_parameters()
|
||||
if n.startswith("model.backbone") and p.requires_grad
|
||||
],
|
||||
"lr": cfg.training.lr_backbone,
|
||||
},
|
||||
]
|
||||
optimizer = torch.optim.AdamW(
|
||||
optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
|
||||
)
|
||||
lr_scheduler = None
|
||||
elif cfg.policy.name == "diffusion":
|
||||
optimizer = torch.optim.Adam(
|
||||
policy.diffusion.parameters(),
|
||||
cfg.training.lr,
|
||||
cfg.training.adam_betas,
|
||||
cfg.training.adam_eps,
|
||||
cfg.training.adam_weight_decay,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
cfg.training.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=cfg.training.lr_warmup_steps,
|
||||
num_training_steps=cfg.training.offline_steps,
|
||||
)
|
||||
elif policy.name == "tdmpc":
|
||||
optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
|
||||
lr_scheduler = None
|
||||
|
||||
elif policy.name == "sac":
|
||||
optimizer = torch.optim.Adam(
|
||||
[
|
||||
{"params": policy.actor.parameters(), "lr": policy.config.actor_lr},
|
||||
{"params": policy.critic_ensemble.parameters(), "lr": policy.config.critic_lr},
|
||||
{"params": policy.temperature.parameters(), "lr": policy.config.temperature_lr},
|
||||
]
|
||||
)
|
||||
lr_scheduler = None
|
||||
|
||||
elif cfg.policy.name == "vqbet":
|
||||
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler
|
||||
|
||||
optimizer = VQBeTOptimizer(policy, cfg)
|
||||
lr_scheduler = VQBeTScheduler(optimizer, cfg)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
return optimizer, lr_scheduler
|
||||
|
||||
|
||||
def update_policy(
|
||||
train_metrics: MetricsTracker,
|
||||
policy: PreTrainedPolicy,
|
||||
|
|
|
@ -22,7 +22,6 @@ from pprint import pformat
|
|||
import hydra
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import wandb
|
||||
from deepdiff import DeepDiff
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from termcolor import colored
|
||||
|
@ -31,6 +30,7 @@ from torch.cuda.amp import GradScaler
|
|||
from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
|
||||
from tqdm import tqdm
|
||||
|
||||
import wandb
|
||||
from lerobot.common.datasets.factory import resolve_delta_timestamps
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.logger import Logger
|
||||
|
|
Loading…
Reference in New Issue