fixes and updated comments

This commit is contained in:
Michel Aractingi 2024-11-26 09:46:59 +00:00
parent 15090c2544
commit 16edbbdeee
3 changed files with 223 additions and 202 deletions

View File

@ -19,7 +19,7 @@ from dataclasses import dataclass, field
@dataclass
class TDMPC2Config:
"""Configuration class for TDMPCPolicy.
"""Configuration class for TDMPC2Policy.
Defaults are configured for training with xarm_lift_medium_replay providing proprioceptive and single
camera observations.
@ -77,18 +77,9 @@ class TDMPC2Config:
image(s) (in units of pixels) for training-time augmentation. If set to 0, no such augmentation
is applied. Note that the input images are assumed to be square for this augmentation.
reward_coeff: Loss weighting coefficient for the reward regression loss.
expectile_weight: Weighting (τ) used in expectile regression for the state value function (V).
v_pred < v_target is weighted by τ and v_pred >= v_target is weighted by (1-τ). τ is expected to
be in [0, 1]. Setting τ closer to 1 results in a more "optimistic" V. This is sensible to do
because v_target is obtained by evaluating the learned state-action value functions (Q) with
in-sample actions that may not be always optimal.
value_coeff: Loss weighting coefficient for both the state-action value (Q) TD loss, and the state
value (V) expectile regression loss.
consistency_coeff: Loss weighting coefficient for the consistency loss.
advantage_scaling: A factor by which the advantages are scaled prior to exponentiation for advantage
weighted regression of the policy (π) estimator parameters. Note that the exponentiated advantages
are clamped at 100.0.
pi_coeff: Loss weighting coefficient for the action regression loss.
temporal_decay_coeff: Exponential decay coefficient for decaying the loss coefficient for future time-
steps. Hint: each loss computation involves `horizon` steps worth of actions starting from the
current time step.
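For example (an illustrative sketch: horizon = 5 is assumed here, while temporal_decay_coeff = 0.5 is the default below), the per-step loss weights would be rho = 0.5 ** torch.arange(5) = [1.0, 0.5, 0.25, 0.125, 0.0625], so prediction errors at later steps of the horizon contribute progressively less to the loss.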
@ -126,9 +117,12 @@ class TDMPC2Config:
state_encoder_hidden_dim: int = 256
latent_dim: int = 512
q_ensemble_size: int = 5
num_enc_layers: int = 2
mlp_dim: int = 512
# Reinforcement learning.
discount: float = 0.9
simnorm_dim: int = 8
dropout: float = 0.01
# actor
log_std_min: float = -10
@ -157,10 +151,10 @@ class TDMPC2Config:
consistency_coeff: float = 20.0
entropy_coef: float = 1e-4
temporal_decay_coeff: float = 0.5
# Target model. NOTE (michel_aractingi) this is equivalent to
# 1 - target_model_momentum of our TD-MPC1 implementation because
# of the use of `torch.lerp`
target_model_momentum: float = 0.01
def __post_init__(self):
"""Input validation (not exhaustive)."""

View File

@ -1,7 +1,7 @@
#!/usr/bin/env python
# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su,
# and The HuggingFace Inc. team. All rights reserved.
# Copyright 2024 Nicklas Hansen and The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -14,11 +14,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of Finetuning Offline World Models in the Real World.
"""Implementation of TD-MPC2: Scalable, Robust World Models for Continuous Control
The comments in this code may sometimes refer to these references:
TD-MPC paper: Temporal Difference Learning for Model Predictive Control (https://arxiv.org/abs/2203.04955)
FOWM paper: Finetuning Offline World Models in the Real World (https://arxiv.org/abs/2310.16029)
We refer to the main paper and codebase:
TD-MPC2 paper: (https://arxiv.org/abs/2310.16828)
TD-MPC2 code: (https://github.com/nicklashansen/tdmpc2)
"""
# ruff: noqa: N806
@ -38,8 +38,16 @@ from torch import Tensor
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.tdmpc2.configuration_tdmpc2 import TDMPC2Config
from lerobot.common.policies.tdmpc2.tdmpc2_utils import (
NormedLinear,
SimNorm,
gaussian_logprob,
soft_cross_entropy,
squash,
two_hot_inv,
)
from lerobot.common.policies.utils import get_device_from_parameters, populate_queues
from lerobot.common.policies.tdmpc2.tdmpc2_utils import NormedLinear, SimNorm, two_hot_inv, gaussian_logprob, squash, soft_cross_entropy
class TDMPC2Policy(
nn.Module,
@ -48,22 +56,7 @@ class TDMPC2Policy(
repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "tdmpc2"],
):
"""Implementation of TD-MPC2 learning + inference.
Please note several warnings for this policy.
- Evaluation of pretrained weights created with the original FOWM code
(https://github.com/fyhMer/fowm) works as expected. To be precise: we trained and evaluated a
model with the FOWM code for the xarm_lift_medium_replay dataset. We ported the weights across
to LeRobot, and were able to evaluate with the same success metric. BUT, we had to use inter-
process communication to use the xarm environment from FOWM. This is because our xarm
environment uses newer dependencies and does not match the environment in FOWM. See
https://github.com/huggingface/lerobot/pull/103 for implementation details.
- We have NOT checked that training on LeRobot reproduces the results from FOWM.
- Nevertheless, we have verified that we can train TD-MPC for PushT. See
`lerobot/configs/policy/tdmpc2_pusht_keypoints.yaml`.
- Our current xarm datasets were generated using the environment from FOWM. Therefore they do not
match our xarm environment.
"""
"""Implementation of TD-MPC2 learning + inference."""
name = "tdmpc2"
@ -83,6 +76,8 @@ class TDMPC2Policy(
config = TDMPC2Config()
self.config = config
self.model = TDMPC2WorldModel(config)
# TODO (michel-aractingi) temp fix for gpu
self.model = self.model.to("cuda:0")
if config.input_normalization_modes is not None:
self.normalize_inputs = Normalize(
@ -109,7 +104,9 @@ class TDMPC2Policy(
self._use_env_state = True
self.scale = RunningScale(self.config.target_model_momentum)
self.discount = self.config.discount #TODO (michel-aractingi) downscale discount according to episode length
self.discount = (
self.config.discount
) # TODO (michel-aractingi) downscale discount according to episode length
self.reset()
@ -204,7 +201,7 @@ class TDMPC2Policy(
for t in range(self.config.horizon):
# Note: Adding a small amount of noise here doesn't hurt during inference and may even be
# helpful for CEM.
pi_actions[t] = self.model.pi(_z, self.config.min_std)[0]
pi_actions[t] = self.model.pi(_z)[0]
_z = self.model.latent_dynamics(_z, pi_actions[t])
# In the CEM loop we will need this for a call to estimate_value with the gaussian sampled
@ -249,14 +246,17 @@ class TDMPC2Policy(
score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value))
score /= score.sum(axis=0, keepdim=True)
# (horizon, batch, action_dim)
mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1) / (score.sum(0) + 1e-9)
mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1) / (
einops.rearrange(score.sum(0), "b -> 1 b 1") + 1e-9
)
std = torch.sqrt(
torch.sum(
einops.rearrange(score, "n b -> n b 1")
* (elite_actions - einops.rearrange(mean, "h b d -> h 1 b d")) ** 2,
dim=1,
) / (score.sum(0) + 1e-9)
).clamp_(self.config.min_std, self.config.max_std)
)
/ (einops.rearrange(score.sum(0), "b -> 1 b 1") + 1e-9)
).clamp_(self.config.min_std, self.config.max_std)
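# In effect this refits the CEM sampling distribution: for each time step, `mean` is the score-weighted
# average of the elite action sequences and `std` the matching weighted standard deviation, clamped to
# [min_std, max_std] so the search distribution neither collapses nor blows up.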
# Keep track of the mean for warm-starting subsequent steps.
self._prev_mean = mean
@ -286,11 +286,11 @@ class TDMPC2Policy(
# Update the return and running discount.
G += running_discount * reward
running_discount *= self.config.discount
#next_action = self.model.pi(z)[0] # (batch, action_dim)
#terminal_values = self.model.Qs(z, next_action, return_type="avg") # (ensemble, batch)
return G + running_discount * self.model.Qs(z, self.model.pi(z)[0], return_type='avg')
# next_action = self.model.pi(z)[0] # (batch, action_dim)
# terminal_values = self.model.Qs(z, next_action, return_type="avg") # (ensemble, batch)
return G + running_discount * self.model.Qs(z, self.model.pi(z)[0], return_type="avg")
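# In other words, the estimated value is an n-step look-ahead: the discounted sum of predicted rewards
# accumulated in G over the rollout, plus a bootstrap term given by the average ensemble Q-value of the
# policy action at the final latent state, scaled by the remaining running_discount.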
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
"""Run the batch through the model and compute the loss.
@ -351,14 +351,15 @@ class TDMPC2Policy(
# Compute various targets with stopgrad.
with torch.no_grad():
# Latent state consistency targets for consistency loss.
z_targets = self.model.encode(next_observations)
# Compute the TD-target from a reward and the next observation
pi = self.model.pi(z_targets)[0]
td_targets = (
reward
+ self.config.discount * self.model.Qs(z_targets, pi, return_type="min", target=True).squeeze()
+ self.config.discount
* self.model.Qs(z_targets, pi, return_type="min", target=True).squeeze()
)
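# i.e. the standard TD target r + discount * Q_target(z', pi(z')), computed with stopgrad and with the
# target Q-value taken pessimistically (return_type="min") to curb value overestimation.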
# Compute losses.
@ -421,18 +422,20 @@ class TDMPC2Policy(
z_preds = z_preds.detach()
self.model.change_q_grad(mode=False)
action_preds, _, log_pis, _ = self.model.pi(z_preds[:-1])
with torch.no_grad():
# avoid unnecessary computation of the gradients during policy optimization
# TODO (michel-aractingi): the same logic should be extended when adding task embeddings
qs = self.model.Qs(z_preds[:-1], action_preds, return_type="avg")
self.scale.update(qs[0])
qs = self.scale(qs)
rho = torch.pow(self.config.temporal_decay_coeff, torch.arange(len(qs), device=qs.device)).unsqueeze(-1)
rho = torch.pow(self.config.temporal_decay_coeff, torch.arange(len(qs), device=qs.device)).unsqueeze(
-1
)
pi_loss = (
(self.config.entropy_coef * log_pis - qs).mean(dim=(1,2))
(self.config.entropy_coef * log_pis - qs).mean(dim=(1, 2))
* rho
# * temporal_loss_coeffs
# `action_preds` depends on the first observation and the actions.
@ -470,56 +473,69 @@ class TDMPC2Policy(
"""Update the target model's using polyak averaging."""
self.model.update_target_Q()
class TDMPC2WorldModel(nn.Module):
"""Latent dynamics model used in TD-MPC2."""
def __init__(self, config: TDMPC2Config):
super().__init__()
self.config = config
self._encoder = TDMPC2ObservationEncoder(config)
# Define latent dynamics head
self._dynamics = nn.Sequential(NormedLinear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
NormedLinear(config.mlp_dim, config.mlp_dim),
NormedLinear(config.mlp_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)))
self._dynamics = nn.Sequential(
NormedLinear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
NormedLinear(config.mlp_dim, config.mlp_dim),
NormedLinear(config.mlp_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)),
)
# Define reward head
self._reward = nn.Sequential(NormedLinear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
NormedLinear(config.mlp_dim, config.mlp_dim),
nn.Linear(config.mlp_dim, max(config.num_bins, 1)))
self._reward = nn.Sequential(
NormedLinear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
NormedLinear(config.mlp_dim, config.mlp_dim),
nn.Linear(config.mlp_dim, max(config.num_bins, 1)),
)
# Define policy head
self._pi = nn.Sequential(NormedLinear(config.latent_dim, config.mlp_dim),
NormedLinear(config.mlp_dim, config.mlp_dim),
nn.Linear(config.mlp_dim, 2 * config.output_shapes["action"][0]))
self._pi = nn.Sequential(
NormedLinear(config.latent_dim, config.mlp_dim),
NormedLinear(config.mlp_dim, config.mlp_dim),
nn.Linear(config.mlp_dim, 2 * config.output_shapes["action"][0]),
)
# Define ensemble of Q functions
self._Qs = nn.ModuleList(
[
nn.Sequential(
NormedLinear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim, dropout=config.dropout),
NormedLinear(
config.latent_dim + config.output_shapes["action"][0],
config.mlp_dim,
dropout=config.dropout,
),
NormedLinear(config.mlp_dim, config.mlp_dim),
nn.Linear(config.mlp_dim, max(config.num_bins, 1))
) for _ in range(config.q_ensemble_size)
nn.Linear(config.mlp_dim, max(config.num_bins, 1)),
)
for _ in range(config.q_ensemble_size)
]
)
self._init_weights()
self._target_Qs = deepcopy(self._Qs).requires_grad_(False)
self.log_std_min = torch.tensor(config.log_std_min)
self.log_std_dif = torch.tensor(config.log_std_max) - self.log_std_min
self.bins = torch.linspace(config.vmin, config.vmax, config.num_bins)
self.config.bin_size = (config.vmax - config.vmin) / (config.num_bins - 1)
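# Worked example with illustrative values (not necessarily the configured defaults): vmin = -10,
# vmax = 10 and num_bins = 101 give bin_size = 20 / 100 = 0.2, i.e. bin centers spaced 0.2 apart in
# symlog space.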
def _init_weights(self):
"""Initialize model weights.
Custom weight initializations proposed in TD-MPC2.
"""
def _apply_fn(m):
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02)
@ -527,13 +543,13 @@ class TDMPC2WorldModel(nn.Module):
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.ParameterList):
for i, p in enumerate(m):
if p.dim() == 3: # Linear
nn.init.trunc_normal_(p, std=0.02) # Weight
nn.init.constant_(m[i+1], 0) # Bias
if p.dim() == 3: # Linear
nn.init.trunc_normal_(p, std=0.02) # Weight
nn.init.constant_(m[i + 1], 0) # Bias
self.apply(_apply_fn)
# initialize the parameters of the final linear layer of the reward and Q heads
for m in [self._reward, *self._Qs]:
assert isinstance(
m[-1], nn.Linear
@ -549,7 +565,7 @@ class TDMPC2WorldModel(nn.Module):
self.log_std_dif = self.log_std_dif.to(*args, **kwargs)
self.bins = self.bins.to(*args, **kwargs)
return self
def train(self, mode):
super().train(mode)
self._target_Qs.train(False)
@ -641,7 +657,7 @@ class TDMPC2WorldModel(nn.Module):
Soft-update target Q-networks using Polyak averaging.
"""
with torch.no_grad():
for p, p_target in zip(self._Qs.parameters(), self._target_Qs.parameters()):
for p, p_target in zip(self._Qs.parameters(), self._target_Qs.parameters(), strict=False):
p_target.data.lerp_(p.data, self.config.target_model_momentum)
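# `lerp_` computes p_target <- p_target + momentum * (p - p_target), so target_model_momentum is the
# weight placed on the *online* parameters; hence the note in the config that it corresponds to
# 1 - target_model_momentum of the TD-MPC1 implementation.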
@ -662,49 +678,51 @@ class TDMPC2ObservationEncoder(nn.Module):
for obs_key in config.input_shapes:
if "observation.image" in config.input_shapes:
encoder_module = nn.Sequential(
nn.Conv2d(config.input_shapes[obs_key][0], config.image_encoder_hidden_dim, 7, stride=2),
nn.ReLU(inplace=True),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2),
nn.ReLU(inplace=True),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
nn.ReLU(inplace=True),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=1),
)
dummy_batch = torch.zeros(1, *config.input_shapes[obs_key])
with torch.inference_mode():
out_shape = self.image_enc_layers(dummy_batch).shape[1:]
out_shape = encoder_module(dummy_batch).shape[1:]
encoder_module.extend(
nn.Sequential(
nn.Flatten(),
NormedLinear(np.prod(out_shape), config.latent_dim, act=SimNorm(config.simnorm_dim)),
)
)
elif "observation.state" in config.input_shapes:
elif (
"observation.state" in obs_key
or "observation.environment_state" in obs_key
):
encoder_module = nn.ModuleList()
encoder_module.append(NormedLinear(config.input_shapes[obs_key][0], config.state_encoder_hidden_dim))
encoder_module.append(
NormedLinear(config.input_shapes[obs_key][0], config.state_encoder_hidden_dim)
)
assert config.num_enc_layers > 0
for _ in range(config.num_enc_layers - 1):
encoder_module.append(NormedLinear(config.state_encoder_hidden_dim, config.state_encoder_hidden_dim))
encoder_module.append(NormedLinear(config.state_encoder_hidden_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)))
encoder_module = nn.Sequential(*encoder_module)
elif "observation.environment_state" in config.input_shapes:
encoder_module = nn.ModuleList()
encoder_module.append(NormedLinear(config.input_shapes[obs_key][0], config.state_encoder_hidden_dim))
assert config.num_enc_layers > 0
for _ in range(config.num_enc_layers - 1):
encoder_module.append(NormedLinear(config.state_encoder_hidden_dim, config.state_encoder_hidden_dim))
encoder_module.append(NormedLinear(config.state_encoder_hidden_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)))
encoder_module.append(
NormedLinear(config.state_encoder_hidden_dim, config.state_encoder_hidden_dim)
)
encoder_module.append(
NormedLinear(
config.state_encoder_hidden_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)
)
)
encoder_module = nn.Sequential(*encoder_module)
else:
raise NotImplementedError(f"No corresponding encoder module for key {obs_key}.")
encoder_dict[obs_key] = encoder_module
encoder_dict[obs_key.replace(".", "")] = encoder_module
self.encoder = nn.ModuleDict(encoder_dict)
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
"""Encode the image and/or state vector.
@ -714,9 +732,11 @@ class TDMPC2ObservationEncoder(nn.Module):
feat = []
for obs_key in self.config.input_shapes:
if "observation.image" in obs_key:
feat.append(flatten_forward_unflatten(self.encoder[obs_key], obs_dict[obs_key]))
feat.append(
flatten_forward_unflatten(self.encoder[obs_key.replace(".", "")], obs_dict[obs_key])
)
else:
feat.append(self.encoder[obs_key](obs_dict[obs_key]))
feat.append(self.encoder[obs_key.replace(".", "")](obs_dict[obs_key]))
return torch.stack(feat, dim=0).mean(0)
@ -816,4 +836,4 @@ class RunningScale:
return x * (1 / self.value)
def __repr__(self):
return f"RunningScale(S: {self.value})"
return f"RunningScale(S: {self.value})"

View File

@ -5,152 +5,159 @@ from functorch import combine_state_for_ensemble
class Ensemble(nn.Module):
"""
Vectorized ensemble of modules.
"""
"""
Vectorized ensemble of modules.
"""
def __init__(self, modules, **kwargs):
super().__init__()
modules = nn.ModuleList(modules)
fn, params, _ = combine_state_for_ensemble(modules)
self.vmap = torch.vmap(fn, in_dims=(0, 0, None), randomness='different', **kwargs)
self.params = nn.ParameterList([nn.Parameter(p) for p in params])
self._repr = str(modules)
def __init__(self, modules, **kwargs):
super().__init__()
modules = nn.ModuleList(modules)
fn, params, _ = combine_state_for_ensemble(modules)
self.vmap = torch.vmap(fn, in_dims=(0, 0, None), randomness="different", **kwargs)
self.params = nn.ParameterList([nn.Parameter(p) for p in params])
self._repr = str(modules)
def forward(self, *args, **kwargs):
return self.vmap([p for p in self.params], (), *args, **kwargs)
def forward(self, *args, **kwargs):
return self.vmap([p for p in self.params], (), *args, **kwargs)
def __repr__(self):
return "Vectorized " + self._repr
def __repr__(self):
return 'Vectorized ' + self._repr
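# Usage sketch (illustrative, not taken from this diff): an Ensemble built from several identical MLPs
# evaluates all members in a single vmapped call instead of a Python loop, e.g.
#   ensemble = Ensemble([nn.Linear(8, 4) for _ in range(5)])
#   out = ensemble(torch.randn(32, 8))  # -> (5, 32, 4), one slice per member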
class SimNorm(nn.Module):
"""
Simplicial normalization.
Adapted from https://arxiv.org/abs/2204.00616.
"""
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
shp = x.shape
x = x.view(*shp[:-1], -1, self.dim)
x = F.softmax(x, dim=-1)
return x.view(*shp)
def __repr__(self):
return f"SimNorm(dim={self.dim})"
"""
Simplicial normalization.
Adapted from https://arxiv.org/abs/2204.00616.
"""
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
shp = x.shape
x = x.view(*shp[:-1], -1, self.dim)
x = F.softmax(x, dim=-1)
return x.view(*shp)
def __repr__(self):
return f"SimNorm(dim={self.dim})"
class NormedLinear(nn.Linear):
"""
Linear layer with LayerNorm, activation, and optionally dropout.
"""
"""
Linear layer with LayerNorm, activation, and optionally dropout.
"""
def __init__(self, *args, dropout=0., act=nn.Mish(inplace=True), **kwargs):
super().__init__(*args, **kwargs)
self.ln = nn.LayerNorm(self.out_features)
self.act = act
self.dropout = nn.Dropout(dropout, inplace=True) if dropout else None
def __init__(self, *args, dropout=0.0, act=nn.Mish(inplace=True), **kwargs):
super().__init__(*args, **kwargs)
self.ln = nn.LayerNorm(self.out_features)
self.act = act
self.dropout = nn.Dropout(dropout, inplace=True) if dropout else None
def forward(self, x):
x = super().forward(x)
if self.dropout:
x = self.dropout(x)
return self.act(self.ln(x))
def __repr__(self):
repr_dropout = f", dropout={self.dropout.p}" if self.dropout else ""
return f"NormedLinear(in_features={self.in_features}, "\
f"out_features={self.out_features}, "\
f"bias={self.bias is not None}{repr_dropout}, "\
f"act={self.act.__class__.__name__})"
def forward(self, x):
x = super().forward(x)
if self.dropout:
x = self.dropout(x)
return self.act(self.ln(x))
def __repr__(self):
repr_dropout = f", dropout={self.dropout.p}" if self.dropout else ""
return (
f"NormedLinear(in_features={self.in_features}, "
f"out_features={self.out_features}, "
f"bias={self.bias is not None}{repr_dropout}, "
f"act={self.act.__class__.__name__})"
)
def soft_cross_entropy(pred, target, cfg):
"""Computes the cross entropy loss between predictions and soft targets."""
pred = F.log_softmax(pred, dim=-1)
target = two_hot(target, cfg)
return -(target * pred).sum(-1, keepdim=True)
"""Computes the cross entropy loss between predictions and soft targets."""
pred = F.log_softmax(pred, dim=-1)
target = two_hot(target, cfg)
import pudb
pudb.set_trace()
return -(target * pred).sum(-1, keepdim=True)
@torch.jit.script
def log_std(x, low, dif):
return low + 0.5 * dif * (torch.tanh(x) + 1)
@torch.jit.script
def _gaussian_residual(eps, log_std):
return -0.5 * eps.pow(2) - log_std
@torch.jit.script
def _gaussian_logprob(residual):
return residual - 0.5 * torch.log(2 * torch.pi)
def gaussian_logprob(eps, log_std, size=None):
"""Compute Gaussian log probability."""
residual = _gaussian_residual(eps, log_std).sum(-1, keepdim=True)
if size is None:
size = eps.size(-1)
return _gaussian_logprob(residual) * size
"""Compute Gaussian log probability."""
residual = _gaussian_residual(eps, log_std).sum(-1, keepdim=True)
if size is None:
size = eps.size(-1)
return _gaussian_logprob(residual) * size
@torch.jit.script
def _squash(pi):
return torch.log(F.relu(1 - pi.pow(2)) + 1e-6)
def squash(mu, pi, log_pi):
"""Apply squashing function."""
mu = torch.tanh(mu)
pi = torch.tanh(pi)
log_pi -= _squash(pi).sum(-1, keepdim=True)
return mu, pi, log_pi
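# Subtracting _squash(pi) is the tanh change-of-variables correction: after a = tanh(u),
# log pi(a) = log N(u) - sum_i log(1 - tanh(u_i)^2); the 1e-6 inside _squash guards against log(0)
# when tanh saturates.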
"""Apply squashing function."""
mu = torch.tanh(mu)
pi = torch.tanh(pi)
log_pi -= _squash(pi).sum(-1, keepdim=True)
return mu, pi, log_pi
@torch.jit.script
def symlog(x):
"""
Symmetric logarithmic function.
Adapted from https://github.com/danijar/dreamerv3.
"""
return torch.sign(x) * torch.log(1 + torch.abs(x))
"""
Symmetric logarithmic function.
Adapted from https://github.com/danijar/dreamerv3.
"""
return torch.sign(x) * torch.log(1 + torch.abs(x))
@torch.jit.script
def symexp(x):
"""
Symmetric exponential function.
Adapted from https://github.com/danijar/dreamerv3.
"""
return torch.sign(x) * (torch.exp(torch.abs(x)) - 1)
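# Quick numeric check: symlog(10) = log(1 + 10) ~ 2.398 and symexp(2.398) = e**2.398 - 1 ~ 10, so the
# two functions are inverses; regression targets are compressed with symlog before discretization and
# expanded back with symexp after decoding.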
"""
Symmetric exponential function.
Adapted from https://github.com/danijar/dreamerv3.
"""
return torch.sign(x) * (torch.exp(torch.abs(x)) - 1)
def two_hot(x, cfg):
"""Converts a batch of scalars to soft two-hot encoded targets for discrete regression."""
if cfg.num_bins == 0:
return x
elif cfg.num_bins == 1:
return symlog(x)
x = torch.clamp(symlog(x), cfg.vmin, cfg.vmax).squeeze(1)
bin_idx = torch.floor((x - cfg.vmin) / cfg.bin_size).long()
bin_offset = ((x - cfg.vmin) / cfg.bin_size - bin_idx.float()).unsqueeze(-1)
soft_two_hot = torch.zeros(x.size(0), cfg.num_bins, device=x.device)
soft_two_hot.scatter_(1, bin_idx.unsqueeze(1), 1 - bin_offset)
soft_two_hot.scatter_(1, (bin_idx.unsqueeze(1) + 1) % cfg.num_bins, bin_offset)
return soft_two_hot
"""Converts a batch of scalars to soft two-hot encoded targets for discrete regression."""
if cfg.num_bins == 0:
return x
elif cfg.num_bins == 1:
return symlog(x)
x = torch.clamp(symlog(x), cfg.vmin, cfg.vmax)
bin_idx = torch.floor((x - cfg.vmin) / cfg.bin_size).long()
bin_offset = (x - cfg.vmin) / cfg.bin_size - bin_idx.float()
soft_two_hot = torch.zeros(x.size(0), cfg.num_bins, device=x.device)
soft_two_hot.scatter_(1, bin_idx, 1 - bin_offset)
soft_two_hot.scatter_(1, (bin_idx + 1) % cfg.num_bins, bin_offset)
return soft_two_hot
def two_hot_inv(x, bins):
"""Converts a batch of soft two-hot encoded vectors to scalars."""
num_bins = bins.shape[0]
if num_bins == 0:
return x
elif num_bins == 1:
return symexp(x)
"""Converts a batch of soft two-hot encoded vectors to scalars."""
num_bins = bins.shape[0]
if num_bins == 0:
return x
elif num_bins == 1:
return symexp(x)
x = F.softmax(x, dim=-1)
x = torch.sum(x * bins, dim=-1, keepdim=True)
return symexp(x)
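# Round-trip sketch (illustrative values, assumed for this example rather than taken from the diff):
# the two-hot target spreads probability mass over the two bins bracketing symlog(x), so its expectation
# over the bin centers recovers the clamped symlog value; two_hot_inv instead expects raw logits from the
# reward/Q heads, which is why it applies its own softmax before decoding.
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(num_bins=101, vmin=-10.0, vmax=10.0, bin_size=0.2)
#   bins = torch.linspace(cfg.vmin, cfg.vmax, cfg.num_bins)
#   target = two_hot(torch.tensor([[3.0]]), cfg)  # (1, 101) soft two-hot vector
#   (target * bins).sum(-1)                       # ~= symlog(3.0) ~= 1.386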