parent
90e099b39f
commit
e71095960f
|
@ -300,7 +300,7 @@ class PI0Policy(PreTrainedPolicy):
|
||||||
self._action_queue.extend(actions.transpose(0, 1))
|
self._action_queue.extend(actions.transpose(0, 1))
|
||||||
return self._action_queue.popleft()
|
return self._action_queue.popleft()
|
||||||
|
|
||||||
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
|
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]:
|
||||||
"""Do a full training forward pass to compute the loss"""
|
"""Do a full training forward pass to compute the loss"""
|
||||||
if self.config.adapt_to_pi_aloha:
|
if self.config.adapt_to_pi_aloha:
|
||||||
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
||||||
|
@ -328,12 +328,12 @@ class PI0Policy(PreTrainedPolicy):
|
||||||
losses = losses[:, :, : self.config.max_action_dim]
|
losses = losses[:, :, : self.config.max_action_dim]
|
||||||
loss_dict["losses_after_rm_padding"] = losses.clone()
|
loss_dict["losses_after_rm_padding"] = losses.clone()
|
||||||
|
|
||||||
loss = losses.mean()
|
|
||||||
# For backward pass
|
# For backward pass
|
||||||
loss_dict["loss"] = loss
|
loss = losses.mean()
|
||||||
# For logging
|
# For logging
|
||||||
loss_dict["l2_loss"] = loss.item()
|
loss_dict["l2_loss"] = loss.item()
|
||||||
return loss_dict
|
|
||||||
|
return loss, loss_dict
|
||||||
|
|
||||||
def prepare_images(self, batch):
|
def prepare_images(self, batch):
|
||||||
"""Apply Pi0 preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and
|
"""Apply Pi0 preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and
|
||||||
|
|
|
@ -102,7 +102,7 @@ class WandBLogger:
|
||||||
self._wandb.log_artifact(artifact)
|
self._wandb.log_artifact(artifact)
|
||||||
|
|
||||||
def log_dict(self, d: dict, step: int, mode: str = "train"):
|
def log_dict(self, d: dict, step: int, mode: str = "train"):
|
||||||
if mode in {"train", "eval"}:
|
if mode not in {"train", "eval"}:
|
||||||
raise ValueError(mode)
|
raise ValueError(mode)
|
||||||
|
|
||||||
for k, v in d.items():
|
for k, v in d.items():
|
||||||
|
@ -114,7 +114,7 @@ class WandBLogger:
|
||||||
self._wandb.log({f"{mode}/{k}": v}, step=step)
|
self._wandb.log({f"{mode}/{k}": v}, step=step)
|
||||||
|
|
||||||
def log_video(self, video_path: str, step: int, mode: str = "train"):
|
def log_video(self, video_path: str, step: int, mode: str = "train"):
|
||||||
if mode in {"train", "eval"}:
|
if mode not in {"train", "eval"}:
|
||||||
raise ValueError(mode)
|
raise ValueError(mode)
|
||||||
|
|
||||||
wandb_video = self._wandb.Video(video_path, fps=self.env_fps, format="mp4")
|
wandb_video = self._wandb.Video(video_path, fps=self.env_fps, format="mp4")
|
||||||
|
|
|
@ -233,7 +233,7 @@ def train(cfg: TrainPipelineConfig):
|
||||||
logging.info(train_tracker)
|
logging.info(train_tracker)
|
||||||
if wandb_logger:
|
if wandb_logger:
|
||||||
wandb_log_dict = {**train_tracker.to_dict(), **output_dict}
|
wandb_log_dict = {**train_tracker.to_dict(), **output_dict}
|
||||||
wandb_logger.log_dict(wandb_log_dict)
|
wandb_logger.log_dict(wandb_log_dict, step)
|
||||||
train_tracker.reset_averages()
|
train_tracker.reset_averages()
|
||||||
|
|
||||||
if cfg.save_checkpoint and is_saving_step:
|
if cfg.save_checkpoint and is_saving_step:
|
||||||
|
@ -271,6 +271,7 @@ def train(cfg: TrainPipelineConfig):
|
||||||
logging.info(eval_tracker)
|
logging.info(eval_tracker)
|
||||||
if wandb_logger:
|
if wandb_logger:
|
||||||
wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
|
wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
|
||||||
|
wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
|
||||||
wandb_logger.log_video(eval_info["video_paths"][0], step, mode="eval")
|
wandb_logger.log_video(eval_info["video_paths"][0], step, mode="eval")
|
||||||
|
|
||||||
if eval_env:
|
if eval_env:
|
||||||
|
|
Loading…
Reference in New Issue