parent c9af8e36a7
commit def42ff487

@@ -0,0 +1,39 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field


@dataclass
class SACConfig:
    # Fields need type annotations for @dataclass to register them, and mutable
    # defaults must be wrapped in field(default_factory=...).
    discount: float = 0.99
    temperature_init: float = 1.0
    num_critics: int = 2
    critic_lr: float = 3e-4
    actor_lr: float = 3e-4
    critic_network_kwargs: dict = field(
        default_factory=lambda: {
            "hidden_dims": [256, 256],
            "activate_final": True,
        }
    )
    actor_network_kwargs: dict = field(
        default_factory=lambda: {
            "hidden_dims": [256, 256],
            "activate_final": True,
        }
    )
    policy_kwargs: dict = field(
        default_factory=lambda: {
            "tanh_squash_distribution": True,
            "std_parameterization": "uniform",
        }
    )
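
For orientation, a minimal sketch (not part of the commit) of how these kwargs dicts are meant to be consumed; MLP is the network class added in the modeling hunk below:

config = SACConfig()
print(config.critic_network_kwargs)  # {'hidden_dims': [256, 256], 'activate_final': True}
# In SACPolicy.__init__ below, the dict is splatted into the constructor:
# MLP(**config.critic_network_kwargs)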

@@ -15,7 +15,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# TODO: (1) better device management

from collections import deque
from copy import deepcopy
from functools import partial

import einops
@@ -27,6 +31,10 @@ from torch import Tensor
from huggingface_hub import PyTorchModelHubMixin

from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.sac.configuration_sac import SACConfig

import numpy as np
from typing import Callable, Optional, Tuple, Sequence


class SACPolicy(
    nn.Module,
@@ -58,12 +66,27 @@ class SACPolicy(
        self.unnormalize_outputs = Unnormalize(
            config.output_shapes, config.output_normalization_modes, dataset_stats
        )

        encoder = SACObservationEncoder(config)
        # Define networks
        critic_nets = []
        for _ in range(config.num_critics):
            critic_net = Critic(
                encoder=encoder,
                network=MLP(**config.critic_network_kwargs),
                activate_final=config.critic_network_kwargs["activate_final"],  # must match the trunk
            )
            critic_nets.append(critic_net)

-       self.critic_ensemble = ...
-       self.critic_target = ...
-       self.actor_network = ...
-       self.temperature = ...
+       self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
+       self.critic_target = deepcopy(self.critic_ensemble)
+
+       self.actor_network = Policy(
+           encoder=encoder,
+           network=MLP(**config.actor_network_kwargs),
+           action_dim=config.output_shapes["action"][0],
+           activate_final=config.actor_network_kwargs["activate_final"],  # must match the trunk
+           **config.policy_kwargs,
+       )
+
+       self.temperature = LagrangeMultiplier(init_value=config.temperature_init)

    def reset(self):
        """
@@ -178,10 +201,483 @@ class SACPolicy(
        # for target_param, param in zip(self.critic_target.parameters(), self.critic_ensemble.parameters()):
        #     target_param.data.copy_(target_param.data * (1.0 - self.config.critic_target_update_weight) + param.data * self.config.critic_target_update_weight)

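The commented-out lines above sketch the soft (Polyak) target update. For reference, a hedged, in-place version of that loop, assuming critic_target_update_weight plays the usual role of tau:

# Illustrative only; tau = config.critic_target_update_weight (not yet defined in SACConfig above).
with torch.no_grad():
    for target_param, param in zip(critic_target.parameters(), critic_ensemble.parameters()):
        target_param.data.mul_(1.0 - tau).add_(param.data, alpha=tau)
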
class MLP(nn.Module):
    def __init__(
        self,
        hidden_dims: Sequence[int],
        activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
        activate_final: bool = False,
        dropout_rate: Optional[float] = None,
    ):
        # Takes hidden_dims/activate_final directly, matching the call sites above:
        # MLP(**config.critic_network_kwargs) and MLP(**config.actor_network_kwargs).
        super().__init__()
        self.activate_final = activate_final
        layers = []

        for i, size in enumerate(hidden_dims):
            # NOTE: the first layer reuses hidden_dims[0] as its input width, so the
            # input features must have dimension hidden_dims[0].
            layers.append(nn.Linear(hidden_dims[i - 1] if i > 0 else hidden_dims[0], size))

            if i + 1 < len(hidden_dims) or activate_final:
                if dropout_rate is not None and dropout_rate > 0:
                    layers.append(nn.Dropout(p=dropout_rate))
                layers.append(nn.LayerNorm(size))
                layers.append(
                    activations if isinstance(activations, nn.Module) else getattr(nn, activations)()
                )

        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor, train: bool = False) -> torch.Tensor:
        # Switch the whole module in or out of training mode. TODO: find a better way to do this
        self.train(train)
        return self.net(x)

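A quick sketch of what this builds (illustrative, not in the commit): with hidden_dims=[256, 256] and activate_final=True, the trunk is Linear, LayerNorm, SiLU repeated twice.

mlp = MLP(hidden_dims=[256, 256], activate_final=True)
x = torch.zeros(8, 256)  # input width must equal hidden_dims[0] (see the NOTE above)
print(mlp(x).shape)  # torch.Size([8, 256])
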
class Critic(nn.Module):
    def __init__(
        self,
        encoder: Optional[nn.Module],
        network: nn.Module,
        init_final: Optional[float] = None,
        activate_final: bool = False,
        device: str = "cuda",
    ):
        super().__init__()
        self.device = torch.device(device)
        self.encoder = encoder
        self.network = network
        self.init_final = init_final
        self.activate_final = activate_final

        # Output layer. When the trunk ends with an activation (and has no dropout),
        # its last Linear sits at net[-3]; otherwise it is net[-2].
        if self.activate_final:
            self.output_layer = nn.Linear(network.net[-3].out_features, 1)
        else:
            self.output_layer = nn.Linear(network.net[-2].out_features, 1)
        if init_final is not None:
            nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
            nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
        else:
            orthogonal_init()(self.output_layer.weight)

        self.to(self.device)

    def forward(
        self,
        observations: torch.Tensor,
        actions: torch.Tensor,
        train: bool = False,
    ) -> torch.Tensor:
        self.train(train)

        observations = observations.to(self.device)
        actions = actions.to(self.device)

        if self.encoder is not None:
            obs_enc = self.encoder(observations)
        else:
            obs_enc = observations

        inputs = torch.cat([obs_enc, actions], dim=-1)
        x = self.network(inputs)
        value = self.output_layer(x)
        return value.squeeze(-1)

    def q_value_ensemble(
        self,
        observations: torch.Tensor,
        actions: torch.Tensor,
        train: bool = False,
    ) -> torch.Tensor:
        observations = observations.to(self.device)
        actions = actions.to(self.device)

        if len(actions.shape) == 3:  # [batch_size, num_actions, action_dim]
            batch_size, num_actions = actions.shape[:2]
            obs_expanded = observations.unsqueeze(1).expand(-1, num_actions, -1)
            obs_flat = obs_expanded.reshape(-1, observations.shape[-1])
            actions_flat = actions.reshape(-1, actions.shape[-1])
            q_values = self(obs_flat, actions_flat, train)
            return q_values.reshape(batch_size, num_actions)
        else:
            return self(observations, actions, train)

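For context, a self-contained sketch (not in the commit) of how an ensemble of critics is typically queried in SAC, taking the minimum Q-value over members; create_critic_ensemble is defined further down.

critics = [
    Critic(encoder=None, network=MLP(hidden_dims=[32, 32], activate_final=True),
           activate_final=True, device="cpu")
    for _ in range(2)
]
ensemble = create_critic_ensemble(critics, num_critics=2, device="cpu")
obs_enc, actions = torch.zeros(8, 28), torch.zeros(8, 4)  # 28 + 4 = hidden_dims[0]
q_values = torch.stack([critic(obs_enc, actions) for critic in ensemble])  # (2, 8)
min_q = q_values.min(dim=0).values  # pessimistic estimate, standard in SAC
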
class Policy(nn.Module):
    def __init__(
        self,
        encoder: Optional[nn.Module],
        network: nn.Module,
        action_dim: int,
        std_parameterization: str = "exp",
        std_min: float = 1e-5,
        std_max: float = 10.0,
        tanh_squash_distribution: bool = False,
        fixed_std: Optional[torch.Tensor] = None,
        init_final: Optional[float] = None,
        activate_final: bool = False,
        device: str = "cuda",
    ):
        super().__init__()
        self.device = torch.device(device)
        self.encoder = encoder
        self.network = network
        self.action_dim = action_dim
        self.std_parameterization = std_parameterization
        self.std_min = std_min
        self.std_max = std_max
        self.tanh_squash_distribution = tanh_squash_distribution
        self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
        self.activate_final = activate_final

        # Mean layer (same trunk-indexing convention as Critic: the trunk's last
        # Linear sits at net[-3] when it ends with an activation, else net[-2])
        if self.activate_final:
            self.mean_layer = nn.Linear(network.net[-3].out_features, action_dim)
        else:
            self.mean_layer = nn.Linear(network.net[-2].out_features, action_dim)
        if init_final is not None:
            nn.init.uniform_(self.mean_layer.weight, -init_final, init_final)
            nn.init.uniform_(self.mean_layer.bias, -init_final, init_final)
        else:
            orthogonal_init()(self.mean_layer.weight)

        # Standard deviation layer or parameter
        if fixed_std is None:
            if std_parameterization == "uniform":
                self.log_stds = nn.Parameter(torch.zeros(action_dim, device=self.device))
            else:
                if self.activate_final:
                    self.std_layer = nn.Linear(network.net[-3].out_features, action_dim)
                else:
                    self.std_layer = nn.Linear(network.net[-2].out_features, action_dim)
                if init_final is not None:
                    nn.init.uniform_(self.std_layer.weight, -init_final, init_final)
                    nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
                else:
                    orthogonal_init()(self.std_layer.weight)

        self.to(self.device)

    def forward(
        self,
        observations: torch.Tensor,
        temperature: float = 1.0,
        train: bool = False,
        non_squash_distribution: bool = False,
    ) -> torch.distributions.Distribution:
        self.train(train)

        # Encode observations if an encoder is present
        # (SACObservationEncoder.forward takes only the observation dict; training
        # mode is controlled through the module state set above)
        if self.encoder is not None:
            with torch.set_grad_enabled(train):
                obs_enc = self.encoder(observations)
        else:
            obs_enc = observations

        # Get network outputs
        outputs = self.network(obs_enc)
        means = self.mean_layer(outputs)

        # Compute standard deviations
        if self.fixed_std is None:
            if self.std_parameterization == "exp":
                log_stds = self.std_layer(outputs)
                stds = torch.exp(log_stds)
            elif self.std_parameterization == "softplus":
                stds = torch.nn.functional.softplus(self.std_layer(outputs))
            elif self.std_parameterization == "uniform":
                stds = torch.exp(self.log_stds).expand_as(means)
            else:
                raise ValueError(f"Invalid std_parameterization: {self.std_parameterization}")
        else:
            assert self.std_parameterization == "fixed"
            stds = self.fixed_std.expand_as(means)

        # Clip standard deviations and scale with temperature
        temperature = torch.tensor(temperature, device=self.device)
        stds = torch.clamp(stds, self.std_min, self.std_max) * torch.sqrt(temperature)

        # Create distribution
        if self.tanh_squash_distribution and not non_squash_distribution:
            distribution = TanhMultivariateNormalDiag(loc=means, scale_diag=stds)
        else:
            distribution = torch.distributions.Normal(loc=means, scale=stds)

        return distribution

    def get_features(self, observations: torch.Tensor) -> torch.Tensor:
        """Get encoded features from observations."""
        observations = observations.to(self.device)
        if self.encoder is not None:
            with torch.no_grad():
                return self.encoder(observations)
        return observations

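A self-contained toy run of the actor side (illustrative, not in the commit): the forward pass returns a distribution from which actions and log-probs are drawn for the actor and temperature losses.

policy = Policy(
    encoder=None,  # feed pre-encoded 16-dim features directly
    network=MLP(hidden_dims=[16, 16], activate_final=True),
    action_dim=4,
    std_parameterization="uniform",
    tanh_squash_distribution=True,
    activate_final=True,
    device="cpu",
)
dist = policy(torch.zeros(8, 16))  # feature width must equal hidden_dims[0]
actions, log_probs = dist.sample_and_log_prob()
print(actions.shape, log_probs.shape)  # torch.Size([8, 4]) torch.Size([8])
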
class SACObservationEncoder(nn.Module):
    """Encode image and/or state vector observations.

    TODO(ke-wang): The original work allows for (1) stacking multiple history frames and (2) using pretrained ResNet encoders.
    """

    def __init__(self, config: SACConfig):
        """Creates encoders for pixel and/or state modalities."""
        super().__init__()
        self.config = config

        if "observation.image" in config.input_shapes:
            self.image_enc_layers = nn.Sequential(
                nn.Conv2d(
                    config.input_shapes["observation.image"][0], config.image_encoder_hidden_dim, 7, stride=2
                ),
                nn.ReLU(),
                nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2),
                nn.ReLU(),
                nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
                nn.ReLU(),
                nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
                nn.ReLU(),
            )
            dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
            with torch.inference_mode():
                out_shape = self.image_enc_layers(dummy_batch).shape[1:]
            self.image_enc_layers.extend(
                nn.Sequential(
                    nn.Flatten(),
                    nn.Linear(np.prod(out_shape), config.latent_dim),
                    nn.LayerNorm(config.latent_dim),
                    nn.Tanh(),
                )
            )
        if "observation.state" in config.input_shapes:
            self.state_enc_layers = nn.Sequential(
                nn.Linear(config.input_shapes["observation.state"][0], config.state_encoder_hidden_dim),
                nn.ELU(),
                nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
                nn.LayerNorm(config.latent_dim),
                nn.Tanh(),
            )
        if "observation.environment_state" in config.input_shapes:
            self.env_state_enc_layers = nn.Sequential(
                nn.Linear(
                    config.input_shapes["observation.environment_state"][0], config.state_encoder_hidden_dim
                ),
                nn.ELU(),
                nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
                nn.LayerNorm(config.latent_dim),
                nn.Tanh(),
            )

    def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
        """Encode the image and/or state vector.

        Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is
        taken over all features.
        """
        feat = []
        # Encode each image modality separately and append its feature vector.
        image_keys = [k for k in self.config.input_shapes if k.startswith("observation.image")]
        for image_key in image_keys:
            feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key]))
        if "observation.environment_state" in self.config.input_shapes:
            feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
        if "observation.state" in self.config.input_shapes:
            feat.append(self.state_enc_layers(obs_dict["observation.state"]))
        return torch.stack(feat, dim=0).mean(0)

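A shape sanity check for the encoder (illustrative): the config stub below is hypothetical; note that input_shapes, latent_dim, and the encoder hidden dims referenced above are not yet fields of the SACConfig shown in the first hunk.

from types import SimpleNamespace

cfg = SimpleNamespace(
    input_shapes={"observation.state": [10]},
    latent_dim=64,
    state_encoder_hidden_dim=128,
)
enc = SACObservationEncoder(cfg)
feat = enc({"observation.state": torch.zeros(4, 10)})
print(feat.shape)  # torch.Size([4, 64])
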
class LagrangeMultiplier(nn.Module):
    def __init__(
        self,
        init_value: float = 1.0,
        constraint_shape: Sequence[int] = (),
        device: str = "cuda",
    ):
        super().__init__()
        self.device = torch.device(device)
        # Inverse softplus, so that softplus(lagrange) == init_value at initialization
        init_value = torch.log(torch.exp(torch.tensor(init_value, device=self.device)) - 1).item()

        # Initialize the Lagrange multiplier as a parameter
        self.lagrange = nn.Parameter(
            torch.full(constraint_shape, init_value, dtype=torch.float32, device=self.device)
        )

        self.to(self.device)

    def forward(
        self,
        lhs: Optional[torch.Tensor] = None,
        rhs: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Get the multiplier value based on the parameterization
        multiplier = torch.nn.functional.softplus(self.lagrange)

        # Return the raw multiplier if no constraint values are provided
        if lhs is None:
            return multiplier

        # Move inputs to the device, defaulting rhs to zero
        lhs = lhs.to(self.device)
        if rhs is not None:
            rhs = rhs.to(self.device)
        else:
            rhs = torch.zeros_like(lhs, device=self.device)

        # Use the multiplier to compute the Lagrange penalty
        diff = lhs - rhs

        assert diff.shape == multiplier.shape, f"Shape mismatch: {diff.shape} vs {multiplier.shape}"

        return multiplier * diff

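A sketch (not in the commit) of how this multiplier typically enters SAC's temperature update; the entropy values below are hypothetical and sign conventions vary across implementations.

temperature = LagrangeMultiplier(init_value=1.0, device="cpu")
alpha = temperature()  # softplus of the learnable parameter, always positive
entropy_estimate = torch.tensor(0.5)  # hypothetical batch entropy estimate
target_entropy = torch.tensor(-4.0)   # hypothetical; often -action_dim
loss = temperature(lhs=entropy_estimate, rhs=target_entropy)  # alpha * (H - H_target)
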
# TanhMultivariateNormalDiag is a probability distribution representing a transformed normal (Gaussian) distribution:
# 1. The base distribution is a diagonal multivariate normal distribution.
# 2. Samples from this normal distribution are passed through a tanh, squashing values into (-1, 1).
# 3. Optionally, values can be further mapped to arbitrary bounds [low, high] via an affine transform.
# This type of distribution is commonly used in reinforcement learning, particularly for continuous action spaces.
class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
    def __init__(
        self,
        loc: torch.Tensor,
        scale_diag: torch.Tensor,
        low: Optional[torch.Tensor] = None,
        high: Optional[torch.Tensor] = None,
    ):
        # Create base normal distribution
        base_distribution = torch.distributions.Normal(loc=loc, scale=scale_diag)

        # Create the list of transforms: tanh squashing, then optional rescaling
        transforms = [torch.distributions.transforms.TanhTransform()]
        if low is not None and high is not None:
            transforms.append(
                torch.distributions.transforms.AffineTransform(
                    loc=(high + low) / 2,
                    scale=(high - low) / 2,
                )
            )

        # Initialize parent class
        super().__init__(base_distribution=base_distribution, transforms=transforms)

        # Store parameters
        self.loc = loc
        self.scale_diag = scale_diag
        self.low = low
        self.high = high

    def mode(self) -> torch.Tensor:
        """Get the mode of the transformed distribution."""
        # The mode of a normal distribution is its mean
        mode = self.loc

        # Apply transforms
        for transform in self.transforms:
            mode = transform(mode)

        return mode

    def rsample(self, sample_shape=torch.Size()) -> torch.Tensor:
        """Reparameterized sample from the distribution."""
        # Sample from the base distribution
        x = self.base_dist.rsample(sample_shape)

        # Apply transforms
        for transform in self.transforms:
            x = transform(x)

        return x

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """Compute the log probability of a value, including the log-det-Jacobian of
        the transforms."""
        # Initialize log prob with the batch shape
        log_prob = torch.zeros_like(value[..., 0])

        # Invert the transforms to get back to the base distribution
        q = value
        for transform in reversed(self.transforms):
            q = transform.inv(q)
            # log_abs_det_jacobian is elementwise for these transforms, so sum over
            # the event (action) dimension to match the summed base log prob
            log_prob = log_prob - transform.log_abs_det_jacobian(q, transform(q)).sum(-1)

        # Add the base distribution's log prob
        log_prob = log_prob + self.base_dist.log_prob(q).sum(-1)

        return log_prob

    def sample_and_log_prob(self, sample_shape=torch.Size()) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample from the distribution and compute the log probability of the sample."""
        x = self.rsample(sample_shape)
        log_prob = self.log_prob(x)
        return x, log_prob

    def entropy(self) -> torch.Tensor:
        """Single-sample estimate of the entropy of the transformed distribution."""
        # Start with the base distribution's entropy
        entropy = self.base_dist.entropy().sum(-1)

        # Add the log-det-Jacobian of each transform, evaluated at one sample
        x = self.rsample()
        for transform in self.transforms:
            entropy = entropy + transform.log_abs_det_jacobian(x, transform(x)).sum(-1)
            x = transform(x)

        return entropy

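A quick check of the distribution's interface (illustrative, not in the commit):

dist = TanhMultivariateNormalDiag(loc=torch.zeros(2, 4), scale_diag=torch.ones(2, 4))
action, log_prob = dist.sample_and_log_prob()
print(action.shape, log_prob.shape)  # torch.Size([2, 4]) torch.Size([2])
assert action.abs().max() <= 1.0  # tanh squashes samples into (-1, 1)
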
def create_critic_ensemble(critics: Sequence[nn.Module], num_critics: int, device: str = "cuda") -> nn.ModuleList:
    """Packages pre-built critic networks into an ensemble on the given device."""
    # SACPolicy.__init__ passes a list of already-constructed critics, so take the
    # list directly rather than a factory.
    assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}"
    return nn.ModuleList(critics).to(device)


def orthogonal_init():
    return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)


# borrowed from tdmpc
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
    """Helper to temporarily flatten extra dims at the start of the image tensor.

    Args:
        fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return
            (B, *), where * is any number of dimensions.
        image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions,
            generally different from *.
    Returns:
        The return value of the callable, reshaped to (**, *).
    """
    if image_tensor.ndim == 4:
        return fn(image_tensor)
    start_dims = image_tensor.shape[:-3]
    inp = torch.flatten(image_tensor, end_dim=-4)
    flat_out = fn(inp)
    return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
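
A shape sanity check for flatten_forward_unflatten (illustrative): leading (batch, time) dims are flattened for the conv and restored afterwards.

conv = nn.Conv2d(3, 8, 3)
imgs = torch.zeros(2, 5, 3, 16, 16)  # (batch, time, C, H, W)
out = flatten_forward_unflatten(conv, imgs)
print(out.shape)  # torch.Size([2, 5, 8, 14, 14])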