change loss fn name, remove unnecessary resizing part
parent eaf0af0ac6
commit 975da28461
@@ -376,7 +376,7 @@ class VQBeTHead(nn.Module):
         # init vqvae
         self.vqvae_model = VqVae(config)
         # loss
-        self._criterion = FocalLoss(gamma=2.0)
+        self._focal_loss_fn = FocalLoss(gamma=2.0)
 
     def discretize(self, discretize_step, actions):
         loss, n_different_codes, n_different_combinations = pretrain_vqvae(self.vqvae_model, discretize_step, actions)
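Note: FocalLoss is constructed here with gamma=2.0, but its definition is not part of this diff. For reference only, a minimal sketch of a standard multi-class focal loss (an assumed implementation, not necessarily the one defined in this repo) looks like this:

import torch
import torch.nn.functional as F
from torch import nn

class FocalLoss(nn.Module):
    # Sketch of a standard multi-class focal loss: per-sample cross-entropy scaled
    # by (1 - p_t) ** gamma, so confidently correct examples contribute less.
    # gamma=0 recovers plain cross-entropy. Assumed reference, not the repo's code.
    def __init__(self, gamma: float = 2.0):
        super().__init__()
        self.gamma = gamma

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        ce = F.cross_entropy(logits, targets, reduction="none")  # -log(p_t) per sample
        p_t = torch.exp(-ce)  # softmax probability assigned to the target class
        return ((1.0 - p_t) ** self.gamma * ce).mean()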
@@ -515,15 +515,15 @@ class VQBeTHead(nn.Module):
         # Now we can compute the loss.
 
         # offset loss is L1 distance between the predicted action and ground truth action
-        offset_loss = torch.nn.L1Loss()(action_seq, predicted_action)
+        offset_loss = F.l1_loss(action_seq, predicted_action)
 
         # calculate primary code prediction loss
-        cbet_loss1 = self._criterion(  # F.cross_entropy
+        cbet_loss1 = self._focal_loss_fn(
             cbet_logits[:, 0, :],
             action_bins[:, 0],
         )
         # calculate secondary code prediction loss
-        cbet_loss2 = self._criterion(  # F.cross_entropy
+        cbet_loss2 = self._focal_loss_fn(
             cbet_logits[:, 1, :],
             action_bins[:, 1],
         )
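The offset loss now uses the functional form F.l1_loss (assuming F is torch.nn.functional, imported at the top of the file) instead of instantiating torch.nn.L1Loss on every forward pass. Both compute the same mean absolute error; a quick check with hypothetical tensor shapes:

import torch
import torch.nn.functional as F

# Hypothetical shapes: (batch * frames, action prediction horizon, action dim)
action_seq = torch.randn(64, 5, 7)
predicted_action = torch.randn(64, 5, 7)

assert torch.allclose(
    torch.nn.L1Loss()(action_seq, predicted_action),  # module form
    F.l1_loss(action_seq, predicted_action),          # functional form
)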
@@ -537,13 +537,10 @@ class VQBeTHead(nn.Module):
                 (action_bins[:, 1] == sampled_centers[:, 1]).int()
             ) / (NT)
 
-        action_mse_error = F.mse_loss(
-            einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T),
-            einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T),
-        )
-        vq_action_error = (abs(einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T) - einops.rearrange(decoded_action, "(N T) W A -> N T W A", T=T))).mean()
-        offset_action_error = (abs(einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T) - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T))).mean()
-        action_error_max = (abs(einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T) - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T))).max()
+        action_mse_error = torch.mean((action_seq - predicted_action) ** 2)
+        vq_action_error = torch.mean(action_seq - decoded_action)
+        offset_action_error = torch.mean(action_seq - predicted_action)
+        action_error_max = torch.max(action_seq - predicted_action)
 
         loss = cbet_loss + self.config.offset_loss_weight * offset_loss
 
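On the metric simplification: the removed lines rearranged the flattened (N T) W A tensors back to N T W A before reducing, but an elementwise difference followed by a reduction over all elements gives the same number either way, which is presumably the unnecessary resizing the commit title refers to. A small sanity check under assumed shapes:

import einops
import torch

# Assumed shapes: N sequences of T frames, W predicted action steps, A action dims.
N, T, W, A = 4, 5, 3, 7
action_seq = torch.randn(N * T, W, A)
predicted_action = torch.randn(N * T, W, A)

grouped = (
    einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)
    - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)
).abs().mean()
flat = (action_seq - predicted_action).abs().mean()
assert torch.allclose(grouped, flat)  # rearranging does not change a full reduction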