remove self.model_target and added a target q ensemble only without the need to copy the
entire policy
This commit is contained in:
parent
a146544765
commit
c41ec08ec1
|
@ -39,7 +39,7 @@ from torch import Tensor
|
|||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.tdmpc2.configuration_tdmpc2 import TDMPC2Config
|
||||
from lerobot.common.policies.utils import get_device_from_parameters, populate_queues
|
||||
from lerobot.common.policies.tdmpc2.tdmpc2_utils import NormedLinear, SimNorm, two_hot_inv
|
||||
from lerobot.common.policies.tdmpc2.tdmpc2_utils import NormedLinear, SimNorm, two_hot_inv, gaussian_logprob, squash
|
||||
|
||||
|
||||
class TDMPC2Policy(
|
||||
|
@ -84,9 +84,6 @@ class TDMPC2Policy(
|
|||
config = TDMPC2Config()
|
||||
self.config = config
|
||||
self.model = TDMPC2WorldModel(config)
|
||||
self.model_target = deepcopy(self.model)
|
||||
for param in self.model_target.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
if config.input_normalization_modes is not None:
|
||||
self.normalize_inputs = Normalize(
|
||||
|
@ -384,12 +381,12 @@ class TDMPC2Policy(
|
|||
# Compute various targets with stopgrad.
|
||||
with torch.no_grad():
|
||||
# Latent state consistency targets.
|
||||
z_targets = self.model_target.encode(next_observations)
|
||||
z_targets = self.model.encode(next_observations)
|
||||
# Compute the TD-target from a reward and the next observation
|
||||
pi = self.model.pi(z_targets)[0]
|
||||
td_targets = (
|
||||
reward
|
||||
+ self.config.discount * self.model_target.Qs(z_targets, pi, return_type="min").squeeze()
|
||||
+ self.config.discount * self.model.Qs(z_targets, pi, return_type="min", target=True).squeeze()
|
||||
)
|
||||
|
||||
# Compute losses.
|
||||
|
@ -450,8 +447,13 @@ class TDMPC2Policy(
|
|||
# Calculate the advantage weighted regression loss for π as detailed in FOWM 3.1.
|
||||
# We won't need these gradients again so detach.
|
||||
z_preds = z_preds.detach()
|
||||
self.model.change_q_grad(mode=False)
|
||||
action_preds, _, log_pis, _ = self.model.pi(z_preds[:-1])
|
||||
qs = self.model_target.Qs(z_preds[:-1], action_preds, return_type="avg")
|
||||
|
||||
with torch.no_grad():
|
||||
# avoid unnessecary computation of the gradients during policy optimization
|
||||
# TODO (michel-aractingi): the same logic should be extended when adding task embeddings
|
||||
qs = self.model.Qs(z_preds[:-1], action_preds, return_type="avg")
|
||||
self.scale.update(qs[0])
|
||||
qs = self.scale(qs)
|
||||
|
||||
|
@ -498,12 +500,8 @@ class TDMPC2Policy(
|
|||
return info
|
||||
|
||||
def update(self):
|
||||
"""Update the target model's parameters with an EMA step."""
|
||||
# Note a minor variation with respect to the original FOWM code. Here they do this based on an EMA
|
||||
# update frequency parameter which is set to 2 (every 2 steps an update is done). To simplify the code
|
||||
# we update every step and adjust the decay parameter `alpha` accordingly (0.99 -> 0.995)
|
||||
update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum)
|
||||
|
||||
"""Update the target model's using polyak averaging."""
|
||||
self.model.update_target_Q()
|
||||
|
||||
class TDMPC2WorldModel(nn.Module):
|
||||
"""Latent dynamics model used in TD-MPC2."""
|
||||
|
@ -587,6 +585,11 @@ class TDMPC2WorldModel(nn.Module):
|
|||
self.bins = self.bins.to(*args, **kwargs)
|
||||
return self
|
||||
|
||||
def train(self, mode):
|
||||
super().train(mode)
|
||||
self._target_Qs.train(False)
|
||||
return self
|
||||
|
||||
def encode(self, obs: dict[str, Tensor]) -> Tensor:
|
||||
"""Encodes an observation into its latent representation."""
|
||||
return self._encoder(obs)
|
||||
|
@ -622,7 +625,7 @@ class TDMPC2WorldModel(nn.Module):
|
|||
x = torch.cat([z, a], dim=-1)
|
||||
return self._dynamics(x)
|
||||
|
||||
def pi(self, z: Tensor, std: float = 0.0) -> Tensor:
|
||||
def pi(self, z: Tensor) -> Tensor:
|
||||
"""Samples an action from the learned policy.
|
||||
|
||||
The policy can also have added (truncated) Gaussian noise injected for encouraging exploration when
|
||||
|
@ -668,6 +671,14 @@ class TDMPC2WorldModel(nn.Module):
|
|||
Q1, Q2 = two_hot_inv(Q1, self.bins), two_hot_inv(Q2, self.bins)
|
||||
return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2
|
||||
|
||||
def update_target_Q(self):
|
||||
"""
|
||||
Soft-update target Q-networks using Polyak averaging.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
for p, p_target in zip(self._Qs.parameters(), self._target_Qs.parameters()):
|
||||
p_target.data.lerp_(p.data, self.config.target_model_momentum)
|
||||
|
||||
|
||||
class TDMPC2ObservationEncoder(nn.Module):
|
||||
"""Encode image and/or state vector observations."""
|
||||
|
@ -777,23 +788,6 @@ def random_shifts_aug(x: Tensor, max_random_shift_ratio: float) -> Tensor:
|
|||
return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False)
|
||||
|
||||
|
||||
def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float):
|
||||
"""Update EMA parameters in place with ema_param <- alpha * ema_param + (1 - alpha) * param."""
|
||||
for ema_module, module in zip(ema_net.modules(), net.modules(), strict=True):
|
||||
for (n_p_ema, p_ema), (n_p, p) in zip(
|
||||
ema_module.named_parameters(recurse=False), module.named_parameters(recurse=False), strict=True
|
||||
):
|
||||
assert n_p_ema == n_p, "Parameter names don't match for EMA model update"
|
||||
if isinstance(p, dict):
|
||||
raise RuntimeError("Dict parameter not supported")
|
||||
if isinstance(module, nn.modules.batchnorm._BatchNorm) or not p.requires_grad:
|
||||
# Copy BatchNorm parameters, and non-trainable parameters directly.
|
||||
p_ema.copy_(p.to(dtype=p_ema.dtype).data)
|
||||
with torch.no_grad():
|
||||
p_ema.mul_(alpha)
|
||||
p_ema.add_(p.to(dtype=p_ema.dtype).data, alpha=1 - alpha)
|
||||
|
||||
|
||||
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
|
||||
"""Helper to temporarily flatten extra dims at the start of the image tensor.
|
||||
|
||||
|
|
Loading…
Reference in New Issue