Enable logging all the information returned by the `forward` methods of policies (#151)
This commit is contained in:
parent
b187942db4
commit
1249aee3ac
|
@ -114,6 +114,11 @@ class Logger:
|
|||
assert mode in {"train", "eval"}
|
||||
if self._wandb is not None:
|
||||
for k, v in d.items():
|
||||
if not isinstance(v, (int, float, str)):
|
||||
logging.warning(
|
||||
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
|
||||
)
|
||||
continue
|
||||
self._wandb.log({f"{mode}/{k}": v}, step=step)
|
||||
|
||||
def log_video(self, video_path: str, step: int, mode: str = "train"):
|
||||
|
|
|
@ -101,7 +101,7 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
|
||||
).mean()
|
||||
|
||||
loss_dict = {"l1_loss": l1_loss}
|
||||
loss_dict = {"l1_loss": l1_loss.item()}
|
||||
if self.config.use_vae:
|
||||
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
|
||||
# each dimension independently, we sum over the latent dimension to get the total
|
||||
|
@ -110,7 +110,7 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
mean_kld = (
|
||||
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
|
||||
)
|
||||
loss_dict["kld_loss"] = mean_kld
|
||||
loss_dict["kld_loss"] = mean_kld.item()
|
||||
loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
|
||||
else:
|
||||
loss_dict["loss"] = l1_loss
|
||||
|
|
|
@ -38,7 +38,8 @@ class Policy(Protocol):
|
|||
def forward(self, batch: dict[str, Tensor]) -> dict:
|
||||
"""Run the batch through the model and compute the loss for training or validation.
|
||||
|
||||
Returns a dictionary with "loss" and maybe other information.
|
||||
Returns a dictionary with "loss" and potentially other information. Apart from "loss" which is a Tensor, all
|
||||
other items should be logging-friendly, native Python types.
|
||||
"""
|
||||
|
||||
def select_action(self, batch: dict[str, Tensor]):
|
||||
|
|
|
@ -72,6 +72,7 @@ def make_optimizer_and_scheduler(cfg, policy):
|
|||
|
||||
|
||||
def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
|
||||
"""Returns a dictionary of items for logging."""
|
||||
start_time = time.time()
|
||||
policy.train()
|
||||
output_dict = policy.forward(batch)
|
||||
|
@ -99,6 +100,7 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
|
|||
"grad_norm": float(grad_norm),
|
||||
"lr": optimizer.param_groups[0]["lr"],
|
||||
"update_s": time.time() - start_time,
|
||||
**{k: v for k, v in output_dict.items() if k != "loss"},
|
||||
}
|
||||
|
||||
return info
|
||||
|
@ -122,7 +124,7 @@ def train_notebook(out_dir=None, job_name=None, config_name="default", config_pa
|
|||
train(cfg, out_dir=out_dir, job_name=job_name)
|
||||
|
||||
|
||||
def log_train_info(logger, info, step, cfg, dataset, is_offline):
|
||||
def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
|
||||
loss = info["loss"]
|
||||
grad_norm = info["grad_norm"]
|
||||
lr = info["lr"]
|
||||
|
|
Binary file not shown.
Loading…
Reference in New Issue