added new implementation of tdmpc2
parent 963738d983 · commit a146544765
lerobot/common/policies/tdmpc2/configuration_tdmpc2.py
@@ -0,0 +1,208 @@
#!/usr/bin/env python

# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su,
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field


@dataclass
class TDMPC2Config:
    """Configuration class for TDMPC2Policy.

    Defaults are configured for training with xarm_lift_medium_replay providing proprioceptive and single
    camera observations.

    The parameters you will most likely need to change are the ones which depend on the environment / sensors.
    Those are: `input_shapes`, `output_shapes`, and perhaps `max_random_shift_ratio`.

    Args:
        n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google
            action repeats in Q-learning or ask your favorite chatbot)
        horizon: Horizon for model predictive control.
        n_action_steps: Number of action steps to take from the plan given by model predictive control. This
            is an alternative to using action repeats. If this is set to more than 1, then we require
            `n_action_repeats == 1`, `use_mpc == True` and `n_action_steps <= horizon`. Note that this
            approach of using multiple steps from the plan is not in the original implementation.
        input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
            the input data name, and the value is a list indicating the dimensions of the corresponding data.
            For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
            indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
            include the batch dimension or temporal dimension.
        output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
            the output data name, and the value is a list indicating the dimensions of the corresponding data.
            For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
            Importantly, `output_shapes` doesn't include the batch dimension or temporal dimension.
        input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
            and the value specifying the normalization mode to apply. The two available modes are "mean_std",
            which subtracts the mean and divides by the standard deviation, and "min_max", which rescales to a
            [-1, 1] range. Note that here this defaults to None, meaning inputs are not normalized. This is to
            match the original implementation.
        output_normalization_modes: Similar dictionary to `input_normalization_modes`, but used to unnormalize
            to the original scale. Note that this is also used for normalizing the training targets. NOTE:
            Clipping to [-1, +1] is used during MPPI/CEM. Therefore, it is recommended that you stick with
            "min_max" normalization mode here.
        image_encoder_hidden_dim: Number of channels for the convolutional layers used for image encoding.
        state_encoder_hidden_dim: Hidden dimension for the MLP used for state vector encoding.
        latent_dim: Observation's latent embedding dimension.
        q_ensemble_size: Number of Q function estimators to use in an ensemble for uncertainty estimation.
        mlp_dim: Hidden dimension of MLPs used for modelling the dynamics encoder, reward function, policy
            (π), Q ensemble, and V.
        discount: Discount factor (γ) to use for the reinforcement learning formalism.
        use_mpc: Whether to use model predictive control. The alternative is to just sample the policy model
            (π) for each step.
        cem_iterations: Number of iterations for the MPPI/CEM loop in MPC.
        max_std: Maximum standard deviation for actions sampled from the gaussian PDF in CEM.
        min_std: Minimum standard deviation for noise applied to actions sampled from the policy model (π).
            Doubles up as the minimum standard deviation for actions sampled from the gaussian PDF in CEM.
        n_gaussian_samples: Number of samples to draw from the gaussian distribution every CEM iteration. Must
            be non-zero.
        n_pi_samples: Number of samples to draw from the policy / world model rollout every CEM iteration. Can
            be zero.
        uncertainty_regularizer_coeff: Coefficient for the uncertainty regularization used when estimating
            trajectory values (this is the λ coefficient in eqn 4 of FOWM).
        n_elites: The number of elite samples to use for updating the gaussian parameters every CEM iteration.
        elite_weighting_temperature: The temperature to use for softmax weighting (by trajectory value) of the
            elites, when updating the gaussian parameters for CEM.
        gaussian_mean_momentum: Momentum (α) used for EMA updates of the mean parameter μ of the gaussian
            parameters optimized in CEM. Updates are calculated as μ⁻ ← αμ⁻ + (1-α)μ.
        max_random_shift_ratio: Maximum random shift (as a proportion of the image size) to apply to the
            image(s) (in units of pixels) for training-time augmentation. If set to 0, no such augmentation
            is applied. Note that the input images are assumed to be square for this augmentation.
        reward_coeff: Loss weighting coefficient for the reward regression loss.
        expectile_weight: Weighting (τ) used in expectile regression for the state value function (V).
            v_pred < v_target is weighted by τ and v_pred >= v_target is weighted by (1-τ). τ is expected to
            be in [0, 1]. Setting τ closer to 1 results in a more "optimistic" V. This is sensible to do
            because v_target is obtained by evaluating the learned state-action value functions (Q) with
            in-sample actions that may not be always optimal.
        value_coeff: Loss weighting coefficient for both the state-action value (Q) TD loss, and the state
            value (V) expectile regression loss.
        consistency_coeff: Loss weighting coefficient for the consistency loss.
        advantage_scaling: A factor by which the advantages are scaled prior to exponentiation for advantage
            weighted regression of the policy (π) estimator parameters. Note that the exponentiated advantages
            are clamped at 100.0.
        pi_coeff: Loss weighting coefficient for the action regression loss.
        temporal_decay_coeff: Exponential decay coefficient for decaying the loss coefficient for future time-
            steps. Hint: each loss computation involves `horizon` steps worth of actions starting from the
            current time step.
        target_model_momentum: Momentum (α) used for EMA updates of the target models. Updates are calculated
            as ϕ ← αϕ + (1-α)θ where ϕ are the parameters of the target model and θ are the parameters of the
            model being trained.
    """

    # Input / output structure.
    n_action_repeats: int = 2
    horizon: int = 5
    n_action_steps: int = 1

    input_shapes: dict[str, list[int]] = field(
        default_factory=lambda: {
            "observation.image": [3, 84, 84],
            "observation.state": [4],
        }
    )
    output_shapes: dict[str, list[int]] = field(
        default_factory=lambda: {
            "action": [4],
        }
    )

    # Normalization / Unnormalization
    input_normalization_modes: dict[str, str] | None = None
    output_normalization_modes: dict[str, str] = field(
        default_factory=lambda: {"action": "min_max"},
    )

    # Architecture / modeling.
    # Neural networks.
    image_encoder_hidden_dim: int = 32
    state_encoder_hidden_dim: int = 256
    latent_dim: int = 50
    q_ensemble_size: int = 5
    mlp_dim: int = 512
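
    # Extra architecture hyperparameters referenced by the world model / encoder code. The defaults below
    # are assumptions taken from the original TD-MPC2 implementation.
    simnorm_dim: int = 8  # Group size for SimNorm (simplicial normalization) of the latent.
    dropout: float = 0.01  # Dropout used in the first layer of each Q function.
    enc_dim: int = 256  # Hidden dimension of the state / environment-state encoder MLPs.
    num_enc_layers: int = 2  # Number of layers in the state / environment-state encoder MLPs.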
    # Reinforcement learning.
    discount: float = 0.9

    # actor
    log_std_min: float = -10  # Lower bound for the learned log standard deviation of the policy (π).
    log_std_max: float = 2  # Upper bound for the learned log standard deviation of the policy (π).
    entropy_coef: float = 1e-4  # Entropy bonus coefficient for the policy loss.

    # critic
    num_bins: int = 101  # Number of bins for the discrete-regression (two-hot) reward and Q heads.
    vmin: int = -10  # Lower edge of the discrete-regression support (in symlog space).
    vmax: int = +10  # Upper edge of the discrete-regression support (in symlog space).

    rho: float = 0.5  # Temporal decay coefficient for the policy loss.
    tau: float = 0.01  # EMA coefficient of the running scale used to normalize Q values in the policy loss.
    # Inference.
    use_mpc: bool = True
    cem_iterations: int = 6
    max_std: float = 2.0
    min_std: float = 0.05
    n_gaussian_samples: int = 512
    n_pi_samples: int = 51
    uncertainty_regularizer_coeff: float = 1.0
    n_elites: int = 50
    elite_weighting_temperature: float = 0.5
    gaussian_mean_momentum: float = 0.1

    # Training and loss computation.
    max_random_shift_ratio: float = 0.0476
    # Loss coefficients.
    reward_coeff: float = 0.1
    expectile_weight: float = 0.9
    value_coeff: float = 0.1
    consistency_coeff: float = 20.0
    advantage_scaling: float = 3.0
    pi_coeff: float = 0.5
    temporal_decay_coeff: float = 0.5
    # Target model.
    target_model_momentum: float = 0.995

    def __post_init__(self):
        """Input validation (not exhaustive)."""
        # There should only be one image key.
        image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
        if len(image_keys) > 1:
            raise ValueError(
                f"{self.__class__.__name__} handles at most one image for now. Got image keys {image_keys}."
            )
        if len(image_keys) > 0:
            image_key = next(iter(image_keys))
            if self.input_shapes[image_key][-2] != self.input_shapes[image_key][-1]:
                # TODO(alexander-soare): This limitation is solely because of code in the random shift
                # augmentation. It should be able to be removed.
                raise ValueError(
                    f"Only square images are handled now. Got image shape {self.input_shapes[image_key]}."
                )
        if self.n_gaussian_samples <= 0:
            raise ValueError(
                f"The number of gaussian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`"
            )
        if self.output_normalization_modes != {"action": "min_max"}:
            raise ValueError(
                "TD-MPC assumes the action space dimensions to all be in [-1, 1]. Therefore it is strongly "
                f"advised that you stick with the default. See {self.__class__.__name__} docstring for more "
                "information."
            )
        if self.n_action_steps > 1:
            if self.n_action_repeats != 1:
                raise ValueError(
                    "If `n_action_steps > 1`, `n_action_repeats` must be left to its default value of 1."
                )
            if not self.use_mpc:
                raise ValueError("If `n_action_steps > 1`, `use_mpc` must be set to `True`.")
            if self.n_action_steps > self.horizon:
                raise ValueError("`n_action_steps` must be less than or equal to `horizon`.")
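

# Example usage (illustrative sketch): only the environment-dependent fields usually need overriding, e.g.
# for a robot with a single 84x84 RGB camera, a 7-dimensional state vector, and 7-dimensional actions:
#
#     config = TDMPC2Config(
#         input_shapes={"observation.image": [3, 84, 84], "observation.state": [7]},
#         output_shapes={"action": [7]},
#     )
#
# `__post_init__` then validates the shapes and the action normalization mode.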

lerobot/common/policies/tdmpc2/modeling_tdmpc2.py
@@ -0,0 +1,860 @@
#!/usr/bin/env python

# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su,
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of Finetuning Offline World Models in the Real World.

The comments in this code may sometimes refer to these references:
TD-MPC paper: Temporal Difference Learning for Model Predictive Control (https://arxiv.org/abs/2203.04955)
FOWM paper: Finetuning Offline World Models in the Real World (https://arxiv.org/abs/2310.16029)
"""

# ruff: noqa: N806

from collections import deque
from copy import deepcopy
from functools import partial
from typing import Callable

import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa: N812
from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor

from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.tdmpc2.configuration_tdmpc2 import TDMPC2Config
from lerobot.common.policies.tdmpc2.tdmpc2_utils import NormedLinear, SimNorm, soft_ce, two_hot_inv
from lerobot.common.policies.utils import get_device_from_parameters, populate_queues


class TDMPC2Policy(
    nn.Module,
    PyTorchModelHubMixin,
    library_name="lerobot",
    repo_url="https://github.com/huggingface/lerobot",
    tags=["robotics", "tdmpc2"],
):
    """Implementation of TD-MPC2 learning + inference.

    Please note several warnings for this policy.
    - Evaluation of pretrained weights created with the original FOWM code
      (https://github.com/fyhMer/fowm) works as expected. To be precise: we trained and evaluated a
      model with the FOWM code for the xarm_lift_medium_replay dataset. We ported the weights across
      to LeRobot, and were able to evaluate with the same success metric. BUT, we had to use inter-
      process communication to use the xarm environment from FOWM. This is because our xarm
      environment uses newer dependencies and does not match the environment in FOWM. See
      https://github.com/huggingface/lerobot/pull/103 for implementation details.
    - We have NOT checked that training on LeRobot reproduces the results from FOWM.
    - Nevertheless, we have verified that we can train TD-MPC for PushT. See
      `lerobot/configs/policy/tdmpc2_pusht_keypoints.yaml`.
    - Our current xarm datasets were generated using the environment from FOWM. Therefore they do not
      match our xarm environment.
    """

    name = "tdmpc2"

    def __init__(
        self, config: TDMPC2Config | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
    ):
        """
        Args:
            config: Policy configuration class instance or None, in which case the default instantiation of
                the configuration class is used.
            dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is
                expected that they will be passed with a call to `load_state_dict` before the policy is used.
        """
        super().__init__()

        if config is None:
            config = TDMPC2Config()
        self.config = config
        self.model = TDMPC2WorldModel(config)
        self.model_target = deepcopy(self.model)
        for param in self.model_target.parameters():
            param.requires_grad = False

        if config.input_normalization_modes is not None:
            self.normalize_inputs = Normalize(
                config.input_shapes, config.input_normalization_modes, dataset_stats
            )
        else:
            self.normalize_inputs = nn.Identity()
        self.normalize_targets = Normalize(
            config.output_shapes, config.output_normalization_modes, dataset_stats
        )
        self.unnormalize_outputs = Unnormalize(
            config.output_shapes, config.output_normalization_modes, dataset_stats
        )

        image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
        # Note: This check is covered in the post-init of the config, but we keep a sanity check here just
        # in case.
        self._use_image = False
        self._use_env_state = False
        if len(image_keys) > 0:
            assert len(image_keys) == 1
            self._use_image = True
            self.input_image_key = image_keys[0]
        if "observation.environment_state" in config.input_shapes:
            self._use_env_state = True

        self.scale = RunningScale(self.config.tau)

        self.reset()

    def reset(self):
        """
        Clear observation and action queues. Clear previous means for warm starting of MPPI/CEM. Should be
        called on `env.reset()`.
        """
        self._queues = {
            "observation.state": deque(maxlen=1),
            "action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
        }
        if self._use_image:
            self._queues["observation.image"] = deque(maxlen=1)
        if self._use_env_state:
            self._queues["observation.environment_state"] = deque(maxlen=1)
        # Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm
        # start CEM for the next step.
        self._prev_mean: torch.Tensor | None = None

    @torch.no_grad()
    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        """Select a single action given environment observations."""
        batch = self.normalize_inputs(batch)
        if self._use_image:
            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
            batch["observation.image"] = batch[self.input_image_key]

        self._queues = populate_queues(self._queues, batch)

        # When the action queue is depleted, populate it again by querying the policy.
        if len(self._queues["action"]) == 0:
            batch = {
                key: torch.stack(list(self._queues[key]), dim=1) for key in batch if key in self._queues
            }

            # Remove the time dimensions as it is not handled yet.
            for key in batch:
                assert batch[key].shape[1] == 1
                batch[key] = batch[key][:, 0]

            # NOTE: Order of observations matters here.
            encode_keys = []
            if self._use_image:
                encode_keys.append("observation.image")
            if self._use_env_state:
                encode_keys.append("observation.environment_state")
            encode_keys.append("observation.state")
            z = self.model.encode({k: batch[k] for k in encode_keys})
            if self.config.use_mpc:  # noqa: SIM108
                actions = self.plan(z)  # (horizon, batch, action_dim)
            else:
                # Plan with the policy (π) alone. This always returns one action so unsqueeze to get a
                # sequence dimension like in the MPC branch.
                actions = self.model.pi(z)[0].unsqueeze(0)

            actions = torch.clamp(actions, -1, +1)

            actions = self.unnormalize_outputs({"action": actions})["action"]

            if self.config.n_action_repeats > 1:
                for _ in range(self.config.n_action_repeats):
                    self._queues["action"].append(actions[0])
            else:
                # The planned trajectory is (horizon, batch, action_dim); extend the queue with its first
                # `n_action_steps` steps.
                self._queues["action"].extend(actions[: self.config.n_action_steps])

        action = self._queues["action"].popleft()
        return action
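
    # Illustrative sketch (not executed): how `select_action` is typically driven by an evaluation loop.
    # `env` is a stand-in for a gym-style environment and `preprocess` / `step_and_preprocess` for whatever
    # converts its observations into the policy's input format.
    #
    #     policy.reset()
    #     observation = preprocess(env.reset())
    #     done = False
    #     while not done:
    #         action = policy.select_action(observation)  # (batch, action_dim)
    #         observation, reward, done = step_and_preprocess(env, action)
    #
    # Internally, the action queue means `plan` is only re-run every `n_action_steps` (or
    # `n_action_repeats`) environment steps.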

    @torch.no_grad()
    def plan(self, z: Tensor) -> Tensor:
        """Plan sequence of actions using TD-MPC inference.

        Args:
            z: (batch, latent_dim,) tensor for the initial state.
        Returns:
            (horizon, batch, action_dim,) tensor for the planned trajectory of actions.
        """
        device = get_device_from_parameters(self)

        batch_size = z.shape[0]

        # Sample Nπ trajectories from the policy.
        pi_actions = torch.empty(
            self.config.horizon,
            self.config.n_pi_samples,
            batch_size,
            self.config.output_shapes["action"][0],
            device=device,
        )
        if self.config.n_pi_samples > 0:
            _z = einops.repeat(z, "b d -> n b d", n=self.config.n_pi_samples)
            for t in range(self.config.horizon):
                # Note: Adding a small amount of noise here doesn't hurt during inference and may even be
                # helpful for CEM.
                pi_actions[t] = self.model.pi(_z, self.config.min_std)[0]
                _z = self.model.latent_dynamics(_z, pi_actions[t])

        # In the CEM loop we will need this for a call to estimate_value with the gaussian sampled
        # trajectories.
        z = einops.repeat(z, "b d -> n b d", n=self.config.n_gaussian_samples + self.config.n_pi_samples)

        # Model Predictive Path Integral (MPPI) with the cross-entropy method (CEM) as the optimization
        # algorithm.
        # The initial mean and standard deviation for the cross-entropy method (CEM).
        mean = torch.zeros(
            self.config.horizon, batch_size, self.config.output_shapes["action"][0], device=device
        )
        # Maybe warm start CEM with the mean from the previous step.
        if self._prev_mean is not None:
            mean[:-1] = self._prev_mean[1:]
        std = self.config.max_std * torch.ones_like(mean)

        for _ in range(self.config.cem_iterations):
            # Randomly sample action trajectories from the gaussian distribution.
            std_normal_noise = torch.randn(
                self.config.horizon,
                self.config.n_gaussian_samples,
                batch_size,
                self.config.output_shapes["action"][0],
                device=std.device,
            )
            gaussian_actions = torch.clamp(mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1)

            # Compute elite actions.
            actions = torch.cat([gaussian_actions, pi_actions], dim=1)
            value = self.estimate_value(z, actions).nan_to_num_(0).squeeze(-1)
            elite_idxs = torch.topk(value, self.config.n_elites, dim=0).indices  # (n_elites, batch)
            elite_value = value.take_along_dim(elite_idxs, dim=0)  # (n_elites, batch)
            # (horizon, n_elites, batch, action_dim)
            elite_actions = actions.take_along_dim(einops.rearrange(elite_idxs, "n b -> 1 n b 1"), dim=1)

            # Update gaussian PDF parameters to be the (weighted) mean and standard deviation of the elites.
            max_value = elite_value.max(0, keepdim=True)[0]  # (1, batch)
            # The weighting is a softmax over trajectory values. Note that this is not the same as the usage
            # of Ω in eqn 4 of the TD-MPC paper. Instead it is the normalized version of it: s = Ω/ΣΩ. This
            # makes the equations: μ = Σ(s⋅Γ), σ = Σ(s⋅(Γ-μ)²).
            score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value))
            score /= score.sum(axis=0, keepdim=True)
            # (horizon, batch, action_dim)
            _mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1)
            _std = torch.sqrt(
                torch.sum(
                    einops.rearrange(score, "n b -> n b 1")
                    * (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d")) ** 2,
                    dim=1,
                )
            )
            # Update mean with an exponential moving average, and std with a direct replacement.
            mean = (
                self.config.gaussian_mean_momentum * mean + (1 - self.config.gaussian_mean_momentum) * _mean
            )
            std = _std.clamp_(self.config.min_std, self.config.max_std)

        # Keep track of the mean for warm-starting subsequent steps.
        self._prev_mean = mean

        # Randomly select one of the elite actions from the last iteration of MPPI/CEM using the softmax
        # scores from the last iteration.
        actions = elite_actions[:, torch.multinomial(score.T, 1).squeeze(), torch.arange(batch_size)]
        return actions
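
    # Worked example for the elite weighting above (illustrative): with elite_weighting_temperature = 0.5
    # and elite values [4, 2, 0] for one batch element, the unnormalized weights are
    # exp(0.5 * ([4, 2, 0] - 4)) = [1.000, 0.368, 0.135], which normalize to s ~= [0.665, 0.245, 0.090].
    # The new gaussian mean is then the s-weighted average of the elite action trajectories.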

    @torch.no_grad()
    def estimate_value(self, z: Tensor, actions: Tensor):
        """Estimates the value of a trajectory as per eqn 4 of the FOWM paper.

        Args:
            z: (*, latent_dim) tensor of initial latent states.
            actions: (horizon, *, action_dim) tensor of action trajectories.
        Returns:
            (*, 1) tensor of values (callers squeeze out the trailing singleton dimension).
        """
        # Initialize return and running discount factor.
        G, running_discount = 0, 1
        # Iterate over the actions in the trajectory to simulate the trajectory using the latent dynamics
        # model. Keep track of return.
        for t in range(actions.shape[0]):
            # We will compute the reward in a moment. First compute the uncertainty regularizer from eqn 4
            # of the FOWM paper.
            if self.config.uncertainty_regularizer_coeff > 0:
                # The Q heads produce discrete-regression logits, so decode them into scalar values before
                # taking the standard deviation over the ensemble.
                q_values = two_hot_inv(self.model.Qs(z, actions[t], return_type="all"), self.model.bins)
                regularization = -(self.config.uncertainty_regularizer_coeff * q_values.std(0))
            else:
                regularization = 0
            # Estimate the next state (latent) and reward.
            z, reward = self.model.latent_dynamics_and_reward(z, actions[t], discretize_reward=True)
            # Update the return and running discount.
            G += running_discount * (reward + regularization)
            running_discount *= self.config.discount
        # Add the estimated value of the final state (using the minimum for a conservative estimate).
        # Do so by predicting the next action, then taking a minimum over the ensemble of state-action value
        # estimators.
        # Note: This small amount of added noise seems to help a bit at inference time as observed by success
        # metrics over 50 episodes of xarm_lift_medium_replay.
        next_action = self.model.pi(z, self.config.min_std)[0]  # (batch, action_dim)
        # Decode each ensemble member's logits into scalar values: (ensemble, *, 1).
        terminal_values = two_hot_inv(self.model.Qs(z, next_action, return_type="all"), self.model.bins)
        # Randomly choose 2 of the Qs for terminal value estimation (as in App C. of the FOWM paper).
        if self.config.q_ensemble_size > 2:
            G += (
                running_discount
                * torch.min(terminal_values[torch.randint(0, self.config.q_ensemble_size, size=(2,))], dim=0)[
                    0
                ]
            )
        else:
            G += running_discount * torch.min(terminal_values, dim=0)[0]
        # Finally, also regularize the terminal value.
        if self.config.uncertainty_regularizer_coeff > 0:
            G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0)
        return G
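
    # For reference, the value estimate computed above is (eqn 4 of the FOWM paper, as implemented here):
    #
    #     G = sum_{t=0}^{H-1} gamma^t * (R(z_t, a_t) - lambda * std_k Q_k(z_t, a_t))
    #         + gamma^H * (min of two sampled Q_k(z_H, pi(z_H)) - lambda * std_k Q_k(z_H, pi(z_H)))
    #
    # where lambda is `uncertainty_regularizer_coeff` and the std is taken over the Q ensemble.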

    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
        """Run the batch through the model and compute the loss.

        Returns a dictionary with loss as a tensor, and other information as native floats.
        """
        device = get_device_from_parameters(self)

        batch = self.normalize_inputs(batch)
        if self._use_image:
            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
            batch["observation.image"] = batch[self.input_image_key]
        batch = self.normalize_targets(batch)

        info = {}

        # (b, t) -> (t, b)
        for key in batch:
            if batch[key].ndim > 1:
                batch[key] = batch[key].transpose(1, 0)

        action = batch["action"]  # (t, b, action_dim)
        reward = batch["next.reward"]  # (t, b)
        observations = {k: v for k, v in batch.items() if k.startswith("observation.")}

        # Apply random image augmentations.
        if self._use_image and self.config.max_random_shift_ratio > 0:
            observations["observation.image"] = flatten_forward_unflatten(
                partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
                observations["observation.image"],
            )

        # Get the current observation for predicting trajectories, and all future observations for use in
        # the latent consistency loss and TD loss.
        current_observation, next_observations = {}, {}
        for k in observations:
            current_observation[k] = observations[k][0]
            next_observations[k] = observations[k][1:]
        horizon, batch_size = next_observations[
            "observation.image" if self._use_image else "observation.environment_state"
        ].shape[:2]

        # Run latent rollout using the latent dynamics model and policy model.
        # Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
        # gives us a next `z`.
        batch_size = batch["index"].shape[0]
        z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device)
        z_preds[0] = self.model.encode(current_observation)
        reward_preds = torch.empty(horizon, batch_size, self.config.num_bins, device=device)
        for t in range(horizon):
            z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(z_preds[t], action[t])

        # Compute Q value predictions based on the latent rollout.
        q_preds_ensemble = self.model.Qs(
            z_preds[:-1], action, return_type="all"
        )  # (ensemble, horizon, batch, num_bins)
        info.update({"Q": q_preds_ensemble.mean().item()})

        # Compute various targets with stopgrad.
        with torch.no_grad():
            # Latent state consistency targets.
            z_targets = self.model_target.encode(next_observations)
            # Compute the TD-target from a reward and the next observation.
            pi = self.model.pi(z_targets)[0]
            td_targets = (
                reward
                + self.config.discount * self.model_target.Qs(z_targets, pi, return_type="min").squeeze(-1)
            )

        # Compute losses.
        # Exponentially decay the loss weight with respect to the timestep. Steps that are more distant in
        # the future have less impact on the loss. Note: unsqueeze will let us broadcast to (seq, batch).
        temporal_loss_coeffs = torch.pow(
            self.config.temporal_decay_coeff, torch.arange(horizon, device=device)
        ).unsqueeze(-1)
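        # For example, with temporal_decay_coeff = 0.5 and horizon = 5, the per-step loss weights are
        # [1, 0.5, 0.25, 0.125, 0.0625]: prediction errors further into the rollout contribute less.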
        # Compute consistency loss as MSE loss between latents predicted from the rollout and latents
        # predicted from the (target model's) observation encoder.
        consistency_loss = (
            (
                temporal_loss_coeffs
                * F.mse_loss(z_preds[1:], z_targets, reduction="none").mean(dim=-1)
                # `z_preds` depends on the current observation and the actions.
                * ~batch["observation.state_is_pad"][0]
                * ~batch["action_is_pad"]
                # `z_targets` depends on the next observation.
                * ~batch["observation.state_is_pad"][1:]
            )
            .sum(0)
            .mean()
        )
        # Compute the reward loss as a soft cross-entropy between the discrete-regression (two-hot) reward
        # predictions from the rollout and the dataset rewards.
        reward_loss = (
            (
                temporal_loss_coeffs
                * soft_ce(reward_preds, reward, self.config).squeeze(-1)
                * ~batch["next.reward_is_pad"]
                # `reward_preds` depends on the current observation and the actions.
                * ~batch["observation.state_is_pad"][0]
                * ~batch["action_is_pad"]
            )
            .sum(0)
            .mean()
        )
        # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
        ce_value_loss = 0.0
        for i in range(self.config.q_ensemble_size):
            ce_value_loss += soft_ce(q_preds_ensemble[i], td_targets, self.config).squeeze(-1)

        q_value_loss = (
            (
                temporal_loss_coeffs
                * ce_value_loss
                # `q_preds_ensemble` depends on the first observation and the actions.
                * ~batch["observation.state_is_pad"][0]
                * ~batch["action_is_pad"]
                # `td_targets` depends on the reward and the next observations.
                * ~batch["next.reward_is_pad"]
                * ~batch["observation.state_is_pad"][1:]
            )
            .sum(0)
            .mean()
        )

        # Calculate the advantage weighted regression loss for π as detailed in FOWM 3.1.
        # We won't need these gradients again so detach.
        z_preds = z_preds.detach()
        action_preds, _, log_pis, _ = self.model.pi(z_preds[:-1])
        qs = self.model_target.Qs(z_preds[:-1], action_preds, return_type="avg")
        self.scale.update(qs[0])
        qs = self.scale(qs)

        rho = torch.pow(self.config.rho, torch.arange(len(qs), device=qs.device)).unsqueeze(-1)

        # mse = F.mse_loss(action_preds, action, reduction="none").sum(-1)  # (t, b)
        # NOTE: The original implementation does not take the sum over the temporal dimension like with the
        # other losses.
        # TODO(alexander-soare): Take the sum over the temporal dimension and check that training still works
        # as well as expected.
        pi_loss = (
            (self.config.entropy_coef * log_pis - qs).mean(dim=-1)
            * rho
            # * temporal_loss_coeffs
            # `action_preds` depends on the first observation and the actions.
            * ~batch["observation.state_is_pad"][0]
            * ~batch["action_is_pad"]
        ).mean()

        loss = (
            self.config.consistency_coeff * consistency_loss
            + self.config.reward_coeff * reward_loss
            + self.config.value_coeff * q_value_loss
            + self.config.pi_coeff * pi_loss
        )

        info.update(
            {
                "consistency_loss": consistency_loss.item(),
                "reward_loss": reward_loss.item(),
                "Q_value_loss": q_value_loss.item(),
                "pi_loss": pi_loss.item(),
                "loss": loss,
                "sum_loss": loss.item() * self.config.horizon,
                "pi_scale": float(self.scale.value),
            }
        )

        # Undo (b, t) -> (t, b).
        for key in batch:
            if batch[key].ndim > 1:
                batch[key] = batch[key].transpose(1, 0)

        return info

    def update(self):
        """Update the target model's parameters with an EMA step."""
        # Note a minor variation with respect to the original FOWM code. Here they do this based on an EMA
        # update frequency parameter which is set to 2 (every 2 steps an update is done). To simplify the
        # code we update every step and adjust the decay parameter `alpha` accordingly (0.99 -> 0.995).
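        # (Two consecutive updates at alpha = 0.995 decay the old parameters by 0.995^2 ~= 0.990, which
        # matches one update at the original alpha = 0.99 every two steps.)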
        update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum)


class TDMPC2WorldModel(nn.Module):
    """Latent dynamics model used in TD-MPC2."""

    def __init__(self, config: TDMPC2Config):
        super().__init__()
        self.config = config

        self._encoder = TDMPC2ObservationEncoder(config)

        # Define latent dynamics head.
        self._dynamics = nn.Sequential(
            NormedLinear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
            NormedLinear(config.mlp_dim, config.mlp_dim),
            NormedLinear(config.mlp_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)),
        )

        # Define reward head.
        self._reward = nn.Sequential(
            NormedLinear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
            NormedLinear(config.mlp_dim, config.mlp_dim),
            nn.Linear(config.mlp_dim, max(config.num_bins, 1)),
        )

        # Define policy head.
        self._pi = nn.Sequential(
            NormedLinear(config.latent_dim, config.mlp_dim),
            NormedLinear(config.mlp_dim, config.mlp_dim),
            nn.Linear(config.mlp_dim, 2 * config.output_shapes["action"][0]),
        )

        # Define ensemble of Q functions.
        self._Qs = nn.ModuleList(
            [
                nn.Sequential(
                    NormedLinear(
                        config.latent_dim + config.output_shapes["action"][0],
                        config.mlp_dim,
                        dropout=config.dropout,
                    ),
                    NormedLinear(config.mlp_dim, config.mlp_dim),
                    nn.Linear(config.mlp_dim, max(config.num_bins, 1)),
                )
                for _ in range(config.q_ensemble_size)
            ]
        )

        self._init_weights()

        self._target_Qs = deepcopy(self._Qs).requires_grad_(False)

        self.log_std_min = torch.tensor(config.log_std_min)
        self.log_std_dif = torch.tensor(config.log_std_max) - self.log_std_min

        self.bins = torch.linspace(config.vmin, config.vmax, config.num_bins)
        # Store the bin size on the config so that the two-hot encoding helpers can use it.
        self.config.bin_size = (config.vmax - config.vmin) / (config.num_bins - 1)

    def _init_weights(self):
        """Initialize model weights.

        Custom weight initializations proposed in TD-MPC2.
        """

        def _apply_fn(m):
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Embedding):
                nn.init.uniform_(m.weight, -0.02, 0.02)
            elif isinstance(m, nn.ParameterList):
                for i, p in enumerate(m):
                    if p.dim() == 3:  # Linear
                        nn.init.trunc_normal_(p, std=0.02)  # Weight
                        nn.init.constant_(m[i + 1], 0)  # Bias

        self.apply(_apply_fn)

        # Zero-initialize the weights of the final linear layer of the reward head and of each Q head.
        for m in [self._reward, *self._Qs]:
            assert isinstance(
                m[-1], nn.Linear
            ), "Sanity check. The last linear layer needs 0 initialization on weights."
            nn.init.zeros_(m[-1].weight)

    def to(self, *args, **kwargs):
        """
        Overriding `to` method to also move additional (non-parameter) tensors to the device.
        """
        super().to(*args, **kwargs)
        self.log_std_min = self.log_std_min.to(*args, **kwargs)
        self.log_std_dif = self.log_std_dif.to(*args, **kwargs)
        self.bins = self.bins.to(*args, **kwargs)
        return self

    def encode(self, obs: dict[str, Tensor]) -> Tensor:
        """Encodes an observation into its latent representation."""
        return self._encoder(obs)

    def latent_dynamics_and_reward(
        self, z: Tensor, a: Tensor, discretize_reward: bool = False
    ) -> tuple[Tensor, Tensor]:
        """Predict the next state's latent representation and the reward given a current latent and action.

        Args:
            z: (*, latent_dim) tensor for the current state's latent representation.
            a: (*, action_dim) tensor for the action to be applied.
            discretize_reward: If True, decode the discrete-regression reward logits into a scalar reward.
        Returns:
            A tuple containing:
                - (*, latent_dim) tensor for the next state's latent representation.
                - (*,) tensor for the estimated reward (logits if `discretize_reward` is False).
        """
        x = torch.cat([z, a], dim=-1)
        reward = self._reward(x).squeeze(-1)
        if discretize_reward:
            reward = two_hot_inv(reward, self.bins)
        return self._dynamics(x), reward

    def latent_dynamics(self, z: Tensor, a: Tensor) -> Tensor:
        """Predict the next state's latent representation given a current latent and action.

        Args:
            z: (*, latent_dim) tensor for the current state's latent representation.
            a: (*, action_dim) tensor for the action to be applied.
        Returns:
            (*, latent_dim) tensor for the next state's latent representation.
        """
        x = torch.cat([z, a], dim=-1)
        return self._dynamics(x)

    def pi(self, z: Tensor, std: float = 0.0) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        """Samples an action from the learned policy.

        The policy can also have added (truncated) Gaussian noise injected for encouraging exploration when
        generating rollouts for online training.

        Args:
            z: (*, latent_dim) tensor for the current state's latent representation.
            std: The standard deviation of the injected noise. NOTE: currently unused in this port; the
                stochasticity comes from the learned log standard deviation. Kept for API compatibility with
                the planner calls.
        Returns:
            A tuple of (*, action_dim) tensors: the sampled (squashed) action, the distribution mean, the
            log-probability of the sampled action, and the log standard deviation.
        """
        mu, log_std = self._pi(z).chunk(2, dim=-1)
        log_std = self.log_std_min + 0.5 * self.log_std_dif * (torch.tanh(log_std) + 1)
        eps = torch.randn_like(mu)

        log_pi = gaussian_logprob(eps, log_std)
        pi = mu + eps * log_std.exp()
        mu, pi, log_pi = squash(mu, pi, log_pi)

        return pi, mu, log_pi, log_std

    def Qs(self, z: Tensor, a: Tensor, return_type: str = "min", target=False) -> Tensor:  # noqa: N802
        """Predict state-action values with the ensemble of learned Q functions.

        Args:
            z: (*, latent_dim) tensor for the current state's latent representation.
            a: (*, action_dim) tensor for the action to be applied.
            return_type: "all" returns the raw discrete-regression logits of every ensemble member, "min"
                returns the element-wise minimum of two randomly chosen (decoded) Q values, and any other
                value returns their average.
            target: If True, use the target Q networks.
        Returns:
            (q_ensemble, *, num_bins) tensor of logits if `return_type == "all"`, otherwise a (*, 1) tensor
            of decoded values.
        """
        x = torch.cat([z, a], dim=-1)

        if target:
            out = torch.stack([q(x).squeeze(-1) for q in self._target_Qs], dim=0)
        else:
            out = torch.stack([q(x).squeeze(-1) for q in self._Qs], dim=0)

        if return_type == "all":
            return out

        Q1, Q2 = out[np.random.choice(len(self._Qs), size=2, replace=False)]
        Q1, Q2 = two_hot_inv(Q1, self.bins), two_hot_inv(Q2, self.bins)
        return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2


class TDMPC2ObservationEncoder(nn.Module):
    """Encode image and/or state vector observations."""

    def __init__(self, config: TDMPC2Config):
        """
        Creates encoders for pixel and/or state modalities.
        TODO(alexander-soare): The original work allows for multiple images by concatenating them along the
            channel dimension. Re-implement this capability.
        """
        super().__init__()
        self.config = config

        # Define an encoder module for each input modality (pixels, state and/or environment state).
        encoder_dict = {}
        for obs_key in config.input_shapes:
            if "observation.image" in obs_key:
                encoder_module = nn.Sequential(
                    nn.Conv2d(config.input_shapes[obs_key][0], config.image_encoder_hidden_dim, 7, stride=2),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=1),
                )
                # Infer the flattened feature size with a dummy forward pass.
                dummy_batch = torch.zeros(1, *config.input_shapes[obs_key])
                with torch.inference_mode():
                    out_shape = encoder_module(dummy_batch).shape[1:]
                encoder_module.extend(
                    nn.Sequential(
                        nn.Flatten(),
                        NormedLinear(np.prod(out_shape), config.latent_dim, act=SimNorm(config.simnorm_dim)),
                    )
                )

            elif obs_key == "observation.state":
                encoder_module = nn.ModuleList()
                encoder_module.append(NormedLinear(config.input_shapes[obs_key][0], config.enc_dim))
                assert config.num_enc_layers > 0
                for _ in range(config.num_enc_layers - 1):
                    encoder_module.append(NormedLinear(config.enc_dim, config.enc_dim))
                encoder_module.append(
                    NormedLinear(config.enc_dim, config.latent_dim, act=SimNorm(config.simnorm_dim))
                )
                encoder_module = nn.Sequential(*encoder_module)

            elif obs_key == "observation.environment_state":
                encoder_module = nn.ModuleList()
                encoder_module.append(NormedLinear(config.input_shapes[obs_key][0], config.enc_dim))
                assert config.num_enc_layers > 0
                for _ in range(config.num_enc_layers - 1):
                    encoder_module.append(NormedLinear(config.enc_dim, config.enc_dim))
                encoder_module.append(
                    NormedLinear(config.enc_dim, config.latent_dim, act=SimNorm(config.simnorm_dim))
                )
                encoder_module = nn.Sequential(*encoder_module)

            else:
                raise NotImplementedError(f"No corresponding encoder module for key {obs_key}.")

            encoder_dict[obs_key] = encoder_module

        self.encoder = nn.ModuleDict(encoder_dict)

    def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
        """Encode the image and/or state vector.

        Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
        over all features.
        """
        feat = []
        for obs_key in self.config.input_shapes:
            if "observation.image" in obs_key:
                feat.append(flatten_forward_unflatten(self.encoder[obs_key], obs_dict[obs_key]))
            else:
                feat.append(self.encoder[obs_key](obs_dict[obs_key]))
        return torch.stack(feat, dim=0).mean(0)


def random_shifts_aug(x: Tensor, max_random_shift_ratio: float) -> Tensor:
    """Randomly shifts images horizontally and vertically.

    Adapted from https://github.com/facebookresearch/drqv2
    """
    b, _, h, w = x.size()
    assert h == w, "non-square images not handled yet"
    pad = int(round(max_random_shift_ratio * h))
    x = F.pad(x, tuple([pad] * 4), "replicate")
    eps = 1.0 / (h + 2 * pad)
    arange = torch.linspace(
        -1.0 + eps,
        1.0 - eps,
        h + 2 * pad,
        device=x.device,
        dtype=torch.float32,
    )[:h]
    arange = einops.repeat(arange, "w -> h w 1", h=h)
    base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2)
    base_grid = einops.repeat(base_grid, "h w c -> b h w c", b=b)
    # A random shift in units of pixels and within the boundaries of the padding.
    shift = torch.randint(
        0,
        2 * pad + 1,
        size=(b, 1, 1, 2),
        device=x.device,
        dtype=torch.float32,
    )
    shift *= 2.0 / (h + 2 * pad)
    grid = base_grid + shift
    return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False)
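

# Illustrative usage sketch: `random_shifts_aug` expects a (batch, channels, height, width) float tensor of
# square images, e.g.
#
#     images = torch.rand(8, 3, 84, 84)
#     shifted = random_shifts_aug(images, max_random_shift_ratio=0.0476)  # ~4 px of padding at 84x84
#
# For batches with extra leading dimensions, wrap the call with `flatten_forward_unflatten` (defined below),
# as done in `TDMPC2Policy.forward`.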


def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float):
    """Update EMA parameters in place with ema_param <- alpha * ema_param + (1 - alpha) * param."""
    for ema_module, module in zip(ema_net.modules(), net.modules(), strict=True):
        for (n_p_ema, p_ema), (n_p, p) in zip(
            ema_module.named_parameters(recurse=False), module.named_parameters(recurse=False), strict=True
        ):
            assert n_p_ema == n_p, "Parameter names don't match for EMA model update"
            if isinstance(p, dict):
                raise RuntimeError("Dict parameter not supported")
            if isinstance(module, nn.modules.batchnorm._BatchNorm) or not p.requires_grad:
                # Copy BatchNorm parameters, and non-trainable parameters directly.
                p_ema.copy_(p.to(dtype=p_ema.dtype).data)
            with torch.no_grad():
                p_ema.mul_(alpha)
                p_ema.add_(p.to(dtype=p_ema.dtype).data, alpha=1 - alpha)


def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
    """Helper to temporarily flatten extra dims at the start of the image tensor.

    Args:
        fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return
            (B, *), where * is any number of dimensions.
        image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions, generally
            different from *.
    Returns:
        A return value from the callable reshaped to (**, *).
    """
    if image_tensor.ndim == 4:
        return fn(image_tensor)
    start_dims = image_tensor.shape[:-3]
    inp = torch.flatten(image_tensor, end_dim=-4)
    flat_out = fn(inp)
    return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
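

# Illustrative usage sketch: apply a (B, C, H, W) -> (B, D) module to a (batch, time, C, H, W) tensor.
# `image_encoder` below is a stand-in for any such module (e.g. one entry of `TDMPC2ObservationEncoder`).
#
#     frames = torch.rand(8, 5, 3, 84, 84)
#     feats = flatten_forward_unflatten(image_encoder, frames)  # -> (8, 5, D)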


class RunningScale:
    """Running trimmed scale estimator."""

    def __init__(self, tau):
        self.tau = tau
        # Default to CUDA when available (matching the original implementation); fall back to CPU so the
        # policy can also be instantiated on CPU-only machines.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._value = torch.ones(1, dtype=torch.float32, device=device)
        self._percentiles = torch.tensor([5, 95], dtype=torch.float32, device=device)

    def state_dict(self):
        return dict(value=self._value, percentiles=self._percentiles)

    def load_state_dict(self, state_dict):
        self._value.data.copy_(state_dict["value"])
        self._percentiles.data.copy_(state_dict["percentiles"])

    @property
    def value(self):
        return self._value.cpu().item()

    def _percentile(self, x):
        x_dtype, x_shape = x.dtype, x.shape
        x = x.view(x.shape[0], -1)
        in_sorted, _ = torch.sort(x, dim=0)
        positions = self._percentiles * (x.shape[0] - 1) / 100
        floored = torch.floor(positions)
        ceiled = floored + 1
        ceiled[ceiled > x.shape[0] - 1] = x.shape[0] - 1
        weight_ceiled = positions - floored
        weight_floored = 1.0 - weight_ceiled
        d0 = in_sorted[floored.long(), :] * weight_floored[:, None]
        d1 = in_sorted[ceiled.long(), :] * weight_ceiled[:, None]
        return (d0 + d1).view(-1, *x_shape[1:]).type(x_dtype)

    def update(self, x):
        percentiles = self._percentile(x.detach())
        value = torch.clamp(percentiles[1] - percentiles[0], min=1.0)
        self._value.data.lerp_(value, self.tau)

    def __call__(self, x, update=False):
        if update:
            self.update(x)
        return x * (1 / self.value)

    def __repr__(self):
        return f"RunningScale(S: {self.value})"

lerobot/common/policies/tdmpc2/tdmpc2_utils.py
@@ -0,0 +1,156 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# NOTE: `combine_state_for_ensemble` comes from functorch, which is deprecated in recent PyTorch releases
# (its functionality now lives under `torch.func`); it is kept here to match the original TD-MPC2 code.
from functorch import combine_state_for_ensemble


class Ensemble(nn.Module):
    """
    Vectorized ensemble of modules.
    """

    def __init__(self, modules, **kwargs):
        super().__init__()
        modules = nn.ModuleList(modules)
        fn, params, _ = combine_state_for_ensemble(modules)
        self.vmap = torch.vmap(fn, in_dims=(0, 0, None), randomness="different", **kwargs)
        self.params = nn.ParameterList([nn.Parameter(p) for p in params])
        self._repr = str(modules)

    def forward(self, *args, **kwargs):
        return self.vmap([p for p in self.params], (), *args, **kwargs)

    def __repr__(self):
        return "Vectorized " + self._repr


class SimNorm(nn.Module):
    """
    Simplicial normalization.
    Adapted from https://arxiv.org/abs/2204.00616.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        shp = x.shape
        x = x.view(*shp[:-1], -1, self.dim)
        x = F.softmax(x, dim=-1)
        return x.view(*shp)

    def __repr__(self):
        return f"SimNorm(dim={self.dim})"
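

# Illustrative sketch of the SimNorm behaviour above: with `SimNorm(dim=8)`, a (batch, 64) input is viewed
# as (batch, 8, 8) -- 8 groups of 8 features -- a softmax is applied within each group, and the result is
# flattened back to (batch, 64), so each group lies on a probability simplex.
#
#     sn = SimNorm(dim=8)
#     out = sn(torch.randn(32, 64))
#     out.view(32, 8, 8).sum(-1)  # ~= all ones: each group sums to 1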


class NormedLinear(nn.Linear):
    """
    Linear layer with LayerNorm, activation, and optionally dropout.
    """

    def __init__(self, *args, dropout=0.0, act=nn.Mish(inplace=True), **kwargs):
        super().__init__(*args, **kwargs)
        self.ln = nn.LayerNorm(self.out_features)
        self.act = act
        self.dropout = nn.Dropout(dropout, inplace=True) if dropout else None

    def forward(self, x):
        x = super().forward(x)
        if self.dropout:
            x = self.dropout(x)
        return self.act(self.ln(x))

    def __repr__(self):
        repr_dropout = f", dropout={self.dropout.p}" if self.dropout else ""
        return (
            f"NormedLinear(in_features={self.in_features}, "
            f"out_features={self.out_features}, "
            f"bias={self.bias is not None}{repr_dropout}, "
            f"act={self.act.__class__.__name__})"
        )


def soft_ce(pred, target, cfg):
    """Computes the cross entropy loss between predictions and soft (two-hot) targets."""
    pred = F.log_softmax(pred, dim=-1)
    target = two_hot(target, cfg)
    return -(target * pred).sum(-1, keepdim=True)


@torch.jit.script
def log_std(x, low, dif):
    return low + 0.5 * dif * (torch.tanh(x) + 1)


@torch.jit.script
def _gaussian_residual(eps, log_std):
    return -0.5 * eps.pow(2) - log_std


@torch.jit.script
def _gaussian_logprob(residual):
    return residual - 0.5 * math.log(2 * math.pi)


def gaussian_logprob(eps, log_std, size=None):
    """Compute Gaussian log probability."""
    residual = _gaussian_residual(eps, log_std).sum(-1, keepdim=True)
    if size is None:
        size = eps.size(-1)
    return _gaussian_logprob(residual) * size


@torch.jit.script
def _squash(pi):
    return torch.log(F.relu(1 - pi.pow(2)) + 1e-6)


def squash(mu, pi, log_pi):
    """Apply squashing function."""
    mu = torch.tanh(mu)
    pi = torch.tanh(pi)
    log_pi -= _squash(pi).sum(-1, keepdim=True)
    return mu, pi, log_pi
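

# Note on the tanh squashing above (for reference): if u ~ N(mu, std) and a = tanh(u), the change of
# variables gives
#
#     log pi(a) = log N(u) - sum_i log(1 - tanh(u_i)^2)
#
# which is what `squash` applies via `_squash` (with a small epsilon for numerical stability).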


@torch.jit.script
def symlog(x):
    """
    Symmetric logarithmic function.
    Adapted from https://github.com/danijar/dreamerv3.
    """
    return torch.sign(x) * torch.log(1 + torch.abs(x))


@torch.jit.script
def symexp(x):
    """
    Symmetric exponential function.
    Adapted from https://github.com/danijar/dreamerv3.
    """
    return torch.sign(x) * (torch.exp(torch.abs(x)) - 1)


def two_hot(x, cfg):
    """Converts a batch of scalars to soft two-hot encoded targets for discrete regression."""
    if cfg.num_bins == 0:
        return x
    elif cfg.num_bins == 1:
        return symlog(x)
    # Accept any leading shape (e.g. (horizon, batch)): flatten, encode, then restore the leading dims with
    # a trailing `num_bins` dimension.
    leading_shape = x.shape
    x = torch.clamp(symlog(x), cfg.vmin, cfg.vmax).reshape(-1)
    bin_idx = torch.floor((x - cfg.vmin) / cfg.bin_size).long()
    bin_offset = ((x - cfg.vmin) / cfg.bin_size - bin_idx.float()).unsqueeze(-1)
    soft_two_hot = torch.zeros(x.size(0), cfg.num_bins, device=x.device)
    soft_two_hot.scatter_(1, bin_idx.unsqueeze(1), 1 - bin_offset)
    soft_two_hot.scatter_(1, (bin_idx.unsqueeze(1) + 1) % cfg.num_bins, bin_offset)
    return soft_two_hot.reshape(*leading_shape, cfg.num_bins)


def two_hot_inv(x, bins):
    """Converts a batch of soft two-hot encoded vectors (logits) to scalars."""
    num_bins = bins.shape[0]
    if num_bins == 0:
        return x
    elif num_bins == 1:
        return symexp(x)

    x = F.softmax(x, dim=-1)
    x = torch.sum(x * bins, dim=-1, keepdim=True)
    return symexp(x)
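

# The short self-check below is an illustrative sketch (not part of the library API); it exercises the
# two-hot encode / decode round trip under assumed discrete-regression settings (101 bins over [-10, 10]
# in symlog space).
if __name__ == "__main__":
    from types import SimpleNamespace

    num_bins, vmin, vmax = 101, -10, 10
    # Config stand-in carrying only the fields these helpers actually read.
    cfg = SimpleNamespace(num_bins=num_bins, vmin=vmin, vmax=vmax, bin_size=(vmax - vmin) / (num_bins - 1))
    bins = torch.linspace(vmin, vmax, num_bins)

    targets = torch.tensor([0.0, 1.5, -4.0])
    encoded = two_hot(targets, cfg)  # (3, num_bins) soft two-hot distributions in symlog space
    # `two_hot_inv` expects logits, so feed it the log of the (epsilon-padded) probabilities.
    decoded = two_hot_inv(torch.log(encoded + 1e-8), bins).squeeze(-1)
    print(targets, decoded)  # decoded ~= targets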