style fixes
This commit is contained in:
parent
012ef3217e
commit
5070295e59
|
@ -53,13 +53,13 @@ class SACConfig:
|
||||||
critic_network_kwargs = {
|
critic_network_kwargs = {
|
||||||
"hidden_dims": [256, 256],
|
"hidden_dims": [256, 256],
|
||||||
"activate_final": True,
|
"activate_final": True,
|
||||||
}
|
}
|
||||||
actor_network_kwargs = {
|
actor_network_kwargs = {
|
||||||
"hidden_dims": [256, 256],
|
"hidden_dims": [256, 256],
|
||||||
"activate_final": True,
|
"activate_final": True,
|
||||||
}
|
}
|
||||||
policy_kwargs = {
|
policy_kwargs = {
|
||||||
"use_tanh_squash": True,
|
"use_tanh_squash": True,
|
||||||
"log_std_min": -5,
|
"log_std_min": -5,
|
||||||
"log_std_max": 2,
|
"log_std_max": 2,
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,7 +19,6 @@
|
||||||
|
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
import math
|
|
||||||
from typing import Callable, Optional, Sequence, Tuple
|
from typing import Callable, Optional, Sequence, Tuple
|
||||||
|
|
||||||
import einops
|
import einops
|
||||||
|
@ -137,21 +136,23 @@ class SACPolicy(
|
||||||
# subsample critics to prevent overfitting if use high UTD (update to date)
|
# subsample critics to prevent overfitting if use high UTD (update to date)
|
||||||
if self.config.num_subsample_critics is not None:
|
if self.config.num_subsample_critics is not None:
|
||||||
indices = torch.randperm(self.config.num_critics)
|
indices = torch.randperm(self.config.num_critics)
|
||||||
indices = indices[:self.config.num_subsample_critics]
|
indices = indices[: self.config.num_subsample_critics]
|
||||||
q_targets = q_targets[indices]
|
q_targets = q_targets[indices]
|
||||||
|
|
||||||
# critics subsample size
|
# critics subsample size
|
||||||
min_q = q_targets.min(dim=0)
|
min_q = q_targets.min(dim=0)
|
||||||
|
|
||||||
# compute td target
|
# compute td target
|
||||||
td_target = rewards + self.config.discount * min_q #+ self.config.discount * self.temperature() * log_probs # add entropy term
|
td_target = (
|
||||||
|
rewards + self.config.discount * min_q
|
||||||
|
) # + self.config.discount * self.temperature() * log_probs # add entropy term
|
||||||
|
|
||||||
# 3- compute predicted qs
|
# 3- compute predicted qs
|
||||||
q_preds = self.critic_ensemble(observations, actions)
|
q_preds = self.critic_ensemble(observations, actions)
|
||||||
|
|
||||||
# 4- Calculate loss
|
# 4- Calculate loss
|
||||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||||
#critics_loss = (
|
# critics_loss = (
|
||||||
# (
|
# (
|
||||||
# F.mse_loss(
|
# F.mse_loss(
|
||||||
# q_preds,
|
# q_preds,
|
||||||
|
@ -167,14 +168,20 @@ class SACPolicy(
|
||||||
# )
|
# )
|
||||||
# .sum(0)
|
# .sum(0)
|
||||||
# .mean()
|
# .mean()
|
||||||
#)
|
# )
|
||||||
# 4- Calculate loss
|
# 4- Calculate loss
|
||||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||||
critics_loss = F.mse_loss(
|
critics_loss = (
|
||||||
q_preds, # shape: [num_critics, batch_size]
|
F.mse_loss(
|
||||||
einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]), # expand td_target to match q_preds shape
|
q_preds, # shape: [num_critics, batch_size]
|
||||||
reduction="none"
|
einops.repeat(
|
||||||
).sum(0).mean()
|
td_target, "b -> e b", e=q_preds.shape[0]
|
||||||
|
), # expand td_target to match q_preds shape
|
||||||
|
reduction="none",
|
||||||
|
)
|
||||||
|
.sum(0)
|
||||||
|
.mean()
|
||||||
|
)
|
||||||
|
|
||||||
# calculate actors loss
|
# calculate actors loss
|
||||||
# 1- temperature
|
# 1- temperature
|
||||||
|
@ -231,7 +238,7 @@ class MLP(nn.Module):
|
||||||
layers = []
|
layers = []
|
||||||
|
|
||||||
for i, size in enumerate(hidden_dims):
|
for i, size in enumerate(hidden_dims):
|
||||||
layers.append(nn.Linear(hidden_dims[i-1] if i > 0 else hidden_dims[0], size))
|
layers.append(nn.Linear(hidden_dims[i - 1] if i > 0 else hidden_dims[0], size))
|
||||||
|
|
||||||
if i + 1 < len(hidden_dims) or activate_final:
|
if i + 1 < len(hidden_dims) or activate_final:
|
||||||
if dropout_rate is not None and dropout_rate > 0:
|
if dropout_rate is not None and dropout_rate > 0:
|
||||||
|
@ -255,7 +262,7 @@ class Critic(nn.Module):
|
||||||
encoder: Optional[nn.Module],
|
encoder: Optional[nn.Module],
|
||||||
network: nn.Module,
|
network: nn.Module,
|
||||||
init_final: Optional[float] = None,
|
init_final: Optional[float] = None,
|
||||||
device: str = "cuda"
|
device: str = "cuda",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.device = torch.device(device)
|
self.device = torch.device(device)
|
||||||
|
@ -305,7 +312,7 @@ class Policy(nn.Module):
|
||||||
fixed_std: Optional[torch.Tensor] = None,
|
fixed_std: Optional[torch.Tensor] = None,
|
||||||
init_final: Optional[float] = None,
|
init_final: Optional[float] = None,
|
||||||
use_tanh_squash: bool = False,
|
use_tanh_squash: bool = False,
|
||||||
device: str = "cuda"
|
device: str = "cuda",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.device = torch.device(device)
|
self.device = torch.device(device)
|
||||||
|
@ -346,13 +353,8 @@ class Policy(nn.Module):
|
||||||
self,
|
self,
|
||||||
observations: torch.Tensor,
|
observations: torch.Tensor,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
|
||||||
# Encode observations if encoder exists
|
# Encode observations if encoder exists
|
||||||
if self.encoder is not None:
|
obs_enc = observations if self.encoder is not None else self.encoder(observations)
|
||||||
with torch.set_grad_enabled(train):
|
|
||||||
obs_enc = self.encoder(observations, train=train)
|
|
||||||
else:
|
|
||||||
obs_enc = observations
|
|
||||||
|
|
||||||
# Get network outputs
|
# Get network outputs
|
||||||
outputs = self.network(obs_enc)
|
outputs = self.network(obs_enc)
|
||||||
|
@ -374,7 +376,7 @@ class Policy(nn.Module):
|
||||||
if self.use_tanh_squash:
|
if self.use_tanh_squash:
|
||||||
actions = torch.tanh(x_t)
|
actions = torch.tanh(x_t)
|
||||||
log_probs -= torch.log((1 - actions.pow(2)) + 1e-6)
|
log_probs -= torch.log((1 - actions.pow(2)) + 1e-6)
|
||||||
log_probs = log_probs.sum(-1) # sum over action dim
|
log_probs = log_probs.sum(-1) # sum over action dim
|
||||||
|
|
||||||
return actions, log_probs
|
return actions, log_probs
|
||||||
|
|
||||||
|
@ -507,10 +509,6 @@ def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: s
|
||||||
return nn.ModuleList(critics).to(device)
|
return nn.ModuleList(critics).to(device)
|
||||||
|
|
||||||
|
|
||||||
def orthogonal_init():
|
|
||||||
return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)
|
|
||||||
|
|
||||||
|
|
||||||
# borrowed from tdmpc
|
# borrowed from tdmpc
|
||||||
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
|
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
|
||||||
"""Helper to temporarily flatten extra dims at the start of the image tensor.
|
"""Helper to temporarily flatten extra dims at the start of the image tensor.
|
||||||
|
|
Loading…
Reference in New Issue