style fixes

Michel Aractingi 2024-12-29 14:35:21 +00:00 committed by AdilZouitine
parent 91fefdecfa
commit 2c2ed084cc
2 changed files with 37 additions and 39 deletions


@@ -53,13 +53,13 @@ class SACConfig:
critic_network_kwargs = {
"hidden_dims": [256, 256],
"activate_final": True,
}
actor_network_kwargs = {
"hidden_dims": [256, 256],
"activate_final": True,
}
policy_kwargs = {
"use_tanh_squash": True,
"log_std_min": -5,
"log_std_max": 2,
}
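
For orientation, these kwargs describe plain feed-forward trunks plus the squashing and log-std clipping options of the policy head. A self-contained sketch of how such kwargs translate into torch modules (build_mlp below is illustrative only, not the repo's builder; the input width and the ReLU activation are assumptions):

import torch.nn as nn

def build_mlp(input_dim: int, hidden_dims: list[int], activate_final: bool) -> nn.Sequential:
    # Stack Linear layers of the requested widths; activate after every layer,
    # including the last one when activate_final is True.
    layers: list[nn.Module] = []
    in_dim = input_dim
    for i, size in enumerate(hidden_dims):
        layers.append(nn.Linear(in_dim, size))
        if i + 1 < len(hidden_dims) or activate_final:
            layers.append(nn.ReLU())
        in_dim = size
    return nn.Sequential(*layers)

critic_trunk = build_mlp(input_dim=32, hidden_dims=[256, 256], activate_final=True)
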


@@ -19,7 +19,6 @@
from collections import deque
from copy import deepcopy
import math
from typing import Callable, Optional, Sequence, Tuple
import einops
@@ -125,9 +124,9 @@ class SACPolicy(
# perform image augmentation
# reward bias from HIL-SERL code base
# add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch
# calculate critics loss
# 1- compute actions from policy
action_preds, log_probs = self.actor(next_observations)
@@ -137,21 +136,23 @@ class SACPolicy(
# subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio
if self.config.num_subsample_critics is not None:
indices = torch.randperm(self.config.num_critics)
- indices = indices[:self.config.num_subsample_critics]
+ indices = indices[: self.config.num_subsample_critics]
q_targets = q_targets[indices]
# keep only the targets of the subsampled critics
min_q = q_targets.min(dim=0).values  # clipped double-Q: elementwise min over the critic ensemble
# compute td target
- td_target = rewards + self.config.discount * min_q #+ self.config.discount * self.temperature() * log_probs # add entropy term
+ td_target = (
+     rewards + self.config.discount * min_q
+ ) # + self.config.discount * self.temperature() * log_probs # add entropy term
# 3- compute predicted qs
q_preds = self.critic_ensemble(observations, actions)
# 4- Calculate loss
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
- #critics_loss = (
+ # critics_loss = (
# (
# F.mse_loss(
# q_preds,
@ -167,14 +168,20 @@ class SACPolicy(
# )
# .sum(0)
# .mean()
- #)
+ # )
# 4- Calculate loss
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
- critics_loss = F.mse_loss(
-     q_preds, # shape: [num_critics, batch_size]
-     einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]), # expand td_target to match q_preds shape
-     reduction="none"
- ).sum(0).mean()
+ critics_loss = (
+     F.mse_loss(
+         q_preds, # shape: [num_critics, batch_size]
+         einops.repeat(
+             td_target, "b -> e b", e=q_preds.shape[0]
+         ), # expand td_target to match q_preds shape
+         reduction="none",
+     )
+     .sum(0)
+     .mean()
+ )
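
Pulled out of the class, the critic update above reduces to the usual clipped double-Q TD loss; a minimal standalone mirror of it (tensor shapes are assumptions stated in the comments, and the entropy term stays omitted, as in the code above):

import torch
import torch.nn.functional as F
import einops

def ensemble_td_loss(q_preds, q_targets, rewards, discount):
    # q_preds, q_targets: [num_critics, batch_size]; rewards: [batch_size]; discount: float.
    min_q = q_targets.min(dim=0).values            # clipped double-Q: min over the ensemble
    td_target = rewards + discount * min_q         # entropy bonus omitted, as above
    td_target = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
    # Sum the per-critic MSE terms, then average over the batch.
    return F.mse_loss(q_preds, td_target.detach(), reduction="none").sum(0).mean()
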
# calculate actor loss
# 1- temperature
@@ -229,10 +236,10 @@ class MLP(nn.Module):
super().__init__()
self.activate_final = activate_final
layers = []
for i, size in enumerate(hidden_dims):
- layers.append(nn.Linear(hidden_dims[i-1] if i > 0 else hidden_dims[0], size))
+ layers.append(nn.Linear(hidden_dims[i - 1] if i > 0 else hidden_dims[0], size))
if i + 1 < len(hidden_dims) or activate_final:
if dropout_rate is not None and dropout_rate > 0:
layers.append(nn.Dropout(p=dropout_rate))
@@ -255,20 +262,20 @@ class Critic(nn.Module):
encoder: Optional[nn.Module],
network: nn.Module,
init_final: Optional[float] = None,
- device: str = "cuda"
+ device: str = "cuda",
):
super().__init__()
self.device = torch.device(device)
self.encoder = encoder
self.network = network
self.init_final = init_final
# Find the last Linear layer's output dimension
for layer in reversed(network.net):
if isinstance(layer, nn.Linear):
out_features = layer.out_features
break
# Output layer
if init_final is not None:
self.output_layer = nn.Linear(out_features, 1)
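The reversed-iteration trick above (also reused by Policy further down) just recovers the width of the trunk's last Linear layer so the value and action heads can be sized. In isolation, and assuming the wrapped network is a plain nn.Sequential:

import torch.nn as nn

def last_linear_out_features(net: nn.Sequential) -> int:
    # Walk the layers from the end and return the first Linear's output width.
    for layer in reversed(net):
        if isinstance(layer, nn.Linear):
            return layer.out_features
    raise ValueError("no nn.Linear layer found in the network")

width = last_linear_out_features(nn.Sequential(nn.Linear(32, 256), nn.ReLU(), nn.Linear(256, 256)))  # 256
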
@@ -305,7 +312,7 @@ class Policy(nn.Module):
fixed_std: Optional[torch.Tensor] = None,
init_final: Optional[float] = None,
use_tanh_squash: bool = False,
- device: str = "cuda"
+ device: str = "cuda",
):
super().__init__()
self.device = torch.device(device)
@@ -316,13 +323,13 @@ class Policy(nn.Module):
self.log_std_max = log_std_max
self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
self.use_tanh_squash = use_tanh_squash
# Find the last Linear layer's output dimension
for layer in reversed(network.net):
if isinstance(layer, nn.Linear):
out_features = layer.out_features
break
# Mean layer
self.mean_layer = nn.Linear(out_features, action_dim)
if init_final is not None:
@@ -339,21 +346,16 @@ class Policy(nn.Module):
nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
else:
orthogonal_init()(self.std_layer.weight)
self.to(self.device)
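
The init_final branch above shrinks the initial weights of the mean/std heads so the head's first outputs stay small; the same scheme applied to a throwaway head, for illustration (the value 0.05 and the layer sizes are made up here):

import torch.nn as nn

init_final = 0.05
mean_head = nn.Linear(256, 6)
nn.init.uniform_(mean_head.weight, -init_final, init_final)
nn.init.uniform_(mean_head.bias, -init_final, init_final)
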
def forward(
self,
observations: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Encode observations if encoder exists
- if self.encoder is not None:
-     with torch.set_grad_enabled(train):
-         obs_enc = self.encoder(observations, train=train)
- else:
-     obs_enc = observations
+ obs_enc = observations if self.encoder is None else self.encoder(observations)
# Get network outputs
outputs = self.network(obs_enc)
means = self.mean_layer(outputs)
@@ -369,15 +371,15 @@ class Policy(nn.Module):
# use a tanh activation to squash the action into the range [-1, 1]
normal = torch.distributions.Normal(means, stds)
x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1))
log_probs = normal.log_prob(x_t)
if self.use_tanh_squash:
actions = torch.tanh(x_t)
log_probs -= torch.log((1 - actions.pow(2)) + 1e-6)
log_probs = log_probs.sum(-1) # sum over action dim
return actions, log_probs
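
For reference, the sampling path above is the standard reparameterized tanh-Gaussian; a minimal standalone version of the same math, detached from the class:

import torch

def sample_tanh_gaussian(means: torch.Tensor, stds: torch.Tensor):
    # means/stds: [batch, action_dim]; returns squashed actions and per-sample log-probs.
    normal = torch.distributions.Normal(means, stds)
    x_t = normal.rsample()                               # reparameterization trick: mean + std * N(0, 1)
    log_probs = normal.log_prob(x_t)
    actions = torch.tanh(x_t)
    log_probs -= torch.log(1 - actions.pow(2) + 1e-6)    # change-of-variables correction for tanh
    return actions, log_probs.sum(-1)                    # sum over the action dimension
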
def get_features(self, observations: torch.Tensor) -> torch.Tensor:
"""Get encoded features from observations"""
observations = observations.to(self.device)
@@ -507,10 +509,6 @@ def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: s
return nn.ModuleList(critics).to(device)
def orthogonal_init():
return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)
# borrowed from tdmpc
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
"""Helper to temporarily flatten extra dims at the start of the image tensor.