remove unused parts
This commit is contained in:
parent
f69d552480
commit
eaf0af0ac6
|
@ -476,7 +476,7 @@ class VQBeTHead(nn.Module):
|
|||
)
|
||||
|
||||
return {
|
||||
"cbet_logits": cbet_logits if "cbet_logits" in locals() else None,
|
||||
"cbet_logits": cbet_logits,
|
||||
"predicted_action": predicted_action,
|
||||
"sampled_centers": sampled_centers,
|
||||
"decoded_action": decoded_action,
|
||||
|
@ -513,8 +513,6 @@ class VQBeTHead(nn.Module):
|
|||
) # action_bins: NT, G
|
||||
|
||||
# Now we can compute the loss.
|
||||
if action_seq.ndim == 2:
|
||||
action_seq = action_seq.unsqueeze(0)
|
||||
|
||||
# offset loss is L1 distance between the predicted action and ground truth action
|
||||
offset_loss = torch.nn.L1Loss()(action_seq, predicted_action)
|
||||
|
|
Loading…
Reference in New Issue