style fixes

Michel Aractingi 2024-12-29 14:35:21 +00:00 committed by Adil Zouitine
parent 012ef3217e
commit 5070295e59
2 changed files with 37 additions and 39 deletions


@@ -53,13 +53,13 @@ class SACConfig:
     critic_network_kwargs = {
         "hidden_dims": [256, 256],
         "activate_final": True,
     }
     actor_network_kwargs = {
         "hidden_dims": [256, 256],
         "activate_final": True,
     }
     policy_kwargs = {
         "use_tanh_squash": True,
         "log_std_min": -5,
         "log_std_max": 2,
     }
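Editor's note: the three dicts above feed directly into network construction and action sampling. A minimal, self-contained sketch of how `policy_kwargs` is typically consumed in a SAC-style policy head (all variable names here are illustrative, not the repo's):

```python
import torch

# Illustrative only: how use_tanh_squash and the log_std bounds are usually applied.
log_std_min, log_std_max = -5, 2

net_out = torch.randn(4, 2 * 6)  # fake batch of 4; 6-dim action -> mean and log_std halves
mean, log_std = net_out.chunk(2, dim=-1)
log_std = torch.clamp(log_std, log_std_min, log_std_max)  # bound the std for stability

dist = torch.distributions.Normal(mean, log_std.exp())
x_t = dist.rsample()          # reparameterized sample
action = torch.tanh(x_t)      # use_tanh_squash=True keeps actions in (-1, 1)
```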


@@ -19,7 +19,6 @@
 from collections import deque
 from copy import deepcopy
-import math
 from typing import Callable, Optional, Sequence, Tuple

 import einops
@@ -137,21 +136,23 @@ class SACPolicy(
         # subsample critics to prevent overfitting if using a high UTD (update-to-data) ratio
         if self.config.num_subsample_critics is not None:
             indices = torch.randperm(self.config.num_critics)
-            indices = indices[:self.config.num_subsample_critics]
+            indices = indices[: self.config.num_subsample_critics]
             q_targets = q_targets[indices]

         # critics subsample size
         min_q = q_targets.min(dim=0)

         # compute td target
-        td_target = rewards + self.config.discount * min_q #+ self.config.discount * self.temperature() * log_probs # add entropy term
+        td_target = (
+            rewards + self.config.discount * min_q
+        )  # + self.config.discount * self.temperature() * log_probs # add entropy term

         # 3- compute predicted qs
         q_preds = self.critic_ensemble(observations, actions)

         # 4- Calculate loss
         # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
-        #critics_loss = (
+        # critics_loss = (
         #     (
         #         F.mse_loss(
         #             q_preds,
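Two things worth flagging in this hunk beyond the formatting: the subsampling is the REDQ-style trick of computing the target from a random subset of critics, and `q_targets.min(dim=0)` returns a `(values, indices)` named tuple, so `min_q` cannot be added to `rewards` as written. A self-contained sketch with that fixed (all shapes are assumed):

```python
import torch

num_critics, num_subsample, batch_size = 10, 2, 4
discount = 0.99

q_targets = torch.randn(num_critics, batch_size)  # fake per-critic target Q-values
rewards = torch.randn(batch_size)

# Subsample critics to keep the target pessimistic but cheap at high UTD ratios.
indices = torch.randperm(num_critics)[:num_subsample]
q_targets = q_targets[indices]

# Elementwise min over the sampled critics; .values unpacks the named tuple.
min_q = q_targets.min(dim=0).values

td_target = rewards + discount * min_q  # entropy term still commented out, as in the diff
print(td_target.shape)                  # torch.Size([4])
```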
@@ -167,14 +168,20 @@ class SACPolicy(
         #     )
         #     .sum(0)
         #     .mean()
-        #)
+        # )
         # 4- Calculate loss
         # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
-        critics_loss = F.mse_loss(
-            q_preds,  # shape: [num_critics, batch_size]
-            einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),  # expand td_target to match q_preds shape
-            reduction="none"
-        ).sum(0).mean()
+        critics_loss = (
+            F.mse_loss(
+                q_preds,  # shape: [num_critics, batch_size]
+                einops.repeat(
+                    td_target, "b -> e b", e=q_preds.shape[0]
+                ),  # expand td_target to match q_preds shape
+                reduction="none",
+            )
+            .sum(0)
+            .mean()
+        )

         # calculate actors loss
         # 1- temperature
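The reshaped `critics_loss` is purely a formatting change; the computation is identical. A quick standalone shape check (fake tensors; in the real update `td_target` comes from target networks and carries no gradient):

```python
import einops
import torch
import torch.nn.functional as F

num_critics, batch_size = 2, 4
q_preds = torch.randn(num_critics, batch_size)
td_target = torch.randn(batch_size)

critics_loss = (
    F.mse_loss(
        q_preds,
        einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),  # -> [2, 4]
        reduction="none",
    )
    .sum(0)   # sum squared errors across critics -> [batch_size]
    .mean()   # average over the batch -> scalar
)
print(critics_loss.shape)  # torch.Size([])
```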
@@ -231,7 +238,7 @@ class MLP(nn.Module):
         layers = []

         for i, size in enumerate(hidden_dims):
-            layers.append(nn.Linear(hidden_dims[i-1] if i > 0 else hidden_dims[0], size))
+            layers.append(nn.Linear(hidden_dims[i - 1] if i > 0 else hidden_dims[0], size))

             if i + 1 < len(hidden_dims) or activate_final:
                 if dropout_rate is not None and dropout_rate > 0:
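One thing the whitespace fix does not touch: the first `nn.Linear` uses `hidden_dims[0]` as its `in_features`, so the MLP only accepts inputs that already have `hidden_dims[0]` features. A variant that takes the input size explicitly might look like this (`input_dim` is an assumed extra argument, not part of the diff):

```python
import torch.nn as nn

def make_mlp(input_dim: int, hidden_dims: list[int], activate_final: bool = True) -> nn.Sequential:
    layers: list[nn.Module] = []
    in_dim = input_dim
    for size in hidden_dims:
        layers.append(nn.Linear(in_dim, size))
        layers.append(nn.ReLU())
        in_dim = size
    if not activate_final:
        layers.pop()  # drop the trailing activation
    return nn.Sequential(*layers)

mlp = make_mlp(input_dim=8, hidden_dims=[256, 256])  # matches the config's hidden_dims
```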
@@ -255,7 +262,7 @@ class Critic(nn.Module):
         encoder: Optional[nn.Module],
         network: nn.Module,
         init_final: Optional[float] = None,
-        device: str = "cuda"
+        device: str = "cuda",
     ):
         super().__init__()
         self.device = torch.device(device)
@@ -305,7 +312,7 @@ class Policy(nn.Module):
         fixed_std: Optional[torch.Tensor] = None,
         init_final: Optional[float] = None,
         use_tanh_squash: bool = False,
-        device: str = "cuda"
+        device: str = "cuda",
     ):
         super().__init__()
         self.device = torch.device(device)
@@ -346,13 +353,8 @@ class Policy(nn.Module):
         self,
         observations: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Encode observations if encoder exists
-        if self.encoder is not None:
-            with torch.set_grad_enabled(train):
-                obs_enc = self.encoder(observations, train=train)
-        else:
-            obs_enc = observations
-
+        obs_enc = observations if self.encoder is None else self.encoder(observations)

         # Get network outputs
         outputs = self.network(obs_enc)
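Besides collapsing the branch, the new one-liner drops the old `torch.set_grad_enabled(train)` wrapper and the `train=train` argument, so gradients now always flow through the encoder when one is present. The control flow, isolated (stand-in modules, made-up shapes):

```python
import torch
import torch.nn as nn

encoder = nn.Linear(12, 8)  # stand-in; set to None to pass observations through untouched
network = nn.Linear(8, 4)

observations = torch.randn(2, 12)

# Same logic as the new one-liner: encode only if an encoder exists.
obs_enc = observations if encoder is None else encoder(observations)
outputs = network(obs_enc)
print(outputs.shape)  # torch.Size([2, 4])
```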
@@ -374,7 +376,7 @@ class Policy(nn.Module):
         if self.use_tanh_squash:
             actions = torch.tanh(x_t)
             log_probs -= torch.log((1 - actions.pow(2)) + 1e-6)
-        log_probs = log_probs.sum(-1) # sum over action dim
+        log_probs = log_probs.sum(-1)  # sum over action dim

         return actions, log_probs
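Spacing aside, this block is the standard change-of-variables correction for a tanh-squashed Gaussian: `log pi(a|s) = log N(x_t) - sum_i log(1 - tanh(x_t)_i^2 + eps)`. A self-contained version (shapes assumed):

```python
import torch

mean, std = torch.zeros(4, 6), torch.ones(4, 6)  # fake policy head output
dist = torch.distributions.Normal(mean, std)

x_t = dist.rsample()                                 # pre-squash sample (reparameterized)
log_probs = dist.log_prob(x_t)                       # per-dimension Gaussian log-density

actions = torch.tanh(x_t)                            # squash into (-1, 1)
log_probs -= torch.log((1 - actions.pow(2)) + 1e-6)  # tanh Jacobian term; 1e-6 avoids log(0)
log_probs = log_probs.sum(-1)                        # joint log-prob over action dims
print(actions.shape, log_probs.shape)                # [4, 6], [4]
```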
@@ -507,10 +509,6 @@ def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cuda") -> nn.ModuleList:
     return nn.ModuleList(critics).to(device)

-
-def orthogonal_init():
-    return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)
-

 # borrowed from tdmpc
 def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
     """Helper to temporarily flatten extra dims at the start of the image tensor.