This commit is contained in:
dirkmcpherson 2024-06-10 13:27:31 +01:00 committed by GitHub
commit ed69dd68af
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 27 additions and 3 deletions

View File

@ -34,6 +34,10 @@ from torch.optim.lr_scheduler import LRScheduler
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import get_global_random_state, set_global_random_state
# local logging reqs
from torch.utils.tensorboard import SummaryWriter
import imageio.v3 as iio
import numpy as np
def log_output_dir(out_dir):
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
@ -131,6 +135,10 @@ class Logger:
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
self._wandb = wandb
self._tensorboard = None
if self._cfg.tensorboard.enable:
self._tensorboard = SummaryWriter(log_dir=self._log_dir)
logging.info(colored("Tensorboard logs will be saved locally.", "yellow", attrs=["bold"]))
@classmethod
def get_checkpoints_dir(cls, log_dir: str | Path) -> Path:
@ -238,8 +246,21 @@ class Logger:
)
continue
self._wandb.log({f"{mode}/{k}": v}, step=step)
if self._tensorboard is not None:
for k, v in d.items():
self._tensorboard.add_scalar(f"{mode}/{k}", v, global_step=step)
self._tensorboard.flush()
def log_video(self, video_path: str, step: int, mode: str = "train"):
assert mode in {"train", "eval"}
wandb_video = self._wandb.Video(video_path, fps=self._cfg.fps, format="mp4")
self._wandb.log({f"{mode}/video": wandb_video}, step=step)
if self._wandb is not None:
wandb_video = self._wandb.Video(video_path, fps=self._cfg.fps, format="mp4")
self._wandb.log({f"{mode}/video": wandb_video}, step=step)
if self._tensorboard is not None:
# Read video file and convert it to tensorboard format
frames = iio.imread(video_path, plugin="pyav")
video_np = np.array(list(frames))
T, H, W, C = video_np.shape
# Transpose the channel position and add a leading 1 for the batch dimension expected by TF
video_np = video_np.transpose(0, 3, 1, 2).reshape(1, T, C, H, W)
self._tensorboard.add_video(f"{mode}/video", video_np, step, fps=self._cfg.fps)

View File

@ -57,3 +57,6 @@ wandb:
disable_artifact: false
project: lerobot
notes: ""
tensorboard:
enable: false

View File

@ -342,7 +342,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
start_seed=cfg.seed,
)
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline=True)
if cfg.wandb.enable:
if cfg.wandb.enable or cfg.tensorboard.enable:
logger.log_video(eval_info["video_paths"][0], step, mode="eval")
logging.info("Resume training")