style fixes
This commit is contained in:
parent
9968dd06e0
commit
3b86ed79db
|
@ -19,7 +19,6 @@
|
|||
|
||||
from collections import deque
|
||||
from copy import deepcopy
|
||||
import math
|
||||
from typing import Callable, Optional, Sequence, Tuple
|
||||
|
||||
import einops
|
||||
|
@ -137,21 +136,23 @@ class SACPolicy(
|
|||
# subsample critics to prevent overfitting if use high UTD (update to date)
|
||||
if self.config.num_subsample_critics is not None:
|
||||
indices = torch.randperm(self.config.num_critics)
|
||||
indices = indices[:self.config.num_subsample_critics]
|
||||
indices = indices[: self.config.num_subsample_critics]
|
||||
q_targets = q_targets[indices]
|
||||
|
||||
# critics subsample size
|
||||
min_q = q_targets.min(dim=0)
|
||||
|
||||
# compute td target
|
||||
td_target = rewards + self.config.discount * min_q #+ self.config.discount * self.temperature() * log_probs # add entropy term
|
||||
td_target = (
|
||||
rewards + self.config.discount * min_q
|
||||
) # + self.config.discount * self.temperature() * log_probs # add entropy term
|
||||
|
||||
# 3- compute predicted qs
|
||||
q_preds = self.critic_ensemble(observations, actions)
|
||||
|
||||
# 4- Calculate loss
|
||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||
#critics_loss = (
|
||||
# critics_loss = (
|
||||
# (
|
||||
# F.mse_loss(
|
||||
# q_preds,
|
||||
|
@ -167,14 +168,20 @@ class SACPolicy(
|
|||
# )
|
||||
# .sum(0)
|
||||
# .mean()
|
||||
#)
|
||||
# )
|
||||
# 4- Calculate loss
|
||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||
critics_loss = F.mse_loss(
|
||||
critics_loss = (
|
||||
F.mse_loss(
|
||||
q_preds, # shape: [num_critics, batch_size]
|
||||
einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]), # expand td_target to match q_preds shape
|
||||
reduction="none"
|
||||
).sum(0).mean()
|
||||
einops.repeat(
|
||||
td_target, "b -> e b", e=q_preds.shape[0]
|
||||
), # expand td_target to match q_preds shape
|
||||
reduction="none",
|
||||
)
|
||||
.sum(0)
|
||||
.mean()
|
||||
)
|
||||
|
||||
# calculate actors loss
|
||||
# 1- temperature
|
||||
|
@ -231,7 +238,7 @@ class MLP(nn.Module):
|
|||
layers = []
|
||||
|
||||
for i, size in enumerate(hidden_dims):
|
||||
layers.append(nn.Linear(hidden_dims[i-1] if i > 0 else hidden_dims[0], size))
|
||||
layers.append(nn.Linear(hidden_dims[i - 1] if i > 0 else hidden_dims[0], size))
|
||||
|
||||
if i + 1 < len(hidden_dims) or activate_final:
|
||||
if dropout_rate is not None and dropout_rate > 0:
|
||||
|
@ -255,7 +262,7 @@ class Critic(nn.Module):
|
|||
encoder: Optional[nn.Module],
|
||||
network: nn.Module,
|
||||
init_final: Optional[float] = None,
|
||||
device: str = "cuda"
|
||||
device: str = "cuda",
|
||||
):
|
||||
super().__init__()
|
||||
self.device = torch.device(device)
|
||||
|
@ -305,7 +312,7 @@ class Policy(nn.Module):
|
|||
fixed_std: Optional[torch.Tensor] = None,
|
||||
init_final: Optional[float] = None,
|
||||
use_tanh_squash: bool = False,
|
||||
device: str = "cuda"
|
||||
device: str = "cuda",
|
||||
):
|
||||
super().__init__()
|
||||
self.device = torch.device(device)
|
||||
|
@ -346,13 +353,8 @@ class Policy(nn.Module):
|
|||
self,
|
||||
observations: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
# Encode observations if encoder exists
|
||||
if self.encoder is not None:
|
||||
with torch.set_grad_enabled(train):
|
||||
obs_enc = self.encoder(observations, train=train)
|
||||
else:
|
||||
obs_enc = observations
|
||||
obs_enc = observations if self.encoder is not None else self.encoder(observations)
|
||||
|
||||
# Get network outputs
|
||||
outputs = self.network(obs_enc)
|
||||
|
@ -507,10 +509,6 @@ def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: s
|
|||
return nn.ModuleList(critics).to(device)
|
||||
|
||||
|
||||
def orthogonal_init():
|
||||
return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)
|
||||
|
||||
|
||||
# borrowed from tdmpc
|
||||
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
|
||||
"""Helper to temporarily flatten extra dims at the start of the image tensor.
|
||||
|
|
Loading…
Reference in New Issue