remove unused parts

This commit is contained in:
jayLEE0301 2024-06-05 13:09:37 -04:00
parent f69d552480
commit eaf0af0ac6
1 changed files with 1 additions and 3 deletions

View File

@ -476,7 +476,7 @@ class VQBeTHead(nn.Module):
)
return {
"cbet_logits": cbet_logits if "cbet_logits" in locals() else None,
"cbet_logits": cbet_logits,
"predicted_action": predicted_action,
"sampled_centers": sampled_centers,
"decoded_action": decoded_action,
@ -513,8 +513,6 @@ class VQBeTHead(nn.Module):
) # action_bins: NT, G
# Now we can compute the loss.
if action_seq.ndim == 2:
action_seq = action_seq.unsqueeze(0)
# offset loss is L1 distance between the predicted action and ground truth action
offset_loss = torch.nn.L1Loss()(action_seq, predicted_action)