From 31984645da7cee5defe76ec692e754932fd0bc03 Mon Sep 17 00:00:00 2001
From: Michel Aractingi
Date: Thu, 21 Nov 2024 17:03:30 +0000
Subject: [PATCH] simplified estimate_value function in policy

---
 .../common/policies/tdmpc2/modeling_tdmpc2.py | 56 +++++--------------
 .../common/policies/tdmpc2/tdmpc2_utils.py    |  2 +-
 2 files changed, 14 insertions(+), 44 deletions(-)

diff --git a/lerobot/common/policies/tdmpc2/modeling_tdmpc2.py b/lerobot/common/policies/tdmpc2/modeling_tdmpc2.py
index 5a454f14..27b3295d 100644
--- a/lerobot/common/policies/tdmpc2/modeling_tdmpc2.py
+++ b/lerobot/common/policies/tdmpc2/modeling_tdmpc2.py
@@ -39,8 +39,7 @@ from torch import Tensor
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.tdmpc2.configuration_tdmpc2 import TDMPC2Config
 from lerobot.common.policies.utils import get_device_from_parameters, populate_queues
-from lerobot.common.policies.tdmpc2.tdmpc2_utils import NormedLinear, SimNorm, two_hot_inv, gaussian_logprob, squash
-
+from lerobot.common.policies.tdmpc2.tdmpc2_utils import NormedLinear, SimNorm, two_hot_inv, gaussian_logprob, squash, soft_cross_entropy
 
 class TDMPC2Policy(
     nn.Module,
@@ -110,6 +109,7 @@ class TDMPC2Policy(
             self._use_env_state = True
 
         self.scale = RunningScale(self.config.tau)
+        self.discount = self.config.discount  # TODO (michel-aractingi) downscale discount according to episode length
 
         self.reset()
 
@@ -286,40 +286,16 @@ class TDMPC2Policy(
         # Iterate over the actions in the trajectory to simulate the trajectory using the latent dynamics
         # model. Keep track of return.
         for t in range(actions.shape[0]):
-            # We will compute the reward in a moment. First compute the uncertainty regularizer from eqn 4
-            # of the FOWM paper.
-            if self.config.uncertainty_regularizer_coeff > 0:
-                regularization = -(
-                    self.config.uncertainty_regularizer_coeff * self.model.Qs(z, actions[t]).std(0)
-                )
-            else:
-                regularization = 0
             # Estimate the next state (latent) and reward.
             z, reward = self.model.latent_dynamics_and_reward(z, actions[t], discretize_reward=True)
             # Update the return and running discount.
-            G += running_discount * (reward + regularization)
+            G += running_discount * reward
             running_discount *= self.config.discount
-        # Add the estimated value of the final state (using the minimum for a conservative estimate).
-        # Do so by predicting the next action, then taking a minimum over the ensemble of state-action value
-        # estimators.
-        # Note: This small amount of added noise seems to help a bit at inference time as observed by success
-        # metrics over 50 episodes of xarm_lift_medium_replay.
-        next_action = self.model.pi(z, self.config.min_std)[0]  # (batch, action_dim)
-        terminal_values = self.model.Qs(z, next_action, return_type="avg")  # (ensemble, batch)
-        # Randomly choose 2 of the Qs for terminal value estimation (as in App C. of the FOWM paper).
-        if self.config.q_ensemble_size > 2:
-            G += (
-                running_discount
-                * torch.min(terminal_values[torch.randint(0, self.config.q_ensemble_size, size=(2,))], dim=0)[
-                    0
-                ]
-            )
-        else:
-            G += running_discount * torch.min(terminal_values, dim=0)[0]
-        # Finally, also regularize the terminal value.
-        if self.config.uncertainty_regularizer_coeff > 0:
-            G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0)
-        return G
+
+        # next_action = self.model.pi(z)[0]  # (batch, action_dim)
+        # terminal_values = self.model.Qs(z, next_action, return_type="avg")  # (ensemble, batch)
+
+        return G + running_discount * self.model.Qs(z, self.model.pi(z)[0], return_type="avg")
 
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
         """Run the batch through the model and compute the loss.
@@ -380,8 +356,9 @@ class TDMPC2Policy(
 
         # Compute various targets with stopgrad.
         with torch.no_grad():
-            # Latent state consistency targets.
+            # Latent state targets for the consistency loss.
             z_targets = self.model.encode(next_observations)
+            # Compute the TD target from the reward and the next observation.
             pi = self.model.pi(z_targets)[0]
 
             td_targets = (
@@ -395,6 +372,7 @@ class TDMPC2Policy(
         temporal_loss_coeffs = torch.pow(
             self.config.temporal_decay_coeff, torch.arange(horizon, device=device)
         ).unsqueeze(-1)
+
         # Compute consistency loss as MSE loss between latents predicted from the rollout and latents
         # predicted from the (target model's) observation encoder.
         consistency_loss = (
@@ -417,7 +395,6 @@ class TDMPC2Policy(
             temporal_loss_coeffs
             * soft_cross_entropy(reward_preds, reward, self.config)
             * ~batch["next.reward_is_pad"]
-            # `reward_preds` depends on the current observation and the actions.
             * ~batch["observation.state_is_pad"][0]
             * ~batch["action_is_pad"]
         )
@@ -457,15 +434,10 @@ class TDMPC2Policy(
             self.scale.update(qs[0])
             qs = self.scale(qs)
 
-        rho = torch.pow(self.config.rho, torch.arange(len(qs), device=qs.device)).unsqueeze(-1)
+        rho = torch.pow(self.config.temporal_decay_coeff, torch.arange(len(qs), device=qs.device)).unsqueeze(-1)
 
-        # mse = F.mse_loss(action_preds, action, reduction="none").sum(-1)  # (t, b)
-        # NOTE: The original implementation does not take the sum over the temporal dimension like with the
-        # other losses.
-        # TODO(alexander-soare): Take the sum over the temporal dimension and check that training still works
-        # as well as expected.
         pi_loss = (
-            (self.config.entropy_coef * log_pis - qs).mean(dim=-1)
+            (self.config.entropy_coef * log_pis - qs).mean(dim=(1, 2))
             * rho
             # * temporal_loss_coeffs
             # `action_preds` depends on the first observation and the actions.
@@ -558,8 +530,6 @@ class TDMPC2WorldModel(nn.Module):
                 nn.init.trunc_normal_(m.weight, std=0.02)
                 if m.bias is not None:
                     nn.init.constant_(m.bias, 0)
-            elif isinstance(m, nn.Embedding):
-                nn.init.uniform_(m.weight, -0.02, 0.02)
             elif isinstance(m, nn.ParameterList):
                 for i, p in enumerate(m):
                     if p.dim() == 3:  # Linear
diff --git a/lerobot/common/policies/tdmpc2/tdmpc2_utils.py b/lerobot/common/policies/tdmpc2/tdmpc2_utils.py
index ccbb16c3..e1a38eab 100644
--- a/lerobot/common/policies/tdmpc2/tdmpc2_utils.py
+++ b/lerobot/common/policies/tdmpc2/tdmpc2_utils.py
@@ -68,7 +68,7 @@ class NormedLinear(nn.Linear):
             f"act={self.act.__class__.__name__})"
 
 
-def soft_ce(pred, target, cfg):
+def soft_cross_entropy(pred, target, cfg):
     """Computes the cross entropy loss between predictions and soft targets."""
     pred = F.log_softmax(pred, dim=-1)
     target = two_hot(target, cfg)
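
The simplified estimate_value above reduces to a discounted sum of predicted rewards from the latent rollout plus an ensemble-averaged Q bootstrap at the final latent state (no uncertainty regularizer and no min-over-two-Qs terminal estimate). A minimal standalone sketch of that computation, assuming a model object with the same latent_dynamics_and_reward, pi, and Qs interface used in the diff (the discount value and function name here are illustrative, not part of the patch):

import torch


def estimate_value_sketch(model, z: torch.Tensor, actions: torch.Tensor, discount: float = 0.99) -> torch.Tensor:
    """Discounted return of an action sequence rolled out in latent space.

    z: (batch, latent_dim) starting latent state.
    actions: (horizon, batch, action_dim) candidate action sequence.
    """
    G, running_discount = 0.0, 1.0
    for t in range(actions.shape[0]):
        # Predict the next latent state and the (discretized) reward for this action.
        z, reward = model.latent_dynamics_and_reward(z, actions[t], discretize_reward=True)
        G += running_discount * reward
        running_discount *= discount
    # Bootstrap with the policy's action at the final latent state, averaging over the Q ensemble
    # rather than taking a conservative minimum over two randomly chosen heads as before this patch.
    return G + running_discount * model.Qs(z, model.pi(z)[0], return_type="avg")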