simplified estimate_value function in policy

This commit is contained in:
Michel Aractingi 2024-11-21 17:03:30 +00:00
parent c41ec08ec1
commit 31984645da
2 changed files with 14 additions and 44 deletions


@@ -39,8 +39,7 @@ from torch import Tensor
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.tdmpc2.configuration_tdmpc2 import TDMPC2Config
from lerobot.common.policies.utils import get_device_from_parameters, populate_queues
from lerobot.common.policies.tdmpc2.tdmpc2_utils import NormedLinear, SimNorm, two_hot_inv, gaussian_logprob, squash
from lerobot.common.policies.tdmpc2.tdmpc2_utils import NormedLinear, SimNorm, two_hot_inv, gaussian_logprob, squash, soft_cross_entropy
class TDMPC2Policy(
nn.Module,
@@ -110,6 +109,7 @@ class TDMPC2Policy(
self._use_env_state = True
self.scale = RunningScale(self.config.tau)
self.discount = self.config.discount #TODO (michel-aractingi) downscale discount according to episode length
self.reset()
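
The TODO above refers to scaling the discount with the episode length. A hedged sketch of one way to do that, loosely following the per-task discount heuristic used in the reference TD-MPC2 implementation; the helper name and the clamp bounds below are assumptions, not part of this diff:

# Hypothetical helper (not in this commit): derive a discount from the episode
# length and clamp it to a sensible range; the constants are assumptions.
def downscale_discount(episode_length: int,
                       denom: float = 5.0,
                       discount_min: float = 0.95,
                       discount_max: float = 0.995) -> float:
    frac = episode_length / denom
    return min(max((frac - 1) / frac, discount_min), discount_max)

# e.g. downscale_discount(100) -> 0.95, downscale_discount(1000) -> 0.995
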
@@ -286,40 +286,16 @@ class TDMPC2Policy(
# Iterate over the actions in the trajectory to simulate the trajectory using the latent dynamics
# model. Keep track of return.
for t in range(actions.shape[0]):
# We will compute the reward in a moment. First compute the uncertainty regularizer from eqn 4
# of the FOWM paper.
if self.config.uncertainty_regularizer_coeff > 0:
regularization = -(
self.config.uncertainty_regularizer_coeff * self.model.Qs(z, actions[t]).std(0)
)
else:
regularization = 0
# Estimate the next state (latent) and reward.
z, reward = self.model.latent_dynamics_and_reward(z, actions[t], discretize_reward=True)
# Update the return and running discount.
G += running_discount * (reward + regularization)
G += running_discount * reward
running_discount *= self.config.discount
# Add the estimated value of the final state (using the minimum for a conservative estimate).
# Do so by predicting the next action, then taking a minimum over the ensemble of state-action value
# estimators.
# Note: This small amount of added noise seems to help a bit at inference time as observed by success
# metrics over 50 episodes of xarm_lift_medium_replay.
next_action = self.model.pi(z, self.config.min_std)[0] # (batch, action_dim)
terminal_values = self.model.Qs(z, next_action, return_type="avg") # (ensemble, batch)
# Randomly choose 2 of the Qs for terminal value estimation (as in App C. of the FOWM paper).
if self.config.q_ensemble_size > 2:
G += (
running_discount
* torch.min(terminal_values[torch.randint(0, self.config.q_ensemble_size, size=(2,))], dim=0)[
0
]
)
else:
G += running_discount * torch.min(terminal_values, dim=0)[0]
# Finally, also regularize the terminal value.
if self.config.uncertainty_regularizer_coeff > 0:
G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0)
return G
#next_action = self.model.pi(z)[0] # (batch, action_dim)
#terminal_values = self.model.Qs(z, next_action, return_type="avg") # (ensemble, batch)
return G + running_discount * self.model.Qs(z, self.model.pi(z)[0], return_type='avg')
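
The change above replaces the FOWM-style terminal value estimate (an uncertainty regularizer plus a min over two randomly chosen Q-heads) with a single ensemble-average Q at the final latent state. A minimal self-contained sketch of the simplified return computation, using a dummy stand-in for the world model (DummyModel below is illustrative only, not the repository's TDMPC2WorldModel):

import torch

class DummyModel:
    """Illustrative stand-in for the world model: correct shapes, random outputs."""
    def __init__(self, latent_dim=8, action_dim=2, num_q=5):
        self.latent_dim, self.action_dim, self.num_q = latent_dim, action_dim, num_q

    def latent_dynamics_and_reward(self, z, a):
        # Pretend next latent (batch, latent_dim) and a scalar reward per batch element.
        return z + 0.01 * a.mean(-1, keepdim=True), torch.randn(z.shape[0])

    def pi(self, z):
        return (torch.randn(z.shape[0], self.action_dim),)

    def Qs(self, z, a, return_type="avg"):
        q = torch.randn(self.num_q, z.shape[0])      # (ensemble, batch)
        return q.mean(0) if return_type == "avg" else q

def estimate_value(model, z, actions, discount):
    """Discounted rewards along the imagined rollout plus the average-Q value
    of the final latent state, mirroring the simplified version in this commit."""
    G, running_discount = 0.0, 1.0
    for t in range(actions.shape[0]):
        z, reward = model.latent_dynamics_and_reward(z, actions[t])
        G += running_discount * reward
        running_discount *= discount
    return G + running_discount * model.Qs(z, model.pi(z)[0], return_type="avg")

model = DummyModel()
z = torch.randn(4, model.latent_dim)           # (batch, latent_dim)
actions = torch.randn(3, 4, model.action_dim)  # (horizon, batch, action_dim)
print(estimate_value(model, z, actions, discount=0.99).shape)  # torch.Size([4])
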
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
"""Run the batch through the model and compute the loss.
@@ -380,8 +356,9 @@ class TDMPC2Policy(
# Compute various targets with stopgrad.
with torch.no_grad():
# Latent state consistency targets.
# Latent state consistency targets for consistency loss.
z_targets = self.model.encode(next_observations)
# Compute the TD-target from a reward and the next observation
pi = self.model.pi(z_targets)[0]
td_targets = (
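
The td_targets expression is cut off by the diff context. As an illustration only (not necessarily the exact expression in the file), a TD target of this kind bootstraps the observed reward with the discounted Q-value of the next latent under the current policy, computed without gradients:

import torch

# Hypothetical shapes and names; the file's actual expression is not shown in the diff.
B = 4
reward = torch.randn(B)          # observed reward, (batch,)
discount = 0.99
q_next = torch.randn(5, B)       # Q ensemble at (z_target, pi(z_target)), (ensemble, batch)

with torch.no_grad():
    # A conservative min over the ensemble is one common choice; the file may use another reduction.
    td_targets = reward + discount * q_next.min(dim=0).values
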
@@ -395,6 +372,7 @@ class TDMPC2Policy(
temporal_loss_coeffs = torch.pow(
self.config.temporal_decay_coeff, torch.arange(horizon, device=device)
).unsqueeze(-1)
# Compute consistency loss as MSE loss between latents predicted from the rollout and latents
# predicted from the (target model's) observation encoder.
consistency_loss = (
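
The consistency_loss expression is likewise truncated here. A hedged sketch of the pattern the comment describes (an MSE between latents predicted by rolling the dynamics model forward and stop-gradient latents from the observation encoder, weighted by the temporal decay coefficients and a padding mask); the shapes and mask name below are assumptions, not the file's exact code:

import torch
import torch.nn.functional as F

# Illustrative shapes: horizon T, batch B, latent dim D.
T, B, D = 5, 4, 8
z_preds = torch.randn(T, B, D)                 # latents from rolling the dynamics model
z_targets = torch.randn(T, B, D)               # latents from encoding the next observations (stop-grad)
is_pad = torch.zeros(T, B, dtype=torch.bool)   # hypothetical padding mask

temporal_decay_coeff = 0.5
temporal_loss_coeffs = torch.pow(temporal_decay_coeff, torch.arange(T)).unsqueeze(-1)  # (T, 1)

consistency_loss = (
    temporal_loss_coeffs
    * F.mse_loss(z_preds, z_targets, reduction="none").mean(dim=-1)  # (T, B)
    * ~is_pad
).mean()
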
@@ -417,7 +395,6 @@ class TDMPC2Policy(
temporal_loss_coeffs
* soft_cross_entropy(reward_preds, reward, self.config)
* ~batch["next.reward_is_pad"]
# `reward_preds` depends on the current observation and the actions.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
)
@@ -457,15 +434,10 @@ class TDMPC2Policy(
self.scale.update(qs[0])
qs = self.scale(qs)
rho = torch.pow(self.config.rho, torch.arange(len(qs), device=qs.device)).unsqueeze(-1)
rho = torch.pow(self.config.temporal_decay_coeff, torch.arange(len(qs), device=qs.device)).unsqueeze(-1)
# mse = F.mse_loss(action_preds, action, reduction="none").sum(-1) # (t, b)
# NOTE: The original implementation does not take the sum over the temporal dimension like with the
# other losses.
# TODO(alexander-soare): Take the sum over the temporal dimension and check that training still works
# as well as expected.
pi_loss = (
(self.config.entropy_coef * log_pis - qs).mean(dim=-1)
(self.config.entropy_coef * log_pis - qs).mean(dim=(1,2))
* rho
# * temporal_loss_coeffs
# `action_preds` depends on the first observation and the actions.
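
The policy loss in this hunk minimizes (entropy_coef * log_pi - Q), i.e. it maximizes the scaled Q-values plus an entropy bonus, with earlier timesteps weighted more heavily by rho. A small self-contained sketch of that structure with hypothetical tensor shapes (the actual tensors in the file may be shaped differently; rho is kept one-dimensional here for a clean elementwise product):

import torch

# Hypothetical shapes: horizon T, ensemble E, batch B.
T, E, B = 5, 2, 4
log_pis = torch.randn(T, 1, B)   # log-probabilities of sampled actions, broadcast over the ensemble
qs = torch.randn(T, E, B)        # scaled Q-values for those actions
entropy_coef = 1e-4
temporal_decay_coeff = 0.5

rho = torch.pow(temporal_decay_coeff, torch.arange(T))   # (T,), earlier steps weigh more

# Mean over ensemble and batch per timestep, then a rho-weighted mean over time.
pi_loss = ((entropy_coef * log_pis - qs).mean(dim=(1, 2)) * rho).mean()
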
@@ -558,8 +530,6 @@ class TDMPC2WorldModel(nn.Module):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Embedding):
nn.init.uniform_(m.weight, -0.02, 0.02)
elif isinstance(m, nn.ParameterList):
for i, p in enumerate(m):
if p.dim() == 3: # Linear


@@ -68,7 +68,7 @@ class NormedLinear(nn.Linear):
f"act={self.act.__class__.__name__})"
def soft_ce(pred, target, cfg):
def soft_cross_entropy(pred, target, cfg):
"""Computes the cross entropy loss between predictions and soft targets."""
pred = F.log_softmax(pred, dim=-1)
target = two_hot(target, cfg)
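
The renamed soft_cross_entropy pairs log-softmax predictions with a two-hot encoded target. A self-contained sketch of the two-hot idea, with a hypothetical Cfg holding the support parameters; the repository's two_hot may differ in detail (for instance by applying a symlog transform to the target first):

import torch
import torch.nn.functional as F
from dataclasses import dataclass

@dataclass
class Cfg:                    # hypothetical config; field names are assumptions
    num_bins: int = 101
    vmin: float = -10.0
    vmax: float = 10.0

def two_hot(x, cfg):
    """Spread each scalar in x over the two nearest bins of a fixed support."""
    x = x.clamp(cfg.vmin, cfg.vmax)
    bin_size = (cfg.vmax - cfg.vmin) / (cfg.num_bins - 1)
    idx = (x - cfg.vmin) / bin_size
    lo = idx.floor().long().clamp(0, cfg.num_bins - 1)
    hi = (lo + 1).clamp(0, cfg.num_bins - 1)
    w_hi = (idx - lo.float()).clamp(0, 1)
    target = torch.zeros(*x.shape, cfg.num_bins)
    target.scatter_(-1, lo.unsqueeze(-1), (1 - w_hi).unsqueeze(-1))
    target.scatter_add_(-1, hi.unsqueeze(-1), w_hi.unsqueeze(-1))
    return target

def soft_cross_entropy(pred, target_scalar, cfg):
    """Cross entropy between logits and the two-hot encoding of a scalar target."""
    pred = F.log_softmax(pred, dim=-1)
    target = two_hot(target_scalar, cfg)
    return -(target * pred).sum(-1)

cfg = Cfg()
logits = torch.randn(4, cfg.num_bins)           # (batch, num_bins)
rewards = torch.tensor([0.0, 0.5, -1.2, 3.3])   # (batch,)
print(soft_cross_entropy(logits, rewards, cfg).shape)  # torch.Size([4])
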