style fixes

Michel Aractingi 2024-12-29 14:35:21 +00:00 committed by Adil Zouitine
parent 012ef3217e
commit 5070295e59
2 changed files with 37 additions and 39 deletions


@@ -19,7 +19,6 @@
from collections import deque
from copy import deepcopy
import math
from typing import Callable, Optional, Sequence, Tuple
import einops
@@ -137,21 +136,23 @@ class SACPolicy(
# subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio
if self.config.num_subsample_critics is not None:
indices = torch.randperm(self.config.num_critics)
indices = indices[:self.config.num_subsample_critics]
indices = indices[: self.config.num_subsample_critics]
q_targets = q_targets[indices]
# take the minimum Q-value across the subsampled critics
min_q, _ = q_targets.min(dim=0)
# compute td target
td_target = rewards + self.config.discount * min_q #+ self.config.discount * self.temperature() * log_probs # add entropy term
td_target = (
rewards + self.config.discount * min_q
) # + self.config.discount * self.temperature() * log_probs # add entropy term
# 3- compute predicted qs
q_preds = self.critic_ensemble(observations, actions)
# 4- Calculate loss
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
#critics_loss = (
# critics_loss = (
# (
# F.mse_loss(
# q_preds,
@@ -167,14 +168,20 @@ class SACPolicy(
# )
# .sum(0)
# .mean()
#)
# )
# 4- Calculate loss
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
critics_loss = F.mse_loss(
critics_loss = (
F.mse_loss(
q_preds, # shape: [num_critics, batch_size]
einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]), # expand td_target to match q_preds shape
reduction="none"
).sum(0).mean()
einops.repeat(
td_target, "b -> e b", e=q_preds.shape[0]
), # expand td_target to match q_preds shape
reduction="none",
)
.sum(0)
.mean()
)
# calculate actor loss
# 1- temperature
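
As a cross-check of the critic update in this hunk, here is a minimal, self-contained sketch with random tensors standing in for the real batch; the shapes and the loss structure follow the diff, while the hyperparameter values and variable names are made up for illustration:

    import torch
    import torch.nn.functional as F
    import einops

    num_critics, num_subsample_critics, batch_size = 10, 2, 4
    discount = 0.99

    # hypothetical stand-ins for the rollout batch and the critic outputs
    rewards = torch.randn(batch_size)
    q_targets = torch.randn(num_critics, batch_size)  # target critics
    q_preds = torch.randn(num_critics, batch_size)  # online critics

    # subsample target critics to curb overfitting at a high UTD ratio
    indices = torch.randperm(num_critics)[:num_subsample_critics]
    q_targets = q_targets[indices]

    # pessimistic (min) target over the subsampled critics, then the TD target
    min_q, _ = q_targets.min(dim=0)
    td_target = rewards + discount * min_q

    # per-critic MSE against the shared TD target, summed over critics, averaged over the batch
    critics_loss = (
        F.mse_loss(
            q_preds,
            einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),
            reduction="none",
        )
        .sum(0)
        .mean()
    )
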
@@ -231,7 +238,7 @@ class MLP(nn.Module):
layers = []
for i, size in enumerate(hidden_dims):
layers.append(nn.Linear(hidden_dims[i-1] if i > 0 else hidden_dims[0], size))
layers.append(nn.Linear(hidden_dims[i - 1] if i > 0 else hidden_dims[0], size))
if i + 1 < len(hidden_dims) or activate_final:
if dropout_rate is not None and dropout_rate > 0:
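
The loop above follows a standard builder pattern: a Linear layer, then optional Dropout and an activation on every layer except the last, unless activate_final is set. A standalone sketch of that construction, assuming nn.ReLU as the activation (the diff does not show which activation the class uses); note that, as written, the first nn.Linear takes hidden_dims[0] as its input size, so the caller's input feature dimension must match hidden_dims[0]:

    import torch.nn as nn

    def build_mlp(hidden_dims, activate_final=False, dropout_rate=None):
        layers = []
        for i, size in enumerate(hidden_dims):
            layers.append(nn.Linear(hidden_dims[i - 1] if i > 0 else hidden_dims[0], size))
            if i + 1 < len(hidden_dims) or activate_final:
                if dropout_rate is not None and dropout_rate > 0:
                    layers.append(nn.Dropout(p=dropout_rate))
                layers.append(nn.ReLU())
        return nn.Sequential(*layers)

    mlp = build_mlp([256, 256], activate_final=True, dropout_rate=0.1)
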
@@ -255,7 +262,7 @@ class Critic(nn.Module):
encoder: Optional[nn.Module],
network: nn.Module,
init_final: Optional[float] = None,
device: str = "cuda"
device: str = "cuda",
):
super().__init__()
self.device = torch.device(device)
@@ -305,7 +312,7 @@ class Policy(nn.Module):
fixed_std: Optional[torch.Tensor] = None,
init_final: Optional[float] = None,
use_tanh_squash: bool = False,
device: str = "cuda"
device: str = "cuda",
):
super().__init__()
self.device = torch.device(device)
@@ -346,13 +353,8 @@ class Policy(nn.Module):
self,
observations: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Encode observations if encoder exists
if self.encoder is not None:
with torch.set_grad_enabled(train):
obs_enc = self.encoder(observations, train=train)
else:
obs_enc = observations
obs_enc = observations if self.encoder is None else self.encoder(observations)
# Get network outputs
outputs = self.network(obs_enc)
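
The refactor collapses the encoder branch into a single conditional expression; a minimal sketch of that pattern with a hypothetical TinyPolicy stand-in (not the real Policy class):

    from typing import Optional

    import torch
    import torch.nn as nn

    class TinyPolicy(nn.Module):
        # hypothetical stand-in showing only the encode-then-network step
        def __init__(self, encoder: Optional[nn.Module], network: nn.Module):
            super().__init__()
            self.encoder = encoder
            self.network = network

        def forward(self, observations: torch.Tensor) -> torch.Tensor:
            # pass raw observations through when no encoder is configured
            obs_enc = observations if self.encoder is None else self.encoder(observations)
            return self.network(obs_enc)

    policy = TinyPolicy(encoder=None, network=nn.Linear(8, 2))
    out = policy(torch.randn(4, 8))  # shape: [4, 2]
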
@@ -507,10 +509,6 @@ def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cuda")
return nn.ModuleList(critics).to(device)
def orthogonal_init():
return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)
# borrowed from tdmpc
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
"""Helper to temporarily flatten extra dims at the start of the image tensor.