Fixes following #670 (#719)

This commit is contained in:
Simon Alibert 2025-02-12 12:53:55 +01:00 committed by GitHub
parent 90e099b39f
commit e71095960f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 8 additions and 7 deletions

View File

@@ -300,7 +300,7 @@ class PI0Policy(PreTrainedPolicy):
self._action_queue.extend(actions.transpose(0, 1)) self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft() return self._action_queue.popleft()
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]: def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]:
"""Do a full training forward pass to compute the loss""" """Do a full training forward pass to compute the loss"""
if self.config.adapt_to_pi_aloha: if self.config.adapt_to_pi_aloha:
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT]) batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
@@ -328,12 +328,12 @@ class PI0Policy(PreTrainedPolicy):
losses = losses[:, :, : self.config.max_action_dim] losses = losses[:, :, : self.config.max_action_dim]
loss_dict["losses_after_rm_padding"] = losses.clone() loss_dict["losses_after_rm_padding"] = losses.clone()
loss = losses.mean()
# For backward pass # For backward pass
loss_dict["loss"] = loss loss = losses.mean()
# For logging # For logging
loss_dict["l2_loss"] = loss.item() loss_dict["l2_loss"] = loss.item()
return loss_dict
return loss, loss_dict
def prepare_images(self, batch): def prepare_images(self, batch):
"""Apply Pi0 preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and """Apply Pi0 preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and

View File

@@ -102,7 +102,7 @@ class WandBLogger:
self._wandb.log_artifact(artifact) self._wandb.log_artifact(artifact)
def log_dict(self, d: dict, step: int, mode: str = "train"): def log_dict(self, d: dict, step: int, mode: str = "train"):
if mode in {"train", "eval"}: if mode not in {"train", "eval"}:
raise ValueError(mode) raise ValueError(mode)
for k, v in d.items(): for k, v in d.items():
@@ -114,7 +114,7 @@ class WandBLogger:
self._wandb.log({f"{mode}/{k}": v}, step=step) self._wandb.log({f"{mode}/{k}": v}, step=step)
def log_video(self, video_path: str, step: int, mode: str = "train"): def log_video(self, video_path: str, step: int, mode: str = "train"):
if mode in {"train", "eval"}: if mode not in {"train", "eval"}:
raise ValueError(mode) raise ValueError(mode)
wandb_video = self._wandb.Video(video_path, fps=self.env_fps, format="mp4") wandb_video = self._wandb.Video(video_path, fps=self.env_fps, format="mp4")

View File

@@ -233,7 +233,7 @@ def train(cfg: TrainPipelineConfig):
logging.info(train_tracker) logging.info(train_tracker)
if wandb_logger: if wandb_logger:
wandb_log_dict = {**train_tracker.to_dict(), **output_dict} wandb_log_dict = {**train_tracker.to_dict(), **output_dict}
wandb_logger.log_dict(wandb_log_dict) wandb_logger.log_dict(wandb_log_dict, step)
train_tracker.reset_averages() train_tracker.reset_averages()
if cfg.save_checkpoint and is_saving_step: if cfg.save_checkpoint and is_saving_step:
@@ -271,6 +271,7 @@ def train(cfg: TrainPipelineConfig):
logging.info(eval_tracker) logging.info(eval_tracker)
if wandb_logger: if wandb_logger:
wandb_log_dict = {**eval_tracker.to_dict(), **eval_info} wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
wandb_logger.log_video(eval_info["video_paths"][0], step, mode="eval") wandb_logger.log_video(eval_info["video_paths"][0], step, mode="eval")
if eval_env: if eval_env: