2024-12-12 18:45:30 +08:00
|
|
|
#!/usr/bin/env python
|
|
|
|
|
2024-12-29 20:51:21 +08:00
|
|
|
# Copyright 2024 The HuggingFace Inc. team.
|
2024-12-12 18:45:30 +08:00
|
|
|
# All rights reserved.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
2024-12-17 21:26:17 +08:00
|
|
|
# TODO: (1) better device management
|
|
|
|
|
2024-12-12 18:45:30 +08:00
|
|
|
from collections import deque
|
2024-12-17 21:26:17 +08:00
|
|
|
from copy import deepcopy
|
2024-12-23 17:44:29 +08:00
|
|
|
from typing import Callable, Optional, Sequence, Tuple
|
2024-12-12 18:45:30 +08:00
|
|
|
|
|
|
|
import einops
|
2024-12-23 17:44:29 +08:00
|
|
|
import numpy as np
|
2024-12-12 18:45:30 +08:00
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
import torch.nn.functional as F # noqa: N812
|
2024-12-23 17:44:29 +08:00
|
|
|
from huggingface_hub import PyTorchModelHubMixin
|
2024-12-12 18:45:30 +08:00
|
|
|
from torch import Tensor
|
|
|
|
|
|
|
|
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
|
|
|
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
2024-12-17 21:26:17 +08:00
|
|
|
|
2024-12-12 18:45:30 +08:00
|
|
|
|
|
|
|
class SACPolicy(
    nn.Module,
    PyTorchModelHubMixin,
    library_name="lerobot",
    repo_url="https://github.com/huggingface/lerobot",
    tags=["robotics", "RL", "SAC"],
):
    """Soft Actor-Critic policy: critic ensemble with target copies, a stochastic actor,
    and a learned temperature (entropy coefficient) via a Lagrange multiplier."""

    name = "sac"

    def __init__(
        self, config: SACConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
    ):
        """
        Args:
            config: Policy configuration. Falls back to the default `SACConfig()` when None.
            dataset_stats: Statistics used by the input/output (un)normalization layers.
        """
        super().__init__()

        if config is None:
            config = SACConfig()
        self.config = config

        if config.input_normalization_modes is not None:
            self.normalize_inputs = Normalize(
                config.input_shapes, config.input_normalization_modes, dataset_stats
            )
        else:
            self.normalize_inputs = nn.Identity()
        self.normalize_targets = Normalize(
            config.output_shapes, config.output_normalization_modes, dataset_stats
        )
        self.unnormalize_outputs = Unnormalize(
            config.output_shapes, config.output_normalization_modes, dataset_stats
        )

        # The observation encoder is shared by the actor and every critic.
        encoder = SACObservationEncoder(config)

        # Define networks
        critic_nets = []
        for _ in range(config.num_critics):
            critic_net = Critic(encoder=encoder, network=MLP(**config.critic_network_kwargs))
            critic_nets.append(critic_net)

        self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
        # Target critics start as an exact copy and are updated via Polyak averaging (see `update`).
        self.critic_target = deepcopy(self.critic_ensemble)

        self.actor = Policy(
            encoder=encoder,
            network=MLP(**config.actor_network_kwargs),
            action_dim=config.output_shapes["action"][0],
            **config.policy_kwargs,
        )
        if config.target_entropy is None:
            config.target_entropy = -np.prod(config.output_shapes["action"][0])  # (-dim(A))
        self.temperature = LagrangeMultiplier(init_value=config.temperature_init)

    def reset(self):
        """
        Clear observation and action queues. Should be called on `env.reset()`
        queues are populated during rollout of the policy, they contain the n latest observations and actions
        """

        self._queues = {
            "observation.state": deque(maxlen=1),
            "action": deque(maxlen=1),
        }
        if "observation.image" in self.config.input_shapes:
            self._queues["observation.image"] = deque(maxlen=1)
        if "observation.environment_state" in self.config.input_shapes:
            self._queues["observation.environment_state"] = deque(maxlen=1)

    @torch.no_grad()
    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        """Select action for inference/evaluation"""
        # FIX: normalize inputs here for consistency with `forward`, which normalizes
        # before calling the actor; without this the actor sees raw observations.
        batch = self.normalize_inputs(batch)
        actions, _ = self.actor(batch)
        actions = self.unnormalize_outputs({"action": actions})["action"]
        return actions

    def _compute_q_values(self, critics: nn.ModuleList, observations, actions: Tensor) -> Tensor:
        """Evaluate every critic in `critics` on (observations, actions).

        Returns a tensor of shape (num_critics, batch_size).
        NOTE(review): each `Critic.forward` calls `.to(device)` on `observations`, which
        assumes a Tensor — confirm how dict observations are routed through the encoder.
        """
        # FIX: `nn.ModuleList` is not callable; evaluate each critic and stack results.
        return torch.stack([critic(observations, actions) for critic in critics])

    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
        """Run the batch through the model and compute the loss.

        Returns a dictionary with loss as a tensor, and other information as native floats.
        """
        batch = self.normalize_inputs(batch)
        # batch shape is (b, 2, ...) where index 1 returns the current observation and
        # the next observation for calculating the right td index.
        actions = batch["action"][:, 0]
        rewards = batch["next.reward"][:, 0]
        observations = {}
        next_observations = {}
        for k in batch:
            if k.startswith("observation."):
                observations[k] = batch[k][:, 0]
                next_observations[k] = batch[k][:, 1]

        # perform image augmentation

        # reward bias from HIL-SERL code base
        # add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch

        # calculate critics loss
        # FIX: the TD target must not backpropagate into the actor or target critics.
        with torch.no_grad():
            # 1- compute actions from policy
            action_preds, log_probs = self.actor(next_observations)

            # 2- compute q targets
            # FIX: was `self.target_qs(...)`, which is not defined; use the target critics.
            q_targets = self._compute_q_values(self.critic_target, next_observations, action_preds)

            # subsample critics to prevent overfitting if use high UTD (update to date)
            if self.config.num_subsample_critics is not None:
                indices = torch.randperm(self.config.num_critics)
                indices = indices[: self.config.num_subsample_critics]
                q_targets = q_targets[indices]

            # critics subsample size
            # FIX: `Tensor.min(dim=...)` returns a (values, indices) namedtuple; take `.values`.
            min_q = q_targets.min(dim=0).values

            # compute td target
            td_target = (
                rewards + self.config.discount * min_q
            )  # + self.config.discount * self.temperature() * log_probs  # add entropy term

        # 3- compute predicted qs
        q_preds = self._compute_q_values(self.critic_ensemble, observations, actions)

        # 4- Calculate loss
        # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
        critics_loss = (
            F.mse_loss(
                q_preds,  # shape: [num_critics, batch_size]
                einops.repeat(
                    td_target, "b -> e b", e=q_preds.shape[0]
                ),  # expand td_target to match q_preds shape
                reduction="none",
            )
            .sum(0)  # sum over ensemble
            .mean()
        )

        # calculate actors loss
        # 1- temperature
        temperature = self.temperature()
        # 2- get actions (batch_size, action_dim) and log probs (batch_size,)
        actions_pi, log_probs = self.actor(observations)
        # 3- get q-value predictions (detached: actor gradient flows only through log_probs here)
        with torch.no_grad():
            q_pi = self._compute_q_values(self.critic_ensemble, observations, actions_pi).mean(dim=0)
        actor_loss = (
            -(q_pi - temperature * log_probs).mean()
            * ~batch["observation.state_is_pad"][0]
            * ~batch["action_is_pad"]
        ).mean()

        # calculate temperature loss
        # 1- calculate entropy
        entropy = -log_probs.mean()
        # FIX: was `self.temp(...)` — the attribute is `self.temperature`. The target entropy
        # is a plain scalar, so wrap it in a tensor for `LagrangeMultiplier.forward`.
        temperature_loss = self.temperature(
            lhs=entropy,
            rhs=torch.tensor(self.config.target_entropy, dtype=entropy.dtype, device=entropy.device),
        )

        loss = critics_loss + actor_loss + temperature_loss

        return {
            "critics_loss": critics_loss.item(),
            "actor_loss": actor_loss.item(),
            "temperature_loss": temperature_loss.item(),
            "temperature": temperature.item(),
            "entropy": entropy.item(),
            "loss": loss,
        }

    def update(self):
        """Polyak (soft) update of the target critics toward the online critics."""
        # FIX: `nn.ModuleList` has no `lerp_`; interpolate parameter-by-parameter instead:
        # target <- (1 - w) * target + w * online
        weight = self.config.critic_target_update_weight
        for target_param, param in zip(self.critic_target.parameters(), self.critic_ensemble.parameters()):
            target_param.data.lerp_(param.data, weight)
        # TODO: implement UTD update
        # First update only critics for utd_ratio-1 times
        # for critic_step in range(self.config.utd_ratio - 1):
        #     only update critic and critic target
        # Then update critic, critic target, actor and temperature
|
|
|
|
|
2024-12-17 21:26:17 +08:00
|
|
|
|
|
|
|
class MLP(nn.Module):
|
|
|
|
def __init__(
|
|
|
|
self,
|
2024-12-27 07:38:46 +08:00
|
|
|
hidden_dims: list[int],
|
2024-12-17 21:26:17 +08:00
|
|
|
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
|
|
|
|
activate_final: bool = False,
|
|
|
|
dropout_rate: Optional[float] = None,
|
|
|
|
):
|
|
|
|
super().__init__()
|
2024-12-27 07:38:46 +08:00
|
|
|
self.activate_final = activate_final
|
2024-12-17 21:26:17 +08:00
|
|
|
layers = []
|
2024-12-29 22:35:21 +08:00
|
|
|
|
2024-12-27 07:38:46 +08:00
|
|
|
for i, size in enumerate(hidden_dims):
|
2024-12-29 22:35:21 +08:00
|
|
|
layers.append(nn.Linear(hidden_dims[i - 1] if i > 0 else hidden_dims[0], size))
|
|
|
|
|
2024-12-27 07:38:46 +08:00
|
|
|
if i + 1 < len(hidden_dims) or activate_final:
|
2024-12-17 21:26:17 +08:00
|
|
|
if dropout_rate is not None and dropout_rate > 0:
|
|
|
|
layers.append(nn.Dropout(p=dropout_rate))
|
|
|
|
layers.append(nn.LayerNorm(size))
|
2024-12-29 20:51:21 +08:00
|
|
|
layers.append(
|
|
|
|
activations if isinstance(activations, nn.Module) else getattr(nn, activations)()
|
|
|
|
)
|
|
|
|
|
2024-12-17 21:26:17 +08:00
|
|
|
self.net = nn.Sequential(*layers)
|
|
|
|
|
|
|
|
def forward(self, x: torch.Tensor, train: bool = False) -> torch.Tensor:
|
|
|
|
# in training mode or not. TODO: find better way to do this
|
2024-12-29 20:51:21 +08:00
|
|
|
self.train(train)
|
2024-12-17 21:26:17 +08:00
|
|
|
return self.net(x)
|
2024-12-29 20:51:21 +08:00
|
|
|
|
|
|
|
|
2024-12-17 21:26:17 +08:00
|
|
|
class Critic(nn.Module):
    """Q-value head: encode observations, concatenate actions, and output a scalar Q per sample."""

    def __init__(
        self,
        encoder: Optional[nn.Module],
        network: nn.Module,
        init_final: Optional[float] = None,
        device: str = "cuda",
    ):
        super().__init__()
        self.device = torch.device(device)
        self.encoder = encoder
        self.network = network
        self.init_final = init_final

        # Infer the backbone's output width from its last Linear layer.
        for module in reversed(network.net):
            if isinstance(module, nn.Linear):
                out_features = module.out_features
                break

        # Output layer: a single scalar Q-value per sample.
        self.output_layer = nn.Linear(out_features, 1)
        if init_final is not None:
            nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
            nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
        else:
            orthogonal_init()(self.output_layer.weight)

        self.to(self.device)

    def forward(self, observations: torch.Tensor, actions: torch.Tensor, train: bool = False) -> torch.Tensor:
        """Return Q(observations, actions) with shape (batch,)."""
        self.train(train)

        # NOTE(review): `.to()` assumes `observations` is a Tensor; with a dict-consuming
        # encoder (e.g. SACObservationEncoder) this would fail — confirm callers.
        observations = observations.to(self.device)
        actions = actions.to(self.device)

        obs_enc = observations if self.encoder is None else self.encoder(observations)

        q_input = torch.cat([obs_enc, actions], dim=-1)
        hidden = self.network(q_input)
        q_value = self.output_layer(hidden)
        return q_value.squeeze(-1)
|
2024-12-29 20:51:21 +08:00
|
|
|
|
2024-12-17 21:26:17 +08:00
|
|
|
|
|
|
|
class Policy(nn.Module):
    """Stochastic Gaussian actor: samples an action (optionally tanh-squashed) and
    returns it with its summed log-probability."""

    def __init__(
        self,
        encoder: Optional[nn.Module],
        network: nn.Module,
        action_dim: int,
        log_std_min: float = -5,
        log_std_max: float = 2,
        fixed_std: Optional[torch.Tensor] = None,
        init_final: Optional[float] = None,
        use_tanh_squash: bool = False,
        device: str = "cuda",
    ):
        """
        Args:
            encoder: Optional observation encoder applied before `network`.
            network: Backbone MLP; must expose `.net` so the output width can be inferred.
            action_dim: Dimensionality of the action space.
            log_std_min/log_std_max: Clamp range for the predicted log standard deviation.
            fixed_std: If given, use this constant std instead of predicting one.
            init_final: If given, uniformly initialize output heads in [-init_final, init_final];
                otherwise use orthogonal init.
            use_tanh_squash: Squash sampled actions into [-1, 1] with tanh (with the
                corresponding log-prob correction).
        """
        super().__init__()
        self.device = torch.device(device)
        self.encoder = encoder
        self.network = network
        self.action_dim = action_dim
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max
        self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
        self.use_tanh_squash = use_tanh_squash

        # Find the last Linear layer's output dimension
        for layer in reversed(network.net):
            if isinstance(layer, nn.Linear):
                out_features = layer.out_features
                break

        # Mean layer
        self.mean_layer = nn.Linear(out_features, action_dim)
        if init_final is not None:
            nn.init.uniform_(self.mean_layer.weight, -init_final, init_final)
            nn.init.uniform_(self.mean_layer.bias, -init_final, init_final)
        else:
            orthogonal_init()(self.mean_layer.weight)

        # Standard deviation layer or parameter
        if fixed_std is None:
            self.std_layer = nn.Linear(out_features, action_dim)
            if init_final is not None:
                nn.init.uniform_(self.std_layer.weight, -init_final, init_final)
                nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
            else:
                orthogonal_init()(self.std_layer.weight)

        self.to(self.device)

    def forward(
        self,
        observations: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample an action and its log-probability for `observations`.

        Returns:
            (actions, log_probs) where log_probs is summed over the action dimension.
        """
        # Encode observations if encoder exists
        # FIX: the condition was inverted (`is not None`), which called the encoder
        # exactly when it was None and skipped it when it existed.
        obs_enc = observations if self.encoder is None else self.encoder(observations)

        # Get network outputs
        outputs = self.network(obs_enc)
        means = self.mean_layer(outputs)

        # Compute standard deviations
        if self.fixed_std is None:
            log_std = self.std_layer(outputs)
            if self.use_tanh_squash:
                log_std = torch.tanh(log_std)
                log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
            # FIX: `stds` was never assigned in this branch (NameError at Normal(...)).
            stds = torch.exp(log_std)
        else:
            stds = self.fixed_std.expand_as(means)

        # uses tahn activation function to squash the action to be in the range of [-1, 1]
        normal = torch.distributions.Normal(means, stds)
        x_t = normal.rsample()  # for reparameterization trick (mean + std * N(0,1))
        log_probs = normal.log_prob(x_t)
        if self.use_tanh_squash:
            actions = torch.tanh(x_t)
            log_probs -= torch.log((1 - actions.pow(2)) + 1e-6)
        else:
            # FIX: `actions` was unbound when not squashing.
            actions = x_t
        log_probs = log_probs.sum(-1)  # sum over action dim

        return actions, log_probs

    def get_features(self, observations: torch.Tensor) -> torch.Tensor:
        """Get encoded features from observations"""
        observations = observations.to(self.device)
        if self.encoder is not None:
            with torch.no_grad():
                return self.encoder(observations, train=False)
        return observations
|
|
|
|
|
|
|
|
|
2024-12-12 18:45:30 +08:00
|
|
|
class SACObservationEncoder(nn.Module):
    """Encode image and/or state vector observations.

    TODO(ke-wang): The original work allows for (1) stacking multiple history frames and (2) using pretrained resnet encoders.
    """

    def __init__(self, config: SACConfig):
        """Build an encoder for each modality present in `config.input_shapes`."""
        super().__init__()
        self.config = config

        if "observation.image" in config.input_shapes:
            hidden = config.image_encoder_hidden_dim
            self.image_enc_layers = nn.Sequential(
                nn.Conv2d(config.input_shapes["observation.image"][0], hidden, 7, stride=2),
                nn.ReLU(),
                nn.Conv2d(hidden, hidden, 5, stride=2),
                nn.ReLU(),
                nn.Conv2d(hidden, hidden, 3, stride=2),
                nn.ReLU(),
                nn.Conv2d(hidden, hidden, 3, stride=2),
                nn.ReLU(),
            )
            # Probe the conv stack with a dummy input to size the projection head.
            dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
            with torch.inference_mode():
                out_shape = self.image_enc_layers(dummy_batch).shape[1:]
            self.image_enc_layers.extend(
                nn.Sequential(
                    nn.Flatten(),
                    nn.Linear(np.prod(out_shape), config.latent_dim),
                    nn.LayerNorm(config.latent_dim),
                    nn.Tanh(),
                )
            )
        if "observation.state" in config.input_shapes:
            self.state_enc_layers = self._vector_encoder(
                config.input_shapes["observation.state"][0], config
            )
        if "observation.environment_state" in config.input_shapes:
            self.env_state_enc_layers = self._vector_encoder(
                config.input_shapes["observation.environment_state"][0], config
            )

    @staticmethod
    def _vector_encoder(in_dim: int, config: SACConfig) -> nn.Sequential:
        """Two-layer MLP projecting a flat vector into the shared latent space."""
        return nn.Sequential(
            nn.Linear(in_dim, config.state_encoder_hidden_dim),
            nn.ELU(),
            nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
            nn.LayerNorm(config.latent_dim),
            nn.Tanh(),
        )

    def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
        """Encode the image and/or state vector.

        Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
        over all features.
        """
        feats = []
        # Concatenate all images along the channel dimension.
        image_keys = [key for key in self.config.input_shapes if key.startswith("observation.image")]
        for image_key in image_keys:
            feats.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key]))
        if "observation.environment_state" in self.config.input_shapes:
            feats.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
        if "observation.state" in self.config.input_shapes:
            feats.append(self.state_enc_layers(obs_dict["observation.state"]))
        return torch.stack(feats, dim=0).mean(0)
|
2024-12-29 20:51:21 +08:00
|
|
|
|
2024-12-17 21:26:17 +08:00
|
|
|
|
|
|
|
class LagrangeMultiplier(nn.Module):
    """Learnable Lagrange multiplier, kept positive via a softplus parameterization."""

    def __init__(self, init_value: float = 1.0, constraint_shape: Sequence[int] = (), device: str = "cuda"):
        super().__init__()
        self.device = torch.device(device)
        # Invert the softplus so that softplus(raw) == init_value at initialization.
        raw_init = torch.log(torch.exp(torch.tensor(init_value, device=self.device)) - 1)

        # Initialize the Lagrange multiplier as a parameter
        self.lagrange = nn.Parameter(
            torch.full(constraint_shape, raw_init, dtype=torch.float32, device=self.device)
        )

        self.to(self.device)

    def forward(self, lhs: Optional[torch.Tensor] = None, rhs: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return the multiplier value, or the penalty `multiplier * (lhs - rhs)` when `lhs` is given."""
        # Get the multiplier value based on parameterization
        multiplier = torch.nn.functional.softplus(self.lagrange)

        # Return the raw multiplier if no constraint values provided
        if lhs is None:
            return multiplier

        # Move inputs to device; a missing rhs defaults to zero.
        lhs = lhs.to(self.device)
        rhs = torch.zeros_like(lhs, device=self.device) if rhs is None else rhs.to(self.device)

        diff = lhs - rhs

        assert diff.shape == multiplier.shape, f"Shape mismatch: {diff.shape} vs {multiplier.shape}"

        return multiplier * diff
|
|
|
|
|
|
|
|
|
2024-12-29 20:30:39 +08:00
|
|
|
def orthogonal_init():
    """Return an initializer that applies in-place orthogonal init (gain 1.0) to a weight tensor."""

    def _init(weight):
        return torch.nn.init.orthogonal_(weight, gain=1.0)

    return _init
|
2024-12-17 21:26:17 +08:00
|
|
|
|
|
|
|
|
2024-12-27 07:38:46 +08:00
|
|
|
def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cuda") -> nn.ModuleList:
    """Creates an ensemble of critic networks.

    Args:
        critics: The critic modules to bundle.
        num_critics: Expected number of critics (sanity check against `len(critics)`).
        device: Device to move the ensemble to.

    Raises:
        ValueError: If `len(critics) != num_critics`.
    """
    # FIX: validate with an explicit exception — `assert` is stripped under `python -O`.
    if len(critics) != num_critics:
        raise ValueError(f"Expected {num_critics} critics, got {len(critics)}")
    return nn.ModuleList(critics).to(device)
|
2024-12-17 21:26:17 +08:00
|
|
|
|
|
|
|
|
|
|
|
# borrowed from tdmpc
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
    """Helper to temporarily flatten extra dims at the start of the image tensor.

    Args:
        fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return
            (B, *), where * is any number of dimensions.
        image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and
            can be more than 1 dimensions, generally different from *.
    Returns:
        A return value from the callable reshaped to (**, *).
    """
    if image_tensor.ndim == 4:
        # Already a plain (B, C, H, W) batch: nothing to fold.
        return fn(image_tensor)
    leading_shape = image_tensor.shape[:-3]
    folded = image_tensor.flatten(end_dim=-4)
    result = fn(folded)
    return result.reshape(*leading_shape, *result.shape[1:])
|