Add video recording to wandb logger
This commit is contained in:
parent
e23bb6819a
commit
0a23fff25e
|
@ -210,6 +210,10 @@ class OnPolicyRunner:
|
|||
"Train/mean_episode_length/time", statistics.mean(locs["lenbuffer"]), self.tot_time
|
||||
)
|
||||
|
||||
# Video recording for wandb
|
||||
if self.logger_type == "wandb":
|
||||
self.writer.update_video_files(log_name="Video", fps=30.0)
|
||||
|
||||
str = f" \033[1m Learning iteration {locs['it']}/{locs['tot_iter']} \033[0m "
|
||||
|
||||
if len(locs["rewbuffer"]) > 0:
|
||||
|
|
|
@ -4,6 +4,8 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import json
|
||||
import pathlib
|
||||
from dataclasses import asdict
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
|
@ -27,14 +29,15 @@ class WandbSummaryWriter(SummaryWriter):
|
|||
try:
|
||||
entity = os.environ["WANDB_USERNAME"]
|
||||
except KeyError:
|
||||
raise KeyError(
|
||||
"Wandb username not found. Please run or add to ~/.bashrc: export WANDB_USERNAME=YOUR_USERNAME"
|
||||
)
|
||||
entity = None
|
||||
print("Wandb username not found. Using default.")
|
||||
|
||||
wandb.init(project=project, entity=entity)
|
||||
|
||||
# Change generated name to project-number format
|
||||
wandb.run.name = project + wandb.run.name.split("-")[-1]
|
||||
with open(os.path.join(log_dir, "wandb_info.json"), "w") as f:
|
||||
json.dump({"wandb_run_id": wandb.run.id, "wandb_run_name": wandb.run.name}, f)
|
||||
|
||||
self.name_map = {
|
||||
"Train/mean_reward/time": "Train/mean_reward_time",
|
||||
|
@ -45,12 +48,18 @@ class WandbSummaryWriter(SummaryWriter):
|
|||
|
||||
wandb.log({"log_dir": run_name})
|
||||
|
||||
# Video logging.
|
||||
self.saved_video_files = {}
|
||||
|
||||
def store_config(self, env_cfg, runner_cfg, alg_cfg, policy_cfg):
|
||||
wandb.config.update({"runner_cfg": runner_cfg})
|
||||
wandb.config.update({"policy_cfg": policy_cfg})
|
||||
wandb.config.update({"alg_cfg": alg_cfg})
|
||||
wandb.config.update({"env_cfg": asdict(env_cfg)})
|
||||
|
||||
def get_config(self):
|
||||
return wandb.config
|
||||
|
||||
def _map_path(self, path):
|
||||
if path in self.name_map:
|
||||
return self.name_map[path]
|
||||
|
@ -67,6 +76,29 @@ class WandbSummaryWriter(SummaryWriter):
|
|||
)
|
||||
wandb.log({self._map_path(tag): scalar_value}, step=global_step)
|
||||
|
||||
def update_video_files(self, log_name, fps: float):
|
||||
# Check if there are new video files
|
||||
log_dir = pathlib.Path(self.log_dir)
|
||||
video_files = list(log_dir.rglob("*.mp4"))
|
||||
for video_file in video_files:
|
||||
file_size_kb = os.stat(str(video_file)).st_size / 1024
|
||||
# If it is new file
|
||||
if str(video_file) not in self.saved_video_files:
|
||||
self.saved_video_files[str(video_file)] = {"size": file_size_kb, "added": False, "count": 0}
|
||||
else:
|
||||
# Only upload if the file size is not changing anymore to avoid uploading non-ready video.
|
||||
video_info = self.saved_video_files[str(video_file)]
|
||||
if video_info["size"] == file_size_kb and file_size_kb > 100 and not video_info["added"]:
|
||||
if video_info["count"] > 10:
|
||||
print(f"[Wandb] Uploading {os.path.basename(str(video_file))}.")
|
||||
self.add_video(str(video_file), fps=fps, log_name=log_name)
|
||||
self.saved_video_files[str(video_file)]["added"] = True
|
||||
else:
|
||||
video_info["count"] += 1
|
||||
else:
|
||||
self.saved_video_files[str(video_file)]["size"] = file_size_kb
|
||||
video_info["count"] = 0
|
||||
|
||||
def stop(self):
|
||||
wandb.finish()
|
||||
|
||||
|
@ -78,3 +110,6 @@ class WandbSummaryWriter(SummaryWriter):
|
|||
|
||||
def save_file(self, path, iter=None):
|
||||
wandb.save(path, base_path=os.path.dirname(path))
|
||||
|
||||
def add_video(self, video_path: str, fps: int = 4, log_name: str = "Video"):
|
||||
wandb.log({log_name: wandb.Video(video_path, fps=fps)})
|
||||
|
|
Loading…
Reference in New Issue