lerobot/lerobot/common/policies/tdmpc/helper.py


import os
import pickle
import re
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from torch import distributions as pyd
from torch.distributions.utils import _standard_normal

DEFAULT_ACT_FN = nn.Mish()


def __REDUCE__(b):  # noqa: N802, N807
    return "mean" if b else "none"


def l1(pred, target, reduce=False):
    """Computes the L1-loss between predictions and targets."""
    return F.l1_loss(pred, target, reduction=__REDUCE__(reduce))


def mse(pred, target, reduce=False):
    """Computes the MSE loss between predictions and targets."""
    return F.mse_loss(pred, target, reduction=__REDUCE__(reduce))


def l2_expectile(diff, expectile=0.7, reduce=False):
    """Computes the expectile loss on a difference `diff = target - pred`.

    Squared errors where `diff > 0` are weighted by `expectile`, the others by `1 - expectile`.
    """
    weight = torch.where(diff > 0, expectile, (1 - expectile))
    loss = weight * (diff**2)
    reduction = __REDUCE__(reduce)
    if reduction == "mean":
        return torch.mean(loss)
    elif reduction == "sum":
        return torch.sum(loss)
    return loss
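
# Example (illustrative sketch, not part of the original file): with expectile=0.7, positive
# residuals are weighted 0.7 and negative residuals 0.3, making the loss asymmetric.
# diff = torch.tensor([1.0, -1.0])
# l2_expectile(diff, expectile=0.7)               # tensor([0.7000, 0.3000])
# l2_expectile(diff, expectile=0.7, reduce=True)  # tensor(0.5000)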


def _get_out_shape(in_shape, layers):
    """Utility function. Returns the output shape of a network for a given input shape."""
    x = torch.randn(*in_shape).unsqueeze(0)
    return (nn.Sequential(*layers) if isinstance(layers, list) else layers)(x).squeeze(0).shape


def gaussian_logprob(eps, log_std):
    """Compute Gaussian log probability."""
    residual = (-0.5 * eps.pow(2) - log_std).sum(-1, keepdim=True)
    return residual - 0.5 * np.log(2 * np.pi) * eps.size(-1)


def squash(mu, pi, log_pi):
    """Apply squashing function."""
    mu = torch.tanh(mu)
    pi = torch.tanh(pi)
    log_pi -= torch.log(F.relu(1 - pi.pow(2)) + 1e-6).sum(-1, keepdim=True)
    return mu, pi, log_pi
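
# Example (illustrative sketch, not part of the original file): sampling a tanh-squashed Gaussian
# action and correcting its log-probability with the change-of-variables term, as in SAC-style actors.
# mu = torch.zeros(1, 4)
# log_std = torch.zeros(1, 4)
# eps = torch.randn_like(mu)
# pi = mu + eps * log_std.exp()              # pre-squash sample
# log_pi = gaussian_logprob(eps, log_std)    # Gaussian log-prob of the sample
# mu, pi, log_pi = squash(mu, pi, log_pi)    # actions now lie in (-1, 1)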


def orthogonal_init(m):
    """Orthogonal layer initialization."""
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Conv2d):
        gain = nn.init.calculate_gain("relu")
        nn.init.orthogonal_(m.weight.data, gain)
        if m.bias is not None:
            nn.init.zeros_(m.bias)


def ema(m, m_target, tau):
    """Update slow-moving average of online network (target network) at rate tau."""
    with torch.no_grad():
        # TODO(rcadene, aliberts): issue with strict=False
        # for p, p_target in zip(m.parameters(), m_target.parameters(), strict=False):
        #     p_target.data.lerp_(p.data, tau)
        m_params_iter = iter(m.parameters())
        m_target_params_iter = iter(m_target.parameters())
        while True:
            try:
                p = next(m_params_iter)
                p_target = next(m_target_params_iter)
                p_target.data.lerp_(p.data, tau)
            except StopIteration:
                # Exit the loop once either iterator is exhausted.
                break
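
# Example (illustrative sketch, not part of the original file): keeping a target network as an
# exponential moving average of the online one, i.e. p_target <- (1 - tau) * p_target + tau * p.
# online = mlp(16, 32, 1)
# target = mlp(16, 32, 1)
# target.load_state_dict(online.state_dict())
# set_requires_grad(target, False)
# ema(online, target, tau=0.01)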


def set_requires_grad(net, value):
    """Enable/disable gradients for a given (sub)network."""
    for param in net.parameters():
        param.requires_grad_(value)


class TruncatedNormal(pyd.Normal):
    """Utility class implementing the truncated normal distribution."""

    default_sample_shape = torch.Size()

    def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6):
        super().__init__(loc, scale, validate_args=False)
        self.low = low
        self.high = high
        self.eps = eps

    def _clamp(self, x):
        # Clamp to (low, high) while keeping a straight-through gradient w.r.t. the unclamped value.
        clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps)
        x = x - x.detach() + clamped_x.detach()
        return x

    def sample(self, clip=None, sample_shape=default_sample_shape):
        shape = self._extended_shape(sample_shape)
        eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
        eps *= self.scale
        if clip is not None:
            eps = torch.clamp(eps, -clip, clip)
        x = self.loc + eps
        return self._clamp(x)
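
# Example (illustrative sketch, not part of the original file): drawing exploration noise clipped to
# +/-0.3 around the mean, then truncated to the (-1, 1) action range.
# dist = TruncatedNormal(torch.zeros(2), torch.ones(2) * 0.5)
# action = dist.sample(clip=0.3)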


class NormalizeImg(nn.Module):
    """Normalizes pixel observations to the [0, 1] range."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.div(255.0)


class Flatten(nn.Module):
    """Flattens its input to a (batched) vector."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.view(x.size(0), -1)


def enc(cfg):
    """Returns a TOLD encoder."""
    obs_shape = {
        "rgb": (3, cfg.img_size, cfg.img_size),
        "state": (cfg.state_dim,),
    }
    pixels_enc_layers, state_enc_layers = None, None
    if cfg.modality in {"pixels", "all"}:
        C = int(3 * cfg.frame_stack)  # noqa: N806
        pixels_enc_layers = [
            NormalizeImg(),
            nn.Conv2d(C, cfg.num_channels, 7, stride=2),
            nn.ReLU(),
            nn.Conv2d(cfg.num_channels, cfg.num_channels, 5, stride=2),
            nn.ReLU(),
            nn.Conv2d(cfg.num_channels, cfg.num_channels, 3, stride=2),
            nn.ReLU(),
            nn.Conv2d(cfg.num_channels, cfg.num_channels, 3, stride=2),
            nn.ReLU(),
        ]
        out_shape = _get_out_shape((C, cfg.img_size, cfg.img_size), pixels_enc_layers)
        pixels_enc_layers.extend(
            [
                Flatten(),
                nn.Linear(np.prod(out_shape), cfg.latent_dim),
                nn.LayerNorm(cfg.latent_dim),
                nn.Sigmoid(),
            ]
        )
        if cfg.modality == "pixels":
            return ConvExt(nn.Sequential(*pixels_enc_layers))
    if cfg.modality in {"state", "all"}:
        # `obs_shape` is a dict here, so the state dimension always comes from the "state" entry.
        state_dim = obs_shape["state"][0]
        state_enc_layers = [
            nn.Linear(state_dim, cfg.enc_dim),
            nn.ELU(),
            nn.Linear(cfg.enc_dim, cfg.latent_dim),
            nn.LayerNorm(cfg.latent_dim),
            nn.Sigmoid(),
        ]
        if cfg.modality == "state":
            return nn.Sequential(*state_enc_layers)
    else:
        raise NotImplementedError
    encoders = {}
    for k in obs_shape:
        if k == "state":
            encoders[k] = nn.Sequential(*state_enc_layers)
        elif k.endswith("rgb"):
            encoders[k] = ConvExt(nn.Sequential(*pixels_enc_layers))
        else:
            raise NotImplementedError
    return Multiplexer(nn.ModuleDict(encoders))
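
# Example (illustrative sketch, not part of the original file): building a state-only encoder from a
# minimal config. The field names mirror the attributes this factory reads (modality, state_dim,
# enc_dim, latent_dim, img_size); a real run would pass the policy's actual config object.
# from types import SimpleNamespace
# cfg = SimpleNamespace(modality="state", state_dim=4, enc_dim=256, latent_dim=50, img_size=84)
# encoder = enc(cfg)
# z = encoder(torch.randn(8, 4))  # (8, 50) batch of latents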


def mlp(in_dim, mlp_dim, out_dim, act_fn=DEFAULT_ACT_FN):
    """Returns an MLP."""
    if isinstance(mlp_dim, int):
        mlp_dim = [mlp_dim, mlp_dim]
    return nn.Sequential(
        nn.Linear(in_dim, mlp_dim[0]),
        nn.LayerNorm(mlp_dim[0]),
        act_fn,
        nn.Linear(mlp_dim[0], mlp_dim[1]),
        nn.LayerNorm(mlp_dim[1]),
        act_fn,
        nn.Linear(mlp_dim[1], out_dim),
    )


def dynamics(in_dim, mlp_dim, out_dim, act_fn=DEFAULT_ACT_FN):
    """Returns a dynamics network."""
    return nn.Sequential(
        mlp(in_dim, mlp_dim, out_dim, act_fn),
        nn.LayerNorm(out_dim),
        nn.Sigmoid(),
    )


def q(cfg):
    """Returns a Q-function that uses Layer Normalization."""
    action_dim = cfg.action_dim
    return nn.Sequential(
        nn.Linear(cfg.latent_dim + action_dim, cfg.mlp_dim),
        nn.LayerNorm(cfg.mlp_dim),
        nn.Tanh(),
        nn.Linear(cfg.mlp_dim, cfg.mlp_dim),
        nn.ELU(),
        nn.Linear(cfg.mlp_dim, 1),
    )


def v(cfg):
    """Returns a state value function that uses Layer Normalization."""
    return nn.Sequential(
        nn.Linear(cfg.latent_dim, cfg.mlp_dim),
        nn.LayerNorm(cfg.mlp_dim),
        nn.Tanh(),
        nn.Linear(cfg.mlp_dim, cfg.mlp_dim),
        nn.ELU(),
        nn.Linear(cfg.mlp_dim, 1),
    )


def aug(cfg):
    """Multiplex augmentation."""
    obs_shape = {
        "rgb": (3, cfg.img_size, cfg.img_size),
        "state": (4,),
    }
    if cfg.modality == "state":
        return nn.Identity()
    elif cfg.modality == "pixels":
        return RandomShiftsAug(cfg)
    else:
        augs = {}
        for k in obs_shape:
            if k == "state":
                augs[k] = nn.Identity()
            elif k.endswith("rgb"):
                augs[k] = RandomShiftsAug(cfg)
            else:
                raise NotImplementedError
        return Multiplexer(nn.ModuleDict(augs))


class ConvExt(nn.Module):
    """Auxiliary conv net accommodating high-dim input"""

    def __init__(self, conv):
        super().__init__()
        self.conv = conv

    def forward(self, x):
        if x.ndim > 4:
            batch_shape = x.shape[:-3]
            out = self.conv(x.view(-1, *x.shape[-3:]))
            out = out.view(*batch_shape, *out.shape[1:])
        else:
            out = self.conv(x)
        return out
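
# Example (illustrative sketch, not part of the original file): ConvExt lets a conv encoder accept
# inputs with extra leading dimensions, e.g. (horizon, batch, C, H, W), by folding them into a single
# batch dimension before the convolution and unfolding them afterwards.
# net = ConvExt(nn.Sequential(nn.Conv2d(3, 8, 3, stride=2), nn.ReLU(), Flatten()))
# out = net(torch.randn(5, 2, 3, 32, 32))  # -> shape (5, 2, 8 * 15 * 15)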


class Multiplexer(nn.Module):
    """Model multiplexer"""

    def __init__(self, choices):
        super().__init__()
        self.choices = choices

    def forward(self, x, key=None):
        if isinstance(x, dict):
            if key is not None:
                # Route the whole input through the module registered under `key`.
                return self.choices[key](x)
            # Otherwise apply each registered module to its matching entry.
            return {k: self.choices[k](_x) for k, _x in x.items()}
        return self.choices(x)


class RandomShiftsAug(nn.Module):
    """
    Random shift image augmentation.
    Adapted from https://github.com/facebookresearch/drqv2
    """

    def __init__(self, cfg):
        super().__init__()
        assert cfg.modality in {"pixels", "all"}
        self.pad = int(cfg.img_size / 21)

    def forward(self, x):
        n, c, h, w = x.size()
        assert h == w
        padding = tuple([self.pad] * 4)
        x = F.pad(x, padding, "replicate")
        eps = 1.0 / (h + 2 * self.pad)
        arange = torch.linspace(
            -1.0 + eps,
            1.0 - eps,
            h + 2 * self.pad,
            device=x.device,
            dtype=torch.float32,
        )[:h]
        arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2)
        base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2)
        base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1)
        # Sample an integer pixel shift per image and apply it through a resampling grid.
        shift = torch.randint(
            0,
            2 * self.pad + 1,
            size=(n, 1, 1, 2),
            device=x.device,
            dtype=torch.float32,
        )
        shift *= 2.0 / (h + 2 * self.pad)
        grid = base_grid + shift
        return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False)
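
# Example (illustrative sketch, not part of the original file): augmenting a batch of 84x84 pixel
# observations. With img_size=84 the padding is 84 // 21 = 4 pixels, so each image is shifted by up
# to +/-4 pixels along each axis.
# from types import SimpleNamespace
# shift_aug = RandomShiftsAug(SimpleNamespace(modality="pixels", img_size=84))
# augmented = shift_aug(torch.rand(16, 3, 84, 84))  # same shape, randomly shifted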


# TODO(aliberts): remove class
# class Episode:
#     """Storage object for a single episode."""
#     def __init__(self, cfg, init_obs):
#         action_dim = cfg.action_dim
#         self.cfg = cfg
#         self.device = torch.device(cfg.buffer_device)
#         if cfg.modality in {"pixels", "state"}:
#             dtype = torch.float32 if cfg.modality == "state" else torch.uint8
#             self.obses = torch.empty(
#                 (cfg.episode_length + 1, *init_obs.shape),
#                 dtype=dtype,
#                 device=self.device,
#             )
#             self.obses[0] = torch.tensor(init_obs, dtype=dtype, device=self.device)
#         elif cfg.modality == "all":
#             self.obses = {}
#             for k, v in init_obs.items():
#                 assert k in {"rgb", "state"}
#                 dtype = torch.float32 if k == "state" else torch.uint8
#                 self.obses[k] = torch.empty(
#                     (cfg.episode_length + 1, *v.shape), dtype=dtype, device=self.device
#                 )
#                 self.obses[k][0] = torch.tensor(v, dtype=dtype, device=self.device)
#         else:
#             raise ValueError
#         self.actions = torch.empty((cfg.episode_length, action_dim), dtype=torch.float32, device=self.device)
#         self.rewards = torch.empty((cfg.episode_length,), dtype=torch.float32, device=self.device)
#         self.dones = torch.empty((cfg.episode_length,), dtype=torch.bool, device=self.device)
#         self.masks = torch.empty((cfg.episode_length,), dtype=torch.float32, device=self.device)
#         self.cumulative_reward = 0
#         self.done = False
#         self.success = False
#         self._idx = 0

#     def __len__(self):
#         return self._idx

#     @classmethod
#     def from_trajectory(cls, cfg, obses, actions, rewards, dones=None, masks=None):
#         """Constructs an episode from a trajectory."""
#         if cfg.modality in {"pixels", "state"}:
#             episode = cls(cfg, obses[0])
#             episode.obses[1:] = torch.tensor(obses[1:], dtype=episode.obses.dtype, device=episode.device)
#         elif cfg.modality == "all":
#             episode = cls(cfg, {k: v[0] for k, v in obses.items()})
#             for k in obses:
#                 episode.obses[k][1:] = torch.tensor(
#                     obses[k][1:], dtype=episode.obses[k].dtype, device=episode.device
#                 )
#         else:
#             raise NotImplementedError
#         episode.actions = torch.tensor(actions, dtype=episode.actions.dtype, device=episode.device)
#         episode.rewards = torch.tensor(rewards, dtype=episode.rewards.dtype, device=episode.device)
#         episode.dones = (
#             torch.tensor(dones, dtype=episode.dones.dtype, device=episode.device)
#             if dones is not None
#             else torch.zeros_like(episode.dones)
#         )
#         episode.masks = (
#             torch.tensor(masks, dtype=episode.masks.dtype, device=episode.device)
#             if masks is not None
#             else torch.ones_like(episode.masks)
#         )
#         episode.cumulative_reward = torch.sum(episode.rewards)
#         episode.done = True
#         episode._idx = cfg.episode_length
#         return episode

#     @property
#     def first(self):
#         return len(self) == 0

#     def __add__(self, transition):
#         self.add(*transition)
#         return self

#     def add(self, obs, action, reward, done, mask=1.0, success=False):
#         """Add a transition into the episode."""
#         if isinstance(obs, dict):
#             for k, v in obs.items():
#                 self.obses[k][self._idx + 1] = torch.tensor(
#                     v, dtype=self.obses[k].dtype, device=self.obses[k].device
#                 )
#         else:
#             self.obses[self._idx + 1] = torch.tensor(obs, dtype=self.obses.dtype, device=self.obses.device)
#         self.actions[self._idx] = action
#         self.rewards[self._idx] = reward
#         self.dones[self._idx] = done
#         self.masks[self._idx] = mask
#         self.cumulative_reward += reward
#         self.done = done
#         self.success = self.success or success
#         self._idx += 1


def get_dataset_dict(cfg, env, return_reward_normalizer=False):
    """Construct a dataset for env"""
    required_keys = [
        "observations",
        "next_observations",
        "actions",
        "rewards",
        "dones",
        "masks",
    ]

    if cfg.task.startswith("xarm"):
        dataset_path = os.path.join(cfg.dataset_dir, "buffer.pkl")
        print(f"Using offline dataset '{dataset_path}'")
        with open(dataset_path, "rb") as f:
            dataset_dict = pickle.load(f)
        # Some buffers store singular key names (e.g. "action"); rename them to the plural form.
        for k in required_keys:
            if k not in dataset_dict and k[:-1] in dataset_dict:
                dataset_dict[k] = dataset_dict.pop(k[:-1])
    elif cfg.task.startswith("legged"):
        dataset_path = os.path.join(cfg.dataset_dir, "buffer.pkl")
        print(f"Using offline dataset '{dataset_path}'")
        with open(dataset_path, "rb") as f:
            dataset_dict = pickle.load(f)
        dataset_dict["actions"] /= env.unwrapped.clip_actions
        print(f"clip_actions={env.unwrapped.clip_actions}")
    else:
        import d4rl

        dataset_dict = d4rl.qlearning_dataset(env)
        dones = np.full_like(dataset_dict["rewards"], False, dtype=bool)
        # Mark episode boundaries: either a terminal flag or a discontinuity between
        # next_observations[i] and observations[i + 1].
        for i in range(len(dones) - 1):
            if (
                np.linalg.norm(dataset_dict["observations"][i + 1] - dataset_dict["next_observations"][i])
                > 1e-6
                or dataset_dict["terminals"][i] == 1.0
            ):
                dones[i] = True
        dones[-1] = True
        dataset_dict["masks"] = 1.0 - dataset_dict["terminals"]
        del dataset_dict["terminals"]
        for k, v in dataset_dict.items():
            dataset_dict[k] = v.astype(np.float32)
        dataset_dict["dones"] = dones

    if cfg.is_data_clip:
        lim = 1 - cfg.data_clip_eps
        dataset_dict["actions"] = np.clip(dataset_dict["actions"], -lim, lim)
    reward_normalizer = get_reward_normalizer(cfg, dataset_dict)
    dataset_dict["rewards"] = reward_normalizer(dataset_dict["rewards"])

    for key in required_keys:
        assert key in dataset_dict, f"Missing `{key}` in dataset."
    if return_reward_normalizer:
        return dataset_dict, reward_normalizer
    return dataset_dict


def get_trajectory_boundaries_and_returns(dataset):
    """
    Split dataset into trajectories and compute returns.

    Returns a tuple (episode_starts, episode_ends, episode_returns).
    """
    episode_starts = [0]
    episode_ends = []
    episode_return = 0
    episode_returns = []

    n_transitions = len(dataset["rewards"])
    for i in range(n_transitions):
        episode_return += dataset["rewards"][i]
        if dataset["dones"][i]:
            episode_returns.append(episode_return)
            episode_ends.append(i + 1)
            if i + 1 < n_transitions:
                episode_starts.append(i + 1)
            episode_return = 0.0

    return episode_starts, episode_ends, episode_returns
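
# Example (illustrative sketch, not part of the original file): two episodes of two transitions each.
# toy = {"rewards": np.array([1.0, 2.0, 0.5, 0.5]), "dones": np.array([False, True, False, True])}
# get_trajectory_boundaries_and_returns(toy)  # ([0, 2], [2, 4], [3.0, 1.0])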


def normalize_returns(dataset, scaling=1000):
    """
    Normalize rewards in the dataset by the spread of episode returns, then rescale by `scaling`.
    """
    (_, _, episode_returns) = get_trajectory_boundaries_and_returns(dataset)
    dataset["rewards"] /= np.max(episode_returns) - np.min(episode_returns)
    dataset["rewards"] *= scaling
    return dataset


def get_reward_normalizer(cfg, dataset):
    """
    Get a reward normalizer for the dataset
    """
    if cfg.task.startswith("xarm"):
        return lambda x: x
    elif "maze" in cfg.task:
        return lambda x: x - 1.0
    elif cfg.task.split("-")[0] in ["hopper", "halfcheetah", "walker2d"]:
        (_, _, episode_returns) = get_trajectory_boundaries_and_returns(dataset)
        return lambda x: x / (np.max(episode_returns) - np.min(episode_returns)) * 1000.0
    elif hasattr(cfg, "reward_scale"):
        return lambda x: x * cfg.reward_scale
    return lambda x: x


def linear_schedule(schdl, step):
    """
    Outputs values following a linear decay schedule.

    `schdl` is either a constant (e.g. "0.1") or a string of the form
    "linear(init,final,duration)" or "linear(init,final,start,end)".
    Adapted from https://github.com/facebookresearch/drqv2
    """
    try:
        return float(schdl)
    except ValueError:
        match = re.match(r"linear\((.+),(.+),(.+),(.+)\)", schdl)
        if match:
            init, final, start, end = (float(g) for g in match.groups())
            mix = np.clip((step - start) / (end - start), 0.0, 1.0)
            return (1.0 - mix) * init + mix * final
        match = re.match(r"linear\((.+),(.+),(.+)\)", schdl)
        if match:
            init, final, duration = (float(g) for g in match.groups())
            mix = np.clip(step / duration, 0.0, 1.0)
            return (1.0 - mix) * init + mix * final
    raise NotImplementedError(schdl)
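
# Example (illustrative sketch, not part of the original file): a schedule that decays from 1.0 to
# 0.1 over the first 100 steps and stays constant afterwards.
# linear_schedule("linear(1.0,0.1,100)", 0)    # 1.0
# linear_schedule("linear(1.0,0.1,100)", 50)   # 0.55
# linear_schedule("linear(1.0,0.1,100)", 200)  # 0.1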