Added normalization schemes and style checks

This commit is contained in:
Michel Aractingi 2024-12-29 12:51:21 +00:00
parent 08ec971086
commit dc54d357ca
10 changed files with 150 additions and 156 deletions

View File

@ -25,13 +25,13 @@ from glob import glob
from pathlib import Path
import torch
import wandb
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
import wandb
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import get_global_random_state, set_global_random_state

View File

@ -2,8 +2,6 @@ import json
import os
from dataclasses import asdict, dataclass
import torch
@dataclass
class ClassifierConfig:

View File

@ -23,9 +23,11 @@ class ClassifierOutput:
self.hidden_states = hidden_states
def __repr__(self):
return (f"ClassifierOutput(logits={self.logits}, "
return (
f"ClassifierOutput(logits={self.logits}, "
f"probabilities={self.probabilities}, "
f"hidden_states={self.hidden_states})")
f"hidden_states={self.hidden_states})"
)
class Classifier(

View File

@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from dataclasses import dataclass, field
@dataclass
@ -41,3 +41,25 @@ class SACConfig:
"tanh_squash_distribution": True,
"std_parameterization": "uniform",
}
input_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"observation.image": [3, 84, 84],
"observation.state": [4],
}
)
output_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"action": [4],
}
)
state_encoder_hidden_dim: int = 256
latent_dim: int = 256
network_hidden_dims: int = 256
# Normalization / Unnormalization
input_normalization_modes: dict[str, str] | None = None
output_normalization_modes: dict[str, str] = field(
default_factory=lambda: {"action": "min_max"},
)

View File

@ -40,11 +40,9 @@ class SACPolicy(
repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "RL", "SAC"],
):
def __init__(
self, config: SACConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
):
super().__init__()
if config is None:
@ -67,10 +65,7 @@ class SACPolicy(
# Define networks
critic_nets = []
for _ in range(config.num_critics):
critic_net = Critic(
encoder=encoder,
network=MLP(**config.critic_network_kwargs)
)
critic_net = Critic(encoder=encoder, network=MLP(**config.critic_network_kwargs))
critic_nets.append(critic_net)
self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
@ -80,7 +75,7 @@ class SACPolicy(
encoder=encoder,
network=MLP(**config.actor_network_kwargs),
action_dim=config.output_shapes["action"][0],
**config.policy_kwargs
**config.policy_kwargs,
)
if config.target_entropy is None:
config.target_entropy = -np.prod(config.output_shapes["action"][0]) # (-dim(A))
@ -103,7 +98,7 @@ class SACPolicy(
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
actions, _ = self.actor_network(batch['observations'])###
actions, _ = self.actor_network(batch["observations"]) ###
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
"""Run the batch through the model and compute the loss.
@ -128,7 +123,6 @@ class SACPolicy(
# from HIL-SERL code base
# add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch
# calculate critics loss
# 1- compute actions from policy
action_preds, log_probs = self.actor_network(observations)
@ -152,6 +146,7 @@ class SACPolicy(
# 4- Calculate loss
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
critics_loss = (
(
F.mse_loss(
q_preds,
einops.repeat(td_target, "t b -> e t b", e=q_preds.shape[0]),
@ -163,15 +158,17 @@ class SACPolicy(
# q_targets depends on the reward and the next observations.
* ~batch["next.reward_is_pad"]
* ~batch["observation.state_is_pad"][1:]
).sum(0).mean()
)
.sum(0)
.mean()
)
# calculate actors loss
# 1- temperature
temperature = self.temperature()
# 2- get actions (batch_size, action_dim) and log probs (batch_size,)
actions, log_probs = self.actor_network(observations) \
actions, log_probs = self.actor_network(observations)
# 3- get q-value predictions
with torch.no_grad():
q_preds = self.critic_ensemble(observations, actions, return_type="mean")
@ -181,14 +178,10 @@ class SACPolicy(
* ~batch["action_is_pad"]
).mean()
# calculate temperature loss
# 1- calculate entropy
entropy = -log_probs.mean()
temperature_loss = self.temp(
lhs=entropy,
rhs=self.config.target_entropy
)
temperature_loss = self.temp(lhs=entropy, rhs=self.config.target_entropy)
loss = critics_loss + actor_loss + temperature_loss
@ -199,7 +192,6 @@ class SACPolicy(
"temperature": temperature.item(),
"entropy": entropy.item(),
"loss": loss,
}
def update(self):
@ -227,13 +219,17 @@ class MLP(nn.Module):
layers = []
for i, size in enumerate(config.network_hidden_dims):
layers.append(nn.Linear(config.network_hidden_dims[i-1] if i > 0 else config.network_hidden_dims[0], size))
layers.append(
nn.Linear(config.network_hidden_dims[i - 1] if i > 0 else config.network_hidden_dims[0], size)
)
if i + 1 < len(config.network_hidden_dims) or activate_final:
if dropout_rate is not None and dropout_rate > 0:
layers.append(nn.Dropout(p=dropout_rate))
layers.append(nn.LayerNorm(size))
layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())
layers.append(
activations if isinstance(activations, nn.Module) else getattr(nn, activations)()
)
self.net = nn.Sequential(*layers)
@ -250,7 +246,7 @@ class Critic(nn.Module):
network: nn.Module,
init_final: Optional[float] = None,
activate_final: bool = False,
device: str = "cuda"
device: str = "cuda",
):
super().__init__()
self.device = torch.device(device)
@ -276,12 +272,7 @@ class Critic(nn.Module):
self.to(self.device)
def forward(
self,
observations: torch.Tensor,
actions: torch.Tensor,
train: bool = False
) -> torch.Tensor:
def forward(self, observations: torch.Tensor, actions: torch.Tensor, train: bool = False) -> torch.Tensor:
self.train(train)
observations = observations.to(self.device)
@ -295,10 +286,7 @@ class Critic(nn.Module):
return value.squeeze(-1)
def q_value_ensemble(
self,
observations: torch.Tensor,
actions: torch.Tensor,
train: bool = False
self, observations: torch.Tensor, actions: torch.Tensor, train: bool = False
) -> torch.Tensor:
observations = observations.to(self.device)
actions = actions.to(self.device)
@ -327,7 +315,7 @@ class Policy(nn.Module):
fixed_std: Optional[torch.Tensor] = None,
init_final: Optional[float] = None,
activate_final: bool = False,
device: str = "cuda"
device: str = "cuda",
):
super().__init__()
self.device = torch.device(device)
@ -374,7 +362,7 @@ class Policy(nn.Module):
observations: torch.Tensor,
temperature: float = 1.0,
train: bool = False,
non_squash_distribution: bool = False
non_squash_distribution: bool = False,
) -> torch.distributions.Distribution:
self.train(train)
@ -398,9 +386,7 @@ class Policy(nn.Module):
elif self.std_parameterization == "uniform":
stds = torch.exp(self.log_stds).expand_as(means)
else:
raise ValueError(
f"Invalid std_parameterization: {self.std_parameterization}"
)
raise ValueError(f"Invalid std_parameterization: {self.std_parameterization}")
else:
assert self.std_parameterization == "fixed"
stds = self.fixed_std.expand_as(means)
@ -506,12 +492,7 @@ class SACObservationEncoder(nn.Module):
class LagrangeMultiplier(nn.Module):
def __init__(
self,
init_value: float = 1.0,
constraint_shape: Sequence[int] = (),
device: str = "cuda"
):
def __init__(self, init_value: float = 1.0, constraint_shape: Sequence[int] = (), device: str = "cuda"):
super().__init__()
self.device = torch.device(device)
init_value = torch.log(torch.exp(torch.tensor(init_value, device=self.device)) - 1)
@ -523,11 +504,7 @@ class LagrangeMultiplier(nn.Module):
self.to(self.device)
def forward(
self,
lhs: Optional[torch.Tensor] = None,
rhs: Optional[torch.Tensor] = None
) -> torch.Tensor:
def forward(self, lhs: Optional[torch.Tensor] = None, rhs: Optional[torch.Tensor] = None) -> torch.Tensor:
# Get the multiplier value based on parameterization
multiplier = torch.nn.functional.softplus(self.lagrange)
@ -578,17 +555,11 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
# Add rescaling transform if bounds are provided
if low is not None and high is not None:
transforms.append(
torch.distributions.transforms.AffineTransform(
loc=(high + low) / 2,
scale=(high - low) / 2
)
torch.distributions.transforms.AffineTransform(loc=(high + low) / 2, scale=(high - low) / 2)
)
# Initialize parent class
super().__init__(
base_distribution=base_distribution,
transforms=transforms
)
super().__init__(base_distribution=base_distribution, transforms=transforms)
# Store parameters
self.loc = loc
@ -691,4 +662,3 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens
inp = torch.flatten(image_tensor, end_dim=-4)
flat_out = fn(inp)
return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))

View File

@ -95,11 +95,13 @@ def make_optimizer_and_scheduler(cfg, policy):
lr_scheduler = None
elif policy.name == "sac":
optimizer = torch.optim.Adam([
{'params': policy.actor.parameters(), 'lr': policy.config.actor_lr},
{'params': policy.critic_ensemble.parameters(), 'lr': policy.config.critic_lr},
{'params': policy.temperature.parameters(), 'lr': policy.config.temperature_lr},
])
optimizer = torch.optim.Adam(
[
{"params": policy.actor.parameters(), "lr": policy.config.actor_lr},
{"params": policy.critic_ensemble.parameters(), "lr": policy.config.critic_lr},
{"params": policy.temperature.parameters(), "lr": policy.config.temperature_lr},
]
)
lr_scheduler = None
elif cfg.policy.name == "vqbet":

View File

@ -22,7 +22,6 @@ from pprint import pformat
import hydra
import torch
import torch.nn as nn
import wandb
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
@ -31,6 +30,7 @@ from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
from tqdm import tqdm
import wandb
from lerobot.common.datasets.factory import resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.logger import Logger