change loss fn name, remove unnecessary resizing part

This commit is contained in:
jayLEE0301 2024-06-05 13:36:54 -04:00
parent eaf0af0ac6
commit 975da28461
1 changed files with 8 additions and 11 deletions

View File

@ -376,7 +376,7 @@ class VQBeTHead(nn.Module):
# init vqvae
self.vqvae_model = VqVae(config)
# loss
self._criterion = FocalLoss(gamma=2.0)
self._focal_loss_fn = FocalLoss(gamma=2.0)
def discretize(self, discretize_step, actions):
loss, n_different_codes, n_different_combinations = pretrain_vqvae(self.vqvae_model, discretize_step, actions)
@ -515,15 +515,15 @@ class VQBeTHead(nn.Module):
# Now we can compute the loss.
# offset loss is L1 distance between the predicted action and ground truth action
offset_loss = torch.nn.L1Loss()(action_seq, predicted_action)
offset_loss = F.l1_loss(action_seq, predicted_action)
# calculate primary code prediction loss
cbet_loss1 = self._criterion( # F.cross_entropy
cbet_loss1 = self._focal_loss_fn(
cbet_logits[:, 0, :],
action_bins[:, 0],
)
# calculate secondary code prediction loss
cbet_loss2 = self._criterion( # F.cross_entropy
cbet_loss2 = self._focal_loss_fn(
cbet_logits[:, 1, :],
action_bins[:, 1],
)
@ -537,13 +537,10 @@ class VQBeTHead(nn.Module):
(action_bins[:, 1] == sampled_centers[:, 1]).int()
) / (NT)
action_mse_error = F.mse_loss(
einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T),
einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T),
)
vq_action_error = (abs(einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T) - einops.rearrange(decoded_action, "(N T) W A -> N T W A", T=T))).mean()
offset_action_error = (abs(einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T) - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T))).mean()
action_error_max = (abs(einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T) - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T))).max()
action_mse_error = torch.mean((action_seq - predicted_action) ** 2)
vq_action_error = torch.mean(action_seq - decoded_action)
offset_action_error = torch.mean(action_seq - predicted_action)
action_error_max = torch.max(action_seq - predicted_action)
loss = cbet_loss + self.config.offset_loss_weight * offset_loss