Remove self.model_target and add a target Q ensemble only, without the need to copy the entire policy
Michel Aractingi 2024-11-21 15:00:03 +00:00
parent a146544765
commit c41ec08ec1
1 changed file with 27 additions and 33 deletions

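In short: instead of keeping a frozen deep copy of the whole world model (`self.model_target`) just to evaluate TD targets, the world model now owns a frozen copy of only the Q ensemble (`_target_Qs`), soft-updated from the online ensemble (`_Qs`). A minimal sketch of that pattern, with hypothetical names (`WorldModelSketch`, `num_q`, `momentum`) that are not the repository's actual API:

```python
import copy

import torch
from torch import nn


class WorldModelSketch(nn.Module):
    """Minimal sketch: online Q ensemble plus a frozen target copy of only the Q heads."""

    def __init__(self, latent_dim: int, action_dim: int, num_q: int = 5):
        super().__init__()
        # Online Q ensemble (trained by the TD loss).
        self._Qs = nn.ModuleList(
            nn.Linear(latent_dim + action_dim, 1) for _ in range(num_q)
        )
        # Target copy of the Q heads only -- no need to duplicate encoder/dynamics/policy.
        self._target_Qs = copy.deepcopy(self._Qs).requires_grad_(False)

    def Qs(self, z: torch.Tensor, a: torch.Tensor, target: bool = False) -> torch.Tensor:
        """Stack Q estimates from either the online or the frozen target ensemble."""
        heads = self._target_Qs if target else self._Qs
        x = torch.cat([z, a], dim=-1)
        return torch.stack([q(x) for q in heads], dim=0)

    @torch.no_grad()
    def update_target_Q(self, momentum: float = 0.005) -> None:
        """Polyak step: move each target parameter a fraction `momentum` toward the online one."""
        for p, p_target in zip(self._Qs.parameters(), self._target_Qs.parameters()):
            p_target.lerp_(p, momentum)
```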

@@ -39,7 +39,7 @@ from torch import Tensor
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.tdmpc2.configuration_tdmpc2 import TDMPC2Config
 from lerobot.common.policies.utils import get_device_from_parameters, populate_queues
-from lerobot.common.policies.tdmpc2.tdmpc2_utils import NormedLinear, SimNorm, two_hot_inv
+from lerobot.common.policies.tdmpc2.tdmpc2_utils import NormedLinear, SimNorm, two_hot_inv, gaussian_logprob, squash

 class TDMPC2Policy(
@@ -84,9 +84,6 @@ class TDMPC2Policy(
             config = TDMPC2Config()
         self.config = config
         self.model = TDMPC2WorldModel(config)
-        self.model_target = deepcopy(self.model)
-        for param in self.model_target.parameters():
-            param.requires_grad = False

         if config.input_normalization_modes is not None:
             self.normalize_inputs = Normalize(
@@ -384,12 +381,12 @@ class TDMPC2Policy(
         # Compute various targets with stopgrad.
         with torch.no_grad():
             # Latent state consistency targets.
-            z_targets = self.model_target.encode(next_observations)
+            z_targets = self.model.encode(next_observations)
             # Compute the TD-target from a reward and the next observation
             pi = self.model.pi(z_targets)[0]
             td_targets = (
                 reward
-                + self.config.discount * self.model_target.Qs(z_targets, pi, return_type="min").squeeze()
+                + self.config.discount * self.model.Qs(z_targets, pi, return_type="min", target=True).squeeze()
             )

         # Compute losses.
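For reference, the TD target in this hunk is the usual one-step bootstrap, reward plus discount times the minimum over the target Q ensemble evaluated at the next latent and the policy's action; `target=True` presumably routes `Qs` through the frozen `_target_Qs` copies rather than the online heads. A small illustrative computation under that assumption, with made-up numbers:

```python
import torch

# Hypothetical shapes: 2 target Q heads, batch of 3 next-latent/action pairs.
reward = torch.tensor([0.1, 0.0, 0.5])
discount = 0.99
q_target_heads = torch.tensor([
    [1.00, 2.00, 0.50],  # head 1: Q̄₁(z', π(z'))
    [0.80, 2.40, 0.45],  # head 2: Q̄₂(z', π(z'))
])

# Pessimistic bootstrap: min over the ensemble, no gradient flows into the targets.
with torch.no_grad():
    td_target = reward + discount * q_target_heads.min(dim=0).values
print(td_target)  # tensor([0.8920, 1.9800, 0.9455])
```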
@@ -450,10 +447,15 @@ class TDMPC2Policy(
         # Calculate the advantage weighted regression loss for π as detailed in FOWM 3.1.
         # We won't need these gradients again so detach.
         z_preds = z_preds.detach()
-        self.model.change_q_grad(mode=False)
         action_preds, _, log_pis, _ = self.model.pi(z_preds[:-1])
-        qs = self.model_target.Qs(z_preds[:-1], action_preds, return_type="avg")
-        self.scale.update(qs[0])
-        qs = self.scale(qs)
+        with torch.no_grad():
+            # avoid unnecessary computation of the gradients during policy optimization
+            # TODO (michel-aractingi): the same logic should be extended when adding task embeddings
+            qs = self.model.Qs(z_preds[:-1], action_preds, return_type="avg")
+            self.scale.update(qs[0])
+            qs = self.scale(qs)

         rho = torch.pow(self.config.rho, torch.arange(len(qs), device=qs.device)).unsqueeze(-1)
@@ -498,12 +500,8 @@ class TDMPC2Policy(
         return info

     def update(self):
-        """Update the target model's parameters with an EMA step."""
-        # Note a minor variation with respect to the original FOWM code. Here they do this based on an EMA
-        # update frequency parameter which is set to 2 (every 2 steps an update is done). To simplify the code
-        # we update every step and adjust the decay parameter `alpha` accordingly (0.99 -> 0.995)
-        update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum)
+        """Update the target Q-networks using Polyak averaging."""
+        self.model.update_target_Q()

 class TDMPC2WorldModel(nn.Module):
     """Latent dynamics model used in TD-MPC2."""
@@ -586,6 +584,11 @@ class TDMPC2WorldModel(nn.Module):
         self.log_std_dif = self.log_std_dif.to(*args, **kwargs)
         self.bins = self.bins.to(*args, **kwargs)
         return self

+    def train(self, mode):
+        super().train(mode)
+        self._target_Qs.train(False)
+        return self

     def encode(self, obs: dict[str, Tensor]) -> Tensor:
         """Encodes an observation into its latent representation."""
@@ -622,7 +625,7 @@ class TDMPC2WorldModel(nn.Module):
         x = torch.cat([z, a], dim=-1)
         return self._dynamics(x)

-    def pi(self, z: Tensor, std: float = 0.0) -> Tensor:
+    def pi(self, z: Tensor) -> Tensor:
         """Samples an action from the learned policy.

         The policy can also have added (truncated) Gaussian noise injected for encouraging exploration when
@@ -668,6 +671,14 @@ class TDMPC2WorldModel(nn.Module):
         Q1, Q2 = two_hot_inv(Q1, self.bins), two_hot_inv(Q2, self.bins)
         return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2

+    def update_target_Q(self):
+        """
+        Soft-update target Q-networks using Polyak averaging.
+        """
+        with torch.no_grad():
+            for p, p_target in zip(self._Qs.parameters(), self._target_Qs.parameters()):
+                p_target.data.lerp_(p.data, self.config.target_model_momentum)

 class TDMPC2ObservationEncoder(nn.Module):
     """Encode image and/or state vector observations."""
@@ -777,23 +788,6 @@ def random_shifts_aug(x: Tensor, max_random_shift_ratio: float) -> Tensor:
     return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False)

-def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float):
-    """Update EMA parameters in place with ema_param <- alpha * ema_param + (1 - alpha) * param."""
-    for ema_module, module in zip(ema_net.modules(), net.modules(), strict=True):
-        for (n_p_ema, p_ema), (n_p, p) in zip(
-            ema_module.named_parameters(recurse=False), module.named_parameters(recurse=False), strict=True
-        ):
-            assert n_p_ema == n_p, "Parameter names don't match for EMA model update"
-            if isinstance(p, dict):
-                raise RuntimeError("Dict parameter not supported")
-            if isinstance(module, nn.modules.batchnorm._BatchNorm) or not p.requires_grad:
-                # Copy BatchNorm parameters, and non-trainable parameters directly.
-                p_ema.copy_(p.to(dtype=p_ema.dtype).data)
-            with torch.no_grad():
-                p_ema.mul_(alpha)
-                p_ema.add_(p.to(dtype=p_ema.dtype).data, alpha=1 - alpha)

 def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
     """Helper to temporarily flatten extra dims at the start of the image tensor.