Enable logging of all the information returned by the `forward` methods of policies (#151)

Alexander Soare authored on 2024-05-10 07:45:32 +01:00; committed by GitHub
parent b187942db4
commit 1249aee3ac
5 changed files with 12 additions and 4 deletions


@@ -114,6 +114,11 @@ class Logger:
         assert mode in {"train", "eval"}
         if self._wandb is not None:
             for k, v in d.items():
+                if not isinstance(v, (int, float, str)):
+                    logging.warning(
+                        f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
+                    )
+                    continue
                 self._wandb.log({f"{mode}/{k}": v}, step=step)

     def log_video(self, video_path: str, step: int, mode: str = "train"):
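
For reference, the guard's behavior in isolation (the helper name `_filter_loggable` is hypothetical, not part of this commit):

    import logging

    def _filter_loggable(d: dict) -> dict:
        # Keep only values the WandB scalar-logging path handles; warn on the rest.
        out = {}
        for k, v in d.items():
            if not isinstance(v, (int, float, str)):
                logging.warning(f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.')
                continue
            out[k] = v
        return out

    # A stray Tensor (e.g. a forgotten .item()) is dropped with a warning:
    # _filter_loggable({"l1_loss": torch.tensor(0.3), "lr": 1e-4}) -> {"lr": 1e-4}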


@@ -101,7 +101,7 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
             F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
         ).mean()

-        loss_dict = {"l1_loss": l1_loss}
+        loss_dict = {"l1_loss": l1_loss.item()}
         if self.config.use_vae:
             # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
             # each dimension independently, we sum over the latent dimension to get the total
@@ -110,7 +110,7 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
             mean_kld = (
                 (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
             )
-            loss_dict["kld_loss"] = mean_kld
+            loss_dict["kld_loss"] = mean_kld.item()
             loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
         else:
             loss_dict["loss"] = l1_loss


@@ -38,7 +38,8 @@ class Policy(Protocol):
     def forward(self, batch: dict[str, Tensor]) -> dict:
         """Run the batch through the model and compute the loss for training or validation.

-        Returns a dictionary with "loss" and maybe other information.
+        Returns a dictionary with "loss" and potentially other information. Apart from "loss" which is a Tensor, all
+        other items should be logging-friendly, native Python types.
         """

     def select_action(self, batch: dict[str, Tensor]):
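
For illustration, a hypothetical policy that satisfies the updated contract (not part of the commit):

    import torch.nn.functional as F
    from torch import Tensor, nn

    class ToyPolicy(nn.Module):
        def __init__(self):
            super().__init__()
            self.net = nn.Linear(4, 2)

        def forward(self, batch: dict[str, Tensor]) -> dict:
            loss = F.mse_loss(self.net(batch["observation"]), batch["action"])
            # "loss" is a Tensor for backprop; everything else is a native Python type.
            return {"loss": loss, "mse_loss": loss.item()}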


@@ -72,6 +72,7 @@ def make_optimizer_and_scheduler(cfg, policy):

 def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
+    """Returns a dictionary of items for logging."""
     start_time = time.time()
     policy.train()
     output_dict = policy.forward(batch)
@@ -99,6 +100,7 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
         "grad_norm": float(grad_norm),
         "lr": optimizer.param_groups[0]["lr"],
         "update_s": time.time() - start_time,
+        **{k: v for k, v in output_dict.items() if k != "loss"},
     }

     return info
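
The `**` unpacking on the new line is what forwards the policy's extra outputs into the logging dict; a small illustration with made-up numbers, where `output_dict` stands in for the return value of `policy.forward(batch)`:

    # Hypothetical values; "loss" is excluded from the merge because
    # update_policy already records it separately in `info`.
    output_dict = {"loss": 0.42, "l1_loss": 0.30, "kld_loss": 0.12}
    info = {
        "grad_norm": 1.7,
        **{k: v for k, v in output_dict.items() if k != "loss"},
    }
    assert info == {"grad_norm": 1.7, "l1_loss": 0.30, "kld_loss": 0.12}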
@@ -122,7 +124,7 @@ def train_notebook(out_dir=None, job_name=None, config_name="default", config_pa
     train(cfg, out_dir=out_dir, job_name=job_name)


-def log_train_info(logger, info, step, cfg, dataset, is_offline):
+def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
     loss = info["loss"]
     grad_norm = info["grad_norm"]
     lr = info["lr"]
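
Taken together, a hedged sketch of how the merged dict might then reach WandB (the body of the real function beyond the lookups above is not shown in this diff, so this is an assumption about the flow, not the repository's code):

    def log_train_info_sketch(logger, info: dict, step: int, mode: str = "train"):
        # By this point "loss", "grad_norm", "lr", "update_s" and any policy
        # extras are all native Python types, so Logger.log_dict can hand them
        # straight to WandB; anything else is dropped by the isinstance guard.
        logger.log_dict(info, step, mode=mode)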