use loggable data types in act return dict
This commit is contained in:
parent
e1addd40f4
commit
0dc6c7265e
|
@ -101,7 +101,7 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||||
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
|
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
|
||||||
).mean()
|
).mean()
|
||||||
|
|
||||||
loss_dict = {"l1_loss": l1_loss}
|
loss_dict = {"l1_loss": l1_loss.item()}
|
||||||
if self.config.use_vae:
|
if self.config.use_vae:
|
||||||
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
|
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
|
||||||
# each dimension independently, we sum over the latent dimension to get the total
|
# each dimension independently, we sum over the latent dimension to get the total
|
||||||
|
@ -110,7 +110,7 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||||
mean_kld = (
|
mean_kld = (
|
||||||
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
|
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
|
||||||
)
|
)
|
||||||
loss_dict["kld_loss"] = mean_kld
|
loss_dict["kld_loss"] = mean_kld.item()
|
||||||
loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
|
loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
|
||||||
else:
|
else:
|
||||||
loss_dict["loss"] = l1_loss
|
loss_dict["loss"] = l1_loss
|
||||||
|
|
Loading…
Reference in New Issue