Refactor TD-MPC (#103)

Co-authored-by: Cadene <re.cadene@gmail.com>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Alexander Soare 2024-05-01 16:40:04 +01:00 committed by GitHub
parent a4891095e4
commit d1855a202a
17 changed files with 1105 additions and 1205 deletions


@@ -22,8 +22,8 @@ test-end-to-end:
 	${MAKE} test-act-ete-eval
 	${MAKE} test-diffusion-ete-train
 	${MAKE} test-diffusion-ete-eval
-	# ${MAKE} test-tdmpc-ete-train
-	# ${MAKE} test-tdmpc-ete-eval
+	${MAKE} test-tdmpc-ete-train
+	${MAKE} test-tdmpc-ete-eval
 	${MAKE} test-default-ete-eval

 test-act-ete-train:
@@ -74,8 +74,10 @@ test-tdmpc-ete-train:
 	python lerobot/scripts/train.py \
 		policy=tdmpc \
 		env=xarm \
+		env.task=XarmLift-v0 \
+		dataset_repo_id=lerobot/xarm_lift_medium_replay \
 		wandb.enable=False \
-		training.offline_steps=1 \
+		training.offline_steps=2 \
 		training.online_steps=2 \
 		eval.n_episodes=1 \
 		env.episode_length=2 \


@@ -22,16 +22,17 @@ class ACTConfig:
             The key represents the input data name, and the value is a list indicating the dimensions
             of the corresponding data. For example, "observation.images.top" refers to an input from the
             "top" camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
-            Importantly, shapes doesnt include batch dimension or temporal dimension.
+            Importantly, shapes doesn't include batch dimension or temporal dimension.
         output_shapes: A dictionary defining the shapes of the output data for the policy.
             The key represents the output data name, and the value is a list indicating the dimensions
             of the corresponding data. For example, "action" refers to an output shape of [14], indicating
-            14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
-        normalize_input_modes: A dictionary with key represents the modality (e.g. "observation.state"),
-            and the value specifies the normalization mode to apply. The two availables
-            modes are "mean_std" which substracts the mean and divide by the standard
-            deviation and "min_max" which rescale in a [-1, 1] range.
-        unnormalize_output_modes: Similar dictionary as `normalize_input_modes`, but to unormalize in original scale.
+            14-dimensional actions. Importantly, shapes doesn't include batch dimension or temporal dimension.
+        input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
+            and the value specifies the normalization mode to apply. The two available modes are "mean_std"
+            which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
+            [-1, 1] range.
+        output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
+            original scale. Note that this is also used for normalizing the training targets.
         vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
         pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
             `None` means no pretrained weights.
@@ -62,13 +63,13 @@ class ACTConfig:
     chunk_size: int = 100
     n_action_steps: int = 100

-    input_shapes: dict[str, list[str]] = field(
+    input_shapes: dict[str, list[int]] = field(
         default_factory=lambda: {
             "observation.images.top": [3, 480, 640],
             "observation.state": [14],
         }
     )
-    output_shapes: dict[str, list[str]] = field(
+    output_shapes: dict[str, list[int]] = field(
         default_factory=lambda: {
             "action": [14],
         }
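As an aside (a sketch, not part of this commit), the corrected `list[int]` annotations and the renamed normalization fields compose as below; the values mirror the defaults shown above, and the choice of "mean_std" for every key is an assumption for illustration:

from lerobot.common.policies.act.configuration_act import ACTConfig

config = ACTConfig(
    input_shapes={
        "observation.images.top": [3, 480, 640],  # (channels, height, width); no batch or time dims
        "observation.state": [14],
    },
    output_shapes={"action": [14]},
    input_normalization_modes={
        "observation.images.top": "mean_std",  # subtract mean, divide by standard deviation
        "observation.state": "mean_std",
    },
    output_normalization_modes={"action": "mean_std"},  # also used to normalize training targets
)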


@@ -31,11 +31,17 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
     name = "act"

-    def __init__(self, config: ACTConfig | None = None, dataset_stats=None):
+    def __init__(
+        self,
+        config: ACTConfig | None = None,
+        dataset_stats: dict[str, dict[str, Tensor]] | None = None,
+    ):
         """
         Args:
             config: Policy configuration class instance or None, in which case the default instantiation of
                 the configuration class is used.
+            dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
+                that they will be passed with a call to `load_state_dict` before the policy is used.
         """
         super().__init__()
         if config is None:
@@ -58,7 +64,7 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
         self._action_queue = deque([], maxlen=self.config.n_action_steps)

     @torch.no_grad
-    def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
+    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select a single action given environment observations.

         This method wraps `select_actions` in order to return one action at a time for execution in the
@@ -81,7 +87,7 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
         self._action_queue.extend(actions.transpose(0, 1))
         return self._action_queue.popleft()

-    def forward(self, batch, **_) -> dict[str, Tensor]:
+    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
         batch = self.normalize_targets(batch)
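Since `select_action` now takes only the batch, a rollout loop would call it once per environment step and let the policy's internal action queue decide when inference actually re-runs. A sketch (not part of this diff; the environment API and observation keys are assumptions):

import torch

policy.reset()  # clear the internal action queue at episode start
observation, _ = env.reset()
done = False
while not done:
    batch = {
        "observation.images.top": torch.from_numpy(observation["pixels"]).float().unsqueeze(0),
        "observation.state": torch.from_numpy(observation["agent_pos"]).float().unsqueeze(0),
    }
    action = policy.select_action(batch)  # one (batch, action_dim) action per call
    observation, reward, terminated, truncated, _ = env.step(action.squeeze(0).numpy())
    done = terminated or truncated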


@@ -3,7 +3,7 @@ from dataclasses import dataclass, field

 @dataclass
 class DiffusionConfig:
-    """Configuration class for Diffusion Policy.
+    """Configuration class for DiffusionPolicy.

     Defaults are configured for training with PushT providing proprioceptive and single camera observations.
@@ -25,11 +25,12 @@ class DiffusionConfig:
             The key represents the output data name, and the value is a list indicating the dimensions
             of the corresponding data. For example, "action" refers to an output shape of [14], indicating
             14-dimensional actions. Importantly, shapes doesn't include batch dimension or temporal dimension.
-        normalize_input_modes: A dictionary with key represents the modality (e.g. "observation.state"),
-            and the value specifies the normalization mode to apply. The two availables
-            modes are "mean_std" which substracts the mean and divide by the standard
-            deviation and "min_max" which rescale in a [-1, 1] range.
-        unnormalize_output_modes: Similar dictionary as `normalize_input_modes`, but to unormalize in original scale.
+        input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
+            and the value specifies the normalization mode to apply. The two available modes are "mean_std"
+            which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
+            [-1, 1] range.
+        output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
+            original scale. Note that this is also used for normalizing the training targets.
         vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
         crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
             within the image size. If None, no cropping is done.
@@ -70,13 +71,13 @@ class DiffusionConfig:
     horizon: int = 16
     n_action_steps: int = 8

-    input_shapes: dict[str, list[str]] = field(
+    input_shapes: dict[str, list[int]] = field(
         default_factory=lambda: {
             "observation.image": [3, 96, 96],
             "observation.state": [2],
         }
     )
-    output_shapes: dict[str, list[str]] = field(
+    output_shapes: dict[str, list[int]] = field(
         default_factory=lambda: {
             "action": [2],
         }


@@ -43,15 +43,16 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
     def __init__(
         self,
         config: DiffusionConfig | None = None,
-        dataset_stats=None,
+        dataset_stats: dict[str, dict[str, Tensor]] | None = None,
     ):
         """
         Args:
             config: Policy configuration class instance or None, in which case the default instantiation of
                 the configuration class is used.
+            dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
+                that they will be passed with a call to `load_state_dict` before the policy is used.
         """
         super().__init__()
-        # TODO(alexander-soare): LR scheduler will be removed.
         if config is None:
             config = DiffusionConfig()
         self.config = config
@@ -88,7 +89,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
         }

     @torch.no_grad
-    def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
+    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select a single action given environment observations.

         This method handles caching a history of observations and an action trajectory generated by the
@@ -136,7 +137,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
         action = self._queues["action"].popleft()
         return action

-    def forward(self, batch: dict[str, Tensor], **_) -> dict[str, Tensor]:
+    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
         batch = self.normalize_targets(batch)
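Tightening `forward` to take only the batch suggests a uniform training step across policies. A sketch (not from this commit; the `"loss"` key is the convention the TD-MPC policy below uses, assumed here for the other policies too):

output_dict = policy.forward(batch)  # batch: dict[str, Tensor] from the dataloader
loss = output_dict["loss"]
loss.backward()
optimizer.step()
optimizer.zero_grad()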


@@ -24,7 +24,10 @@ def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
 def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
     """Get the policy's class and config class given a name (matching the policy class' `name` attribute)."""
     if name == "tdmpc":
-        raise NotImplementedError("Coming soon!")
+        from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
+        from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
+
+        return TDMPCPolicy, TDMPCConfig
     elif name == "diffusion":
         from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
         from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
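With this change the factory resolves TD-MPC instead of raising. Minimal usage sketch (not part of the diff):

policy_cls, config_cls = get_policy_and_config_classes("tdmpc")
policy = policy_cls(config_cls())  # default config; dataset_stats can be passed for normalization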


@@ -21,6 +21,14 @@ class Policy(Protocol):
     name: str

+    def __init__(self, cfg, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
+        """
+        Args:
+            cfg: Policy configuration class instance or None, in which case the default instantiation of the
+                configuration class is used.
+            dataset_stats: Dataset statistics to be used for normalization.
+        """
+
     def reset(self):
         """To be called whenever the environment is reset.

@@ -39,3 +47,13 @@ class Policy(Protocol):
         When the model uses a history of observations, or outputs a sequence of actions, this method deals
         with caching.
         """
+
+
+@runtime_checkable
+class PolicyWithUpdate(Policy, Protocol):
+    def update(self):
+        """An update method that is to be called after a training optimization step.
+
+        Implements any additional updates the model parameters may need (for example, doing an EMA step for a
+        target model, or incrementing an internal buffer).
+        """


@@ -0,0 +1,150 @@
from dataclasses import dataclass, field
@dataclass
class TDMPCConfig:
"""Configuration class for TDMPCPolicy.
Defaults are configured for training with xarm_lift_medium_replay providing proprioceptive and single
camera observations.
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `input_shapes`, `output_shapes`, and perhaps `max_random_shift_ratio`.
Args:
n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google
action repeats in Q-learning or ask your favorite chatbot)
horizon: Horizon for model predictive control.
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
[-1, 1] range. Note that here this defaults to None meaning inputs are not normalized. This is to
match the original implementation.
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
original scale. Note that this is also used for normalizing the training targets. NOTE: Clipping
to [-1, +1] is used during MPPI/CEM. Therefore, it is recommended that you stick with "min_max"
normalization mode here.
image_encoder_hidden_dim: Number of channels for the convolutional layers used for image encoding.
state_encoder_hidden_dim: Hidden dimension for MLP used for state vector encoding.
latent_dim: Observation's latent embedding dimension.
q_ensemble_size: Number of Q function estimators to use in an ensemble for uncertainty estimation.
mlp_dim: Hidden dimension of MLPs used for modelling the dynamics encoder, reward function, policy
(π), Q ensemble, and V.
discount: Discount factor (γ) to use for the reinforcement learning formalism.
use_mpc: Whether to use model predictive control. The alternative is to just sample the policy model
(π) for each step.
cem_iterations: Number of iterations for the MPPI/CEM loop in MPC.
max_std: Maximum standard deviation for actions sampled from the gaussian PDF in CEM.
min_std: Minimum standard deviation for noise applied to actions sampled from the policy model (π).
Doubles up as the minimum standard deviation for actions sampled from the gaussian PDF in CEM.
n_gaussian_samples: Number of samples to draw from the gaussian distribution every CEM iteration. Must
be non-zero.
n_pi_samples: Number of samples to draw from the policy / world model rollout every CEM iteration. Can
be zero.
uncertainty_regularizer_coeff: Coefficient for the uncertainty regularization used when estimating
trajectory values (this is the λ coefficient in eqn 4 of FOWM).
n_elites: The number of elite samples to use for updating the gaussian parameters every CEM iteration.
elite_weighting_temperature: The temperature to use for softmax weighting (by trajectory value) of the
elites, when updating the gaussian parameters for CEM.
gaussian_mean_momentum: Momentum (α) used for EMA updates of the mean parameter μ of the gaussian
parameters optimized in CEM. Updates are calculated as μ ← αμ + (1-α)μ̂.
max_random_shift_ratio: Maximum random shift (as a proportion of the image size) to apply to the
image(s) for training-time augmentation. If set to 0, no such augmentation
is applied. Note that the input images are assumed to be square for this augmentation.
reward_coeff: Loss weighting coefficient for the reward regression loss.
expectile_weight: Weighting (τ) used in expectile regression for the state value function (V).
v_pred < v_target is weighted by τ and v_pred >= v_target is weighted by (1-τ). τ is expected to
be in [0, 1]. Setting τ closer to 1 results in a more "optimistic" V. This is sensible to do
because v_target is obtained by evaluating the learned state-action value functions (Q) with
in-sample actions that may not be always optimal.
value_coeff: Loss weighting coefficient for both the state-action value (Q) TD loss, and the state
value (V) expectile regression loss.
consistency_coeff: Loss weighting coefficient for the consistency loss.
advantage_scaling: A factor by which the advantages are scaled prior to exponentiation for advantage
weighted regression of the policy (π) estimator parameters. Note that the exponentiated advantages
are clamped at 100.0.
pi_coeff: Loss weighting coefficient for the action regression loss.
temporal_decay_coeff: Exponential decay coefficient for decaying the loss coefficient for future time-
steps. Hint: each loss computation involves `horizon` steps worth of actions starting from the
current time step.
target_model_momentum: Momentum (α) used for EMA updates of the target models. Updates are calculated
as ϕ ← αϕ + (1-α)θ where ϕ are the parameters of the target model and θ are the parameters of the
model being trained.
"""
# Input / output structure.
n_action_repeats: int = 2
horizon: int = 5
input_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"observation.image": [3, 84, 84],
"observation.state": [4],
}
)
output_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"action": [4],
}
)
# Normalization / Unnormalization
input_normalization_modes: dict[str, str] | None = None
output_normalization_modes: dict[str, str] = field(
default_factory=lambda: {"action": "min_max"},
)
# Architecture / modeling.
# Neural networks.
image_encoder_hidden_dim: int = 32
state_encoder_hidden_dim: int = 256
latent_dim: int = 50
q_ensemble_size: int = 5
mlp_dim: int = 512
# Reinforcement learning.
discount: float = 0.9
# Inference.
use_mpc: bool = True
cem_iterations: int = 6
max_std: float = 2.0
min_std: float = 0.05
n_gaussian_samples: int = 512
n_pi_samples: int = 51
uncertainty_regularizer_coeff: float = 1.0
n_elites: int = 50
elite_weighting_temperature: float = 0.5
gaussian_mean_momentum: float = 0.1
# Training and loss computation.
max_random_shift_ratio: float = 0.0476
# Loss coefficients.
reward_coeff: float = 0.5
expectile_weight: float = 0.9
value_coeff: float = 0.1
consistency_coeff: float = 20.0
advantage_scaling: float = 3.0
pi_coeff: float = 0.5
temporal_decay_coeff: float = 0.5
# Target model.
target_model_momentum: float = 0.995
def __post_init__(self):
"""Input validation (not exhaustive)."""
if self.input_shapes["observation.image"][-2] != self.input_shapes["observation.image"][-1]:
# TODO(alexander-soare): This limitation is solely because of code in the random shift
# augmentation. It should be able to be removed.
raise ValueError(
"Only square images are handled now. Got image shape "
f"{self.input_shapes['observation.image']}."
)
if self.n_gaussian_samples <= 0:
raise ValueError(
f"The number of guassian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`"
)
if self.output_normalization_modes != {"action": "min_max"}:
raise ValueError(
"TD-MPC assumes the action space dimensions to all be in [-1, 1]. Therefore it is strongly "
f"advised that you stick with the default. See {self.__class__.__name__} docstring for more "
"information."
)
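To make the defaults concrete, here is a sketch (not part of the commit) of overriding the config for a hypothetical setup with a 7-dimensional state and action space; `__post_init__` above would reject non-square images or a non-"min_max" output mode:

from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig

config = TDMPCConfig(
    input_shapes={
        "observation.image": [3, 84, 84],  # must be square (see __post_init__)
        "observation.state": [7],          # hypothetical 7-dim proprioceptive state
    },
    output_shapes={"action": [7]},
    # Keep the default {"action": "min_max"}: MPPI/CEM clips actions to [-1, 1].
)

# By contrast, TDMPCConfig(output_normalization_modes={"action": "mean_std"}) raises ValueError.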


@@ -1,576 +0,0 @@
import os
import pickle
import re
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from torch import distributions as pyd
from torch.distributions.utils import _standard_normal
DEFAULT_ACT_FN = nn.Mish()
def __REDUCE__(b): # noqa: N802, N807
return "mean" if b else "none"
def l1(pred, target, reduce=False):
"""Computes the L1-loss between predictions and targets."""
return F.l1_loss(pred, target, reduction=__REDUCE__(reduce))
def mse(pred, target, reduce=False):
"""Computes the MSE loss between predictions and targets."""
return F.mse_loss(pred, target, reduction=__REDUCE__(reduce))
def l2_expectile(diff, expectile=0.7, reduce=False):
weight = torch.where(diff > 0, expectile, (1 - expectile))
loss = weight * (diff**2)
reduction = __REDUCE__(reduce)
if reduction == "mean":
return torch.mean(loss)
elif reduction == "sum":
return torch.sum(loss)
return loss
def _get_out_shape(in_shape, layers):
"""Utility function. Returns the output shape of a network for a given input shape."""
x = torch.randn(*in_shape).unsqueeze(0)
return (nn.Sequential(*layers) if isinstance(layers, list) else layers)(x).squeeze(0).shape
def gaussian_logprob(eps, log_std):
"""Compute Gaussian log probability."""
residual = (-0.5 * eps.pow(2) - log_std).sum(-1, keepdim=True)
return residual - 0.5 * np.log(2 * np.pi) * eps.size(-1)
def squash(mu, pi, log_pi):
"""Apply squashing function."""
mu = torch.tanh(mu)
pi = torch.tanh(pi)
log_pi -= torch.log(F.relu(1 - pi.pow(2)) + 1e-6).sum(-1, keepdim=True)
return mu, pi, log_pi
def orthogonal_init(m):
"""Orthogonal layer initialization."""
if isinstance(m, nn.Linear):
nn.init.orthogonal_(m.weight.data)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Conv2d):
gain = nn.init.calculate_gain("relu")
nn.init.orthogonal_(m.weight.data, gain)
if m.bias is not None:
nn.init.zeros_(m.bias)
def ema(m, m_target, tau):
"""Update slow-moving average of online network (target network) at rate tau."""
with torch.no_grad():
# TODO(rcadene, aliberts): issue with strict=False
# for p, p_target in zip(m.parameters(), m_target.parameters(), strict=False):
# p_target.data.lerp_(p.data, tau)
m_params_iter = iter(m.parameters())
m_target_params_iter = iter(m_target.parameters())
while True:
try:
p = next(m_params_iter)
p_target = next(m_target_params_iter)
p_target.data.lerp_(p.data, tau)
except StopIteration:
# If any iterator is exhausted, exit the loop
break
def set_requires_grad(net, value):
"""Enable/disable gradients for a given (sub)network."""
for param in net.parameters():
param.requires_grad_(value)
class TruncatedNormal(pyd.Normal):
"""Utility class implementing the truncated normal distribution."""
default_sample_shape = torch.Size()
def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6):
super().__init__(loc, scale, validate_args=False)
self.low = low
self.high = high
self.eps = eps
def _clamp(self, x):
clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps)
x = x - x.detach() + clamped_x.detach()
return x
def sample(self, clip=None, sample_shape=default_sample_shape):
shape = self._extended_shape(sample_shape)
eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
eps *= self.scale
if clip is not None:
eps = torch.clamp(eps, -clip, clip)
x = self.loc + eps
return self._clamp(x)
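# Usage sketch (not part of the original file): this class typically supplies clipped
# exploration noise; `_clamp` keeps samples inside (low, high) while the straight-through
# trick `x - x.detach() + clamped_x.detach()` passes gradients through as if no clamping
# had happened, e.g.:
#
#     dist = TruncatedNormal(loc=torch.zeros(4), scale=0.3 * torch.ones(4))
#     action = dist.sample(clip=0.3)  # pre-clamp noise is limited to [-0.3, 0.3]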
class NormalizeImg(nn.Module):
"""Normalizes pixel observations to [0,1) range."""
def __init__(self):
super().__init__()
def forward(self, x):
return x.div(255.0)
class Flatten(nn.Module):
"""Flattens its input to a (batched) vector."""
def __init__(self):
super().__init__()
def forward(self, x):
return x.view(x.size(0), -1)
def enc(cfg):
obs_shape = {
"rgb": (3, cfg.img_size, cfg.img_size),
"state": (cfg.state_dim,),
}
"""Returns a TOLD encoder."""
pixels_enc_layers, state_enc_layers = None, None
if cfg.modality in {"pixels", "all"}:
C = int(3 * cfg.frame_stack) # noqa: N806
pixels_enc_layers = [
NormalizeImg(),
nn.Conv2d(C, cfg.num_channels, 7, stride=2),
nn.ReLU(),
nn.Conv2d(cfg.num_channels, cfg.num_channels, 5, stride=2),
nn.ReLU(),
nn.Conv2d(cfg.num_channels, cfg.num_channels, 3, stride=2),
nn.ReLU(),
nn.Conv2d(cfg.num_channels, cfg.num_channels, 3, stride=2),
nn.ReLU(),
]
out_shape = _get_out_shape((C, cfg.img_size, cfg.img_size), pixels_enc_layers)
pixels_enc_layers.extend(
[
Flatten(),
nn.Linear(np.prod(out_shape), cfg.latent_dim),
nn.LayerNorm(cfg.latent_dim),
nn.Sigmoid(),
]
)
if cfg.modality == "pixels":
return ConvExt(nn.Sequential(*pixels_enc_layers))
if cfg.modality in {"state", "all"}:
state_dim = obs_shape[0] if cfg.modality == "state" else obs_shape["state"][0]
state_enc_layers = [
nn.Linear(state_dim, cfg.enc_dim),
nn.ELU(),
nn.Linear(cfg.enc_dim, cfg.latent_dim),
nn.LayerNorm(cfg.latent_dim),
nn.Sigmoid(),
]
if cfg.modality == "state":
return nn.Sequential(*state_enc_layers)
else:
raise NotImplementedError
encoders = {}
for k in obs_shape:
if k == "state":
encoders[k] = nn.Sequential(*state_enc_layers)
elif k.endswith("rgb"):
encoders[k] = ConvExt(nn.Sequential(*pixels_enc_layers))
else:
raise NotImplementedError
return Multiplexer(nn.ModuleDict(encoders))
def mlp(in_dim, mlp_dim, out_dim, act_fn=DEFAULT_ACT_FN):
"""Returns an MLP."""
if isinstance(mlp_dim, int):
mlp_dim = [mlp_dim, mlp_dim]
return nn.Sequential(
nn.Linear(in_dim, mlp_dim[0]),
nn.LayerNorm(mlp_dim[0]),
act_fn,
nn.Linear(mlp_dim[0], mlp_dim[1]),
nn.LayerNorm(mlp_dim[1]),
act_fn,
nn.Linear(mlp_dim[1], out_dim),
)
def dynamics(in_dim, mlp_dim, out_dim, act_fn=DEFAULT_ACT_FN):
"""Returns a dynamics network."""
return nn.Sequential(
mlp(in_dim, mlp_dim, out_dim, act_fn),
nn.LayerNorm(out_dim),
nn.Sigmoid(),
)
def q(cfg):
action_dim = cfg.action_dim
"""Returns a Q-function that uses Layer Normalization."""
return nn.Sequential(
nn.Linear(cfg.latent_dim + action_dim, cfg.mlp_dim),
nn.LayerNorm(cfg.mlp_dim),
nn.Tanh(),
nn.Linear(cfg.mlp_dim, cfg.mlp_dim),
nn.ELU(),
nn.Linear(cfg.mlp_dim, 1),
)
def v(cfg):
"""Returns a state value function that uses Layer Normalization."""
return nn.Sequential(
nn.Linear(cfg.latent_dim, cfg.mlp_dim),
nn.LayerNorm(cfg.mlp_dim),
nn.Tanh(),
nn.Linear(cfg.mlp_dim, cfg.mlp_dim),
nn.ELU(),
nn.Linear(cfg.mlp_dim, 1),
)
def aug(cfg):
obs_shape = {
"rgb": (3, cfg.img_size, cfg.img_size),
"state": (4,),
}
"""Multiplex augmentation"""
if cfg.modality == "state":
return nn.Identity()
elif cfg.modality == "pixels":
return RandomShiftsAug(cfg)
else:
augs = {}
for k in obs_shape:
if k == "state":
augs[k] = nn.Identity()
elif k.endswith("rgb"):
augs[k] = RandomShiftsAug(cfg)
else:
raise NotImplementedError
return Multiplexer(nn.ModuleDict(augs))
class ConvExt(nn.Module):
"""Auxiliary conv net accommodating high-dim input"""
def __init__(self, conv):
super().__init__()
self.conv = conv
def forward(self, x):
if x.ndim > 4:
batch_shape = x.shape[:-3]
out = self.conv(x.view(-1, *x.shape[-3:]))
out = out.view(*batch_shape, *out.shape[1:])
else:
out = self.conv(x)
return out
class Multiplexer(nn.Module):
"""Model multiplexer"""
def __init__(self, choices):
super().__init__()
self.choices = choices
def forward(self, x, key=None):
if isinstance(x, dict):
if key is not None:
return self.choices[key](x)
return {k: self.choices[k](_x) for k, _x in x.items()}
return self.choices(x)
class RandomShiftsAug(nn.Module):
"""
Random shift image augmentation.
Adapted from https://github.com/facebookresearch/drqv2
"""
def __init__(self, cfg):
super().__init__()
assert cfg.modality in {"pixels", "all"}
self.pad = int(cfg.img_size / 21)
def forward(self, x):
n, c, h, w = x.size()
assert h == w
padding = tuple([self.pad] * 4)
x = F.pad(x, padding, "replicate")
eps = 1.0 / (h + 2 * self.pad)
arange = torch.linspace(
-1.0 + eps,
1.0 - eps,
h + 2 * self.pad,
device=x.device,
dtype=torch.float32,
)[:h]
arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2)
base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2)
base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1)
shift = torch.randint(
0,
2 * self.pad + 1,
size=(n, 1, 1, 2),
device=x.device,
dtype=torch.float32,
)
shift *= 2.0 / (h + 2 * self.pad)
grid = base_grid + shift
return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False)
# TODO(aliberts): remove class
# class Episode:
# """Storage object for a single episode."""
# def __init__(self, cfg, init_obs):
# action_dim = cfg.action_dim
# self.cfg = cfg
# self.device = torch.device(cfg.buffer_device)
# if cfg.modality in {"pixels", "state"}:
# dtype = torch.float32 if cfg.modality == "state" else torch.uint8
# self.obses = torch.empty(
# (cfg.episode_length + 1, *init_obs.shape),
# dtype=dtype,
# device=self.device,
# )
# self.obses[0] = torch.tensor(init_obs, dtype=dtype, device=self.device)
# elif cfg.modality == "all":
# self.obses = {}
# for k, v in init_obs.items():
# assert k in {"rgb", "state"}
# dtype = torch.float32 if k == "state" else torch.uint8
# self.obses[k] = torch.empty(
# (cfg.episode_length + 1, *v.shape), dtype=dtype, device=self.device
# )
# self.obses[k][0] = torch.tensor(v, dtype=dtype, device=self.device)
# else:
# raise ValueError
# self.actions = torch.empty((cfg.episode_length, action_dim), dtype=torch.float32, device=self.device)
# self.rewards = torch.empty((cfg.episode_length,), dtype=torch.float32, device=self.device)
# self.dones = torch.empty((cfg.episode_length,), dtype=torch.bool, device=self.device)
# self.masks = torch.empty((cfg.episode_length,), dtype=torch.float32, device=self.device)
# self.cumulative_reward = 0
# self.done = False
# self.success = False
# self._idx = 0
# def __len__(self):
# return self._idx
# @classmethod
# def from_trajectory(cls, cfg, obses, actions, rewards, dones=None, masks=None):
# """Constructs an episode from a trajectory."""
# if cfg.modality in {"pixels", "state"}:
# episode = cls(cfg, obses[0])
# episode.obses[1:] = torch.tensor(obses[1:], dtype=episode.obses.dtype, device=episode.device)
# elif cfg.modality == "all":
# episode = cls(cfg, {k: v[0] for k, v in obses.items()})
# for k in obses:
# episode.obses[k][1:] = torch.tensor(
# obses[k][1:], dtype=episode.obses[k].dtype, device=episode.device
# )
# else:
# raise NotImplementedError
# episode.actions = torch.tensor(actions, dtype=episode.actions.dtype, device=episode.device)
# episode.rewards = torch.tensor(rewards, dtype=episode.rewards.dtype, device=episode.device)
# episode.dones = (
# torch.tensor(dones, dtype=episode.dones.dtype, device=episode.device)
# if dones is not None
# else torch.zeros_like(episode.dones)
# )
# episode.masks = (
# torch.tensor(masks, dtype=episode.masks.dtype, device=episode.device)
# if masks is not None
# else torch.ones_like(episode.masks)
# )
# episode.cumulative_reward = torch.sum(episode.rewards)
# episode.done = True
# episode._idx = cfg.episode_length
# return episode
# @property
# def first(self):
# return len(self) == 0
# def __add__(self, transition):
# self.add(*transition)
# return self
# def add(self, obs, action, reward, done, mask=1.0, success=False):
# """Add a transition into the episode."""
# if isinstance(obs, dict):
# for k, v in obs.items():
# self.obses[k][self._idx + 1] = torch.tensor(
# v, dtype=self.obses[k].dtype, device=self.obses[k].device
# )
# else:
# self.obses[self._idx + 1] = torch.tensor(obs, dtype=self.obses.dtype, device=self.obses.device)
# self.actions[self._idx] = action
# self.rewards[self._idx] = reward
# self.dones[self._idx] = done
# self.masks[self._idx] = mask
# self.cumulative_reward += reward
# self.done = done
# self.success = self.success or success
# self._idx += 1
def get_dataset_dict(cfg, env, return_reward_normalizer=False):
"""Construct a dataset for env"""
required_keys = [
"observations",
"next_observations",
"actions",
"rewards",
"dones",
"masks",
]
if cfg.task.startswith("xarm"):
dataset_path = os.path.join(cfg.dataset_dir, "buffer.pkl")
print(f"Using offline dataset '{dataset_path}'")
with open(dataset_path, "rb") as f:
dataset_dict = pickle.load(f)
for k in required_keys:
if k not in dataset_dict and k[:-1] in dataset_dict:
dataset_dict[k] = dataset_dict.pop(k[:-1])
elif cfg.task.startswith("legged"):
dataset_path = os.path.join(cfg.dataset_dir, "buffer.pkl")
print(f"Using offline dataset '{dataset_path}'")
with open(dataset_path, "rb") as f:
dataset_dict = pickle.load(f)
dataset_dict["actions"] /= env.unwrapped.clip_actions
print(f"clip_actions={env.unwrapped.clip_actions}")
else:
import d4rl
dataset_dict = d4rl.qlearning_dataset(env)
dones = np.full_like(dataset_dict["rewards"], False, dtype=bool)
for i in range(len(dones) - 1):
if (
np.linalg.norm(dataset_dict["observations"][i + 1] - dataset_dict["next_observations"][i])
> 1e-6
or dataset_dict["terminals"][i] == 1.0
):
dones[i] = True
dones[-1] = True
dataset_dict["masks"] = 1.0 - dataset_dict["terminals"]
del dataset_dict["terminals"]
for k, v in dataset_dict.items():
dataset_dict[k] = v.astype(np.float32)
dataset_dict["dones"] = dones
if cfg.is_data_clip:
lim = 1 - cfg.data_clip_eps
dataset_dict["actions"] = np.clip(dataset_dict["actions"], -lim, lim)
reward_normalizer = get_reward_normalizer(cfg, dataset_dict)
dataset_dict["rewards"] = reward_normalizer(dataset_dict["rewards"])
for key in required_keys:
assert key in dataset_dict, f"Missing `{key}` in dataset."
if return_reward_normalizer:
return dataset_dict, reward_normalizer
return dataset_dict
def get_trajectory_boundaries_and_returns(dataset):
"""
Split dataset into trajectories and compute returns
"""
episode_starts = [0]
episode_ends = []
episode_return = 0
episode_returns = []
n_transitions = len(dataset["rewards"])
for i in range(n_transitions):
episode_return += dataset["rewards"][i]
if dataset["dones"][i]:
episode_returns.append(episode_return)
episode_ends.append(i + 1)
if i + 1 < n_transitions:
episode_starts.append(i + 1)
episode_return = 0.0
return episode_starts, episode_ends, episode_returns
def normalize_returns(dataset, scaling=1000):
"""
Normalize returns in the dataset
"""
(_, _, episode_returns) = get_trajectory_boundaries_and_returns(dataset)
dataset["rewards"] /= np.max(episode_returns) - np.min(episode_returns)
dataset["rewards"] *= scaling
return dataset
def get_reward_normalizer(cfg, dataset):
"""
Get a reward normalizer for the dataset
"""
if cfg.task.startswith("xarm"):
return lambda x: x
elif "maze" in cfg.task:
return lambda x: x - 1.0
elif cfg.task.split("-")[0] in ["hopper", "halfcheetah", "walker2d"]:
(_, _, episode_returns) = get_trajectory_boundaries_and_returns(dataset)
return lambda x: x / (np.max(episode_returns) - np.min(episode_returns)) * 1000.0
elif hasattr(cfg, "reward_scale"):
return lambda x: x * cfg.reward_scale
return lambda x: x
def linear_schedule(schdl, step):
"""
Outputs values following a linear decay schedule.
Adapted from https://github.com/facebookresearch/drqv2
"""
try:
return float(schdl)
except ValueError:
match = re.match(r"linear\((.+),(.+),(.+),(.+)\)", schdl)
if match:
init, final, start, end = (float(g) for g in match.groups())
mix = np.clip((step - start) / (end - start), 0.0, 1.0)
return (1.0 - mix) * init + mix * final
match = re.match(r"linear\((.+),(.+),(.+)\)", schdl)
if match:
init, final, duration = (float(g) for g in match.groups())
mix = np.clip(step / duration, 0.0, 1.0)
return (1.0 - mix) * init + mix * final
raise NotImplementedError(schdl)
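For reference, a sketch of the schedule-string format parsed above (values follow from the formulas in the function):

linear_schedule("0.1", step=100)                  # plain float: 0.1 at every step
linear_schedule("linear(1.0,0.1,500)", step=0)    # 1.0 (decay not started)
linear_schedule("linear(1.0,0.1,500)", step=250)  # 0.55 (halfway between 1.0 and 0.1)
linear_schedule("linear(1.0,0.1,500)", step=500)  # 0.1 (fully decayed)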


@@ -0,0 +1,798 @@
"""Implementation of Finetuning Offline World Models in the Real World.
The comments in this code may sometimes refer to these references:
TD-MPC paper: Temporal Difference Learning for Model Predictive Control (https://arxiv.org/abs/2203.04955)
FOWM paper: Finetuning Offline World Models in the Real World (https://arxiv.org/abs/2310.16029)
TODO(alexander-soare): Make rollout work for batch sizes larger than 1.
TODO(alexander-soare): Use batch-first throughout.
"""
# ruff: noqa: N806
import logging
from collections import deque
from copy import deepcopy
from functools import partial
from typing import Callable
import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.common.policies.utils import get_device_from_parameters, populate_queues
class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
"""Implementation of TD-MPC learning + inference.
Please note several warnings for this policy.
- Evaluation of pretrained weights created with the original FOWM code
(https://github.com/fyhMer/fowm) works as expected. To be precise: we trained and evaluated a
model with the FOWM code for the xarm_lift_medium_replay dataset. We ported the weights across
to LeRobot, and were able to evaluate with the same success metric. BUT, we had to use inter-
process communication to use the xarm environment from FOWM. This is because our xarm
environment uses newer dependencies and does not match the environment in FOWM. See
https://github.com/huggingface/lerobot/pull/103 for implementation details.
- We have NOT checked that training on LeRobot reproduces SOTA results. This is a TODO.
- Our current xarm datasets were generated using the environment from FOWM. Therefore they do not
match our xarm environment.
"""
name = "tdmpc"
def __init__(
self, config: TDMPCConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__()
logging.warning(
"""
Please note several warnings for this policy.
- Evaluation of pretrained weights created with the original FOWM code
(https://github.com/fyhMer/fowm) works as expected. To be precise: we trained and evaluated a
model with the FOWM code for the xarm_lift_medium_replay dataset. We ported the weights across
to LeRobot, and were able to evaluate with the same success metric. BUT, we had to use inter-
process communication to use the xarm environment from FOWM. This is because our xarm
environment uses newer dependencies and does not match the environment in FOWM. See
https://github.com/huggingface/lerobot/pull/103 for implementation details.
- We have NOT checked that training on LeRobot reproduces SOTA results. This is a TODO.
- Our current xarm datasets were generated using the environment from FOWM. Therefore they do not
match our xarm environment.
"""
)
if config is None:
config = TDMPCConfig()
self.config = config
self.model = TDMPCTOLD(config)
self.model_target = deepcopy(self.model)
self.model_target.eval()
if config.input_normalization_modes is not None:
self.normalize_inputs = Normalize(
config.input_shapes, config.input_normalization_modes, dataset_stats
)
else:
self.normalize_inputs = nn.Identity()
self.normalize_targets = Normalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
def save(self, fp):
"""Save state dict of TOLD model to filepath."""
torch.save(self.state_dict(), fp)
def load(self, fp):
"""Load a saved state dict from filepath into current agent."""
self.load_state_dict(torch.load(fp))
def reset(self):
"""
Clear observation and action queues. Clear previous means for warm starting of MPPI/CEM. Should be
called on `env.reset()`
"""
self._queues = {
"observation.image": deque(maxlen=1),
"observation.state": deque(maxlen=1),
"action": deque(maxlen=self.config.n_action_repeats),
}
# Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start
# CEM for the next step.
self._prev_mean: torch.Tensor | None = None
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]):
"""Select a single action given environment observations."""
assert "observation.image" in batch
assert "observation.state" in batch
assert len(batch) == 2
batch = self.normalize_inputs(batch)
self._queues = populate_queues(self._queues, batch)
# When the action queue is depleted, populate it again by querying the policy.
if len(self._queues["action"]) == 0:
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
# Remove the time dimensions as it is not handled yet.
for key in batch:
assert batch[key].shape[1] == 1
batch[key] = batch[key][:, 0]
# NOTE: Order of observations matters here.
z = self.model.encode({k: batch[k] for k in ["observation.image", "observation.state"]})
if self.config.use_mpc:
batch_size = batch["observation.image"].shape[0]
# Batch processing is not handled in MPC mode, so process the batch in a loop.
action = [] # will be a batch of actions for one step
for i in range(batch_size):
# Note: self.plan does not handle batches, hence the indexing into the batch dimension.
action.append(self.plan(z[i]))
action = torch.stack(action)
else:
# Plan with the policy (π) alone.
action = self.model.pi(z)
action = self.unnormalize_outputs({"action": action})["action"]
for _ in range(self.config.n_action_repeats):
self._queues["action"].append(action)
action = self._queues["action"].popleft()
return torch.clamp(action, -1, 1)
@torch.no_grad()
def plan(self, z: Tensor) -> Tensor:
"""Plan next action using TD-MPC inference.
Args:
z: (latent_dim,) tensor for the initial state.
Returns:
(action_dim,) tensor for the next action.
TODO(alexander-soare) Extend this to be able to work with batches.
"""
device = get_device_from_parameters(self)
# Sample Nπ trajectories from the policy.
pi_actions = torch.empty(
self.config.horizon,
self.config.n_pi_samples,
self.config.output_shapes["action"][0],
device=device,
)
if self.config.n_pi_samples > 0:
_z = einops.repeat(z, "d -> n d", n=self.config.n_pi_samples)
for t in range(self.config.horizon):
# Note: Adding a small amount of noise here doesn't hurt during inference and may even be
# helpful for CEM.
pi_actions[t] = self.model.pi(_z, self.config.min_std)
_z = self.model.latent_dynamics(_z, pi_actions[t])
# In the CEM loop we will need this for a call to estimate_value with the gaussian sampled
# trajectories.
z = einops.repeat(z, "d -> n d", n=self.config.n_gaussian_samples + self.config.n_pi_samples)
# Model Predictive Path Integral (MPPI) with the cross-entropy method (CEM) as the optimization
# algorithm.
# The initial mean and standard deviation for the cross-entropy method (CEM).
mean = torch.zeros(self.config.horizon, self.config.output_shapes["action"][0], device=device)
# Maybe warm start CEM with the mean from the previous step.
if self._prev_mean is not None:
mean[:-1] = self._prev_mean[1:]
std = self.config.max_std * torch.ones_like(mean)
for _ in range(self.config.cem_iterations):
# Randomly sample action trajectories for the gaussian distribution.
std_normal_noise = torch.randn(
self.config.horizon,
self.config.n_gaussian_samples,
self.config.output_shapes["action"][0],
device=std.device,
)
gaussian_actions = torch.clamp(mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1)
# Compute elite actions.
actions = torch.cat([gaussian_actions, pi_actions], dim=1)
value = self.estimate_value(z, actions).nan_to_num_(0)
elite_idxs = torch.topk(value, self.config.n_elites, dim=0).indices
elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]
# Update gaussian PDF parameters to be the (weighted) mean and standard deviation of the elites.
max_value = elite_value.max(0)[0]
# The weighting is a softmax over trajectory values. Note that this is not the same as the usage
# of Ω in eqn 4 of the TD-MPC paper. Instead it is the normalized version of it: s = Ω/ΣΩ. This
# makes the equations: μ = Σ(s⋅Γ), σ = √(Σ(s⋅(Γ-μ)²)).
score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value))
score /= score.sum()
_mean = torch.sum(einops.rearrange(score, "n -> n 1") * elite_actions, dim=1)
_std = torch.sqrt(
torch.sum(
einops.rearrange(score, "n -> n 1")
* (elite_actions - einops.rearrange(_mean, "h d -> h 1 d")) ** 2,
dim=1,
)
)
# Update mean with an exponential moving average, and std with a direct replacement.
mean = (
self.config.gaussian_mean_momentum * mean + (1 - self.config.gaussian_mean_momentum) * _mean
)
std = _std.clamp_(self.config.min_std, self.config.max_std)
# Keep track of the mean for warm-starting subsequent steps.
self._prev_mean = mean
# Randomly select one of the elite actions from the last iteration of MPPI/CEM using the softmax
# scores from the last iteration.
actions = elite_actions[:, torch.multinomial(score, 1).item()]
# Select only the first action
action = actions[0]
return action
@torch.no_grad()
def estimate_value(self, z: Tensor, actions: Tensor):
"""Estimates the value of a trajectory as per eqn 4 of the FOWM paper.
Args:
z: (batch, latent_dim) tensor of initial latent states.
actions: (horizon, batch, action_dim) tensor of action trajectories.
Returns:
(batch,) tensor of values.
"""
# Initialize return and running discount factor.
G, running_discount = 0, 1
# Iterate over the actions in the trajectory to simulate the trajectory using the latent dynamics
# model. Keep track of return.
for t in range(actions.shape[0]):
# We will compute the reward in a moment. First compute the uncertainty regularizer from eqn 4
# of the FOWM paper.
if self.config.uncertainty_regularizer_coeff > 0:
regularization = -(
self.config.uncertainty_regularizer_coeff * self.model.Qs(z, actions[t]).std(0)
)
else:
regularization = 0
# Estimate the next state (latent) and reward.
z, reward = self.model.latent_dynamics_and_reward(z, actions[t])
# Update the return and running discount.
G += running_discount * (reward + regularization)
running_discount *= self.config.discount
# Add the estimated value of the final state (using the minimum for a conservative estimate).
# Do so by predicting the next action, then taking a minimum over the ensemble of state-action value
# estimators.
# Note: This small amount of added noise seems to help a bit at inference time as observed by success
# metrics over 50 episodes of xarm_lift_medium_replay.
next_action = self.model.pi(z, self.config.min_std) # (batch, action_dim)
terminal_values = self.model.Qs(z, next_action) # (ensemble, batch)
# Randomly choose 2 of the Qs for terminal value estimation (as in App C. of the FOWM paper).
if self.config.q_ensemble_size > 2:
G += (
running_discount
* torch.min(terminal_values[torch.randint(0, self.config.q_ensemble_size, size=(2,))], dim=0)[
0
]
)
else:
G += running_discount * torch.min(terminal_values, dim=0)[0]
# Finally, also regularize the terminal value.
if self.config.uncertainty_regularizer_coeff > 0:
G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0)
return G
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss."""
device = get_device_from_parameters(self)
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
info = {}
# TODO(alexander-soare): Refactor TDMPC and make it comply with the policy interface documentation.
batch_size = batch["index"].shape[0]
# (b, t) -> (t, b)
for key in batch:
if batch[key].ndim > 1:
batch[key] = batch[key].transpose(1, 0)
action = batch["action"] # (t, b)
reward = batch["next.reward"] # (t, b)
observations = {k: v for k, v in batch.items() if k.startswith("observation.")}
# Apply random image augmentations.
if self.config.max_random_shift_ratio > 0:
observations["observation.image"] = flatten_forward_unflatten(
partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
observations["observation.image"],
)
# Get the current observation for predicting trajectories, and all future observations for use in
# the latent consistency loss and TD loss.
current_observation, next_observations = {}, {}
for k in observations:
current_observation[k] = observations[k][0]
next_observations[k] = observations[k][1:]
horizon = next_observations["observation.image"].shape[0]
# Run latent rollout using the latent dynamics model and policy model.
# Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
# gives us a next `z`.
z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device)
z_preds[0] = self.model.encode(current_observation)
reward_preds = torch.empty_like(reward, device=device)
for t in range(horizon):
z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(z_preds[t], action[t])
# Compute Q and V value predictions based on the latent rollout.
q_preds_ensemble = self.model.Qs(z_preds[:-1], action) # (ensemble, horizon, batch)
v_preds = self.model.V(z_preds[:-1])
info.update({"Q": q_preds_ensemble.mean().item(), "V": v_preds.mean().item()})
# Compute various targets with stopgrad.
with torch.no_grad():
# Latent state consistency targets.
z_targets = self.model_target.encode(next_observations)
# State-action value targets (or TD targets) as in eqn 3 of the FOWM. Unlike TD-MPC which uses the
# learned state-action value function in conjunction with the learned policy: Q(z, π(z)), FOWM
# uses a learned state value function: V(z). This means the TD targets only depend on in-sample
# actions (not actions estimated by π).
# Note: Here we do not use self.model_target, but self.model. This is to follow the original code
# and the FOWM paper.
q_targets = reward + self.config.discount * self.model.V(self.model.encode(next_observations))
# From eqn 3 of FOWM. These appear as Q(z, a). Here we call them v_targets to emphasize that we
# are using them to compute loss for V.
v_targets = self.model_target.Qs(z_preds[:-1].detach(), action, return_min=True)
# Compute losses.
# Exponentially decay the loss weight with respect to the timestep. Steps that are more distant in the
# future have less impact on the loss. Note: unsqueeze will let us broadcast to (seq, batch).
temporal_loss_coeffs = torch.pow(
self.config.temporal_decay_coeff, torch.arange(horizon, device=device)
).unsqueeze(-1)
# Compute consistency loss as MSE loss between latents predicted from the rollout and latents
# predicted from the (target model's) observation encoder.
consistency_loss = (
(
temporal_loss_coeffs
* F.mse_loss(z_preds[1:], z_targets, reduction="none").mean(dim=-1)
# `z_preds` depends on the current observation and the actions.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
# `z_targets` depends on the next observation.
* ~batch["observation.state_is_pad"][1:]
)
.sum(0)
.mean()
)
# Compute the reward loss as MSE loss between rewards predicted from the rollout and the dataset
# rewards.
reward_loss = (
(
temporal_loss_coeffs
* F.mse_loss(reward_preds, reward, reduction="none")
* ~batch["next.reward_is_pad"]
# `reward_preds` depends on the current observation and the actions.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
)
.sum(0)
.mean()
)
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
q_value_loss = (
(
F.mse_loss(
q_preds_ensemble,
einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]),
reduction="none",
).sum(0) # sum over ensemble
# `q_preds_ensemble` depends on the first observation and the actions.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
# q_targets depends on the reward and the next observations.
* ~batch["next.reward_is_pad"]
* ~batch["observation.state_is_pad"][1:]
)
.sum(0)
.mean()
)
# Compute state value loss as in eqn 3 of FOWM.
diff = v_targets - v_preds
# Expectile loss penalizes:
# - `v_preds < v_targets` with weighting `expectile_weight`
# - `v_preds >= v_targets` with weighting `1 - expectile_weight`
raw_v_value_loss = torch.where(
diff > 0, self.config.expectile_weight, (1 - self.config.expectile_weight)
) * (diff**2)
v_value_loss = (
(
temporal_loss_coeffs
* raw_v_value_loss
# `v_targets` depends on the first observation and the actions, as does `v_preds`.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
)
.sum(0)
.mean()
)
# Calculate the advantage weighted regression loss for π as detailed in FOWM 3.1.
# We won't need these gradients again so detach.
z_preds = z_preds.detach()
# Use stopgrad for the advantage calculation.
with torch.no_grad():
advantage = self.model_target.Qs(z_preds[:-1], action, return_min=True) - self.model.V(
z_preds[:-1]
)
info["advantage"] = advantage[0]
# (t, b)
exp_advantage = torch.clamp(torch.exp(advantage * self.config.advantage_scaling), max=100.0)
action_preds = self.model.pi(z_preds[:-1]) # (t, b, a)
# Calculate the MSE between the actions and the action predictions.
# Note: FOWM's original code calculates the log probability (wrt a unit standard deviation
# gaussian) and sums over the action dimension. Computing the log probability amounts to multiplying
# the MSE by 0.5 and adding a constant offset (the log(2*pi) term). Here we drop the constant offset
# as it doesn't change the optimization step, and we drop the 0.5 as we instead make a configuration
# parameter for it (see below where we compute the total loss).
mse = F.mse_loss(action_preds, action, reduction="none").sum(-1) # (t, b)
# NOTE: The original implementation does not take the sum over the temporal dimension like with the
# other losses.
# TODO(alexander-soare): Take the sum over the temporal dimension and check that training still works
# as well as expected.
pi_loss = (
exp_advantage
* mse
* temporal_loss_coeffs
# `action_preds` depends on the first observation and the actions.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
).mean()
loss = (
self.config.consistency_coeff * consistency_loss
+ self.config.reward_coeff * reward_loss
+ self.config.value_coeff * q_value_loss
+ self.config.value_coeff * v_value_loss
+ self.config.pi_coeff * pi_loss
)
info.update(
{
"consistency_loss": consistency_loss.item(),
"reward_loss": reward_loss.item(),
"Q_value_loss": q_value_loss.item(),
"V_value_loss": v_value_loss.item(),
"pi_loss": pi_loss.item(),
"loss": loss,
"sum_loss": loss.item() * self.config.horizon,
}
)
# Undo (b, t) -> (t, b).
for key in batch:
if batch[key].ndim > 1:
batch[key] = batch[key].transpose(1, 0)
return info
def update(self):
"""Update the target model's parameters with an EMA step."""
# Note a minor variation with respect to the original FOWM code. Here they do this based on an EMA
# update frequency parameter which is set to 2 (every 2 steps an update is done). To simplify the code
# we update every step and adjust the decay parameter `alpha` accordingly (0.99 -> 0.995)
update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum)
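# For intuition, a simplified sketch (not part of this file) of the EMA step used above;
# `update_ema_parameters` is defined further down this file, beyond this excerpt, and this
# version assumes every parameter is trainable:
#
#     @torch.no_grad()
#     def ema_sketch(target: nn.Module, online: nn.Module, alpha: float):
#         # ϕ ← αϕ + (1-α)θ
#         for p_t, p in zip(target.parameters(), online.parameters(), strict=True):
#             p_t.lerp_(p, 1 - alpha)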
class TDMPCTOLD(nn.Module):
"""Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC."""
def __init__(self, config: TDMPCConfig):
super().__init__()
self.config = config
self._encoder = TDMPCObservationEncoder(config)
self._dynamics = nn.Sequential(
nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
nn.LayerNorm(config.mlp_dim),
nn.Mish(),
nn.Linear(config.mlp_dim, config.mlp_dim),
nn.LayerNorm(config.mlp_dim),
nn.Mish(),
nn.Linear(config.mlp_dim, config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Sigmoid(),
)
self._reward = nn.Sequential(
nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
nn.LayerNorm(config.mlp_dim),
nn.Mish(),
nn.Linear(config.mlp_dim, config.mlp_dim),
nn.LayerNorm(config.mlp_dim),
nn.Mish(),
nn.Linear(config.mlp_dim, 1),
)
self._pi = nn.Sequential(
nn.Linear(config.latent_dim, config.mlp_dim),
nn.LayerNorm(config.mlp_dim),
nn.Mish(),
nn.Linear(config.mlp_dim, config.mlp_dim),
nn.LayerNorm(config.mlp_dim),
nn.Mish(),
nn.Linear(config.mlp_dim, config.output_shapes["action"][0]),
)
self._Qs = nn.ModuleList(
[
nn.Sequential(
nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
nn.LayerNorm(config.mlp_dim),
nn.Tanh(),
nn.Linear(config.mlp_dim, config.mlp_dim),
nn.ELU(),
nn.Linear(config.mlp_dim, 1),
)
for _ in range(config.q_ensemble_size)
]
)
self._V = nn.Sequential(
nn.Linear(config.latent_dim, config.mlp_dim),
nn.LayerNorm(config.mlp_dim),
nn.Tanh(),
nn.Linear(config.mlp_dim, config.mlp_dim),
nn.ELU(),
nn.Linear(config.mlp_dim, 1),
)
self._init_weights()
def _init_weights(self):
"""Initialize model weights.
Orthogonal initialization for all linear and convolutional layers' weights (apart from final layers
of reward network and Q networks which get zero initialization).
Zero initialization for all linear and convolutional layers' biases.
"""
def _apply_fn(m):
if isinstance(m, nn.Linear):
nn.init.orthogonal_(m.weight.data)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Conv2d):
gain = nn.init.calculate_gain("relu")
nn.init.orthogonal_(m.weight.data, gain)
if m.bias is not None:
nn.init.zeros_(m.bias)
self.apply(_apply_fn)
for m in [self._reward, *self._Qs]:
assert isinstance(
m[-1], nn.Linear
), "Sanity check. The last linear layer needs 0 initialization on weights."
nn.init.zeros_(m[-1].weight)
nn.init.zeros_(m[-1].bias) # this has already been done, but keep this line here for good measure
def encode(self, obs: dict[str, Tensor]) -> Tensor:
"""Encodes an observation into its latent representation."""
return self._encoder(obs)
def latent_dynamics_and_reward(self, z: Tensor, a: Tensor) -> tuple[Tensor, Tensor]:
"""Predict the next state's latent representation and the reward given a current latent and action.
Args:
z: (*, latent_dim) tensor for the current state's latent representation.
a: (*, action_dim) tensor for the action to be applied.
Returns:
A tuple containing:
- (*, latent_dim) tensor for the next state's latent representation.
- (*,) tensor for the estimated reward.
"""
x = torch.cat([z, a], dim=-1)
return self._dynamics(x), self._reward(x).squeeze(-1)
def latent_dynamics(self, z: Tensor, a: Tensor) -> Tensor:
"""Predict the next state's latent representation given a current latent and action.
Args:
z: (*, latent_dim) tensor for the current state's latent representation.
a: (*, action_dim) tensor for the action to be applied.
Returns:
(*, latent_dim) tensor for the next state's latent representation.
"""
x = torch.cat([z, a], dim=-1)
return self._dynamics(x)
def pi(self, z: Tensor, std: float = 0.0) -> Tensor:
"""Samples an action from the learned policy.
Gaussian noise can also be injected to encourage exploration when generating rollouts for online
training.
Args:
z: (*, latent_dim) tensor for the current state's latent representation.
std: The standard deviation of the injected noise.
Returns:
(*, action_dim) tensor for the sampled action.
"""
action = torch.tanh(self._pi(z))
if std > 0:
std = torch.ones_like(action) * std
action += torch.randn_like(action) * std
return action
def V(self, z: Tensor) -> Tensor: # noqa: N802
"""Predict state value (V).
Args:
z: (*, latent_dim) tensor for the current state's latent representation.
Returns:
(*,) tensor of estimated state values.
"""
return self._V(z).squeeze(-1)
def Qs(self, z: Tensor, a: Tensor, return_min: bool = False) -> Tensor: # noqa: N802
"""Predict state-action value for all of the learned Q functions.
Args:
z: (*, latent_dim) tensor for the current state's latent representation.
a: (*, action_dim) tensor for the action to be applied.
return_min: Set to true to implement the detail in App. C of the FOWM paper: randomly select
2 of the Qs and return the minimum.
Returns:
(q_ensemble, *) tensor for the value predictions of each learned Q function in the ensemble OR
(*,) tensor if return_min=True.
"""
x = torch.cat([z, a], dim=-1)
if not return_min:
return torch.stack([q(x).squeeze(-1) for q in self._Qs], dim=0)
else:
if len(self._Qs) > 2: # noqa: SIM108
Qs = [self._Qs[i] for i in np.random.choice(len(self._Qs), size=2, replace=False)]
else:
Qs = self._Qs
return torch.stack([q(x).squeeze(-1) for q in Qs], dim=0).min(dim=0)[0]
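# Editorial sketch, not part of this commit: how the TOLD components compose for one
# latent-space step. Shapes follow the docstrings above.
def _told_one_step_demo(model: TDMPCTOLD, obs: dict[str, Tensor]) -> tuple[Tensor, Tensor, Tensor]:
    z = model.encode(obs)  # (b, latent_dim)
    a = model.pi(z, std=0.05)  # sampled action with exploration noise, (b, action_dim)
    next_z, reward = model.latent_dynamics_and_reward(z, a)  # (b, latent_dim), (b,)
    q = model.Qs(z, a, return_min=True)  # pessimistic state-action value, (b,)
    return next_z, reward, q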
class TDMPCObservationEncoder(nn.Module):
"""Encode image and/or state vector observations."""
def __init__(self, config: TDMPCConfig):
"""
Creates encoders for pixel and/or state modalities.
TODO(alexander-soare): The original work allows for multiple images by concatenating them along the
channel dimension. Re-implement this capability.
"""
super().__init__()
self.config = config
if "observation.image" in config.input_shapes:
self.image_enc_layers = nn.Sequential(
nn.Conv2d(
config.input_shapes["observation.image"][0], config.image_encoder_hidden_dim, 7, stride=2
),
nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2),
nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
nn.ReLU(),
)
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
with torch.inference_mode():
out_shape = self.image_enc_layers(dummy_batch).shape[1:]
self.image_enc_layers.extend(
nn.Sequential(
nn.Flatten(),
nn.Linear(np.prod(out_shape), config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Sigmoid(),
)
)
if "observation.state" in config.input_shapes:
self.state_enc_layers = nn.Sequential(
nn.Linear(config.input_shapes["observation.state"][0], config.state_encoder_hidden_dim),
nn.ELU(),
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Sigmoid(),
)
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
"""Encode the image and/or state vector.
Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
over all features.
"""
feat = []
if "observation.image" in self.config.input_shapes:
feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict["observation.image"]))
if "observation.state" in self.config.input_shapes:
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
return torch.stack(feat, dim=0).mean(0)
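# Editorial example, not part of this commit: with both modalities configured, the encoder
# returns the element-wise mean of the per-modality feature vectors. Illustrative shapes
# (assuming 84x84 images and a 4-dimensional state, as in the xarm config):
#
#     encoder = TDMPCObservationEncoder(config)
#     obs = {"observation.image": torch.rand(8, 3, 84, 84), "observation.state": torch.rand(8, 4)}
#     z = encoder(obs)  # (8, config.latent_dim)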
def random_shifts_aug(x: Tensor, max_random_shift_ratio: float) -> Tensor:
"""Randomly shifts images horizontally and vertically.
Adapted from https://github.com/facebookresearch/drqv2
"""
b, _, h, w = x.size()
assert h == w, "non-square images not handled yet"
pad = int(round(max_random_shift_ratio * h))
x = F.pad(x, tuple([pad] * 4), "replicate")
eps = 1.0 / (h + 2 * pad)
arange = torch.linspace(
-1.0 + eps,
1.0 - eps,
h + 2 * pad,
device=x.device,
dtype=torch.float32,
)[:h]
arange = einops.repeat(arange, "w -> h w 1", h=h)
base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2)
base_grid = einops.repeat(base_grid, "h w c -> b h w c", b=b)
# A random shift in units of pixels and within the boundaries of the padding.
shift = torch.randint(
0,
2 * pad + 1,
size=(b, 1, 1, 2),
device=x.device,
dtype=torch.float32,
)
shift *= 2.0 / (h + 2 * pad)
grid = base_grid + shift
return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False)
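# Editorial note, not part of this commit: with the default max_random_shift_ratio = 0.0476 and
# 84x84 frames, pad = round(0.0476 * 84) = 4, i.e. shifts of up to +/-4 pixels as in DrQ-v2.
# Illustrative usage:
#
#     imgs = torch.rand(32, 3, 84, 84)
#     shifted = random_shifts_aug(imgs, max_random_shift_ratio=0.0476)  # same shape as `imgs`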
def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float):
"""Update EMA parameters in place with ema_param <- alpha * ema_param + (1 - alpha) * param."""
for ema_module, module in zip(ema_net.modules(), net.modules(), strict=True):
for (n_p_ema, p_ema), (n_p, p) in zip(
ema_module.named_parameters(recurse=False), module.named_parameters(recurse=False), strict=True
):
assert n_p_ema == n_p, "Parameter names don't match for EMA model update"
if isinstance(p, dict):
raise RuntimeError("Dict parameter not supported")
if isinstance(module, nn.modules.batchnorm._BatchNorm) or not p.requires_grad:
# Copy BatchNorm parameters, and non-trainable parameters, directly.
p_ema.data.copy_(p.to(dtype=p_ema.dtype).data)
continue
with torch.no_grad():
p_ema.mul_(alpha)
p_ema.add_(p.to(dtype=p_ema.dtype).data, alpha=1 - alpha)
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
"""Helper to temporarily flatten extra dims at the start of the image tensor.
Args:
fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return
(B, *), where * is any number of dimensions.
image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions, generally
different from *.
Returns:
A return value from the callable reshaped to (**, *).
"""
if image_tensor.ndim == 4:
return fn(image_tensor)
start_dims = image_tensor.shape[:-3]
inp = torch.flatten(image_tensor, end_dim=-4)
flat_out = fn(inp)
return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
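# Editorial example, not part of this commit: applying a (B, C, H, W) image encoder to a
# (T, B, C, H, W) batch by temporarily folding the leading dimensions.
#
#     images = torch.rand(5, 32, 3, 84, 84)  # (T, B, C, H, W)
#     feats = flatten_forward_unflatten(encoder.image_enc_layers, images)  # (5, 32, latent_dim)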

View File

@ -1,495 +0,0 @@
# ruff: noqa: N806
import time
from collections import deque
from copy import deepcopy
import einops
import numpy as np
import torch
import torch.nn as nn
import lerobot.common.policies.tdmpc.helper as h
from lerobot.common.policies.utils import populate_queues
from lerobot.common.utils.utils import get_safe_torch_device
FIRST_FRAME = 0
class TOLD(nn.Module):
"""Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC."""
def __init__(self, cfg):
super().__init__()
action_dim = cfg.action_dim
self.cfg = cfg
self._encoder = h.enc(cfg)
self._dynamics = h.dynamics(cfg.latent_dim + action_dim, cfg.mlp_dim, cfg.latent_dim)
self._reward = h.mlp(cfg.latent_dim + action_dim, cfg.mlp_dim, 1)
self._pi = h.mlp(cfg.latent_dim, cfg.mlp_dim, action_dim)
self._Qs = nn.ModuleList([h.q(cfg) for _ in range(cfg.num_q)])
self._V = h.v(cfg)
self.apply(h.orthogonal_init)
for m in [self._reward, *self._Qs]:
m[-1].weight.data.fill_(0)
m[-1].bias.data.fill_(0)
def track_q_grad(self, enable=True):
"""Utility function. Enables/disables gradient tracking of Q-networks."""
for m in self._Qs:
h.set_requires_grad(m, enable)
def track_v_grad(self, enable=True):
"""Utility function. Enables/disables gradient tracking of Q-networks."""
if hasattr(self, "_V"):
h.set_requires_grad(self._V, enable)
def encode(self, obs):
"""Encodes an observation into its latent representation."""
out = self._encoder(obs)
if isinstance(obs, dict):
# fusion
out = torch.stack([v for k, v in out.items()]).mean(dim=0)
return out
def next(self, z, a):
"""Predicts next latent state (d) and single-step reward (R)."""
x = torch.cat([z, a], dim=-1)
return self._dynamics(x), self._reward(x)
def next_dynamics(self, z, a):
"""Predicts next latent state (d)."""
x = torch.cat([z, a], dim=-1)
return self._dynamics(x)
def pi(self, z, std=0):
"""Samples an action from the learned policy (pi)."""
mu = torch.tanh(self._pi(z))
if std > 0:
std = torch.ones_like(mu) * std
return h.TruncatedNormal(mu, std).sample(clip=0.3)
return mu
def V(self, z): # noqa: N802
"""Predict state value (V)."""
return self._V(z)
def Q(self, z, a, return_type): # noqa: N802
"""Predict state-action value (Q)."""
assert return_type in {"min", "avg", "all"}
x = torch.cat([z, a], dim=-1)
if return_type == "all":
return torch.stack([q(x) for q in self._Qs], dim=0)
idxs = np.random.choice(self.cfg.num_q, 2, replace=False)
Q1, Q2 = self._Qs[idxs[0]](x), self._Qs[idxs[1]](x)
return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2
class TDMPCPolicy(nn.Module):
"""Implementation of TD-MPC learning + inference."""
name = "tdmpc"
def __init__(self, cfg, n_obs_steps, n_action_steps, device):
super().__init__()
self.action_dim = cfg.action_dim
self.cfg = cfg
self.n_obs_steps = n_obs_steps
self.n_action_steps = n_action_steps
self.device = get_safe_torch_device(device)
self.std = h.linear_schedule(cfg.std_schedule, 0)
self.model = TOLD(cfg)
self.model.to(self.device)
self.model_target = deepcopy(self.model)
self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr)
# self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
self.model.eval()
self.model_target.eval()
self.register_buffer("step", torch.zeros(1))
def state_dict(self):
"""Retrieve state dict of TOLD model, including slow-moving target network."""
return {
"model": self.model.state_dict(),
"model_target": self.model_target.state_dict(),
}
def save(self, fp):
"""Save state dict of TOLD model to filepath."""
torch.save(self.state_dict(), fp)
def load(self, fp):
"""Load a saved state dict from filepath into current agent."""
d = torch.load(fp)
self.model.load_state_dict(d["model"])
self.model_target.load_state_dict(d["model_target"])
def reset(self):
"""
Clear observation and action queues. Should be called on `env.reset()`
"""
self._queues = {
"observation.image": deque(maxlen=self.n_obs_steps),
"observation.state": deque(maxlen=self.n_obs_steps),
"action": deque(maxlen=self.n_action_steps),
}
@torch.no_grad()
def select_action(self, batch, step):
assert "observation.image" in batch
assert "observation.state" in batch
assert len(batch) == 2
self._queues = populate_queues(self._queues, batch)
t0 = step == 0
self.eval()
if len(self._queues["action"]) == 0:
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
if self.n_obs_steps == 1:
# hack to remove the time dimension
for key in batch:
assert batch[key].shape[1] == 1
batch[key] = batch[key][:, 0]
actions = []
batch_size = batch["observation.image"].shape[0]
for i in range(batch_size):
obs = {
"rgb": batch["observation.image"][[i]],
"state": batch["observation.state"][[i]],
}
# Note: unsqueeze needed because `act` still uses non-batch logic.
action = self.act(obs, t0=t0, step=self.step)
actions.append(action)
action = torch.stack(actions)
# tdmpc returns an action for 1 timestep only, so we copy it over `n_action_steps` times
for _ in range(self.n_action_steps):
self._queues["action"].append(action)
action = self._queues["action"].popleft()
return action
@torch.no_grad()
def act(self, obs, t0=False, step=None):
"""Take an action. Uses either MPC or the learned policy, depending on the self.cfg.mpc flag."""
obs = {k: o.detach() for k, o in obs.items()} if isinstance(obs, dict) else obs.detach()
z = self.model.encode(obs)
if self.cfg.mpc:
a = self.plan(z, t0=t0, step=step)
else:
a = self.model.pi(z, self.cfg.min_std * self.model.training).squeeze(0)
return a
@torch.no_grad()
def estimate_value(self, z, actions, horizon):
"""Estimate value of a trajectory starting at latent state z and executing given actions."""
G, discount = 0, 1
for t in range(horizon):
if self.cfg.uncertainty_cost > 0:
G -= (
discount
* self.cfg.uncertainty_cost
* self.model.Q(z, actions[t], return_type="all").std(dim=0)
)
z, reward = self.model.next(z, actions[t])
G += discount * reward
discount *= self.cfg.discount
pi = self.model.pi(z, self.cfg.min_std)
G += discount * self.model.Q(z, pi, return_type="min")
if self.cfg.uncertainty_cost > 0:
G -= discount * self.cfg.uncertainty_cost * self.model.Q(z, pi, return_type="all").std(dim=0)
return G
@torch.no_grad()
def plan(self, z, step=None, t0=True):
"""
Plan next action using TD-MPC inference.
z: latent state.
step: current time step; determines e.g. the planning horizon.
t0: whether current step is the first step of an episode.
"""
# During evaluation, uniform sampling and action noise are disabled.
assert step is not None
# Seed steps
if step < self.cfg.seed_steps and self.model.training:
return torch.empty(self.action_dim, dtype=torch.float32, device=self.device).uniform_(-1, 1)
# Sample policy trajectories
horizon = int(min(self.cfg.horizon, h.linear_schedule(self.cfg.horizon_schedule, step)))
num_pi_trajs = int(self.cfg.mixture_coef * self.cfg.num_samples)
if num_pi_trajs > 0:
pi_actions = torch.empty(horizon, num_pi_trajs, self.action_dim, device=self.device)
_z = z.repeat(num_pi_trajs, 1)
for t in range(horizon):
pi_actions[t] = self.model.pi(_z, self.cfg.min_std)
_z = self.model.next_dynamics(_z, pi_actions[t])
# Initialize state and parameters
z = z.repeat(self.cfg.num_samples + num_pi_trajs, 1)
mean = torch.zeros(horizon, self.action_dim, device=self.device)
std = self.cfg.max_std * torch.ones(horizon, self.action_dim, device=self.device)
if not t0 and hasattr(self, "_prev_mean"):
mean[:-1] = self._prev_mean[1:]
# Iterate CEM
for _ in range(self.cfg.iterations):
actions = torch.clamp(
mean.unsqueeze(1)
+ std.unsqueeze(1)
* torch.randn(horizon, self.cfg.num_samples, self.action_dim, device=std.device),
-1,
1,
)
if num_pi_trajs > 0:
actions = torch.cat([actions, pi_actions], dim=1)
# Compute elite actions
value = self.estimate_value(z, actions, horizon).nan_to_num_(0)
elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices
elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]
# Update parameters
max_value = elite_value.max(0)[0]
score = torch.exp(self.cfg.temperature * (elite_value - max_value))
score /= score.sum(0)
_mean = torch.sum(score.unsqueeze(0) * elite_actions, dim=1) / (score.sum(0) + 1e-9)
_std = torch.sqrt(
torch.sum(
score.unsqueeze(0) * (elite_actions - _mean.unsqueeze(1)) ** 2,
dim=1,
)
/ (score.sum(0) + 1e-9)
)
_std = _std.clamp_(self.std, self.cfg.max_std)
mean, std = self.cfg.momentum * mean + (1 - self.cfg.momentum) * _mean, _std
# Outputs
# TODO(rcadene): remove numpy with
# # Convert score tensor to probabilities using softmax
# probabilities = torch.softmax(score, dim=0)
# # Generate a random sample index based on the probabilities
# sample_index = torch.multinomial(probabilities, 1).item()
score = score.squeeze(1).cpu().numpy()
actions = elite_actions[:, np.random.choice(np.arange(score.shape[0]), p=score)]
self._prev_mean = mean
mean, std = actions[0], _std[0]
a = mean
if self.model.training:
a += std * torch.randn(self.action_dim, device=std.device)
return torch.clamp(a, -1, 1)
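# (Editorial summary of the loop above, not part of this commit.) Each CEM iteration:
#     1. sample action sequences from N(mean, std), clamped to [-1, 1], plus policy rollouts;
#     2. score each sequence with `estimate_value` and keep the top `num_elites`;
#     3. re-fit mean/std with softmax weights exp(temperature * (value - max_value)),
#        applying a momentum update to the mean.
# A final action sequence is then drawn from the elites in proportion to their scores.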
def update_pi(self, zs, acts=None):
"""Update policy using a sequence of latent states."""
self.pi_optim.zero_grad(set_to_none=True)
self.model.track_q_grad(False)
self.model.track_v_grad(False)
info = {}
# Advantage Weighted Regression
assert acts is not None
vs = self.model.V(zs)
qs = self.model_target.Q(zs, acts, return_type="min")
adv = qs - vs
exp_a = torch.exp(adv * self.cfg.A_scaling)
exp_a = torch.clamp(exp_a, max=100.0)
log_probs = h.gaussian_logprob(self.model.pi(zs) - acts, 0)
rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device))
pi_loss = -((exp_a * log_probs).mean(dim=(1, 2)) * rho).mean()
info["adv"] = adv[0]
pi_loss.backward()
torch.nn.utils.clip_grad_norm_(
self.model._pi.parameters(),
self.cfg.grad_clip_norm,
error_if_nonfinite=False,
)
self.pi_optim.step()
self.model.track_q_grad(True)
self.model.track_v_grad(True)
info["pi_loss"] = pi_loss.item()
return pi_loss.item(), info
@torch.no_grad()
def _td_target(self, next_z, reward, mask):
"""Compute the TD-target from a reward and the observation at the following time step."""
next_v = self.model.V(next_z)
td_target = reward + self.cfg.discount * mask * next_v.squeeze(2)
return td_target
def forward(self, batch, step):
# TODO(alexander-soare): Refactor TDMPC and make it comply with the policy interface documentation.
raise NotImplementedError()
def update(self, batch, step):
"""Main update function. Corresponds to one iteration of the model learning."""
start_time = time.time()
batch_size = batch["index"].shape[0]
# TODO(rcadene): convert tdmpc with (batch size, time/horizon, channels)
# instead of currently (time/horizon, batch size, channels) which is not the pytorch convention
# batch size b = 256, time/horizon t = 5
# b t ... -> t b ...
for key in batch:
if batch[key].ndim > 1:
batch[key] = batch[key].transpose(1, 0)
action = batch["action"]
reward = batch["next.reward"]
# idxs = batch["index"] # TODO(rcadene): use idxs to update sampling weights
done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
weights = torch.ones(batch_size, dtype=torch.bool, device=reward.device)
obses = {
"rgb": batch["observation.image"],
"state": batch["observation.state"],
}
shapes = {}
for k in obses:
shapes[k] = obses[k].shape
obses[k] = einops.rearrange(obses[k], "t b ... -> (t b) ... ")
# Apply augmentations
aug_tf = h.aug(self.cfg)
obses = aug_tf(obses)
for k in obses:
t, b = shapes[k][:2]
obses[k] = einops.rearrange(obses[k], "(t b) ... -> t b ... ", b=b, t=t)
obs, next_obses = {}, {}
for k in obses:
obs[k] = obses[k][0]
next_obses[k] = obses[k][1:].clone()
horizon = next_obses["rgb"].shape[0]
loss_mask = torch.ones_like(mask, device=self.device)
for t in range(1, horizon):
loss_mask[t] = loss_mask[t - 1] * (~done[t - 1])
self.optim.zero_grad(set_to_none=True)
self.std = h.linear_schedule(self.cfg.std_schedule, step)
self.model.train()
data_s = time.time() - start_time
# Compute targets
with torch.no_grad():
next_z = self.model.encode(next_obses)
z_targets = self.model_target.encode(next_obses)
td_targets = self._td_target(next_z, reward, mask)
# Latent rollout
zs = torch.empty(horizon + 1, batch_size, self.cfg.latent_dim, device=self.device)
reward_preds = torch.empty_like(reward, device=self.device)
assert reward.shape[0] == horizon
z = self.model.encode(obs)
zs[0] = z
value_info = {"Q": 0.0, "V": 0.0}
for t in range(horizon):
z, reward_pred = self.model.next(z, action[t])
zs[t + 1] = z
reward_preds[t] = reward_pred.squeeze(1)
with torch.no_grad():
v_target = self.model_target.Q(zs[:-1].detach(), action, return_type="min")
# Predictions
qs = self.model.Q(zs[:-1], action, return_type="all")
qs = qs.squeeze(3)
value_info["Q"] = qs.mean().item()
v = self.model.V(zs[:-1])
value_info["V"] = v.mean().item()
# Losses
rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(-1, 1)
consistency_loss = (rho * torch.mean(h.mse(zs[1:], z_targets), dim=2) * loss_mask).sum(dim=0)
reward_loss = (rho * h.mse(reward_preds, reward) * loss_mask).sum(dim=0)
q_value_loss, priority_loss = 0, 0
for q in range(self.cfg.num_q):
q_value_loss += (rho * h.mse(qs[q], td_targets) * loss_mask).sum(dim=0)
priority_loss += (rho * h.l1(qs[q], td_targets) * loss_mask).sum(dim=0)
expectile = h.linear_schedule(self.cfg.expectile, step)
v_value_loss = (rho * h.l2_expectile(v_target - v, expectile=expectile).squeeze(2) * loss_mask).sum(
dim=0
)
total_loss = (
self.cfg.consistency_coef * consistency_loss
+ self.cfg.reward_coef * reward_loss
+ self.cfg.value_coef * q_value_loss
+ self.cfg.value_coef * v_value_loss
)
weighted_loss = (total_loss * weights).mean()
weighted_loss.register_hook(lambda grad: grad * (1 / self.cfg.horizon))
has_nan = torch.isnan(weighted_loss).item()
if has_nan:
print(f"weighted_loss has nan: {total_loss=} {weights=}")
else:
weighted_loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
self.model.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False
)
self.optim.step()
# TODO(rcadene): implement PrioritizedSampling by modifying sampler.weights with priorities computed by a criterion
# if self.cfg.per:
# # Update priorities
# priorities = priority_loss.clamp(max=1e4).detach()
# has_nan = torch.isnan(priorities).any().item()
# if has_nan:
# print(f"priorities has nan: {priorities=}")
# else:
# replay_buffer.update_priority(
# idxs[:num_slices],
# priorities[:num_slices],
# )
# if demo_batch_size > 0:
# demo_buffer.update_priority(demo_idxs, priorities[num_slices:])
# Update policy + target network
_, pi_update_info = self.update_pi(zs[:-1].detach(), acts=action)
if step % self.cfg.update_freq == 0:
h.ema(self.model._encoder, self.model_target._encoder, self.cfg.tau)
h.ema(self.model._Qs, self.model_target._Qs, self.cfg.tau)
self.model.eval()
info = {
"consistency_loss": float(consistency_loss.mean().item()),
"reward_loss": float(reward_loss.mean().item()),
"Q_value_loss": float(q_value_loss.mean().item()),
"V_value_loss": float(v_value_loss.mean().item()),
"sum_loss": float(total_loss.mean().item()),
"loss": float(weighted_loss.mean().item()),
"grad_norm": float(grad_norm),
"lr": self.cfg.lr,
"data_s": data_s,
"update_s": time.time() - start_time,
}
# info["demo_batch_size"] = demo_batch_size
info["expectile"] = expectile
info.update(value_info)
info.update(pi_update_info)
self.step[0] = step
return info

View File

@ -17,6 +17,7 @@ training:
offline_steps: ??? offline_steps: ???
online_steps: ??? online_steps: ???
online_steps_between_rollouts: ??? online_steps_between_rollouts: ???
online_sampling_ratio: 0.5
eval_freq: ??? eval_freq: ???
save_freq: ??? save_freq: ???
log_freq: 250 log_freq: 250

View File

@ -1,85 +1,76 @@
# @package _global_ # @package _global_
n_action_steps: 2 seed: 1
n_obs_steps: 1
training:
offline_steps: 25000
online_steps: 25000
eval_freq: 5000
online_steps_between_rollouts: 1
online_sampling_ratio: 0.5
batch_size: 256
grad_clip_norm: 10.0
lr: 3e-4
delta_timestamps:
observation.image: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
action: "[i / ${fps} for i in range(${policy.horizon})]"
next.reward: "[i / ${fps} for i in range(${policy.horizon})]"
policy: policy:
name: tdmpc name: tdmpc
reward_scale: 1.0 pretrained_model_path:
episode_length: ${env.episode_length} # Input / output structure.
discount: 0.9 n_action_repeats: 2
modality: 'all'
# pixels
frame_stack: 1
num_channels: 32
img_size: ${env.image_size}
state_dim: ${env.action_dim}
action_dim: ${env.action_dim}
# planning
mpc: true
iterations: 6
num_samples: 512
num_elites: 50
mixture_coef: 0.1
min_std: 0.05
max_std: 2.0
temperature: 0.5
momentum: 0.1
uncertainty_cost: 1
# actor
log_std_min: -10
log_std_max: 2
# learning
batch_size: 256
max_buffer_size: 10000
horizon: 5 horizon: 5
reward_coef: 0.5
value_coef: 0.1
consistency_coef: 20
rho: 0.5
kappa: 0.1
lr: 3e-4
std_schedule: ${policy.min_std}
horizon_schedule: ${policy.horizon}
per: true
per_alpha: 0.6
per_beta: 0.4
grad_clip_norm: 10
seed_steps: 0
update_freq: 2
tau: 0.01
online_steps_between_rollouts: 1
# offline rl input_shapes:
# dataset_dir: ??? # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
data_first_percent: 1.0 observation.image: [3, 84, 84]
is_data_clip: true observation.state: ["${env.state_dim}"]
data_clip_eps: 1e-5 output_shapes:
expectile: 0.9 action: ["${env.action_dim}"]
A_scaling: 3.0
# offline->online # Normalization / Unnormalization
offline_steps: ${offline_steps} input_normalization_modes: null
pretrained_model_path: "" output_normalization_modes:
# pretrained_model_path: "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt" action: min_max
# pretrained_model_path: "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
balanced_sampling: true
demo_schedule: 0.5
# architecture # Architecture / modeling.
enc_dim: 256 # Neural networks.
num_q: 5 image_encoder_hidden_dim: 32
mlp_dim: 512 state_encoder_hidden_dim: 256
latent_dim: 50 latent_dim: 50
q_ensemble_size: 5
mlp_dim: 512
# Reinforcement learning.
discount: 0.9
delta_timestamps: # Inference.
observation.image: "[i / ${fps} for i in range(6)]" use_mpc: false
observation.state: "[i / ${fps} for i in range(6)]" cem_iterations: 6
action: "[i / ${fps} for i in range(5)]" max_std: 2.0
next.reward: "[i / ${fps} for i in range(5)]" min_std: 0.05
n_gaussian_samples: 512
n_pi_samples: 51
uncertainty_regularizer_coeff: 1.0
n_elites: 50
elite_weighting_temperature: 0.5
gaussian_mean_momentum: 0.1
# Training and loss computation.
max_random_shift_ratio: 0.0476
# Loss coefficients.
reward_coeff: 0.5
expectile_weight: 0.9
value_coeff: 0.1
consistency_coeff: 20.0
advantage_scaling: 3.0
pi_coeff: 0.5
temporal_decay_coeff: 0.5
# Target model.
target_model_momentum: 0.995
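# Editorial note, not part of this commit: with the training defaults above and, e.g., the xarm
# env's fps of 15, the delta_timestamps resolve to horizon + 1 = 6 observation frames and
# horizon = 5 actions/rewards:
#   observation.image: [0.0, 0.0667, 0.1333, 0.2, 0.2667, 0.3333]
#   action: [0.0, 0.0667, 0.1333, 0.2, 0.2667]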

View File

@ -67,10 +67,10 @@ def eval_policy(
""" """
set `return_episode_data` to return a Hugging Face dataset object in an "episodes" key of the return dict. set `return_episode_data` to return a Hugging Face dataset object in an "episodes" key of the return dict.
""" """
policy.eval()
fps = env.unwrapped.metadata["render_fps"] fps = env.unwrapped.metadata["render_fps"]
if policy is not None:
policy.eval()
device = "cpu" if policy is None else next(policy.parameters()).device device = "cpu" if policy is None else next(policy.parameters()).device
start = time.time() start = time.time()
@ -132,7 +132,7 @@ def eval_policy(
# get the next action for the environment # get the next action for the environment
with torch.inference_mode(): with torch.inference_mode():
action = policy.select_action(observation, step=step) action = policy.select_action(observation)
# convert to cpu numpy # convert to cpu numpy
action = postprocess_action(action) action = postprocess_action(action)
@ -386,6 +386,7 @@ def eval(
else: else:
# Note: We need the dataset stats to pass to the policy's normalization modules. # Note: We need the dataset stats to pass to the policy's normalization modules.
policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats) policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats)
policy.eval()
info = eval_policy( info = eval_policy(
env, env,

View File

@ -15,6 +15,7 @@ from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env from lerobot.common.envs.factory import make_env
from lerobot.common.logger import Logger, log_output_dir from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.policy_protocol import PolicyWithUpdate
from lerobot.common.utils.utils import ( from lerobot.common.utils.utils import (
format_big_number, format_big_number,
get_safe_torch_device, get_safe_torch_device,
@ -39,12 +40,17 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
optimizer.step() optimizer.step()
optimizer.zero_grad() optimizer.zero_grad()
if lr_scheduler is not None: if lr_scheduler is not None:
lr_scheduler.step() lr_scheduler.step()
if hasattr(policy, "ema") and policy.ema is not None: if hasattr(policy, "ema") and policy.ema is not None:
policy.ema.step(policy.diffusion) policy.ema.step(policy.diffusion)
if isinstance(policy, PolicyWithUpdate):
# Possibly update an internal buffer (for instance, an Exponential Moving Average target like in TDMPC).
policy.update()
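# (Editorial sketch, not part of this commit.) `PolicyWithUpdate` is presumably a
# runtime-checkable Protocol along the lines of:
#
#     from typing import Protocol, runtime_checkable
#
#     @runtime_checkable
#     class PolicyWithUpdate(Protocol):
#         def update(self): ...
#
# which lets the `isinstance` check select policies that need a post-optimizer-step update.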
info = { info = {
"loss": loss.item(), "loss": loss.item(),
"grad_norm": float(grad_norm), "grad_norm": float(grad_norm),
@ -246,11 +252,12 @@ def train(cfg: dict, out_dir=None, job_name=None):
raise NotImplementedError() raise NotImplementedError()
if job_name is None: if job_name is None:
raise NotImplementedError() raise NotImplementedError()
if cfg.training.online_steps > 0:
assert cfg.eval.batch_size == 1, "eval.batch_size > 1 not supported for online training steps"
init_logging() init_logging()
if cfg.training.online_steps > 0 and cfg.eval.batch_size > 1:
logging.warning("eval.batch_size > 1 not supported for online training steps")
# Check device is available # Check device is available
get_safe_torch_device(cfg.device, log=True) get_safe_torch_device(cfg.device, log=True)
@ -305,7 +312,10 @@ def train(cfg: dict, out_dir=None, job_name=None):
num_training_steps=cfg.training.offline_steps, num_training_steps=cfg.training.offline_steps,
) )
elif policy.name == "tdmpc": elif policy.name == "tdmpc":
raise NotImplementedError("TD-MPC not implemented yet.") optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
lr_scheduler = None
else:
raise NotImplementedError()
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters()) num_total_params = sum(p.numel() for p in policy.parameters())
@ -361,12 +371,12 @@ def train(cfg: dict, out_dir=None, job_name=None):
) )
dl_iter = cycle(dataloader) dl_iter = cycle(dataloader)
policy.train()
step = 0 # number of policy update (forward + backward + optim) step = 0 # number of policy update (forward + backward + optim)
is_offline = True is_offline = True
for offline_step in range(cfg.training.offline_steps): for offline_step in range(cfg.training.offline_steps):
if offline_step == 0: if offline_step == 0:
logging.info("Start offline training on a fixed dataset") logging.info("Start offline training on a fixed dataset")
policy.train()
batch = next(dl_iter) batch = next(dl_iter)
for key in batch: for key in batch:
@ -414,6 +424,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
if env_step == 0: if env_step == 0:
logging.info("Start online training by interacting with environment") logging.info("Start online training by interacting with environment")
policy.eval()
with torch.no_grad(): with torch.no_grad():
eval_info = eval_policy( eval_info = eval_policy(
rollout_env, rollout_env,
@ -428,11 +439,11 @@ def train(cfg: dict, out_dir=None, job_name=None):
sampler, sampler,
hf_dataset=eval_info["episodes"]["hf_dataset"], hf_dataset=eval_info["episodes"]["hf_dataset"],
episode_data_index=eval_info["episodes"]["episode_data_index"], episode_data_index=eval_info["episodes"]["episode_data_index"],
pc_online_samples=cfg.get("demo_schedule", 0.5), pc_online_samples=cfg.training.online_sampling_ratio,
) )
for _ in range(cfg.training.online_steps_between_rollouts):
policy.train() policy.train()
for _ in range(cfg.training.online_steps_between_rollouts):
batch = next(dl_iter) batch = next(dl_iter)
for key in batch: for key in batch:

View File

@ -6,7 +6,7 @@ import pytest
import lerobot import lerobot
from lerobot.common.policies.act.modeling_act import ACTPolicy from lerobot.common.policies.act.modeling_act import ACTPolicy
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
from tests.utils import require_env from tests.utils import require_env

View File

@ -19,10 +19,6 @@ from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
@pytest.mark.parametrize("policy_name", available_policies) @pytest.mark.parametrize("policy_name", available_policies)
def test_get_policy_and_config_classes(policy_name: str): def test_get_policy_and_config_classes(policy_name: str):
"""Check that the correct policy and config classes are returned.""" """Check that the correct policy and config classes are returned."""
if policy_name == "tdmpc":
with pytest.raises(NotImplementedError):
get_policy_and_config_classes(policy_name)
return
policy_cls, config_cls = get_policy_and_config_classes(policy_name) policy_cls, config_cls = get_policy_and_config_classes(policy_name)
assert policy_cls.name == policy_name assert policy_cls.name == policy_name
assert issubclass(config_cls, inspect.signature(policy_cls.__init__).parameters["config"].annotation) assert issubclass(config_cls, inspect.signature(policy_cls.__init__).parameters["config"].annotation)
@ -32,8 +28,7 @@ def test_get_policy_and_config_classes(policy_name: str):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"env_name,policy_name,extra_overrides", "env_name,policy_name,extra_overrides",
[ [
# ("xarm", "tdmpc", ["policy.mpc=true"]), ("xarm", "tdmpc", ["policy.use_mpc=true", "dataset_repo_id=lerobot/xarm_lift_medium"]),
# ("pusht", "tdmpc", ["policy.mpc=false"]),
("pusht", "diffusion", []), ("pusht", "diffusion", []),
("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_repo_id=lerobot/aloha_sim_insertion_human"]), ("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_repo_id=lerobot/aloha_sim_insertion_human"]),
( (
@ -103,7 +98,7 @@ def test_policy(env_name, policy_name, extra_overrides):
batch[key] = batch[key].to(DEVICE, non_blocking=True) batch[key] = batch[key].to(DEVICE, non_blocking=True)
# Test updating the policy # Test updating the policy
policy.forward(batch, step=0) policy.forward(batch)
# reset the policy and environment # reset the policy and environment
policy.reset() policy.reset()
@ -117,7 +112,7 @@ def test_policy(env_name, policy_name, extra_overrides):
# get the next action for the environment # get the next action for the environment
with torch.inference_mode(): with torch.inference_mode():
action = policy.select_action(observation, step=0) action = policy.select_action(observation)
# convert action to cpu numpy array # convert action to cpu numpy array
action = postprocess_action(action) action = postprocess_action(action)
@ -129,20 +124,12 @@ def test_policy(env_name, policy_name, extra_overrides):
@pytest.mark.parametrize("policy_name", available_policies) @pytest.mark.parametrize("policy_name", available_policies)
def test_policy_defaults(policy_name: str): def test_policy_defaults(policy_name: str):
"""Check that the policy can be instantiated with defaults.""" """Check that the policy can be instantiated with defaults."""
if policy_name == "tdmpc":
with pytest.raises(NotImplementedError):
get_policy_and_config_classes(policy_name)
return
policy_cls, _ = get_policy_and_config_classes(policy_name) policy_cls, _ = get_policy_and_config_classes(policy_name)
policy_cls() policy_cls()
@pytest.mark.parametrize("policy_name", available_policies) @pytest.mark.parametrize("policy_name", available_policies)
def test_save_and_load_pretrained(policy_name: str): def test_save_and_load_pretrained(policy_name: str):
if policy_name == "tdmpc":
with pytest.raises(NotImplementedError):
get_policy_and_config_classes(policy_name)
return
policy_cls, _ = get_policy_and_config_classes(policy_name) policy_cls, _ = get_policy_and_config_classes(policy_name)
policy: Policy = policy_cls() policy: Policy = policy_cls()
save_dir = "/tmp/test_save_and_load_pretrained_{policy_cls.__name__}" save_dir = "/tmp/test_save_and_load_pretrained_{policy_cls.__name__}"