style fixes

Michel Aractingi 2024-12-29 14:35:21 +00:00 committed by Adil Zouitine
parent 012ef3217e
commit 5070295e59
2 changed files with 37 additions and 39 deletions

@@ -19,7 +19,6 @@
 from collections import deque
 from copy import deepcopy
-import math
 from typing import Callable, Optional, Sequence, Tuple

 import einops
@@ -144,7 +143,9 @@ class SACPolicy(
         min_q = q_targets.min(dim=0)

         # compute td target
-        td_target = rewards + self.config.discount * min_q  # + self.config.discount * self.temperature() * log_probs # add entropy term
+        td_target = (
+            rewards + self.config.discount * min_q
+        )  # + self.config.discount * self.temperature() * log_probs # add entropy term

         # 3- compute predicted qs
         q_preds = self.critic_ensemble(observations, actions)
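The reformatted TD-target line is pure style, but the shapes are easy to lose track of. A minimal standalone sketch of the same step, under assumed shapes (batch size B and ensemble size E are hypothetical; the entropy term stays commented out, as in the diff):

import torch

B, E = 256, 2  # hypothetical batch size and number of target critics
rewards = torch.randn(B)
q_targets = torch.randn(E, B)  # target Q-values, one row per critic
discount = 0.99

# clipped double-Q: element-wise minimum over the ensemble
# (note: Tensor.min(dim=0) returns a (values, indices) named tuple,
# so the values must be unpacked before doing arithmetic with it)
min_q = q_targets.min(dim=0).values

td_target = rewards + discount * min_q  # shape: [B]
# + discount * temperature * log_probs  # entropy term, disabled in this commit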
@@ -170,11 +171,17 @@ class SACPolicy(
         # )

         # 4- Calculate loss
         # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
-        critics_loss = F.mse_loss(
-            q_preds,  # shape: [num_critics, batch_size]
-            einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),  # expand td_target to match q_preds shape
-            reduction="none"
-        ).sum(0).mean()
+        critics_loss = (
+            F.mse_loss(
+                q_preds,  # shape: [num_critics, batch_size]
+                einops.repeat(
+                    td_target, "b -> e b", e=q_preds.shape[0]
+                ),  # expand td_target to match q_preds shape
+                reduction="none",
+            )
+            .sum(0)
+            .mean()
+        )

         # calculate actors loss
         # 1- temperature
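For reference, a self-contained sketch of the ensemble TD loss as restructured above; E and B are assumed shapes matching the inline comments:

import einops
import torch
import torch.nn.functional as F

E, B = 2, 256  # hypothetical ensemble and batch sizes
q_preds = torch.randn(E, B)  # one row of predicted Q-values per critic
td_target = torch.randn(B)  # shared regression target

critics_loss = (
    F.mse_loss(
        q_preds,
        einops.repeat(td_target, "b -> e b", e=E),  # copy the target for every critic
        reduction="none",  # keep the per-element losses, shape [E, B]
    )
    .sum(0)   # sum across the ensemble -> [B]
    .mean()   # average over the batch -> scalar
)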
@@ -255,7 +262,7 @@ class Critic(nn.Module):
         encoder: Optional[nn.Module],
         network: nn.Module,
         init_final: Optional[float] = None,
-        device: str = "cuda"
+        device: str = "cuda",
     ):
         super().__init__()
         self.device = torch.device(device)
@@ -305,7 +312,7 @@ class Policy(nn.Module):
         fixed_std: Optional[torch.Tensor] = None,
         init_final: Optional[float] = None,
         use_tanh_squash: bool = False,
-        device: str = "cuda"
+        device: str = "cuda",
     ):
         super().__init__()
         self.device = torch.device(device)
@@ -346,13 +353,8 @@ class Policy(nn.Module):
         self,
         observations: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Encode observations if encoder exists
-        if self.encoder is not None:
-            with torch.set_grad_enabled(train):
-                obs_enc = self.encoder(observations, train=train)
-        else:
-            obs_enc = observations
+        obs_enc = observations if self.encoder is None else self.encoder(observations)

         # Get network outputs
         outputs = self.network(obs_enc)
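The collapsed conditional also drops the old train flag and the torch.set_grad_enabled block. A small sketch of the branch logic (the ternary must fall back to raw observations only when no encoder is set); nn.Identity stands in for a hypothetical encoder:

import torch
from torch import nn

observations = torch.randn(4, 10)

for encoder in (nn.Identity(), None):  # with and without an encoder
    obs_enc = observations if encoder is None else encoder(observations)
    assert obs_enc.shape == observations.shape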
@@ -507,10 +509,6 @@ def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: s
     return nn.ModuleList(critics).to(device)

-def orthogonal_init():
-    return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)

 # borrowed from tdmpc
 def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
     """Helper to temporarily flatten extra dims at the start of the image tensor.