Enable logging all the information returned by the `forward` methods of policies (#151)
parent b187942db4
commit 1249aee3ac
@@ -114,6 +114,11 @@ class Logger:
         assert mode in {"train", "eval"}
         if self._wandb is not None:
             for k, v in d.items():
+                if not isinstance(v, (int, float, str)):
+                    logging.warning(
+                        f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
+                    )
+                    continue
                 self._wandb.log({f"{mode}/{k}": v}, step=step)
 
     def log_video(self, video_path: str, step: int, mode: str = "train"):
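The guard added above only touches the WandB branch of the logger: values that are not plain ints, floats, or strings are skipped with a warning instead of being passed to `wandb.log`. Below is a minimal standalone sketch of the same filtering idea; `filter_loggable` is a hypothetical helper written for illustration, not a function from the repository.

import logging

def filter_loggable(d: dict) -> dict:
    """Keep only WandB-friendly values (int, float, str) and warn about the rest."""
    kept = {}
    for k, v in d.items():
        if not isinstance(v, (int, float, str)):
            logging.warning(f'Key "{k}" was skipped: type {type(v).__name__} is not handled.')
            continue
        kept[k] = v
    return kept

# A tensor-like object is dropped with a warning; scalars and strings pass through.
print(filter_loggable({"loss": 0.42, "lr": 1e-4, "debug_obj": object()}))
# {'loss': 0.42, 'lr': 0.0001}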
@@ -101,7 +101,7 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
             F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
         ).mean()
 
-        loss_dict = {"l1_loss": l1_loss}
+        loss_dict = {"l1_loss": l1_loss.item()}
         if self.config.use_vae:
             # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
             # each dimension independently, we sum over the latent dimension to get the total
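For context on the unchanged lines above: the L1 term is computed per element with reduction="none" and then multiplied by the negated padding mask, so padded action steps contribute zero before the mean is taken. The following is a small sketch of that masking pattern with made-up shapes and values, not code from the repository.

import torch
import torch.nn.functional as F

actions = torch.tensor([[[1.0], [2.0], [3.0]]])        # (batch=1, time=3, action_dim=1)
actions_hat = torch.zeros_like(actions)                  # stand-in predictions
action_is_pad = torch.tensor([[False, False, True]])     # last step is padding

per_element = F.l1_loss(actions, actions_hat, reduction="none")   # |1|, |2|, |3|
masked = per_element * ~action_is_pad.unsqueeze(-1)                # zero out padded steps
print(masked.squeeze())  # tensor([1., 2., 0.])
print(masked.mean())     # tensor(1.) -- the mean still divides by all elements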
@@ -110,7 +110,7 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
             mean_kld = (
                 (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
             )
-            loss_dict["kld_loss"] = mean_kld
+            loss_dict["kld_loss"] = mean_kld.item()
             loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
         else:
             loss_dict["loss"] = l1_loss
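The mean_kld expression above is the usual closed-form KL divergence between a diagonal Gaussian N(μ, σ²) and a standard normal, summed over the latent dimension and averaged over the batch; with this commit only the Python-float copy goes into loss_dict["kld_loss"], while the tensor still feeds the differentiable "loss" entry. A quick numerical sanity check of that closed form against torch.distributions (a sketch with arbitrary shapes, not repository code):

import torch
from torch.distributions import Normal, kl_divergence

mu_hat = torch.randn(4, 32)             # batch of 4, latent dimension 32
log_sigma_x2_hat = torch.randn(4, 32)   # log of the variance

# Closed form used in the diff: -0.5 * (1 + log(sigma^2) - mu^2 - sigma^2), summed over the latent dim.
closed_form = (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - log_sigma_x2_hat.exp())).sum(-1).mean()

# Same quantity via torch.distributions; std = exp(log(sigma^2) / 2).
latent_pdf = Normal(mu_hat, (0.5 * log_sigma_x2_hat).exp())
standard_normal = Normal(torch.zeros_like(mu_hat), torch.ones_like(mu_hat))
reference = kl_divergence(latent_pdf, standard_normal).sum(-1).mean()

assert torch.allclose(closed_form, reference, atol=1e-5)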
@@ -38,7 +38,8 @@ class Policy(Protocol):
     def forward(self, batch: dict[str, Tensor]) -> dict:
         """Run the batch through the model and compute the loss for training or validation.
 
-        Returns a dictionary with "loss" and maybe other information.
+        Returns a dictionary with "loss" and potentially other information. Apart from "loss" which is a Tensor, all
+        other items should be logging-friendly, native Python types.
         """
 
     def select_action(self, batch: dict[str, Tensor]):
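The revised docstring states the convention this commit relies on throughout: only "loss" remains a Tensor (so the training loop can backpropagate through it), and every other returned item is a native Python type ready for logging. A minimal sketch of a conforming implementation follows; ToyPolicy and its fields are invented for illustration and are not part of the repository.

import torch.nn.functional as F
from torch import Tensor, nn

class ToyPolicy(nn.Module):
    """Hypothetical policy whose forward follows the Policy protocol's contract."""

    def __init__(self):
        super().__init__()
        self.net = nn.Linear(8, 2)

    def forward(self, batch: dict[str, Tensor]) -> dict:
        pred = self.net(batch["observation"])
        loss = F.mse_loss(pred, batch["action"])
        return {
            "loss": loss,                               # Tensor: backpropagated by the training loop
            "mse_loss": loss.item(),                    # native float: safe to log as-is
            "pred_abs_mean": pred.abs().mean().item(),  # extra diagnostic, also a float
        }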
@@ -72,6 +72,7 @@ def make_optimizer_and_scheduler(cfg, policy):
 
 
 def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
+    """Returns a dictionary of items for logging."""
     start_time = time.time()
     policy.train()
     output_dict = policy.forward(batch)
@@ -99,6 +100,7 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
         "grad_norm": float(grad_norm),
         "lr": optimizer.param_groups[0]["lr"],
         "update_s": time.time() - start_time,
+        **{k: v for k, v in output_dict.items() if k != "loss"},
     }
 
     return info
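The added dictionary unpacking merges every extra item returned by forward into the logging dict, while skipping "loss" so the entry already placed in info is not overwritten. A toy illustration of that merge, with made-up values standing in for the real ones:

output_dict = {"loss": "tensor(...)", "l1_loss": 0.37, "kld_loss": 0.02}  # "loss" stands in for a Tensor

info = {
    "loss": 0.41,       # value recorded separately by the training loop
    "grad_norm": 9.8,
    "lr": 1e-4,
    "update_s": 0.05,
    **{k: v for k, v in output_dict.items() if k != "loss"},  # merge extras, keep the existing "loss"
}

print(info)
# {'loss': 0.41, 'grad_norm': 9.8, 'lr': 0.0001, 'update_s': 0.05, 'l1_loss': 0.37, 'kld_loss': 0.02}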
@@ -122,7 +124,7 @@ def train_notebook(out_dir=None, job_name=None, config_name="default", config_pa
     train(cfg, out_dir=out_dir, job_name=job_name)
 
 
-def log_train_info(logger, info, step, cfg, dataset, is_offline):
+def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
     loss = info["loss"]
     grad_norm = info["grad_norm"]
     lr = info["lr"]