change loss to loss_dict (vq-bet policy output name)

This commit is contained in:
jayLEE0301 2024-06-04 16:59:02 -04:00
parent 0657c0c6c1
commit 6012b4c859
1 changed files with 4 additions and 4 deletions

View File

@ -113,9 +113,9 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
loss, n_different_codes, n_different_combinations = self.vqbet.discretize(self.config.discretize_step, batch['action'])
return {"loss": loss, "n_different_codes": n_different_codes, "n_different_combinations": n_different_combinations}
# if Residual VQ is already trained, VQ-BeT trains its GPT and bin ped header / offset header parts.
_, loss = self.vqbet(batch, rollout=False)
_, loss_dict = self.vqbet(batch, rollout=False)
return loss
return loss_dict
class SpatialSoftmax(nn.Module):
"""
@ -326,12 +326,12 @@ class VQBeTModel(nn.Module):
output[:, i, :, :] = action[:, i : i + act_w, :]
action = output
loss = self.action_head.loss_fn(
loss_dict = self.action_head.loss_fn(
pred_action,
action,
reduction="mean",
)
return pred_action, loss
return pred_action, loss_dict
class VQBeTHead(nn.Module):