# ruff: noqa: N806

import time
from copy import deepcopy

import einops
import numpy as np
import torch
import torch.nn as nn

import lerobot.common.policies.tdmpc.helper as h
from lerobot.common.policies.abstract import AbstractPolicy

FIRST_FRAME = 0


class TOLD(nn.Module):
    """Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC."""

    def __init__(self, cfg):
        super().__init__()
        action_dim = cfg.action_dim

        self.cfg = cfg
        self._encoder = h.enc(cfg)
        self._dynamics = h.dynamics(cfg.latent_dim + action_dim, cfg.mlp_dim, cfg.latent_dim)
        self._reward = h.mlp(cfg.latent_dim + action_dim, cfg.mlp_dim, 1)
        self._pi = h.mlp(cfg.latent_dim, cfg.mlp_dim, action_dim)
        self._Qs = nn.ModuleList([h.q(cfg) for _ in range(cfg.num_q)])
        self._V = h.v(cfg)
        self.apply(h.orthogonal_init)
        # Zero-initialize the last layer of the reward and Q heads.
        for m in [self._reward, *self._Qs]:
            m[-1].weight.data.fill_(0)
            m[-1].bias.data.fill_(0)

    def track_q_grad(self, enable=True):
        """Utility function. Enables/disables gradient tracking of Q-networks."""
        for m in self._Qs:
            h.set_requires_grad(m, enable)

    def track_v_grad(self, enable=True):
        """Utility function. Enables/disables gradient tracking of the V-network."""
        if hasattr(self, "_V"):
            h.set_requires_grad(self._V, enable)

    def encode(self, obs):
        """Encodes an observation into its latent representation."""
        out = self._encoder(obs)
        if isinstance(obs, dict):
            # Fusion: average the per-modality latents.
            out = torch.stack([v for k, v in out.items()]).mean(dim=0)
        return out

    def next(self, z, a):
        """Predicts the next latent state (d) and single-step reward (R)."""
        x = torch.cat([z, a], dim=-1)
        return self._dynamics(x), self._reward(x)

    def next_dynamics(self, z, a):
        """Predicts the next latent state (d)."""
        x = torch.cat([z, a], dim=-1)
        return self._dynamics(x)

    def pi(self, z, std=0):
        """Samples an action from the learned policy (pi)."""
        mu = torch.tanh(self._pi(z))
        if std > 0:
            std = torch.ones_like(mu) * std
            return h.TruncatedNormal(mu, std).sample(clip=0.3)
        return mu

    def V(self, z):  # noqa: N802
        """Predict state value (V)."""
        return self._V(z)

    def Q(self, z, a, return_type):  # noqa: N802
        """Predict state-action value (Q)."""
        assert return_type in {"min", "avg", "all"}
        x = torch.cat([z, a], dim=-1)

        if return_type == "all":
            return torch.stack([q(x) for q in self._Qs], dim=0)

        # Clipped double Q: evaluate two randomly chosen heads of the ensemble.
        idxs = np.random.choice(self.cfg.num_q, 2, replace=False)
        Q1, Q2 = self._Qs[idxs[0]](x), self._Qs[idxs[1]](x)
        return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2


class TDMPC(AbstractPolicy):
    """Implementation of TD-MPC learning + inference."""

    def __init__(self, cfg, device):
        super().__init__(None)
        self.action_dim = cfg.action_dim

        self.cfg = cfg
        self.device = torch.device(device)
        self.std = h.linear_schedule(cfg.std_schedule, 0)
        self.model = TOLD(cfg).cuda() if torch.cuda.is_available() and device == "cuda" else TOLD(cfg)
        self.model_target = deepcopy(self.model)
        self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
        self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr)
        # self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
        self.model.eval()
        self.model_target.eval()
        self.batch_size = cfg.batch_size

        self.register_buffer("step", torch.zeros(1))

    def state_dict(self):
        """Retrieve state dict of TOLD model, including slow-moving target network."""
        return {
            "model": self.model.state_dict(),
            "model_target": self.model_target.state_dict(),
        }

    def save(self, fp):
        """Save state dict of TOLD model to filepath."""
        torch.save(self.state_dict(), fp)

    def load(self, fp):
        """Load a saved state dict from filepath into current agent."""
        d = torch.load(fp)
        self.model.load_state_dict(d["model"])
        self.model_target.load_state_dict(d["model_target"])

    @torch.no_grad()
    def select_actions(self, observation, step_count):
        if observation["image"].shape[0] != 1:
            raise NotImplementedError("Batch size > 1 not handled")

        t0 = step_count.item() == 0

        obs = {
            # TODO(rcadene): remove contiguous hack...
            "rgb": observation["image"].contiguous(),
            "state": observation["state"].contiguous(),
        }
        # Note: unsqueeze needed because `act` still uses non-batch logic.
        action = self.act(obs, t0=t0, step=self.step.item()).unsqueeze(0)
        return action

    @torch.no_grad()
    def act(self, obs, t0=False, step=None):
        """Take an action. Uses either MPC or the learned policy, depending on the self.cfg.mpc flag."""
        obs = {k: o.detach() for k, o in obs.items()} if isinstance(obs, dict) else obs.detach()
        z = self.model.encode(obs)
        if self.cfg.mpc:
            a = self.plan(z, t0=t0, step=step)
        else:
            a = self.model.pi(z, self.cfg.min_std * self.model.training).squeeze(0)
        return a

    @torch.no_grad()
    def estimate_value(self, z, actions, horizon):
        """Estimate value of a trajectory starting at latent state z and executing given actions."""
        G, discount = 0, 1
        for t in range(horizon):
            # Optionally penalize the disagreement (std) of the Q ensemble as an uncertainty cost.
            if self.cfg.uncertainty_cost > 0:
                G -= (
                    discount
                    * self.cfg.uncertainty_cost
                    * self.model.Q(z, actions[t], return_type="all").std(dim=0)
                )
            z, reward = self.model.next(z, actions[t])
            G += discount * reward
            discount *= self.cfg.discount
        pi = self.model.pi(z, self.cfg.min_std)
        G += discount * self.model.Q(z, pi, return_type="min")
        if self.cfg.uncertainty_cost > 0:
            G -= discount * self.cfg.uncertainty_cost * self.model.Q(z, pi, return_type="all").std(dim=0)
        return G

    @torch.no_grad()
    def plan(self, z, step=None, t0=True):
        """
        Plan next action using TD-MPC inference.

        z: latent state.
        step: current time step. Determines e.g. the planning horizon.
        t0: whether current step is the first step of an episode.
        """
        # During evaluation, the uniform seed-step sampling and the action noise below are disabled
        # (both are gated on `self.model.training`).
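        # High-level sketch of the planner below:
        #   1. Roll out `num_pi_trajs` action sequences with the learned policy pi as proposals.
        #   2. Run CEM for `cfg.iterations`: sample Gaussian action sequences around (mean, std),
        #      score all candidates with `estimate_value`, keep the `cfg.num_elites` best, and
        #      refit (mean, std) with a softmax weighting of the elite values.
        #   3. Sample one elite sequence, execute its first action, and warm-start the next call
        #      through `self._prev_mean`.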
        assert step is not None
        # Seed steps: act uniformly at random until enough environment steps have been collected.
        if step < self.cfg.seed_steps and self.model.training:
            return torch.empty(self.action_dim, dtype=torch.float32, device=self.device).uniform_(-1, 1)

        # Sample policy trajectories
        horizon = int(min(self.cfg.horizon, h.linear_schedule(self.cfg.horizon_schedule, step)))
        num_pi_trajs = int(self.cfg.mixture_coef * self.cfg.num_samples)
        if num_pi_trajs > 0:
            pi_actions = torch.empty(horizon, num_pi_trajs, self.action_dim, device=self.device)
            _z = z.repeat(num_pi_trajs, 1)
            for t in range(horizon):
                pi_actions[t] = self.model.pi(_z, self.cfg.min_std)
                _z = self.model.next_dynamics(_z, pi_actions[t])

        # Initialize state and parameters
        z = z.repeat(self.cfg.num_samples + num_pi_trajs, 1)
        mean = torch.zeros(horizon, self.action_dim, device=self.device)
        std = self.cfg.max_std * torch.ones(horizon, self.action_dim, device=self.device)
        if not t0 and hasattr(self, "_prev_mean"):
            mean[:-1] = self._prev_mean[1:]

        # Iterate CEM
        for _ in range(self.cfg.iterations):
            actions = torch.clamp(
                mean.unsqueeze(1)
                + std.unsqueeze(1)
                * torch.randn(horizon, self.cfg.num_samples, self.action_dim, device=std.device),
                -1,
                1,
            )
            if num_pi_trajs > 0:
                actions = torch.cat([actions, pi_actions], dim=1)

            # Compute elite actions
            value = self.estimate_value(z, actions, horizon).nan_to_num_(0)
            elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices
            elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]

            # Update parameters
            max_value = elite_value.max(0)[0]
            score = torch.exp(self.cfg.temperature * (elite_value - max_value))
            score /= score.sum(0)
            _mean = torch.sum(score.unsqueeze(0) * elite_actions, dim=1) / (score.sum(0) + 1e-9)
            _std = torch.sqrt(
                torch.sum(
                    score.unsqueeze(0) * (elite_actions - _mean.unsqueeze(1)) ** 2,
                    dim=1,
                )
                / (score.sum(0) + 1e-9)
            )
            _std = _std.clamp_(self.std, self.cfg.max_std)
            mean, std = self.cfg.momentum * mean + (1 - self.cfg.momentum) * _mean, _std

        # Outputs
        # TODO(rcadene): remove numpy with
        # # Convert score tensor to probabilities using softmax
        # probabilities = torch.softmax(score, dim=0)
        # # Generate a random sample index based on the probabilities
        # sample_index = torch.multinomial(probabilities, 1).item()
        score = score.squeeze(1).cpu().numpy()
        actions = elite_actions[:, np.random.choice(np.arange(score.shape[0]), p=score)]
        self._prev_mean = mean
        mean, std = actions[0], _std[0]
        a = mean
        if self.model.training:
            a += std * torch.randn(self.action_dim, device=std.device)
        return torch.clamp(a, -1, 1)

    def update_pi(self, zs, acts=None):
        """Update policy using a sequence of latent states."""
        self.pi_optim.zero_grad(set_to_none=True)
        self.model.track_q_grad(False)
        self.model.track_v_grad(False)

        info = {}
        # Advantage Weighted Regression (AWR): weight the log-likelihood of the dataset actions
        # by exp(A * A_scaling), with advantage A = Q_target - V.
        assert acts is not None
        vs = self.model.V(zs)
        qs = self.model_target.Q(zs, acts, return_type="min")
        adv = qs - vs
        exp_a = torch.exp(adv * self.cfg.A_scaling)
        exp_a = torch.clamp(exp_a, max=100.0)
        log_probs = h.gaussian_logprob(self.model.pi(zs) - acts, 0)
        rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device))
        pi_loss = -((exp_a * log_probs).mean(dim=(1, 2)) * rho).mean()
        info["adv"] = adv[0]

        pi_loss.backward()
        torch.nn.utils.clip_grad_norm_(
            self.model._pi.parameters(),
            self.cfg.grad_clip_norm,
            error_if_nonfinite=False,
        )
        self.pi_optim.step()
        self.model.track_q_grad(True)
        self.model.track_v_grad(True)

        info["pi_loss"] = pi_loss.item()
        return pi_loss.item(), info
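    # Note on the value targets used below: `_td_target` bootstraps from the state-value head V
    # (one-step target r + discount * mask * V(z')), while V itself is trained towards the min of
    # the target Q ensemble with an expectile loss (see `v_value_loss` in `update`).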
    @torch.no_grad()
    def _td_target(self, next_z, reward, mask):
        """Compute the TD-target from a reward and the observation at the following time step."""
        next_v = self.model.V(next_z)
        td_target = reward + self.cfg.discount * mask * next_v
        return td_target

    def update(self, replay_buffer, step, demo_buffer=None):
        """Main update function. Corresponds to one iteration of the model learning."""
        start_time = time.time()

        num_slices = self.cfg.batch_size
        batch_size = self.cfg.horizon * num_slices

        if demo_buffer is None:
            demo_batch_size = 0
        else:
            # Update oversampling ratio
            demo_pc_batch = h.linear_schedule(self.cfg.demo_schedule, step)
            demo_num_slices = int(demo_pc_batch * self.batch_size)
            demo_batch_size = self.cfg.horizon * demo_num_slices
            batch_size -= demo_batch_size
            num_slices -= demo_num_slices
            replay_buffer._sampler.num_slices = num_slices
            demo_buffer._sampler.num_slices = demo_num_slices

            assert demo_batch_size % self.cfg.horizon == 0
            assert demo_batch_size % demo_num_slices == 0

        assert batch_size % self.cfg.horizon == 0
        assert batch_size % num_slices == 0

        # Sample from interaction dataset

        def process_batch(batch, horizon, num_slices):
            # trajectory t = 256, horizon h = 5
            # (t h) ... -> h t ...
            batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()

            obs = {
                "rgb": batch["observation", "image"][FIRST_FRAME].to(self.device, non_blocking=True),
                "state": batch["observation", "state"][FIRST_FRAME].to(self.device, non_blocking=True),
            }
            action = batch["action"].to(self.device, non_blocking=True)
            next_obses = {
                "rgb": batch["next", "observation", "image"].to(self.device, non_blocking=True),
                "state": batch["next", "observation", "state"].to(self.device, non_blocking=True),
            }
            reward = batch["next", "reward"].to(self.device, non_blocking=True)

            idxs = batch["index"][FIRST_FRAME].to(self.device, non_blocking=True)
            weights = batch["_weight"][FIRST_FRAME, :, None].to(self.device, non_blocking=True)

            # TODO(rcadene): rearrange directly in offline dataset
            if reward.ndim == 2:
                reward = einops.rearrange(reward, "h t -> h t 1")

            assert reward.ndim == 3
            assert reward.shape == (horizon, num_slices, 1)
            # We don't use `batch["next", "done"]` since it only indicates the end of an
            # episode, but not the end of the trajectory of an episode.
            # Neither does `batch["next", "terminated"]`.
            done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
            mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
            return obs, action, next_obses, reward, mask, done, idxs, weights
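        # `process_batch` returns time-major tensors: `action` and `reward` have shape
        # (horizon, num_slices, ...), while `obs` keeps only the first frame of each slice,
        # since the latent rollout below re-predicts the remaining steps from it.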
        batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample()

        obs, action, next_obses, reward, mask, done, idxs, weights = process_batch(
            batch, self.cfg.horizon, num_slices
        )

        # Sample from demonstration dataset
        if demo_batch_size > 0:
            demo_batch = demo_buffer.sample(demo_batch_size)
            (
                demo_obs,
                demo_action,
                demo_next_obses,
                demo_reward,
                demo_mask,
                demo_done,
                demo_idxs,
                demo_weights,
            ) = process_batch(demo_batch, self.cfg.horizon, demo_num_slices)

            if isinstance(obs, dict):
                obs = {k: torch.cat([obs[k], demo_obs[k]]) for k in obs}
                next_obses = {k: torch.cat([next_obses[k], demo_next_obses[k]], dim=1) for k in next_obses}
            else:
                obs = torch.cat([obs, demo_obs])
                next_obses = torch.cat([next_obses, demo_next_obses], dim=1)
            action = torch.cat([action, demo_action], dim=1)
            reward = torch.cat([reward, demo_reward], dim=1)
            mask = torch.cat([mask, demo_mask], dim=1)
            done = torch.cat([done, demo_done], dim=1)
            idxs = torch.cat([idxs, demo_idxs])
            weights = torch.cat([weights, demo_weights])

        # Apply augmentations
        aug_tf = h.aug(self.cfg)
        obs = aug_tf(obs)

        for k in next_obses:
            next_obses[k] = einops.rearrange(next_obses[k], "h t ... -> (h t) ...")
        next_obses = aug_tf(next_obses)
        for k in next_obses:
            next_obses[k] = einops.rearrange(
                next_obses[k],
                "(h t) ... -> h t ...",
                h=self.cfg.horizon,
                t=self.cfg.batch_size,
            )

        horizon = self.cfg.horizon
        loss_mask = torch.ones_like(mask, device=self.device)
        for t in range(1, horizon):
            loss_mask[t] = loss_mask[t - 1] * (~done[t - 1])

        self.optim.zero_grad(set_to_none=True)
        self.std = h.linear_schedule(self.cfg.std_schedule, step)
        self.model.train()

        data_s = time.time() - start_time

        # Compute targets
        with torch.no_grad():
            next_z = self.model.encode(next_obses)
            z_targets = self.model_target.encode(next_obses)
            td_targets = self._td_target(next_z, reward, mask)

        # Latent rollout
        zs = torch.empty(horizon + 1, self.batch_size, self.cfg.latent_dim, device=self.device)
        reward_preds = torch.empty_like(reward, device=self.device)
        assert reward.shape[0] == horizon
        z = self.model.encode(obs)
        zs[0] = z
        value_info = {"Q": 0.0, "V": 0.0}
        for t in range(horizon):
            z, reward_pred = self.model.next(z, action[t])
            zs[t + 1] = z
            reward_preds[t] = reward_pred

        with torch.no_grad():
            v_target = self.model_target.Q(zs[:-1].detach(), action, return_type="min")

        # Predictions
        qs = self.model.Q(zs[:-1], action, return_type="all")
        value_info["Q"] = qs.mean().item()
        v = self.model.V(zs[:-1])
        value_info["V"] = v.mean().item()
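        # Loss terms below, each discounted by rho^t and gated by `loss_mask`:
        #   - consistency: predicted latents zs[1:] vs. target-encoder latents of the next observations
        #   - reward: predicted single-step rewards vs. observed rewards
        #   - Q: every Q head regressed towards the TD target r + discount * V(z')
        #   - V: expectile regression of V towards min-Q of the target network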
        # Losses
        rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(-1, 1, 1)
        consistency_loss = (rho * torch.mean(h.mse(zs[1:], z_targets), dim=2, keepdim=True) * loss_mask).sum(
            dim=0
        )
        reward_loss = (rho * h.mse(reward_preds, reward) * loss_mask).sum(dim=0)
        q_value_loss, priority_loss = 0, 0
        for q in range(self.cfg.num_q):
            q_value_loss += (rho * h.mse(qs[q], td_targets) * loss_mask).sum(dim=0)
            priority_loss += (rho * h.l1(qs[q], td_targets) * loss_mask).sum(dim=0)

        expectile = h.linear_schedule(self.cfg.expectile, step)
        v_value_loss = (rho * h.l2_expectile(v_target - v, expectile=expectile) * loss_mask).sum(dim=0)

        total_loss = (
            self.cfg.consistency_coef * consistency_loss
            + self.cfg.reward_coef * reward_loss
            + self.cfg.value_coef * q_value_loss
            + self.cfg.value_coef * v_value_loss
        )

        weighted_loss = (total_loss.squeeze(1) * weights).mean()
        weighted_loss.register_hook(lambda grad: grad * (1 / self.cfg.horizon))
        has_nan = torch.isnan(weighted_loss).item()
        if has_nan:
            print(f"weighted_loss has nan: {total_loss=} {weights=}")
        else:
            weighted_loss.backward()

        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False
        )
        self.optim.step()

        if self.cfg.per:
            # Update priorities
            priorities = priority_loss.clamp(max=1e4).detach()
            has_nan = torch.isnan(priorities).any().item()
            if has_nan:
                print(f"priorities has nan: {priorities=}")
            else:
                replay_buffer.update_priority(
                    idxs[:num_slices],
                    priorities[:num_slices],
                )
                if demo_batch_size > 0:
                    demo_buffer.update_priority(demo_idxs, priorities[num_slices:])

        # Update policy + target network
        _, pi_update_info = self.update_pi(zs[:-1].detach(), acts=action)
        if step % self.cfg.update_freq == 0:
            h.ema(self.model._encoder, self.model_target._encoder, self.cfg.tau)
            h.ema(self.model._Qs, self.model_target._Qs, self.cfg.tau)

        self.model.eval()

        info = {
            "consistency_loss": float(consistency_loss.mean().item()),
            "reward_loss": float(reward_loss.mean().item()),
            "Q_value_loss": float(q_value_loss.mean().item()),
            "V_value_loss": float(v_value_loss.mean().item()),
            "sum_loss": float(total_loss.mean().item()),
            "loss": float(weighted_loss.mean().item()),
            "grad_norm": float(grad_norm),
            "lr": self.cfg.lr,
            "data_s": data_s,
            "update_s": time.time() - start_time,
        }
        info["demo_batch_size"] = demo_batch_size
        info["expectile"] = expectile
        info.update(value_info)
        info.update(pi_update_info)

        self.step[0] = step
        return info
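# Usage sketch (hypothetical driver code; the exact `cfg` fields are defined by the policy configs
# used elsewhere in this repository):
#
#     policy = TDMPC(cfg, device="cuda")
#     action = policy.select_actions(observation, step_count)  # plan (or query pi) for one step
#     info = policy.update(replay_buffer, step)                # one learning iteration
#     policy.save("tdmpc.pt")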