Merge 25914596c9
into a06598678c
This commit is contained in:
commit
ed69dd68af
|
@ -34,6 +34,10 @@ from torch.optim.lr_scheduler import LRScheduler
|
|||
from lerobot.common.policies.policy_protocol import Policy
|
||||
from lerobot.common.utils.utils import get_global_random_state, set_global_random_state
|
||||
|
||||
# local logging reqs
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
import imageio.v3 as iio
|
||||
import numpy as np
|
||||
|
||||
def log_output_dir(out_dir):
|
||||
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
|
||||
|
@ -131,6 +135,10 @@ class Logger:
|
|||
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
|
||||
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
|
||||
self._wandb = wandb
|
||||
self._tensorboard = None
|
||||
if self._cfg.tensorboard.enable:
|
||||
self._tensorboard = SummaryWriter(log_dir=self._log_dir)
|
||||
logging.info(colored("Tensorboard logs will be saved locally.", "yellow", attrs=["bold"]))
|
||||
|
||||
@classmethod
|
||||
def get_checkpoints_dir(cls, log_dir: str | Path) -> Path:
|
||||
|
@ -238,8 +246,21 @@ class Logger:
|
|||
)
|
||||
continue
|
||||
self._wandb.log({f"{mode}/{k}": v}, step=step)
|
||||
if self._tensorboard is not None:
|
||||
for k, v in d.items():
|
||||
self._tensorboard.add_scalar(f"{mode}/{k}", v, global_step=step)
|
||||
self._tensorboard.flush()
|
||||
|
||||
def log_video(self, video_path: str, step: int, mode: str = "train"):
|
||||
assert mode in {"train", "eval"}
|
||||
wandb_video = self._wandb.Video(video_path, fps=self._cfg.fps, format="mp4")
|
||||
self._wandb.log({f"{mode}/video": wandb_video}, step=step)
|
||||
if self._wandb is not None:
|
||||
wandb_video = self._wandb.Video(video_path, fps=self._cfg.fps, format="mp4")
|
||||
self._wandb.log({f"{mode}/video": wandb_video}, step=step)
|
||||
if self._tensorboard is not None:
|
||||
# Read video file and convert it to tensorboard format
|
||||
frames = iio.imread(video_path, plugin="pyav")
|
||||
video_np = np.array(list(frames))
|
||||
T, H, W, C = video_np.shape
|
||||
# Transpose the channel position and add a leading 1 for the batch dimension expected by TF
|
||||
video_np = video_np.transpose(0, 3, 1, 2).reshape(1, T, C, H, W)
|
||||
self._tensorboard.add_video(f"{mode}/video", video_np, step, fps=self._cfg.fps)
|
||||
|
|
|
@ -57,3 +57,6 @@ wandb:
|
|||
disable_artifact: false
|
||||
project: lerobot
|
||||
notes: ""
|
||||
|
||||
tensorboard:
|
||||
enable: false
|
||||
|
|
|
@ -342,7 +342,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
start_seed=cfg.seed,
|
||||
)
|
||||
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline=True)
|
||||
if cfg.wandb.enable:
|
||||
if cfg.wandb.enable or cfg.tensorboard.enable:
|
||||
logger.log_video(eval_info["video_paths"][0], step, mode="eval")
|
||||
logging.info("Resume training")
|
||||
|
||||
|
|
Loading…
Reference in New Issue