Remove self.model_target and add a target Q ensemble only, without the need to copy the entire policy
Michel Aractingi 2024-11-21 15:00:03 +00:00
parent a146544765
commit c41ec08ec1
1 changed files with 27 additions and 33 deletions

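This change replaces the full deep copy of the world model with a frozen target Q ensemble held inside the model itself. As a rough orientation before the diff, here is a minimal, hedged sketch of that target-critic pattern; the class name, the linear critics, and the ensemble size below are illustrative stand-ins, not the actual lerobot implementation.

# Minimal sketch of the target-critic pattern this commit moves to: only the Q
# ensemble is duplicated and frozen, instead of deep-copying the whole model.
# All names below (WorldModelSketch, the Linear critics, num_q) are illustrative.
import copy
import torch
import torch.nn as nn

class WorldModelSketch(nn.Module):
    def __init__(self, latent_dim: int, action_dim: int, num_q: int = 5):
        super().__init__()
        self._Qs = nn.ModuleList(
            [nn.Linear(latent_dim + action_dim, 1) for _ in range(num_q)]
        )
        # Target ensemble: a frozen copy of the critics only.
        self._target_Qs = copy.deepcopy(self._Qs).requires_grad_(False)

    def Qs(self, z: torch.Tensor, a: torch.Tensor, target: bool = False) -> torch.Tensor:
        qs = self._target_Qs if target else self._Qs
        x = torch.cat([z, a], dim=-1)
        return torch.stack([q(x) for q in qs], dim=0)   # (num_q, batch, 1)

    def update_target_Q(self, tau: float = 0.005) -> None:
        # Polyak averaging: target <- target + tau * (online - target).
        with torch.no_grad():
            for p, p_t in zip(self._Qs.parameters(), self._target_Qs.parameters()):
                p_t.lerp_(p, tau)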

@@ -39,7 +39,7 @@ from torch import Tensor
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.tdmpc2.configuration_tdmpc2 import TDMPC2Config
from lerobot.common.policies.utils import get_device_from_parameters, populate_queues
from lerobot.common.policies.tdmpc2.tdmpc2_utils import NormedLinear, SimNorm, two_hot_inv
from lerobot.common.policies.tdmpc2.tdmpc2_utils import NormedLinear, SimNorm, two_hot_inv, gaussian_logprob, squash
class TDMPC2Policy(
@@ -84,9 +84,6 @@ class TDMPC2Policy(
config = TDMPC2Config()
self.config = config
self.model = TDMPC2WorldModel(config)
self.model_target = deepcopy(self.model)
for param in self.model_target.parameters():
param.requires_grad = False
if config.input_normalization_modes is not None:
self.normalize_inputs = Normalize(
@@ -384,12 +381,12 @@ class TDMPC2Policy(
# Compute various targets with stopgrad.
with torch.no_grad():
# Latent state consistency targets.
z_targets = self.model_target.encode(next_observations)
z_targets = self.model.encode(next_observations)
# Compute the TD-target from a reward and the next observation
pi = self.model.pi(z_targets)[0]
td_targets = (
reward
+ self.config.discount * self.model_target.Qs(z_targets, pi, return_type="min").squeeze()
+ self.config.discount * self.model.Qs(z_targets, pi, return_type="min", target=True).squeeze()
)
# Compute losses.
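As a side note on the hunk above, the TD target follows the standard bootstrapped form: reward plus the discounted minimum over the frozen target Q ensemble, evaluated at the next latent and the policy's action there. A tiny self-contained restatement, with toy tensors standing in for the encoder, policy, and critic outputs:

# Hedged restatement of the TD target computed above; all tensors here are toy
# stand-ins for model.encode / model.pi / the target Q ensemble outputs.
import torch

num_q, batch = 5, 4
reward = torch.randn(batch)
discount = 0.99
q_next = torch.randn(num_q, batch)   # target ensemble at (z', pi(z'))

# "min" aggregation over the ensemble, as in Qs(..., return_type="min", target=True)
td_target = reward + discount * q_next.min(dim=0).values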
@@ -450,8 +447,13 @@ class TDMPC2Policy(
# Calculate the advantage weighted regression loss for π as detailed in FOWM 3.1.
# We won't need these gradients again so detach.
z_preds = z_preds.detach()
self.model.change_q_grad(mode=False)
action_preds, _, log_pis, _ = self.model.pi(z_preds[:-1])
qs = self.model_target.Qs(z_preds[:-1], action_preds, return_type="avg")
with torch.no_grad():
# avoid unnecessary computation of the gradients during policy optimization
# TODO (michel-aractingi): the same logic should be extended when adding task embeddings
qs = self.model.Qs(z_preds[:-1], action_preds, return_type="avg")
self.scale.update(qs[0])
qs = self.scale(qs)
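The torch.no_grad() block above takes over the role of the old change_q_grad(mode=False) toggle: the Q values only act as detached weights in the advantage-weighted regression loss, so no graph needs to be built through the critics. A small illustrative check (the loss form below is a simplified stand-in, not the exact FOWM objective):

# The critic values are used only as detached weights, so computing them under
# no_grad skips graph construction; gradients still reach the policy log-probs.
import torch

log_pis = torch.randn(4, requires_grad=True)   # stands in for policy log-probs
with torch.no_grad():
    qs = torch.randn(4)                        # stands in for critic values
policy_loss = -(qs.softmax(dim=0) * log_pis).sum()
policy_loss.backward()                         # gradients flow only into log_pis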
@@ -498,12 +500,8 @@ class TDMPC2Policy(
return info
def update(self):
"""Update the target model's parameters with an EMA step."""
# Note a minor variation with respect to the original FOWM code. Here they do this based on an EMA
# update frequency parameter which is set to 2 (every 2 steps an update is done). To simplify the code
# we update every step and adjust the decay parameter `alpha` accordingly (0.99 -> 0.995)
update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum)
"""Update the target model's using polyak averaging."""
self.model.update_target_Q()
class TDMPC2WorldModel(nn.Module):
"""Latent dynamics model used in TD-MPC2."""
@@ -587,6 +585,11 @@ class TDMPC2WorldModel(nn.Module):
self.bins = self.bins.to(*args, **kwargs)
return self
def train(self, mode):
super().train(mode)
self._target_Qs.train(False)
return self
def encode(self, obs: dict[str, Tensor]) -> Tensor:
"""Encodes an observation into its latent representation."""
return self._encoder(obs)
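The train() override added above keeps the frozen target ensemble in eval mode even when the rest of the module is switched to training. A minimal sketch of that pattern with an assumed toy module (not the lerobot classes):

# Keep a frozen target submodule in eval mode regardless of the parent's mode.
import torch.nn as nn

class WithFrozenTarget(nn.Module):
    def __init__(self):
        super().__init__()
        self.online = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.1))
        self.target = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.1))

    def train(self, mode: bool = True):
        super().train(mode)
        self.target.train(False)   # target stays deterministic (eval mode)
        return self

m = WithFrozenTarget().train()
assert m.online.training and not m.target.training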
@@ -622,7 +625,7 @@ class TDMPC2WorldModel(nn.Module):
x = torch.cat([z, a], dim=-1)
return self._dynamics(x)
def pi(self, z: Tensor, std: float = 0.0) -> Tensor:
def pi(self, z: Tensor) -> Tensor:
"""Samples an action from the learned policy.
The policy can also have added (truncated) Gaussian noise injected for encouraging exploration when
@@ -668,6 +671,14 @@ class TDMPC2WorldModel(nn.Module):
Q1, Q2 = two_hot_inv(Q1, self.bins), two_hot_inv(Q2, self.bins)
return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2
def update_target_Q(self):
"""
Soft-update target Q-networks using Polyak averaging.
"""
with torch.no_grad():
for p, p_target in zip(self._Qs.parameters(), self._target_Qs.parameters()):
p_target.data.lerp_(p.data, self.config.target_model_momentum)
class TDMPC2ObservationEncoder(nn.Module):
"""Encode image and/or state vector observations."""
@@ -777,23 +788,6 @@ def random_shifts_aug(x: Tensor, max_random_shift_ratio: float) -> Tensor:
return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False)
def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float):
"""Update EMA parameters in place with ema_param <- alpha * ema_param + (1 - alpha) * param."""
for ema_module, module in zip(ema_net.modules(), net.modules(), strict=True):
for (n_p_ema, p_ema), (n_p, p) in zip(
ema_module.named_parameters(recurse=False), module.named_parameters(recurse=False), strict=True
):
assert n_p_ema == n_p, "Parameter names don't match for EMA model update"
if isinstance(p, dict):
raise RuntimeError("Dict parameter not supported")
if isinstance(module, nn.modules.batchnorm._BatchNorm) or not p.requires_grad:
# Copy BatchNorm parameters, and non-trainable parameters directly.
p_ema.copy_(p.to(dtype=p_ema.dtype).data)
with torch.no_grad():
p_ema.mul_(alpha)
p_ema.add_(p.to(dtype=p_ema.dtype).data, alpha=1 - alpha)
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
"""Helper to temporarily flatten extra dims at the start of the image tensor.