Added normalization schemes and style checks

Michel Aractingi 2024-12-29 12:51:21 +00:00 committed by AdilZouitine
parent 9dafad15e6
commit 80b86e9bc3
10 changed files with 206 additions and 150 deletions


@@ -25,13 +25,13 @@ from glob import glob
 from pathlib import Path
 import torch
-import wandb
 from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
 from omegaconf import DictConfig, OmegaConf
 from termcolor import colored
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LRScheduler
+import wandb
 from lerobot.common.policies.policy_protocol import Policy
 from lerobot.common.utils.utils import get_global_random_state, set_global_random_state


@@ -2,8 +2,6 @@ import json
 import os
 from dataclasses import asdict, dataclass
-import torch

 @dataclass
 class ClassifierConfig:


@@ -23,9 +23,11 @@ class ClassifierOutput:
         self.hidden_states = hidden_states

     def __repr__(self):
-        return (f"ClassifierOutput(logits={self.logits}, "
-                f"probabilities={self.probabilities}, "
-                f"hidden_states={self.hidden_states})")
+        return (
+            f"ClassifierOutput(logits={self.logits}, "
+            f"probabilities={self.probabilities}, "
+            f"hidden_states={self.hidden_states})"
+        )

 class Classifier(


@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from dataclasses import dataclass
+from dataclasses import dataclass, field

 @dataclass
@@ -30,14 +30,36 @@ class SACConfig:
     critic_target_update_weight = 0.005
     utd_ratio = 2
     critic_network_kwargs = {
         "hidden_dims": [256, 256],
         "activate_final": True,
     }
     actor_network_kwargs = {
         "hidden_dims": [256, 256],
         "activate_final": True,
     }
     policy_kwargs = {
         "tanh_squash_distribution": True,
         "std_parameterization": "uniform",
     }
+
+    input_shapes: dict[str, list[int]] = field(
+        default_factory=lambda: {
+            "observation.image": [3, 84, 84],
+            "observation.state": [4],
+        }
+    )
+    output_shapes: dict[str, list[int]] = field(
+        default_factory=lambda: {
+            "action": [4],
+        }
+    )
+
+    state_encoder_hidden_dim: int = 256
+    latent_dim: int = 256
+    network_hidden_dims: int = 256
+
+    # Normalization / Unnormalization
+    input_normalization_modes: dict[str, str] | None = None
+    output_normalization_modes: dict[str, str] = field(
+        default_factory=lambda: {"action": "min_max"},
+    )
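For context, the new normalization fields mirror how other LeRobot policies feed dataset statistics into their Normalize/Unnormalize modules. Below is a minimal hypothetical usage sketch, not part of this commit; the module paths and the Normalize/Unnormalize signatures are assumed from the existing lerobot codebase.

import torch

# Assumed import paths, following the existing lerobot layout rather than this commit.
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.sac.configuration_sac import SACConfig

config = SACConfig(
    input_normalization_modes={"observation.state": "min_max"},
    output_normalization_modes={"action": "min_max"},
)

# Dummy statistics standing in for LeRobotDataset metadata (min/max per feature).
dataset_stats = {
    "observation.state": {"min": torch.zeros(4), "max": torch.ones(4)},
    "action": {"min": -torch.ones(4), "max": torch.ones(4)},
}

normalize_inputs = Normalize(config.input_shapes, config.input_normalization_modes, dataset_stats)
unnormalize_outputs = Unnormalize(config.output_shapes, config.output_normalization_modes, dataset_stats)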


@@ -40,11 +40,9 @@ class SACPolicy(
     repo_url="https://github.com/huggingface/lerobot",
     tags=["robotics", "RL", "SAC"],
 ):
     def __init__(
         self, config: SACConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
     ):
         super().__init__()
         if config is None:
@@ -67,10 +65,7 @@ class SACPolicy(
         # Define networks
         critic_nets = []
         for _ in range(config.num_critics):
-            critic_net = Critic(
-                encoder=encoder,
-                network=MLP(**config.critic_network_kwargs)
-            )
+            critic_net = Critic(encoder=encoder, network=MLP(**config.critic_network_kwargs))
             critic_nets.append(critic_net)

         self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
@@ -80,10 +75,10 @@ class SACPolicy(
             encoder=encoder,
             network=MLP(**config.actor_network_kwargs),
             action_dim=config.output_shapes["action"][0],
-            **config.policy_kwargs
+            **config.policy_kwargs,
         )

         if config.target_entropy is None:
             config.target_entropy = -np.prod(config.output_shapes["action"][0])  # (-dim(A))

         self.temperature = LagrangeMultiplier(init_value=config.temperature_init)

     def reset(self):
@@ -103,7 +98,7 @@ class SACPolicy(
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
-        actions, _ = self.actor_network(batch['observations'])###
+        actions, _ = self.actor_network(batch["observations"])  ###

     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
         """Run the batch through the model and compute the loss.
@@ -128,7 +123,6 @@ class SACPolicy(
         # from HIL-SERL code base
         # add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch

         # calculate critics loss
         # 1- compute actions from policy
         action_preds, log_probs = self.actor_network(observations)
@@ -137,7 +131,7 @@ class SACPolicy(
         # subsample critics to prevent overfitting if use high UTD (update to date)
         if self.config.num_subsample_critics is not None:
             indices = torch.randperm(self.config.num_critics)
-            indices = indices[:self.config.num_subsample_critics]
+            indices = indices[: self.config.num_subsample_critics]
             q_targets = q_targets[indices]

         # critics subsample size
@@ -152,7 +146,8 @@ class SACPolicy(
         # 4- Calculate loss
         # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
         critics_loss = (
-            F.mse_loss(
+            (
+                F.mse_loss(
                 q_preds,
                 einops.repeat(td_target, "t b -> e t b", e=q_preds.shape[0]),
                 reduction="none",
@@ -163,15 +158,17 @@ class SACPolicy(
             # q_targets depends on the reward and the next observations.
             * ~batch["next.reward_is_pad"]
             * ~batch["observation.state_is_pad"][1:]
-        ).sum(0).mean()
+            )
+            .sum(0)
+            .mean()
+        )

         # calculate actors loss
         # 1- temperature
         temperature = self.temperature()
         # 2- get actions (batch_size, action_dim) and log probs (batch_size,)
-        actions, log_probs = self.actor_network(observations) \
+        actions, log_probs = self.actor_network(observations)
         # 3- get q-value predictions
         with torch.no_grad():
             q_preds = self.critic_ensemble(observations, actions, return_type="mean")
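For readers unfamiliar with the quantity masked above, the td_target used in the critic loss is the soft Bellman backup of SAC. A minimal, self-contained sketch of that backup follows; the function name and arguments are illustrative and not taken from this commit.

import torch

def compute_td_target(rewards, dones, next_q_min, next_log_probs, gamma, temperature):
    # Soft Bellman backup used by SAC: bootstrap from the (typically minimum)
    # target-critic value at the next state and subtract the temperature-scaled
    # log-probability of the sampled next action (the entropy term).
    return rewards + gamma * (1.0 - dones) * (next_q_min - temperature * next_log_probs)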
@@ -181,36 +178,31 @@ class SACPolicy(
             * ~batch["action_is_pad"]
         ).mean()

         # calculate temperature loss
         # 1- calculate entropy
         entropy = -log_probs.mean()
-        temperature_loss = self.temp(
-            lhs=entropy,
-            rhs=self.config.target_entropy
-        )
+        temperature_loss = self.temp(lhs=entropy, rhs=self.config.target_entropy)

         loss = critics_loss + actor_loss + temperature_loss

         return {
             "critics_loss": critics_loss.item(),
             "actor_loss": actor_loss.item(),
             "temperature_loss": temperature_loss.item(),
             "temperature": temperature.item(),
             "entropy": entropy.item(),
             "loss": loss,
-            }
+        }

     def update(self):
         self.critic_target.lerp_(self.critic_ensemble, self.config.critic_target_update_weight)
         # TODO: implement UTD update
         # First update only critics for utd_ratio-1 times
-        #for critic_step in range(self.config.utd_ratio - 1):
+        # for critic_step in range(self.config.utd_ratio - 1):
         # only update critic and critic target
         # Then update critic, critic target, actor and temperature
-        #for target_param, param in zip(self.critic_target.parameters(), self.critic_ensemble.parameters()):
+        # for target_param, param in zip(self.critic_target.parameters(), self.critic_ensemble.parameters()):
         # target_param.data.copy_(target_param.data * (1.0 - self.config.critic_target_update_weight) + param.data * self.critic_target_update_weight)
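The commented-out loop above describes a Polyak (soft) target update. A minimal sketch of that update as a standalone helper, assuming ordinary nn.Module targets; the soft_update name is hypothetical and not part of this commit.

import torch
import torch.nn as nn

def soft_update(target: nn.Module, source: nn.Module, tau: float) -> None:
    # Polyak averaging: target <- (1 - tau) * target + tau * source,
    # the same update the lerp-style call in SACPolicy.update() expresses.
    with torch.no_grad():
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.lerp_(param, tau)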
@@ -227,13 +219,17 @@ class MLP(nn.Module):
         layers = []

         for i, size in enumerate(config.network_hidden_dims):
-            layers.append(nn.Linear(config.network_hidden_dims[i-1] if i > 0 else config.network_hidden_dims[0], size))
+            layers.append(
+                nn.Linear(config.network_hidden_dims[i - 1] if i > 0 else config.network_hidden_dims[0], size)
+            )

             if i + 1 < len(config.network_hidden_dims) or activate_final:
                 if dropout_rate is not None and dropout_rate > 0:
                     layers.append(nn.Dropout(p=dropout_rate))
                 layers.append(nn.LayerNorm(size))
-                layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())
+                layers.append(
+                    activations if isinstance(activations, nn.Module) else getattr(nn, activations)()
+                )

         self.net = nn.Sequential(*layers)
@@ -250,7 +246,7 @@ class Critic(nn.Module):
         network: nn.Module,
         init_final: Optional[float] = None,
         activate_final: bool = False,
-        device: str = "cuda"
+        device: str = "cuda",
     ):
         super().__init__()
         self.device = torch.device(device)
@@ -276,12 +272,7 @@ class Critic(nn.Module):
         self.to(self.device)

-    def forward(
-        self,
-        observations: torch.Tensor,
-        actions: torch.Tensor,
-        train: bool = False
-    ) -> torch.Tensor:
+    def forward(self, observations: torch.Tensor, actions: torch.Tensor, train: bool = False) -> torch.Tensor:
         self.train(train)

         observations = observations.to(self.device)
@@ -295,10 +286,7 @@ class Critic(nn.Module):
         return value.squeeze(-1)

     def q_value_ensemble(
-        self,
-        observations: torch.Tensor,
-        actions: torch.Tensor,
-        train: bool = False
+        self, observations: torch.Tensor, actions: torch.Tensor, train: bool = False
     ) -> torch.Tensor:
         observations = observations.to(self.device)
         actions = actions.to(self.device)
@@ -327,7 +315,7 @@ class Policy(nn.Module):
         fixed_std: Optional[torch.Tensor] = None,
         init_final: Optional[float] = None,
         activate_final: bool = False,
-        device: str = "cuda"
+        device: str = "cuda",
     ):
         super().__init__()
         self.device = torch.device(device)
@@ -374,7 +362,7 @@ class Policy(nn.Module):
         observations: torch.Tensor,
         temperature: float = 1.0,
         train: bool = False,
-        non_squash_distribution: bool = False
+        non_squash_distribution: bool = False,
     ) -> torch.distributions.Distribution:
         self.train(train)
@@ -398,9 +386,7 @@ class Policy(nn.Module):
             elif self.std_parameterization == "uniform":
                 stds = torch.exp(self.log_stds).expand_as(means)
             else:
-                raise ValueError(
-                    f"Invalid std_parameterization: {self.std_parameterization}"
-                )
+                raise ValueError(f"Invalid std_parameterization: {self.std_parameterization}")
         else:
             assert self.std_parameterization == "fixed"
             stds = self.fixed_std.expand_as(means)
@@ -506,12 +492,7 @@ class SACObservationEncoder(nn.Module):

 class LagrangeMultiplier(nn.Module):
-    def __init__(
-        self,
-        init_value: float = 1.0,
-        constraint_shape: Sequence[int] = (),
-        device: str = "cuda"
-    ):
+    def __init__(self, init_value: float = 1.0, constraint_shape: Sequence[int] = (), device: str = "cuda"):
         super().__init__()
         self.device = torch.device(device)
         init_value = torch.log(torch.exp(torch.tensor(init_value, device=self.device)) - 1)
@@ -523,11 +504,7 @@ class LagrangeMultiplier(nn.Module):
         self.to(self.device)

-    def forward(
-        self,
-        lhs: Optional[torch.Tensor] = None,
-        rhs: Optional[torch.Tensor] = None
-    ) -> torch.Tensor:
+    def forward(self, lhs: Optional[torch.Tensor] = None, rhs: Optional[torch.Tensor] = None) -> torch.Tensor:
         # Get the multiplier value based on parameterization
         multiplier = torch.nn.functional.softplus(self.lagrange)
@@ -578,17 +555,11 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
         # Add rescaling transform if bounds are provided
         if low is not None and high is not None:
             transforms.append(
-                torch.distributions.transforms.AffineTransform(
-                    loc=(high + low) / 2,
-                    scale=(high - low) / 2
-                )
+                torch.distributions.transforms.AffineTransform(loc=(high + low) / 2, scale=(high - low) / 2)
             )

         # Initialize parent class
-        super().__init__(
-            base_distribution=base_distribution,
-            transforms=transforms
-        )
+        super().__init__(base_distribution=base_distribution, transforms=transforms)

         # Store parameters
         self.loc = loc
@@ -691,4 +662,3 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens
     inp = torch.flatten(image_tensor, end_dim=-4)
     flat_out = fn(inp)
     return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
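flatten_forward_unflatten, shown above, lets an encoder that expects (B, C, H, W) inputs run on tensors with extra leading dimensions. A small hypothetical usage sketch follows; the encoder and shapes are illustrative, and the function is assumed to be in scope from the module above.

import torch
import torch.nn as nn

# flatten_forward_unflatten is assumed to be imported from modeling_sac.py above.
encoder = nn.Conv2d(3, 8, kernel_size=3, padding=1)
images = torch.randn(5, 2, 3, 84, 84)  # (time, batch, channels, height, width)

# Leading dims are flattened into one batch dim, the encoder is applied,
# and the output is reshaped back to (time, batch, ...).
features = flatten_forward_unflatten(encoder, images)
print(features.shape)  # torch.Size([5, 2, 8, 84, 84])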


@@ -53,6 +53,70 @@ from lerobot.configs.train import TrainPipelineConfig
 from lerobot.scripts.eval import eval_policy

+def make_optimizer_and_scheduler(cfg, policy):
+    if cfg.policy.name == "act":
+        optimizer_params_dicts = [
+            {
+                "params": [
+                    p
+                    for n, p in policy.named_parameters()
+                    if not n.startswith("model.backbone") and p.requires_grad
+                ]
+            },
+            {
+                "params": [
+                    p
+                    for n, p in policy.named_parameters()
+                    if n.startswith("model.backbone") and p.requires_grad
+                ],
+                "lr": cfg.training.lr_backbone,
+            },
+        ]
+        optimizer = torch.optim.AdamW(
+            optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
+        )
+        lr_scheduler = None
+    elif cfg.policy.name == "diffusion":
+        optimizer = torch.optim.Adam(
+            policy.diffusion.parameters(),
+            cfg.training.lr,
+            cfg.training.adam_betas,
+            cfg.training.adam_eps,
+            cfg.training.adam_weight_decay,
+        )
+        from diffusers.optimization import get_scheduler
+        lr_scheduler = get_scheduler(
+            cfg.training.lr_scheduler,
+            optimizer=optimizer,
+            num_warmup_steps=cfg.training.lr_warmup_steps,
+            num_training_steps=cfg.training.offline_steps,
+        )
+    elif policy.name == "tdmpc":
+        optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
+        lr_scheduler = None
+    elif policy.name == "sac":
+        optimizer = torch.optim.Adam(
+            [
+                {"params": policy.actor.parameters(), "lr": policy.config.actor_lr},
+                {"params": policy.critic_ensemble.parameters(), "lr": policy.config.critic_lr},
+                {"params": policy.temperature.parameters(), "lr": policy.config.temperature_lr},
+            ]
+        )
+        lr_scheduler = None
+    elif cfg.policy.name == "vqbet":
+        from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler
+        optimizer = VQBeTOptimizer(policy, cfg)
+        lr_scheduler = VQBeTScheduler(optimizer, cfg)
+    else:
+        raise NotImplementedError()
+
+    return optimizer, lr_scheduler
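A hypothetical call site for the function added above, assuming cfg and policy were already built by the surrounding training script; the per-branch learning rates end up as separate optimizer parameter groups.

# Hypothetical usage, not part of this commit; cfg and policy are assumed in scope.
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)

# For the "sac" branch, each parameter group (actor, critic ensemble, temperature)
# keeps its own learning rate, which standard torch optimizers expose:
for param_group in optimizer.param_groups:
    print(param_group["lr"])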
 def update_policy(
     train_metrics: MetricsTracker,
     policy: PreTrainedPolicy,


@@ -22,7 +22,6 @@ from pprint import pformat
 import hydra
 import torch
 import torch.nn as nn
-import wandb
 from deepdiff import DeepDiff
 from omegaconf import DictConfig, OmegaConf
 from termcolor import colored
@@ -31,6 +30,7 @@ from torch.cuda.amp import GradScaler
 from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
 from tqdm import tqdm
+import wandb
 from lerobot.common.datasets.factory import resolve_delta_timestamps
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.logger import Logger