simplified estimate_value function in policy
parent c41ec08ec1
commit 31984645da
@@ -39,8 +39,7 @@ from torch import Tensor
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.tdmpc2.configuration_tdmpc2 import TDMPC2Config
 from lerobot.common.policies.utils import get_device_from_parameters, populate_queues
-from lerobot.common.policies.tdmpc2.tdmpc2_utils import NormedLinear, SimNorm, two_hot_inv, gaussian_logprob, squash
-
+from lerobot.common.policies.tdmpc2.tdmpc2_utils import NormedLinear, SimNorm, two_hot_inv, gaussian_logprob, squash, soft_cross_entropy

 class TDMPC2Policy(
     nn.Module,
@@ -110,6 +109,7 @@ class TDMPC2Policy(
         self._use_env_state = True

         self.scale = RunningScale(self.config.tau)
+        self.discount = self.config.discount #TODO (michel-aractingi) downscale discount according to episode length

         self.reset()

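A note on the added `self.discount` line: the TODO refers to scaling the discount with the episode length. Below is a minimal sketch of one such heuristic, assuming hypothetical `episode_length`, `discount_min`, and `discount_max` parameters; it is not the repository's implementation, only an illustration of the idea behind the TODO.

def downscaled_discount(episode_length: int, discount_min: float = 0.95, discount_max: float = 0.995) -> float:
    """Illustrative heuristic only: longer episodes get a discount closer to 1."""
    # Ignore a short warm-up of 5 steps, then clip into a sensible range.
    raw = (episode_length - 5) / episode_length
    return max(discount_min, min(discount_max, raw))

# Example: a 50-step episode gets a smaller discount than a 500-step episode.
assert downscaled_discount(50) < downscaled_discount(500)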
@@ -286,40 +286,16 @@ class TDMPC2Policy(
         # Iterate over the actions in the trajectory to simulate the trajectory using the latent dynamics
         # model. Keep track of return.
         for t in range(actions.shape[0]):
-            # We will compute the reward in a moment. First compute the uncertainty regularizer from eqn 4
-            # of the FOWM paper.
-            if self.config.uncertainty_regularizer_coeff > 0:
-                regularization = -(
-                    self.config.uncertainty_regularizer_coeff * self.model.Qs(z, actions[t]).std(0)
-                )
-            else:
-                regularization = 0
             # Estimate the next state (latent) and reward.
             z, reward = self.model.latent_dynamics_and_reward(z, actions[t], discretize_reward=True)
             # Update the return and running discount.
-            G += running_discount * (reward + regularization)
+            G += running_discount * reward
             running_discount *= self.config.discount
-        # Add the estimated value of the final state (using the minimum for a conservative estimate).
-        # Do so by predicting the next action, then taking a minimum over the ensemble of state-action value
-        # estimators.
-        # Note: This small amount of added noise seems to help a bit at inference time as observed by success
-        # metrics over 50 episodes of xarm_lift_medium_replay.
-        next_action = self.model.pi(z, self.config.min_std)[0] # (batch, action_dim)
-        terminal_values = self.model.Qs(z, next_action, return_type="avg") # (ensemble, batch)
-        # Randomly choose 2 of the Qs for terminal value estimation (as in App C. of the FOWM paper).
-        if self.config.q_ensemble_size > 2:
-            G += (
-                running_discount
-                * torch.min(terminal_values[torch.randint(0, self.config.q_ensemble_size, size=(2,))], dim=0)[
-                    0
-                ]
-            )
-        else:
-            G += running_discount * torch.min(terminal_values, dim=0)[0]
-        # Finally, also regularize the terminal value.
-        if self.config.uncertainty_regularizer_coeff > 0:
-            G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0)
-        return G
+
+        #next_action = self.model.pi(z)[0] # (batch, action_dim)
+        #terminal_values = self.model.Qs(z, next_action, return_type="avg") # (ensemble, batch)
+
+        return G + running_discount * self.model.Qs(z, self.model.pi(z)[0], return_type='avg')

     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
         """Run the batch through the model and compute the loss.
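For orientation: the removed code implemented the FOWM-style uncertainty regularizer and a conservative terminal bootstrap over a random pair of Q heads, while the new code sums discounted rewards along the latent rollout and bootstraps once with the ensemble-averaged Q of the policy's action. A minimal sketch of the simplified control flow, assuming a `model` object exposing the same `latent_dynamics_and_reward`, `pi`, and `Qs` interface used above (the stand-in is not the repository's `TDMPC2WorldModel`):

import torch

def estimate_value_sketch(model, z: torch.Tensor, actions: torch.Tensor, discount: float) -> torch.Tensor:
    """Discounted rewards along the latent rollout plus one averaged-Q bootstrap."""
    G = torch.zeros(z.shape[0], device=z.device)  # running return, shape (batch,)
    running_discount = 1.0
    for t in range(actions.shape[0]):
        # Advance the latent state and predict the reward for this action.
        z, reward = model.latent_dynamics_and_reward(z, actions[t], discretize_reward=True)
        G = G + running_discount * reward
        running_discount *= discount
    # Bootstrap with the ensemble-averaged Q-value of the policy's action at the final latent.
    return G + running_discount * model.Qs(z, model.pi(z)[0], return_type="avg")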
@@ -380,8 +356,9 @@ class TDMPC2Policy(

         # Compute various targets with stopgrad.
         with torch.no_grad():
-            # Latent state consistency targets.
+            # Latent state consistency targets for consistency loss.
             z_targets = self.model.encode(next_observations)
+
             # Compute the TD-target from a reward and the next observation
             pi = self.model.pi(z_targets)[0]
             td_targets = (
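The hunk above only shows the opening of the `td_targets` expression. As a hedged sketch of what a one-step TD target of this shape typically computes (reward plus the discounted Q-value of the policy's action at the next latent state, with gradients blocked), not the exact expression in this file:

import torch

def td_target_sketch(reward: torch.Tensor, z_next: torch.Tensor, model, discount: float) -> torch.Tensor:
    """Hedged sketch: td_target = r + gamma * Q(z', pi(z')), treated as a constant."""
    with torch.no_grad():
        pi = model.pi(z_next)[0]  # action proposed by the policy head at the next latent state
        q_next = model.Qs(z_next, pi, return_type="avg")  # value of that action, averaged over the Q ensemble
        return reward + discount * q_next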
@@ -395,6 +372,7 @@ class TDMPC2Policy(
         temporal_loss_coeffs = torch.pow(
             self.config.temporal_decay_coeff, torch.arange(horizon, device=device)
         ).unsqueeze(-1)
+
         # Compute consistency loss as MSE loss between latents predicted from the rollout and latents
         # predicted from the (target model's) observation encoder.
         consistency_loss = (
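As a quick illustration of what `temporal_loss_coeffs` contains: a geometric decay over the horizon, shaped to broadcast against per-timestep losses. The values below assume `temporal_decay_coeff = 0.5` and `horizon = 4`, chosen purely for illustration:

import torch

temporal_decay_coeff = 0.5  # illustrative value, not the config default
horizon = 4

coeffs = torch.pow(temporal_decay_coeff, torch.arange(horizon)).unsqueeze(-1)
print(coeffs.squeeze(-1))  # tensor([1.0000, 0.5000, 0.2500, 0.1250])
# Shape (horizon, 1): broadcasts across the batch dimension of a (horizon, batch) loss tensor.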
@@ -417,7 +395,6 @@ class TDMPC2Policy(
             temporal_loss_coeffs
             * soft_cross_entropy(reward_preds, reward, self.config)
             * ~batch["next.reward_is_pad"]
-            # `reward_preds` depends on the current observation and the actions.
             * ~batch["observation.state_is_pad"][0]
             * ~batch["action_is_pad"]
         )
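The `~batch[...]_is_pad` factors zero out loss terms coming from padded timesteps. A tiny example of the masking pattern, with shapes chosen only for illustration:

import torch

per_step_loss = torch.tensor([[0.3], [0.7], [0.9]])       # (horizon=3, batch=1)
reward_is_pad = torch.tensor([[False], [False], [True]])  # last step is padding

masked = per_step_loss * ~reward_is_pad  # padded steps contribute 0
print(masked.squeeze(-1))  # tensor([0.3000, 0.7000, 0.0000])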
@@ -457,15 +434,10 @@ class TDMPC2Policy(
         self.scale.update(qs[0])
         qs = self.scale(qs)

-        rho = torch.pow(self.config.rho, torch.arange(len(qs), device=qs.device)).unsqueeze(-1)
+        rho = torch.pow(self.config.temporal_decay_coeff, torch.arange(len(qs), device=qs.device)).unsqueeze(-1)

-        # mse = F.mse_loss(action_preds, action, reduction="none").sum(-1) # (t, b)
-        # NOTE: The original implementation does not take the sum over the temporal dimension like with the
-        # other losses.
-        # TODO(alexander-soare): Take the sum over the temporal dimension and check that training still works
-        # as well as expected.
         pi_loss = (
-            (self.config.entropy_coef * log_pis - qs).mean(dim=-1)
+            (self.config.entropy_coef * log_pis - qs).mean(dim=(1,2))
             * rho
             # * temporal_loss_coeffs
             # `action_preds` depends on the first observation and the actions.
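On the change from `.mean(dim=-1)` to `.mean(dim=(1, 2))`: since `rho` is indexed by `len(qs)` with a temporal decay coefficient, the leading dimension of `qs` is presumably the rollout horizon. Assuming a `(horizon, batch, 1)` layout (an assumption; the shapes are not visible in this hunk), averaging over dims 1 and 2 collapses batch and the trailing value dimension, leaving one term per timestep to be weighted by `rho`:

import torch

qs = torch.randn(3, 8, 1)        # assumed layout: (horizon, batch, 1)
log_pis = torch.randn(3, 8, 1)   # same layout as qs
entropy_coef = 1e-4              # illustrative value
temporal_decay_coeff = 0.5       # illustrative value

term = entropy_coef * log_pis - qs
print(term.mean(dim=-1).shape)      # torch.Size([3, 8]) -- one value per (timestep, batch item)
print(term.mean(dim=(1, 2)).shape)  # torch.Size([3])    -- one value per timestep

# rho kept 1-D here so it lines up with the per-timestep terms.
rho = torch.pow(temporal_decay_coeff, torch.arange(len(qs)))
pi_loss = (term.mean(dim=(1, 2)) * rho).mean()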
@@ -558,8 +530,6 @@ class TDMPC2WorldModel(nn.Module):
                 nn.init.trunc_normal_(m.weight, std=0.02)
                 if m.bias is not None:
                     nn.init.constant_(m.bias, 0)
-            elif isinstance(m, nn.Embedding):
-                nn.init.uniform_(m.weight, -0.02, 0.02)
             elif isinstance(m, nn.ParameterList):
                 for i, p in enumerate(m):
                     if p.dim() == 3: # Linear

@@ -68,7 +68,7 @@ class NormedLinear(nn.Linear):
             f"act={self.act.__class__.__name__})"


-def soft_ce(pred, target, cfg):
+def soft_cross_entropy(pred, target, cfg):
     """Computes the cross entropy loss between predictions and soft targets."""
     pred = F.log_softmax(pred, dim=-1)
     target = two_hot(target, cfg)
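The hunk cuts off before the rest of `soft_cross_entropy`. For orientation, here is a self-contained sketch of a two-hot soft cross entropy in the same spirit; the `two_hot_sketch` discretization and its bin parameters are simplified stand-ins, not the repository's `two_hot` implementation:

import torch
import torch.nn.functional as F

def two_hot_sketch(x: torch.Tensor, num_bins: int = 101, vmin: float = -10.0, vmax: float = 10.0) -> torch.Tensor:
    """Encode scalars as a soft two-hot distribution over uniformly spaced bins."""
    x = x.clamp(vmin, vmax)
    bin_size = (vmax - vmin) / (num_bins - 1)
    idx = (x - vmin) / bin_size
    lo = idx.floor().long().clamp(0, num_bins - 1)
    hi = (lo + 1).clamp(0, num_bins - 1)
    frac = (idx - lo.float()).unsqueeze(-1)
    target = torch.zeros(*x.shape, num_bins)
    # Put (1 - frac) mass on the lower bin and frac mass on the upper bin.
    target.scatter_(-1, lo.unsqueeze(-1), 1.0 - frac)
    target.scatter_add_(-1, hi.unsqueeze(-1), frac)
    return target

def soft_cross_entropy_sketch(pred_logits: torch.Tensor, target_scalar: torch.Tensor) -> torch.Tensor:
    """Cross entropy between predicted bin logits and a two-hot soft target."""
    log_probs = F.log_softmax(pred_logits, dim=-1)
    target = two_hot_sketch(target_scalar)
    return -(target * log_probs).sum(-1)

# Example: logits over 101 reward bins for a batch of 4 scalar rewards.
loss = soft_cross_entropy_sketch(torch.randn(4, 101), torch.tensor([0.0, 1.5, -3.2, 9.9]))
print(loss.shape)  # torch.Size([4])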