change loss to loss_dict (vq-bet policy output name)
This commit is contained in:
parent
0657c0c6c1
commit
6012b4c859
|
@ -113,9 +113,9 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
loss, n_different_codes, n_different_combinations = self.vqbet.discretize(self.config.discretize_step, batch['action'])
|
||||
return {"loss": loss, "n_different_codes": n_different_codes, "n_different_combinations": n_different_combinations}
|
||||
# if Residual VQ is already trained, VQ-BeT trains its GPT and bin prediction header / offset header parts.
|
||||
_, loss = self.vqbet(batch, rollout=False)
|
||||
_, loss_dict = self.vqbet(batch, rollout=False)
|
||||
|
||||
return loss
|
||||
return loss_dict
|
||||
|
||||
class SpatialSoftmax(nn.Module):
|
||||
"""
|
||||
|
@ -326,12 +326,12 @@ class VQBeTModel(nn.Module):
|
|||
output[:, i, :, :] = action[:, i : i + act_w, :]
|
||||
action = output
|
||||
|
||||
loss = self.action_head.loss_fn(
|
||||
loss_dict = self.action_head.loss_fn(
|
||||
pred_action,
|
||||
action,
|
||||
reduction="mean",
|
||||
)
|
||||
return pred_action, loss
|
||||
return pred_action, loss_dict
|
||||
|
||||
|
||||
class VQBeTHead(nn.Module):
|
||||
|
|
Loading…
Reference in New Issue