Remove self.model_target and add a target Q ensemble instead, without the need to copy the entire policy
This commit is contained in:
parent a146544765
commit c41ec08ec1
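For context on the change: instead of deep-copying the whole TDMPC2WorldModel into self.model_target, the world model now keeps a frozen copy of just its Q ensemble and serves target values through a target flag, refreshed by a Polyak soft update. The following is a minimal sketch of that pattern, not the actual LeRobot implementation; only the names _Qs, _target_Qs, Qs and update_target_Q come from this diff, everything else (class name, layer sizes, tau) is invented for illustration.

import copy

import torch
from torch import nn


class WorldModelSketch(nn.Module):
    """Illustrative only: an online Q ensemble plus a frozen target copy of it."""

    def __init__(self, latent_dim: int, action_dim: int, num_q: int = 5, tau: float = 0.005):
        super().__init__()
        self.tau = tau
        # Online Q ensemble, trained by the TD loss.
        self._Qs = nn.ModuleList(
            nn.Sequential(nn.Linear(latent_dim + action_dim, 256), nn.ReLU(), nn.Linear(256, 1))
            for _ in range(num_q)
        )
        # Target copy of the Q heads only; never receives gradients.
        self._target_Qs = copy.deepcopy(self._Qs).requires_grad_(False)

    def Qs(self, z: torch.Tensor, a: torch.Tensor, target: bool = False) -> torch.Tensor:
        # Dispatch to the frozen target heads when target=True.
        heads = self._target_Qs if target else self._Qs
        x = torch.cat([z, a], dim=-1)
        return torch.stack([q(x) for q in heads])  # (num_q, batch, 1)

    def update_target_Q(self) -> None:
        # Soft (Polyak) update: target <- (1 - tau) * target + tau * online.
        with torch.no_grad():
            for p, p_target in zip(self._Qs.parameters(), self._target_Qs.parameters()):
                p_target.lerp_(p, self.tau)

The point of the design is that only the Q-head parameters are duplicated and synchronized; the encoder, dynamics and policy heads exist once.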
@@ -39,7 +39,7 @@ from torch import Tensor
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.tdmpc2.configuration_tdmpc2 import TDMPC2Config
 from lerobot.common.policies.utils import get_device_from_parameters, populate_queues
-from lerobot.common.policies.tdmpc2.tdmpc2_utils import NormedLinear, SimNorm, two_hot_inv
+from lerobot.common.policies.tdmpc2.tdmpc2_utils import NormedLinear, SimNorm, two_hot_inv, gaussian_logprob, squash


 class TDMPC2Policy(
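The newly imported gaussian_logprob and squash helpers are the standard ingredients of a tanh-squashed Gaussian policy head: compute the log-probability of the pre-tanh sample, then squash mean and sample through tanh and correct the log-probability with the tanh Jacobian. A hedged sketch of what such helpers typically compute (not necessarily the exact tdmpc2_utils implementation):

import math

import torch


def gaussian_logprob(eps: torch.Tensor, log_std: torch.Tensor) -> torch.Tensor:
    """Log-probability of a diagonal Gaussian sample, where eps = (x - mu) / std."""
    residual = -0.5 * eps.pow(2) - log_std
    return residual.sum(-1, keepdim=True) - 0.5 * math.log(2 * math.pi) * eps.size(-1)


def squash(mu: torch.Tensor, pi: torch.Tensor, log_pi: torch.Tensor):
    """Apply tanh squashing and correct log_pi with the tanh Jacobian."""
    mu = torch.tanh(mu)
    pi = torch.tanh(pi)
    # log |d tanh(x)/dx| = log(1 - tanh(x)^2), accumulated over action dimensions.
    log_pi = log_pi - torch.log(torch.relu(1 - pi.pow(2)) + 1e-6).sum(-1, keepdim=True)
    return mu, pi, log_pi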
@@ -84,9 +84,6 @@ class TDMPC2Policy(
         config = TDMPC2Config()
         self.config = config
         self.model = TDMPC2WorldModel(config)
-        self.model_target = deepcopy(self.model)
-        for param in self.model_target.parameters():
-            param.requires_grad = False

         if config.input_normalization_modes is not None:
             self.normalize_inputs = Normalize(
@@ -384,12 +381,12 @@ class TDMPC2Policy(
         # Compute various targets with stopgrad.
         with torch.no_grad():
             # Latent state consistency targets.
-            z_targets = self.model_target.encode(next_observations)
+            z_targets = self.model.encode(next_observations)
             # Compute the TD-target from a reward and the next observation
             pi = self.model.pi(z_targets)[0]
             td_targets = (
                 reward
-                + self.config.discount * self.model_target.Qs(z_targets, pi, return_type="min").squeeze()
+                + self.config.discount * self.model.Qs(z_targets, pi, return_type="min", target=True).squeeze()
             )

         # Compute losses.
@@ -450,10 +447,15 @@ class TDMPC2Policy(
         # Calculate the advantage weighted regression loss for π as detailed in FOWM 3.1.
         # We won't need these gradients again so detach.
         z_preds = z_preds.detach()
+        self.model.change_q_grad(mode=False)
         action_preds, _, log_pis, _ = self.model.pi(z_preds[:-1])
-        qs = self.model_target.Qs(z_preds[:-1], action_preds, return_type="avg")
-        self.scale.update(qs[0])
-        qs = self.scale(qs)
+        with torch.no_grad():
+            # avoid unnecessary computation of the gradients during policy optimization
+            # TODO (michel-aractingi): the same logic should be extended when adding task embeddings
+            qs = self.model.Qs(z_preds[:-1], action_preds, return_type="avg")
+            self.scale.update(qs[0])
+            qs = self.scale(qs)

         rho = torch.pow(self.config.rho, torch.arange(len(qs), device=qs.device)).unsqueeze(-1)
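The change_q_grad(mode=False) call introduced above is not defined in this diff; judging by its name and the accompanying comment, it presumably flips requires_grad on the online Q parameters so that the policy (actor) loss does not backpropagate through the critics. A hypothetical sketch of such a toggle (class name and layer sizes invented for illustration):

from torch import nn


class QGradToggleSketch(nn.Module):
    """Illustrative only: a world-model fragment with a gradient toggle for its Q heads."""

    def __init__(self, num_q: int = 5):
        super().__init__()
        self._Qs = nn.ModuleList(nn.Linear(8, 1) for _ in range(num_q))

    def change_q_grad(self, mode: bool = True) -> None:
        # mode=False before the policy update, mode=True again before the next critic update.
        for p in self._Qs.parameters():
            p.requires_grad_(mode)

Under that assumption, the caller would re-enable the critic gradients with change_q_grad(mode=True) once the policy loss has been computed.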
@@ -498,12 +500,8 @@ class TDMPC2Policy(
         return info

     def update(self):
-        """Update the target model's parameters with an EMA step."""
-        # Note a minor variation with respect to the original FOWM code. Here they do this based on an EMA
-        # update frequency parameter which is set to 2 (every 2 steps an update is done). To simplify the code
-        # we update every step and adjust the decay parameter `alpha` accordingly (0.99 -> 0.995)
-        update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum)
+        """Update the target Q-networks using Polyak averaging."""
+        self.model.update_target_Q()


 class TDMPC2WorldModel(nn.Module):
     """Latent dynamics model used in TD-MPC2."""
@@ -586,6 +584,11 @@ class TDMPC2WorldModel(nn.Module):
         self.log_std_dif = self.log_std_dif.to(*args, **kwargs)
         self.bins = self.bins.to(*args, **kwargs)
         return self

+    def train(self, mode):
+        super().train(mode)
+        self._target_Qs.train(False)
+        return self
+
     def encode(self, obs: dict[str, Tensor]) -> Tensor:
         """Encodes an observation into its latent representation."""
@@ -622,7 +625,7 @@ class TDMPC2WorldModel(nn.Module):
         x = torch.cat([z, a], dim=-1)
         return self._dynamics(x)

-    def pi(self, z: Tensor, std: float = 0.0) -> Tensor:
+    def pi(self, z: Tensor) -> Tensor:
         """Samples an action from the learned policy.

         The policy can also have added (truncated) Gaussian noise injected for encouraging exploration when
@@ -668,6 +671,14 @@ class TDMPC2WorldModel(nn.Module):
         Q1, Q2 = two_hot_inv(Q1, self.bins), two_hot_inv(Q2, self.bins)
         return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2

+    def update_target_Q(self):
+        """
+        Soft-update target Q-networks using Polyak averaging.
+        """
+        with torch.no_grad():
+            for p, p_target in zip(self._Qs.parameters(), self._target_Qs.parameters()):
+                p_target.data.lerp_(p.data, self.config.target_model_momentum)
+

 class TDMPC2ObservationEncoder(nn.Module):
     """Encode image and/or state vector observations."""
@@ -777,23 +788,6 @@ def random_shifts_aug(x: Tensor, max_random_shift_ratio: float) -> Tensor:
     return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False)


-def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float):
-    """Update EMA parameters in place with ema_param <- alpha * ema_param + (1 - alpha) * param."""
-    for ema_module, module in zip(ema_net.modules(), net.modules(), strict=True):
-        for (n_p_ema, p_ema), (n_p, p) in zip(
-            ema_module.named_parameters(recurse=False), module.named_parameters(recurse=False), strict=True
-        ):
-            assert n_p_ema == n_p, "Parameter names don't match for EMA model update"
-            if isinstance(p, dict):
-                raise RuntimeError("Dict parameter not supported")
-            if isinstance(module, nn.modules.batchnorm._BatchNorm) or not p.requires_grad:
-                # Copy BatchNorm parameters, and non-trainable parameters directly.
-                p_ema.copy_(p.to(dtype=p_ema.dtype).data)
-            with torch.no_grad():
-                p_ema.mul_(alpha)
-                p_ema.add_(p.to(dtype=p_ema.dtype).data, alpha=1 - alpha)
-
-
 def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
     """Helper to temporarily flatten extra dims at the start of the image tensor.
