fixes and updated comments

This commit is contained in:
Michel Aractingi 2024-11-26 09:46:59 +00:00
parent 15090c2544
commit 16edbbdeee
3 changed files with 223 additions and 202 deletions

View File

@ -19,7 +19,7 @@ from dataclasses import dataclass, field
@dataclass
class TDMPC2Config:
"""Configuration class for TDMPCPolicy.
"""Configuration class for TDMPC2Policy.
Defaults are configured for training with xarm_lift_medium_replay providing proprioceptive and single
camera observations.
@ -77,18 +77,9 @@ class TDMPC2Config:
image(s) (in units of pixels) for training-time augmentation. If set to 0, no such augmentation
is applied. Note that the input images are assumed to be square for this augmentation.
reward_coeff: Loss weighting coefficient for the reward regression loss.
expectile_weight: Weighting (τ) used in expectile regression for the state value function (V).
v_pred < v_target is weighted by τ and v_pred >= v_target is weighted by (1-τ). τ is expected to
be in [0, 1]. Setting τ closer to 1 results in a more "optimistic" V. This is sensible to do
because v_target is obtained by evaluating the learned state-action value functions (Q) with
in-sample actions that may not be always optimal.
value_coeff: Loss weighting coefficient for both the state-action value (Q) TD loss, and the state
value (V) expectile regression loss.
consistency_coeff: Loss weighting coefficient for the consistency loss.
advantage_scaling: A factor by which the advantages are scaled prior to exponentiation for advantage
weighted regression of the policy (π) estimator parameters. Note that the exponentiated advantages
are clamped at 100.0.
pi_coeff: Loss weighting coefficient for the action regression loss.
temporal_decay_coeff: Exponential decay coefficient for decaying the loss coefficient for future time-
steps. Hint: each loss computation involves `horizon` steps worth of actions starting from the
current time step.
@ -126,9 +117,12 @@ class TDMPC2Config:
state_encoder_hidden_dim: int = 256
latent_dim: int = 512
q_ensemble_size: int = 5
num_enc_layers: int = 2
mlp_dim: int = 512
# Reinforcement learning.
discount: float = 0.9
simnorm_dim: int = 8
dropout: float = 0.01
# actor
log_std_min: float = -10

View File

@ -1,7 +1,7 @@
#!/usr/bin/env python
# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su,
# and The HuggingFace Inc. team. All rights reserved.
# Copyright 2024 Nicklas Hansen and The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -14,11 +14,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of Finetuning Offline World Models in the Real World.
"""Implementation of TD-MPC2: Scalable, Robust World Models for Continuous Control
The comments in this code may sometimes refer to these references:
TD-MPC paper: Temporal Difference Learning for Model Predictive Control (https://arxiv.org/abs/2203.04955)
FOWM paper: Finetuning Offline World Models in the Real World (https://arxiv.org/abs/2310.16029)
We refer to the main paper and codebase:
TD-MPC2 paper: (https://arxiv.org/abs/2310.16828)
TD-MPC2 code: (https://github.com/nicklashansen/tdmpc2)
"""
# ruff: noqa: N806
@ -38,8 +38,16 @@ from torch import Tensor
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.tdmpc2.configuration_tdmpc2 import TDMPC2Config
from lerobot.common.policies.tdmpc2.tdmpc2_utils import (
NormedLinear,
SimNorm,
gaussian_logprob,
soft_cross_entropy,
squash,
two_hot_inv,
)
from lerobot.common.policies.utils import get_device_from_parameters, populate_queues
from lerobot.common.policies.tdmpc2.tdmpc2_utils import NormedLinear, SimNorm, two_hot_inv, gaussian_logprob, squash, soft_cross_entropy
class TDMPC2Policy(
nn.Module,
@ -48,22 +56,7 @@ class TDMPC2Policy(
repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "tdmpc2"],
):
"""Implementation of TD-MPC2 learning + inference.
Please note several warnings for this policy.
- Evaluation of pretrained weights created with the original FOWM code
(https://github.com/fyhMer/fowm) works as expected. To be precise: we trained and evaluated a
model with the FOWM code for the xarm_lift_medium_replay dataset. We ported the weights across
to LeRobot, and were able to evaluate with the same success metric. BUT, we had to use inter-
process communication to use the xarm environment from FOWM. This is because our xarm
environment uses newer dependencies and does not match the environment in FOWM. See
https://github.com/huggingface/lerobot/pull/103 for implementation details.
- We have NOT checked that training on LeRobot reproduces the results from FOWM.
- Nevertheless, we have verified that we can train TD-MPC for PushT. See
`lerobot/configs/policy/tdmpc2_pusht_keypoints.yaml`.
- Our current xarm datasets were generated using the environment from FOWM. Therefore they do not
match our xarm environment.
"""
"""Implementation of TD-MPC2 learning + inference."""
name = "tdmpc2"
@ -83,6 +76,8 @@ class TDMPC2Policy(
config = TDMPC2Config()
self.config = config
self.model = TDMPC2WorldModel(config)
# TODO (michel-aractingi) temp fix for gpu
self.model = self.model.to("cuda:0")
if config.input_normalization_modes is not None:
self.normalize_inputs = Normalize(
@ -109,7 +104,9 @@ class TDMPC2Policy(
self._use_env_state = True
self.scale = RunningScale(self.config.target_model_momentum)
self.discount = self.config.discount #TODO (michel-aractingi) downscale discount according to episode length
self.discount = (
self.config.discount
) # TODO (michel-aractingi) downscale discount according to episode length
self.reset()
@ -204,7 +201,7 @@ class TDMPC2Policy(
for t in range(self.config.horizon):
# Note: Adding a small amount of noise here doesn't hurt during inference and may even be
# helpful for CEM.
pi_actions[t] = self.model.pi(_z, self.config.min_std)[0]
pi_actions[t] = self.model.pi(_z)[0]
_z = self.model.latent_dynamics(_z, pi_actions[t])
# In the CEM loop we will need this for a call to estimate_value with the gaussian sampled
@ -249,13 +246,16 @@ class TDMPC2Policy(
score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value))
score /= score.sum(axis=0, keepdim=True)
# (horizon, batch, action_dim)
mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1) / (score.sum(0) + 1e-9)
mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1) / (
einops.rearrange(score.sum(0), "b -> 1 b 1") + 1e-9
)
std = torch.sqrt(
torch.sum(
einops.rearrange(score, "n b -> n b 1")
* (elite_actions - einops.rearrange(mean, "h b d -> h 1 b d")) ** 2,
dim=1,
) / (score.sum(0) + 1e-9)
)
/ (einops.rearrange(score.sum(0), "b -> 1 b 1") + 1e-9)
).clamp_(self.config.min_std, self.config.max_std)
# Keep track of the mean for warm-starting subsequent steps.
@ -290,7 +290,7 @@ class TDMPC2Policy(
# next_action = self.model.pi(z)[0] # (batch, action_dim)
# terminal_values = self.model.Qs(z, next_action, return_type="avg") # (ensemble, batch)
return G + running_discount * self.model.Qs(z, self.model.pi(z)[0], return_type='avg')
return G + running_discount * self.model.Qs(z, self.model.pi(z)[0], return_type="avg")
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
"""Run the batch through the model and compute the loss.
@ -358,7 +358,8 @@ class TDMPC2Policy(
pi = self.model.pi(z_targets)[0]
td_targets = (
reward
+ self.config.discount * self.model.Qs(z_targets, pi, return_type="min", target=True).squeeze()
+ self.config.discount
* self.model.Qs(z_targets, pi, return_type="min", target=True).squeeze()
)
# Compute losses.
@ -429,7 +430,9 @@ class TDMPC2Policy(
self.scale.update(qs[0])
qs = self.scale(qs)
rho = torch.pow(self.config.temporal_decay_coeff, torch.arange(len(qs), device=qs.device)).unsqueeze(-1)
rho = torch.pow(self.config.temporal_decay_coeff, torch.arange(len(qs), device=qs.device)).unsqueeze(
-1
)
pi_loss = (
(self.config.entropy_coef * log_pis - qs).mean(dim=(1, 2))
@ -470,6 +473,7 @@ class TDMPC2Policy(
"""Update the target model's using polyak averaging."""
self.model.update_target_Q()
class TDMPC2WorldModel(nn.Module):
"""Latent dynamics model used in TD-MPC2."""
@ -480,28 +484,39 @@ class TDMPC2WorldModel(nn.Module):
self._encoder = TDMPC2ObservationEncoder(config)
# Define latent dynamics head
self._dynamics = nn.Sequential(NormedLinear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
self._dynamics = nn.Sequential(
NormedLinear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
NormedLinear(config.mlp_dim, config.mlp_dim),
NormedLinear(config.mlp_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)))
NormedLinear(config.mlp_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)),
)
# Define reward head
self._reward = nn.Sequential(NormedLinear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
self._reward = nn.Sequential(
NormedLinear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
NormedLinear(config.mlp_dim, config.mlp_dim),
nn.Linear(config.mlp_dim, max(config.num_bins, 1)))
nn.Linear(config.mlp_dim, max(config.num_bins, 1)),
)
# Define policy head
self._pi = nn.Sequential(NormedLinear(config.latent_dim, config.mlp_dim),
self._pi = nn.Sequential(
NormedLinear(config.latent_dim, config.mlp_dim),
NormedLinear(config.mlp_dim, config.mlp_dim),
nn.Linear(config.mlp_dim, 2 * config.output_shapes["action"][0]))
nn.Linear(config.mlp_dim, 2 * config.output_shapes["action"][0]),
)
# Define ensemble of Q functions
self._Qs = nn.ModuleList(
[
nn.Sequential(
NormedLinear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim, dropout=config.dropout),
NormedLinear(
config.latent_dim + config.output_shapes["action"][0],
config.mlp_dim,
dropout=config.dropout,
),
NormedLinear(config.mlp_dim, config.mlp_dim),
nn.Linear(config.mlp_dim, max(config.num_bins, 1))
) for _ in range(config.q_ensemble_size)
nn.Linear(config.mlp_dim, max(config.num_bins, 1)),
)
for _ in range(config.q_ensemble_size)
]
)
@ -520,6 +535,7 @@ class TDMPC2WorldModel(nn.Module):
Custom weight initializations proposed in TD-MPC2.
"""
def _apply_fn(m):
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02)
@ -641,7 +657,7 @@ class TDMPC2WorldModel(nn.Module):
Soft-update target Q-networks using Polyak averaging.
"""
with torch.no_grad():
for p, p_target in zip(self._Qs.parameters(), self._target_Qs.parameters()):
for p, p_target in zip(self._Qs.parameters(), self._target_Qs.parameters(), strict=False):
p_target.data.lerp_(p.data, self.config.target_model_momentum)
@ -672,7 +688,7 @@ class TDMPC2ObservationEncoder(nn.Module):
)
dummy_batch = torch.zeros(1, *config.input_shapes[obs_key])
with torch.inference_mode():
out_shape = self.image_enc_layers(dummy_batch).shape[1:]
out_shape = encoder_module(dummy_batch).shape[1:]
encoder_module.extend(
nn.Sequential(
nn.Flatten(),
@ -680,28 +696,30 @@ class TDMPC2ObservationEncoder(nn.Module):
)
)
elif "observation.state" in config.input_shapes:
elif (
"observation.state" in config.input_shapes
or "observation.environment_state" in config.input_shapes
):
encoder_module = nn.ModuleList()
encoder_module.append(NormedLinear(config.input_shapes[obs_key][0], config.state_encoder_hidden_dim))
encoder_module.append(
NormedLinear(config.input_shapes[obs_key][0], config.state_encoder_hidden_dim)
)
assert config.num_enc_layers > 0
for _ in range(config.num_enc_layers - 1):
encoder_module.append(NormedLinear(config.state_encoder_hidden_dim, config.state_encoder_hidden_dim))
encoder_module.append(NormedLinear(config.state_encoder_hidden_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)))
encoder_module = nn.Sequential(*encoder_module)
elif "observation.environment_state" in config.input_shapes:
encoder_module = nn.ModuleList()
encoder_module.append(NormedLinear(config.input_shapes[obs_key][0], config.state_encoder_hidden_dim))
assert config.num_enc_layers > 0
for _ in range(config.num_enc_layers - 1):
encoder_module.append(NormedLinear(config.state_encoder_hidden_dim, config.state_encoder_hidden_dim))
encoder_module.append(NormedLinear(config.state_encoder_hidden_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)))
encoder_module.append(
NormedLinear(config.state_encoder_hidden_dim, config.state_encoder_hidden_dim)
)
encoder_module.append(
NormedLinear(
config.state_encoder_hidden_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)
)
)
encoder_module = nn.Sequential(*encoder_module)
else:
raise NotImplementedError(f"No corresponding encoder module for key {obs_key}.")
encoder_dict[obs_key] = encoder_module
encoder_dict[obs_key.replace(".", "")] = encoder_module
self.encoder = nn.ModuleDict(encoder_dict)
@ -714,9 +732,11 @@ class TDMPC2ObservationEncoder(nn.Module):
feat = []
for obs_key in self.config.input_shapes:
if "observation.image" in obs_key:
feat.append(flatten_forward_unflatten(self.encoder[obs_key], obs_dict[obs_key]))
feat.append(
flatten_forward_unflatten(self.encoder[obs_key.replace(".", "")], obs_dict[obs_key])
)
else:
feat.append(self.encoder[obs_key](obs_dict[obs_key]))
feat.append(self.encoder[obs_key.replace(".", "")](obs_dict[obs_key]))
return torch.stack(feat, dim=0).mean(0)

View File

@ -13,7 +13,7 @@ class Ensemble(nn.Module):
super().__init__()
modules = nn.ModuleList(modules)
fn, params, _ = combine_state_for_ensemble(modules)
self.vmap = torch.vmap(fn, in_dims=(0, 0, None), randomness='different', **kwargs)
self.vmap = torch.vmap(fn, in_dims=(0, 0, None), randomness="different", **kwargs)
self.params = nn.ParameterList([nn.Parameter(p) for p in params])
self._repr = str(modules)
@ -21,7 +21,8 @@ class Ensemble(nn.Module):
return self.vmap([p for p in self.params], (), *args, **kwargs)
def __repr__(self):
return 'Vectorized ' + self._repr
return "Vectorized " + self._repr
class SimNorm(nn.Module):
"""
@ -48,7 +49,7 @@ class NormedLinear(nn.Linear):
Linear layer with LayerNorm, activation, and optionally dropout.
"""
def __init__(self, *args, dropout=0., act=nn.Mish(inplace=True), **kwargs):
def __init__(self, *args, dropout=0.0, act=nn.Mish(inplace=True), **kwargs):
super().__init__(*args, **kwargs)
self.ln = nn.LayerNorm(self.out_features)
self.act = act
@ -62,16 +63,21 @@ class NormedLinear(nn.Linear):
def __repr__(self):
repr_dropout = f", dropout={self.dropout.p}" if self.dropout else ""
return f"NormedLinear(in_features={self.in_features}, "\
f"out_features={self.out_features}, "\
f"bias={self.bias is not None}{repr_dropout}, "\
return (
f"NormedLinear(in_features={self.in_features}, "
f"out_features={self.out_features}, "
f"bias={self.bias is not None}{repr_dropout}, "
f"act={self.act.__class__.__name__})"
)
def soft_cross_entropy(pred, target, cfg):
"""Computes the cross entropy loss between predictions and soft targets."""
pred = F.log_softmax(pred, dim=-1)
target = two_hot(target, cfg)
import pudb
pudb.set_trace()
return -(target * pred).sum(-1, keepdim=True)
@ -135,14 +141,15 @@ def two_hot(x, cfg):
return x
elif cfg.num_bins == 1:
return symlog(x)
x = torch.clamp(symlog(x), cfg.vmin, cfg.vmax).squeeze(1)
x = torch.clamp(symlog(x), cfg.vmin, cfg.vmax)
bin_idx = torch.floor((x - cfg.vmin) / cfg.bin_size).long()
bin_offset = ((x - cfg.vmin) / cfg.bin_size - bin_idx.float()).unsqueeze(-1)
bin_offset = (x - cfg.vmin) / cfg.bin_size - bin_idx.float()
soft_two_hot = torch.zeros(x.size(0), cfg.num_bins, device=x.device)
soft_two_hot.scatter_(1, bin_idx.unsqueeze(1), 1 - bin_offset)
soft_two_hot.scatter_(1, (bin_idx.unsqueeze(1) + 1) % cfg.num_bins, bin_offset)
soft_two_hot.scatter_(1, bin_idx, 1 - bin_offset)
soft_two_hot.scatter_(1, (bin_idx + 1) % cfg.num_bins, bin_offset)
return soft_two_hot
def two_hot_inv(x, bins):
"""Converts a batch of soft two-hot encoded vectors to scalars."""
num_bins = bins.shape[0]