Add video recording to wandb logger

This commit is contained in:
Farbod Farshidian 2024-02-07 20:00:01 -05:00
parent e23bb6819a
commit 0a23fff25e
2 changed files with 42 additions and 3 deletions

View File

@ -210,6 +210,10 @@ class OnPolicyRunner:
"Train/mean_episode_length/time", statistics.mean(locs["lenbuffer"]), self.tot_time "Train/mean_episode_length/time", statistics.mean(locs["lenbuffer"]), self.tot_time
) )
# Video recording for wandb
if self.logger_type == "wandb":
self.writer.update_video_files(log_name="Video", fps=30.0)
str = f" \033[1m Learning iteration {locs['it']}/{locs['tot_iter']} \033[0m " str = f" \033[1m Learning iteration {locs['it']}/{locs['tot_iter']} \033[0m "
if len(locs["rewbuffer"]) > 0: if len(locs["rewbuffer"]) > 0:

View File

@ -4,6 +4,8 @@
from __future__ import annotations from __future__ import annotations
import os import os
import json
import pathlib
from dataclasses import asdict from dataclasses import asdict
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
@ -27,14 +29,15 @@ class WandbSummaryWriter(SummaryWriter):
try: try:
entity = os.environ["WANDB_USERNAME"] entity = os.environ["WANDB_USERNAME"]
except KeyError: except KeyError:
raise KeyError( entity = None
"Wandb username not found. Please run or add to ~/.bashrc: export WANDB_USERNAME=YOUR_USERNAME" print("Wandb username not found. Using default.")
)
wandb.init(project=project, entity=entity) wandb.init(project=project, entity=entity)
# Change generated name to project-number format # Change generated name to project-number format
wandb.run.name = project + wandb.run.name.split("-")[-1] wandb.run.name = project + wandb.run.name.split("-")[-1]
with open(os.path.join(log_dir, "wandb_info.json"), "w") as f:
json.dump({"wandb_run_id": wandb.run.id, "wandb_run_name": wandb.run.name}, f)
self.name_map = { self.name_map = {
"Train/mean_reward/time": "Train/mean_reward_time", "Train/mean_reward/time": "Train/mean_reward_time",
@ -45,12 +48,18 @@ class WandbSummaryWriter(SummaryWriter):
wandb.log({"log_dir": run_name}) wandb.log({"log_dir": run_name})
# Video logging.
self.saved_video_files = {}
def store_config(self, env_cfg, runner_cfg, alg_cfg, policy_cfg): def store_config(self, env_cfg, runner_cfg, alg_cfg, policy_cfg):
wandb.config.update({"runner_cfg": runner_cfg}) wandb.config.update({"runner_cfg": runner_cfg})
wandb.config.update({"policy_cfg": policy_cfg}) wandb.config.update({"policy_cfg": policy_cfg})
wandb.config.update({"alg_cfg": alg_cfg}) wandb.config.update({"alg_cfg": alg_cfg})
wandb.config.update({"env_cfg": asdict(env_cfg)}) wandb.config.update({"env_cfg": asdict(env_cfg)})
def get_config(self):
return wandb.config
def _map_path(self, path): def _map_path(self, path):
if path in self.name_map: if path in self.name_map:
return self.name_map[path] return self.name_map[path]
@ -67,6 +76,29 @@ class WandbSummaryWriter(SummaryWriter):
) )
wandb.log({self._map_path(tag): scalar_value}, step=global_step) wandb.log({self._map_path(tag): scalar_value}, step=global_step)
def update_video_files(self, log_name, fps: float):
# Check if there are new video files
log_dir = pathlib.Path(self.log_dir)
video_files = list(log_dir.rglob("*.mp4"))
for video_file in video_files:
file_size_kb = os.stat(str(video_file)).st_size / 1024
# If it is new file
if str(video_file) not in self.saved_video_files:
self.saved_video_files[str(video_file)] = {"size": file_size_kb, "added": False, "count": 0}
else:
# Only upload if the file size is not changing anymore to avoid uploading non-ready video.
video_info = self.saved_video_files[str(video_file)]
if video_info["size"] == file_size_kb and file_size_kb > 100 and not video_info["added"]:
if video_info["count"] > 10:
print(f"[Wandb] Uploading {os.path.basename(str(video_file))}.")
self.add_video(str(video_file), fps=fps, log_name=log_name)
self.saved_video_files[str(video_file)]["added"] = True
else:
video_info["count"] += 1
else:
self.saved_video_files[str(video_file)]["size"] = file_size_kb
video_info["count"] = 0
def stop(self): def stop(self):
wandb.finish() wandb.finish()
@ -78,3 +110,6 @@ class WandbSummaryWriter(SummaryWriter):
def save_file(self, path, iter=None): def save_file(self, path, iter=None):
wandb.save(path, base_path=os.path.dirname(path)) wandb.save(path, base_path=os.path.dirname(path))
def add_video(self, video_path: str, fps: int = 4, log_name: str = "Video"):
wandb.log({log_name: wandb.Video(video_path, fps=fps)})