style fixes
This commit is contained in:
parent
91fefdecfa
commit
2c2ed084cc
|
@ -53,13 +53,13 @@ class SACConfig:
|
|||
critic_network_kwargs = {
|
||||
"hidden_dims": [256, 256],
|
||||
"activate_final": True,
|
||||
}
|
||||
}
|
||||
actor_network_kwargs = {
|
||||
"hidden_dims": [256, 256],
|
||||
"activate_final": True,
|
||||
}
|
||||
}
|
||||
policy_kwargs = {
|
||||
"use_tanh_squash": True,
|
||||
"log_std_min": -5,
|
||||
"log_std_max": 2,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,7 +19,6 @@
|
|||
|
||||
from collections import deque
|
||||
from copy import deepcopy
|
||||
import math
|
||||
from typing import Callable, Optional, Sequence, Tuple
|
||||
|
||||
import einops
|
||||
|
@ -125,9 +124,9 @@ class SACPolicy(
|
|||
|
||||
# perform image augmentation
|
||||
|
||||
# reward bias from HIL-SERL code base
|
||||
# reward bias from HIL-SERL code base
|
||||
# add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch
|
||||
|
||||
|
||||
# calculate critics loss
|
||||
# 1- compute actions from policy
|
||||
action_preds, log_probs = self.actor(next_observations)
|
||||
|
@ -137,21 +136,23 @@ class SACPolicy(
|
|||
# subsample critics to prevent overfitting if use high UTD (update to date)
|
||||
if self.config.num_subsample_critics is not None:
|
||||
indices = torch.randperm(self.config.num_critics)
|
||||
indices = indices[:self.config.num_subsample_critics]
|
||||
indices = indices[: self.config.num_subsample_critics]
|
||||
q_targets = q_targets[indices]
|
||||
|
||||
# critics subsample size
|
||||
min_q = q_targets.min(dim=0)
|
||||
|
||||
# compute td target
|
||||
td_target = rewards + self.config.discount * min_q #+ self.config.discount * self.temperature() * log_probs # add entropy term
|
||||
td_target = (
|
||||
rewards + self.config.discount * min_q
|
||||
) # + self.config.discount * self.temperature() * log_probs # add entropy term
|
||||
|
||||
# 3- compute predicted qs
|
||||
q_preds = self.critic_ensemble(observations, actions)
|
||||
|
||||
# 4- Calculate loss
|
||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||
#critics_loss = (
|
||||
# critics_loss = (
|
||||
# (
|
||||
# F.mse_loss(
|
||||
# q_preds,
|
||||
|
@ -167,14 +168,20 @@ class SACPolicy(
|
|||
# )
|
||||
# .sum(0)
|
||||
# .mean()
|
||||
#)
|
||||
# )
|
||||
# 4- Calculate loss
|
||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||
critics_loss = F.mse_loss(
|
||||
q_preds, # shape: [num_critics, batch_size]
|
||||
einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]), # expand td_target to match q_preds shape
|
||||
reduction="none"
|
||||
).sum(0).mean()
|
||||
critics_loss = (
|
||||
F.mse_loss(
|
||||
q_preds, # shape: [num_critics, batch_size]
|
||||
einops.repeat(
|
||||
td_target, "b -> e b", e=q_preds.shape[0]
|
||||
), # expand td_target to match q_preds shape
|
||||
reduction="none",
|
||||
)
|
||||
.sum(0)
|
||||
.mean()
|
||||
)
|
||||
|
||||
# calculate actors loss
|
||||
# 1- temperature
|
||||
|
@ -229,10 +236,10 @@ class MLP(nn.Module):
|
|||
super().__init__()
|
||||
self.activate_final = activate_final
|
||||
layers = []
|
||||
|
||||
|
||||
for i, size in enumerate(hidden_dims):
|
||||
layers.append(nn.Linear(hidden_dims[i-1] if i > 0 else hidden_dims[0], size))
|
||||
|
||||
layers.append(nn.Linear(hidden_dims[i - 1] if i > 0 else hidden_dims[0], size))
|
||||
|
||||
if i + 1 < len(hidden_dims) or activate_final:
|
||||
if dropout_rate is not None and dropout_rate > 0:
|
||||
layers.append(nn.Dropout(p=dropout_rate))
|
||||
|
@ -255,20 +262,20 @@ class Critic(nn.Module):
|
|||
encoder: Optional[nn.Module],
|
||||
network: nn.Module,
|
||||
init_final: Optional[float] = None,
|
||||
device: str = "cuda"
|
||||
device: str = "cuda",
|
||||
):
|
||||
super().__init__()
|
||||
self.device = torch.device(device)
|
||||
self.encoder = encoder
|
||||
self.network = network
|
||||
self.init_final = init_final
|
||||
|
||||
|
||||
# Find the last Linear layer's output dimension
|
||||
for layer in reversed(network.net):
|
||||
if isinstance(layer, nn.Linear):
|
||||
out_features = layer.out_features
|
||||
break
|
||||
|
||||
|
||||
# Output layer
|
||||
if init_final is not None:
|
||||
self.output_layer = nn.Linear(out_features, 1)
|
||||
|
@ -305,7 +312,7 @@ class Policy(nn.Module):
|
|||
fixed_std: Optional[torch.Tensor] = None,
|
||||
init_final: Optional[float] = None,
|
||||
use_tanh_squash: bool = False,
|
||||
device: str = "cuda"
|
||||
device: str = "cuda",
|
||||
):
|
||||
super().__init__()
|
||||
self.device = torch.device(device)
|
||||
|
@ -316,13 +323,13 @@ class Policy(nn.Module):
|
|||
self.log_std_max = log_std_max
|
||||
self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
|
||||
self.use_tanh_squash = use_tanh_squash
|
||||
|
||||
|
||||
# Find the last Linear layer's output dimension
|
||||
for layer in reversed(network.net):
|
||||
if isinstance(layer, nn.Linear):
|
||||
out_features = layer.out_features
|
||||
break
|
||||
|
||||
|
||||
# Mean layer
|
||||
self.mean_layer = nn.Linear(out_features, action_dim)
|
||||
if init_final is not None:
|
||||
|
@ -339,21 +346,16 @@ class Policy(nn.Module):
|
|||
nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
|
||||
else:
|
||||
orthogonal_init()(self.std_layer.weight)
|
||||
|
||||
|
||||
self.to(self.device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
observations: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
# Encode observations if encoder exists
|
||||
if self.encoder is not None:
|
||||
with torch.set_grad_enabled(train):
|
||||
obs_enc = self.encoder(observations, train=train)
|
||||
else:
|
||||
obs_enc = observations
|
||||
|
||||
obs_enc = observations if self.encoder is not None else self.encoder(observations)
|
||||
|
||||
# Get network outputs
|
||||
outputs = self.network(obs_enc)
|
||||
means = self.mean_layer(outputs)
|
||||
|
@ -369,15 +371,15 @@ class Policy(nn.Module):
|
|||
|
||||
# uses tahn activation function to squash the action to be in the range of [-1, 1]
|
||||
normal = torch.distributions.Normal(means, stds)
|
||||
x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1))
|
||||
x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1))
|
||||
log_probs = normal.log_prob(x_t)
|
||||
if self.use_tanh_squash:
|
||||
actions = torch.tanh(x_t)
|
||||
log_probs -= torch.log((1 - actions.pow(2)) + 1e-6)
|
||||
log_probs = log_probs.sum(-1) # sum over action dim
|
||||
log_probs = log_probs.sum(-1) # sum over action dim
|
||||
|
||||
return actions, log_probs
|
||||
|
||||
|
||||
def get_features(self, observations: torch.Tensor) -> torch.Tensor:
|
||||
"""Get encoded features from observations"""
|
||||
observations = observations.to(self.device)
|
||||
|
@ -507,10 +509,6 @@ def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: s
|
|||
return nn.ModuleList(critics).to(device)
|
||||
|
||||
|
||||
def orthogonal_init():
|
||||
return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)
|
||||
|
||||
|
||||
# borrowed from tdmpc
|
||||
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
|
||||
"""Helper to temporarily flatten extra dims at the start of the image tensor.
|
||||
|
|
Loading…
Reference in New Issue