Allow for simultaneous wandb and tensorboard logging.
This commit is contained in:
parent
ccf2782d8a
commit
ba886ed437
|
@ -53,7 +53,6 @@ class Logger:
|
|||
run_offline = not enable_wandb or not project
|
||||
if run_offline:
|
||||
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
|
||||
self._local_writer = SummaryWriter(log_dir=self._log_dir)
|
||||
self._wandb = None
|
||||
else:
|
||||
os.environ["WANDB_SILENT"] = "true"
|
||||
|
@ -78,7 +77,10 @@ class Logger:
|
|||
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
|
||||
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
|
||||
self._wandb = wandb
|
||||
self._local_writer = None
|
||||
self._tensorboard = None
|
||||
if self._cfg.tensorboard.enable:
|
||||
self._tensorboard = SummaryWriter(log_dir=self._log_dir)
|
||||
logging.info(colored("Tensorboard logs will be saved locally.", "yellow", attrs=["bold"]))
|
||||
|
||||
def save_model(self, policy: Policy, identifier):
|
||||
if self._save_model:
|
||||
|
@ -127,21 +129,21 @@ class Logger:
|
|||
)
|
||||
continue
|
||||
self._wandb.log({f"{mode}/{k}": v}, step=step)
|
||||
elif self._local_writer is not None:
|
||||
if self._tensorboard is not None:
|
||||
for k, v in d.items():
|
||||
self._local_writer.add_scalar(f"{mode}/{k}", v, global_step=step)
|
||||
self._local_writer.flush()
|
||||
self._tensorboard.add_scalar(f"{mode}/{k}", v, global_step=step)
|
||||
self._tensorboard.flush()
|
||||
|
||||
def log_video(self, video_path: str, step: int, mode: str = "train"):
|
||||
assert mode in {"train", "eval"}
|
||||
if self._wandb is not None:
|
||||
wandb_video = self._wandb.Video(video_path, fps=self._cfg.fps, format="mp4")
|
||||
self._wandb.log({f"{mode}/video": wandb_video}, step=step)
|
||||
elif self._local_writer is not None:
|
||||
if self._tensorboard is not None:
|
||||
# Read video file and convert it to tensorboard format
|
||||
frames = iio.imread(video_path, plugin="pyav")
|
||||
video_np = np.array(list(frames))
|
||||
T, H, W, C = video_np.shape
|
||||
# Transpose the channel position and add a leading 1 for the batch dimension expected by TF
|
||||
video_np = video_np.transpose(0, 3, 1, 2).reshape(1, T, C, H, W)
|
||||
self._local_writer.add_video(f"{mode}/video", video_np, step, fps=self._cfg.fps)
|
||||
self._tensorboard.add_video(f"{mode}/video", video_np, step, fps=self._cfg.fps)
|
||||
|
|
|
@ -40,3 +40,6 @@ wandb:
|
|||
disable_artifact: false
|
||||
project: lerobot
|
||||
notes: ""
|
||||
|
||||
tensorboard:
|
||||
enable: false
|
||||
|
|
|
@ -351,7 +351,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||
start_seed=cfg.seed,
|
||||
)
|
||||
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
|
||||
logger.log_video(eval_info["video_paths"][0], step, mode="eval")
|
||||
if cfg.wandb.enable or cfg.tensorboard.enable:
|
||||
logger.log_video(eval_info["video_paths"][0], step, mode="eval")
|
||||
logging.info("Resume training")
|
||||
|
||||
if cfg.training.save_model and step % cfg.training.save_freq == 0:
|
||||
|
|
Loading…
Reference in New Issue