style fixes

Michel Aractingi 2024-12-29 14:35:21 +00:00 committed by Adil Zouitine
parent 012ef3217e
commit 5070295e59
2 changed files with 37 additions and 39 deletions


@@ -53,13 +53,13 @@ class SACConfig:
     critic_network_kwargs = {
         "hidden_dims": [256, 256],
         "activate_final": True,
     }
     actor_network_kwargs = {
         "hidden_dims": [256, 256],
         "activate_final": True,
     }
     policy_kwargs = {
         "use_tanh_squash": True,
         "log_std_min": -5,
         "log_std_max": 2,
     }
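As a quick sanity check on the bounds above (plain arithmetic, nothing repo-specific): with the usual convention std = exp(log_std), log_std_min=-5 and log_std_max=2 clamp the policy's per-dimension std to roughly [0.0067, 7.39].

    import math

    # std range implied by the log-std clamp above (assumes std = exp(log_std))
    print(math.exp(-5), math.exp(2))  # 0.006737946999085467 7.38905609893065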


@@ -19,7 +19,6 @@
 from collections import deque
 from copy import deepcopy
-import math
 from typing import Callable, Optional, Sequence, Tuple

 import einops
@@ -125,9 +124,9 @@ class SACPolicy(
         # perform image augmentation

         # reward bias from HIL-SERL code base
         # add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch

         # calculate critics loss
         # 1- compute actions from policy
         action_preds, log_probs = self.actor(next_observations)
@@ -137,21 +136,23 @@ class SACPolicy(
         # subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio
         if self.config.num_subsample_critics is not None:
             indices = torch.randperm(self.config.num_critics)
-            indices = indices[:self.config.num_subsample_critics]
+            indices = indices[: self.config.num_subsample_critics]
             q_targets = q_targets[indices]

         # critics subsample size
         min_q = q_targets.min(dim=0).values
         # compute td target
-        td_target = rewards + self.config.discount * min_q  # + self.config.discount * self.temperature() * log_probs # add entropy term
+        td_target = (
+            rewards + self.config.discount * min_q
+        )  # + self.config.discount * self.temperature() * log_probs # add entropy term

         # 3- compute predicted qs
         q_preds = self.critic_ensemble(observations, actions)

         # 4- Calculate loss
         # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
-        #critics_loss = (
+        # critics_loss = (
         #     (
         #         F.mse_loss(
         #             q_preds,
@@ -167,14 +168,20 @@ class SACPolicy(
         #         )
         #     )
         #     .sum(0)
         #     .mean()
-        #)
+        # )

         # 4- Calculate loss
         # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
-        critics_loss = F.mse_loss(
-            q_preds,  # shape: [num_critics, batch_size]
-            einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),  # expand td_target to match q_preds shape
-            reduction="none"
-        ).sum(0).mean()
+        critics_loss = (
+            F.mse_loss(
+                q_preds,  # shape: [num_critics, batch_size]
+                einops.repeat(
+                    td_target, "b -> e b", e=q_preds.shape[0]
+                ),  # expand td_target to match q_preds shape
+                reduction="none",
+            )
+            .sum(0)
+            .mean()
+        )

         # calculate actors loss
         # 1- temperature
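To make the shapes in the two hunks above concrete, here is a self-contained sketch of the target and loss computation, using dummy tensors and stand-in values for the config fields (the critic count, subsample size, and discount are assumptions, not the repo's defaults):

    import torch
    import torch.nn.functional as F
    import einops

    num_critics, num_subsample, batch, discount = 10, 2, 4, 0.99  # stand-in values
    q_targets = torch.randn(num_critics, batch)  # per-critic target Q values
    rewards = torch.randn(batch)

    # subsample the ensemble, take the elementwise min, form the TD target
    indices = torch.randperm(num_critics)[:num_subsample]
    min_q = q_targets[indices].min(dim=0).values  # [batch]
    td_target = rewards + discount * min_q        # [batch]

    # unreduced MSE against every critic, summed over critics, averaged over batch
    q_preds = torch.randn(num_critics, batch)  # stand-in predictions from the full ensemble
    target = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
    critics_loss = F.mse_loss(q_preds, target, reduction="none").sum(0).mean()
    print(critics_loss)  # scalar tensor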
@@ -229,10 +236,10 @@ class MLP(nn.Module):
         super().__init__()
         self.activate_final = activate_final
         layers = []

         for i, size in enumerate(hidden_dims):
-            layers.append(nn.Linear(hidden_dims[i-1] if i > 0 else hidden_dims[0], size))
+            layers.append(nn.Linear(hidden_dims[i - 1] if i > 0 else hidden_dims[0], size))

             if i + 1 < len(hidden_dims) or activate_final:
                 if dropout_rate is not None and dropout_rate > 0:
                     layers.append(nn.Dropout(p=dropout_rate))
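One detail worth noting in the loop above: the first entry of hidden_dims also serves as the input width, so MLP([256, 256]) expects 256-dim inputs. A stripped-down sketch of just that sizing rule (ReLU stands in for the configurable activation; dropout omitted):

    import torch
    import torch.nn as nn

    hidden_dims = [256, 256]
    layers = []
    for i, size in enumerate(hidden_dims):
        # first layer maps hidden_dims[0] -> hidden_dims[0]; later layers chain on
        layers.append(nn.Linear(hidden_dims[i - 1] if i > 0 else hidden_dims[0], size))
        layers.append(nn.ReLU())
    net = nn.Sequential(*layers)
    print(net(torch.randn(4, 256)).shape)  # torch.Size([4, 256])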
@@ -255,20 +262,20 @@ class Critic(nn.Module):
         encoder: Optional[nn.Module],
         network: nn.Module,
         init_final: Optional[float] = None,
-        device: str = "cuda"
+        device: str = "cuda",
     ):
         super().__init__()
         self.device = torch.device(device)
         self.encoder = encoder
         self.network = network
         self.init_final = init_final

         # Find the last Linear layer's output dimension
         for layer in reversed(network.net):
             if isinstance(layer, nn.Linear):
                 out_features = layer.out_features
                 break

         # Output layer
         if init_final is not None:
             self.output_layer = nn.Linear(out_features, 1)
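The reversed() scan above, in isolation: it walks the Sequential backwards and stops at the first hit, i.e. the last nn.Linear. Here `net` is a stand-in for network.net:

    import torch.nn as nn

    net = nn.Sequential(nn.Linear(32, 256), nn.ReLU(), nn.Linear(256, 128), nn.Tanh())
    out_features = next(m.out_features for m in reversed(net) if isinstance(m, nn.Linear))
    print(out_features)  # 128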
@@ -305,7 +312,7 @@ class Policy(nn.Module):
         fixed_std: Optional[torch.Tensor] = None,
         init_final: Optional[float] = None,
         use_tanh_squash: bool = False,
-        device: str = "cuda"
+        device: str = "cuda",
     ):
         super().__init__()
         self.device = torch.device(device)
@@ -316,13 +323,13 @@ class Policy(nn.Module):
         self.log_std_max = log_std_max
         self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
         self.use_tanh_squash = use_tanh_squash

         # Find the last Linear layer's output dimension
         for layer in reversed(network.net):
             if isinstance(layer, nn.Linear):
                 out_features = layer.out_features
                 break

         # Mean layer
         self.mean_layer = nn.Linear(out_features, action_dim)
         if init_final is not None:
@@ -339,21 +346,16 @@ class Policy(nn.Module):
             nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
         else:
             orthogonal_init()(self.std_layer.weight)

         self.to(self.device)

     def forward(
         self,
         observations: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Encode observations if encoder exists
-        if self.encoder is not None:
-            with torch.set_grad_enabled(train):
-                obs_enc = self.encoder(observations, train=train)
-        else:
-            obs_enc = observations
+        obs_enc = observations if self.encoder is None else self.encoder(observations)

         # Get network outputs
         outputs = self.network(obs_enc)
         means = self.mean_layer(outputs)
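A quick check of the conditional encode above: raw observations should pass through only when no encoder is configured, which is why the ternary tests `is None`. A sketch with a stand-in nn.Linear encoder (an assumption, not the repo's encoder):

    import torch
    import torch.nn as nn

    observations = torch.randn(4, 16)

    encoder = None  # no encoder: observations pass through unchanged
    obs_enc = observations if encoder is None else encoder(observations)
    assert obs_enc is observations

    encoder = nn.Linear(16, 8)  # stand-in encoder
    obs_enc = observations if encoder is None else encoder(observations)
    print(obs_enc.shape)  # torch.Size([4, 8])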
@@ -369,15 +371,15 @@ class Policy(nn.Module):
         # uses tanh activation function to squash the action to be in the range of [-1, 1]
         normal = torch.distributions.Normal(means, stds)
         x_t = normal.rsample()  # for reparameterization trick (mean + std * N(0,1))
         log_probs = normal.log_prob(x_t)

         if self.use_tanh_squash:
             actions = torch.tanh(x_t)
             log_probs -= torch.log((1 - actions.pow(2)) + 1e-6)
         log_probs = log_probs.sum(-1)  # sum over action dim

         return actions, log_probs

     def get_features(self, observations: torch.Tensor) -> torch.Tensor:
         """Get encoded features from observations"""
         observations = observations.to(self.device)
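The squash-and-correct step above, standalone: rsample draws a reparameterized Gaussian sample, tanh bounds it to [-1, 1], and log(1 - a^2 + eps) is the change-of-variables correction, summed over action dimensions (dummy batch and action sizes below):

    import torch

    means, stds = torch.zeros(4, 3), torch.ones(4, 3)  # batch of 4, action_dim 3
    normal = torch.distributions.Normal(means, stds)
    x_t = normal.rsample()            # reparameterized sample: mean + std * N(0, 1)
    log_probs = normal.log_prob(x_t)  # [4, 3], per-dimension
    actions = torch.tanh(x_t)         # squash into [-1, 1]
    log_probs -= torch.log((1 - actions.pow(2)) + 1e-6)  # tanh correction term
    log_probs = log_probs.sum(-1)     # [4], summed over action dims
    print(actions.shape, log_probs.shape)  # torch.Size([4, 3]) torch.Size([4])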
@@ -507,10 +509,6 @@ def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cuda")
     return nn.ModuleList(critics).to(device)


-def orthogonal_init():
-    return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)
-
-
 # borrowed from tdmpc
 def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
     """Helper to temporarily flatten extra dims at the start of the image tensor.