From c41ec08ec1a6be4b88a60ded7b845ff1cc8f2a54 Mon Sep 17 00:00:00 2001
From: Michel Aractingi
Date: Thu, 21 Nov 2024 15:00:03 +0000
Subject: [PATCH] remove self.model_target and add a target Q-ensemble without
 the need to copy the entire policy

---
 .../common/policies/tdmpc2/modeling_tdmpc2.py | 60 +++++++++----------
 1 file changed, 27 insertions(+), 33 deletions(-)

diff --git a/lerobot/common/policies/tdmpc2/modeling_tdmpc2.py b/lerobot/common/policies/tdmpc2/modeling_tdmpc2.py
index 1bd29369..5a454f14 100644
--- a/lerobot/common/policies/tdmpc2/modeling_tdmpc2.py
+++ b/lerobot/common/policies/tdmpc2/modeling_tdmpc2.py
@@ -39,7 +39,7 @@ from torch import Tensor
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.tdmpc2.configuration_tdmpc2 import TDMPC2Config
 from lerobot.common.policies.utils import get_device_from_parameters, populate_queues
-from lerobot.common.policies.tdmpc2.tdmpc2_utils import NormedLinear, SimNorm, two_hot_inv
+from lerobot.common.policies.tdmpc2.tdmpc2_utils import NormedLinear, SimNorm, two_hot_inv, gaussian_logprob, squash
 
 
 class TDMPC2Policy(
@@ -84,9 +84,6 @@ class TDMPC2Policy(
             config = TDMPC2Config()
         self.config = config
         self.model = TDMPC2WorldModel(config)
-        self.model_target = deepcopy(self.model)
-        for param in self.model_target.parameters():
-            param.requires_grad = False
 
         if config.input_normalization_modes is not None:
             self.normalize_inputs = Normalize(
@@ -384,12 +381,12 @@ class TDMPC2Policy(
         # Compute various targets with stopgrad.
         with torch.no_grad():
             # Latent state consistency targets.
-            z_targets = self.model_target.encode(next_observations)
+            z_targets = self.model.encode(next_observations)
             # Compute the TD-target from a reward and the next observation
             pi = self.model.pi(z_targets)[0]
             td_targets = (
                 reward
-                + self.config.discount * self.model_target.Qs(z_targets, pi, return_type="min").squeeze()
+                + self.config.discount * self.model.Qs(z_targets, pi, return_type="min", target=True).squeeze()
             )
 
         # Compute losses.
@@ -450,10 +447,15 @@ class TDMPC2Policy(
         # Calculate the advantage weighted regression loss for π as detailed in FOWM 3.1.
         # We won't need these gradients again so detach.
         z_preds = z_preds.detach()
+        self.model.change_q_grad(mode=False)  # turn off gradients for the Q-networks during the policy update
         action_preds, _, log_pis, _ = self.model.pi(z_preds[:-1])
-        qs = self.model_target.Qs(z_preds[:-1], action_preds, return_type="avg")
-        self.scale.update(qs[0])
-        qs = self.scale(qs)
+
+        with torch.no_grad():
+            # avoid unnecessary computation of the gradients during policy optimization
+            # TODO (michel-aractingi): the same logic should be extended when adding task embeddings
+            qs = self.model.Qs(z_preds[:-1], action_preds, return_type="avg")
+            self.scale.update(qs[0])
+            qs = self.scale(qs)
 
         rho = torch.pow(self.config.rho, torch.arange(len(qs), device=qs.device)).unsqueeze(-1)
 
@@ -498,12 +500,8 @@ class TDMPC2Policy(
         return info
 
     def update(self):
-        """Update the target model's parameters with an EMA step."""
-        # Note a minor variation with respect to the original FOWM code. Here they do this based on an EMA
-        # update frequency parameter which is set to 2 (every 2 steps an update is done). To simplify the code
-        # we update every step and adjust the decay parameter `alpha` accordingly (0.99 -> 0.995)
-        update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum)
-
+        """Update the target Q-ensemble with a Polyak-averaging step."""
+        self.model.update_target_Q()
 
 class TDMPC2WorldModel(nn.Module):
     """Latent dynamics model used in TD-MPC2."""
@@ -586,6 +584,11 @@ class TDMPC2WorldModel(nn.Module):
         self.log_std_dif = self.log_std_dif.to(*args, **kwargs)
         self.bins = self.bins.to(*args, **kwargs)
         return self
+
+    def train(self, mode: bool = True):
+        super().train(mode)
+        self._target_Qs.train(False)  # target Q-networks always stay in eval mode
+        return self
 
     def encode(self, obs: dict[str, Tensor]) -> Tensor:
         """Encodes an observation into its latent representation."""
@@ -622,7 +625,7 @@ class TDMPC2WorldModel(nn.Module):
         x = torch.cat([z, a], dim=-1)
         return self._dynamics(x)
 
-    def pi(self, z: Tensor, std: float = 0.0) -> Tensor:
+    def pi(self, z: Tensor) -> Tensor:
         """Samples an action from the learned policy.
 
         The policy can also have added (truncated) Gaussian noise injected for encouraging exploration when
@@ -668,6 +671,14 @@ class TDMPC2WorldModel(nn.Module):
         Q1, Q2 = two_hot_inv(Q1, self.bins), two_hot_inv(Q2, self.bins)
         return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2
 
+    def update_target_Q(self):
+        """
+        Soft-update the target Q-networks: target <- momentum * target + (1 - momentum) * online.
+        """
+        with torch.no_grad():
+            for p, p_target in zip(self._Qs.parameters(), self._target_Qs.parameters()):
+                p_target.data.lerp_(p.data, 1 - self.config.target_model_momentum)
+
 
 class TDMPC2ObservationEncoder(nn.Module):
     """Encode image and/or state vector observations."""
@@ -777,23 +788,6 @@ def random_shifts_aug(x: Tensor, max_random_shift_ratio: float) -> Tensor:
     return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False)
 
 
-def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float):
-    """Update EMA parameters in place with ema_param <- alpha * ema_param + (1 - alpha) * param."""
-    for ema_module, module in zip(ema_net.modules(), net.modules(), strict=True):
-        for (n_p_ema, p_ema), (n_p, p) in zip(
-            ema_module.named_parameters(recurse=False), module.named_parameters(recurse=False), strict=True
-        ):
-            assert n_p_ema == n_p, "Parameter names don't match for EMA model update"
-            if isinstance(p, dict):
-                raise RuntimeError("Dict parameter not supported")
-            if isinstance(module, nn.modules.batchnorm._BatchNorm) or not p.requires_grad:
-                # Copy BatchNorm parameters, and non-trainable parameters directly.
-                p_ema.copy_(p.to(dtype=p_ema.dtype).data)
-            with torch.no_grad():
-                p_ema.mul_(alpha)
-                p_ema.add_(p.to(dtype=p_ema.dtype).data, alpha=1 - alpha)
-
-
 def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
     """Helper to temporarily flatten extra dims at the start of the image tensor.
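
Note: below is a minimal, self-contained sketch of the Polyak-averaged target-Q update this patch adopts. It is illustrative only: the network shape, the `momentum` default, and the helper name `update_target` are assumptions for the example, not the repo's API (the patch itself does this inside `TDMPC2WorldModel.update_target_Q`).

    import copy

    import torch
    import torch.nn as nn

    # Hypothetical stand-in for the Q-ensemble; the real networks use
    # NormedLinear layers and two-hot value bins.
    q_net = nn.Sequential(nn.Linear(8, 64), nn.ReLU(), nn.Linear(64, 1))
    target_q = copy.deepcopy(q_net)
    for p in target_q.parameters():
        p.requires_grad_(False)  # the target is written only by the soft update

    def update_target(online: nn.Module, target: nn.Module, momentum: float = 0.995):
        # Polyak averaging: target <- momentum * target + (1 - momentum) * online.
        # Note that lerp_'s weight is the step fraction (1 - momentum), not the decay itself.
        with torch.no_grad():
            for p, p_t in zip(online.parameters(), target.parameters()):
                p_t.lerp_(p, 1.0 - momentum)

    update_target(q_net, target_q)

Compared with the removed update_ema_parameters, which walked every module of a deep-copied policy, this scheme only touches the Q-networks, which is the point of the patch.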