Add video recording to wandb logger
This commit is contained in:
parent
e23bb6819a
commit
0a23fff25e
|
@ -210,6 +210,10 @@ class OnPolicyRunner:
|
||||||
"Train/mean_episode_length/time", statistics.mean(locs["lenbuffer"]), self.tot_time
|
"Train/mean_episode_length/time", statistics.mean(locs["lenbuffer"]), self.tot_time
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Video recording for wandb
|
||||||
|
if self.logger_type == "wandb":
|
||||||
|
self.writer.update_video_files(log_name="Video", fps=30.0)
|
||||||
|
|
||||||
str = f" \033[1m Learning iteration {locs['it']}/{locs['tot_iter']} \033[0m "
|
str = f" \033[1m Learning iteration {locs['it']}/{locs['tot_iter']} \033[0m "
|
||||||
|
|
||||||
if len(locs["rewbuffer"]) > 0:
|
if len(locs["rewbuffer"]) > 0:
|
||||||
|
|
|
@ -4,6 +4,8 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import json
|
||||||
|
import pathlib
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
@ -27,14 +29,15 @@ class WandbSummaryWriter(SummaryWriter):
|
||||||
try:
|
try:
|
||||||
entity = os.environ["WANDB_USERNAME"]
|
entity = os.environ["WANDB_USERNAME"]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise KeyError(
|
entity = None
|
||||||
"Wandb username not found. Please run or add to ~/.bashrc: export WANDB_USERNAME=YOUR_USERNAME"
|
print("Wandb username not found. Using default.")
|
||||||
)
|
|
||||||
|
|
||||||
wandb.init(project=project, entity=entity)
|
wandb.init(project=project, entity=entity)
|
||||||
|
|
||||||
# Change generated name to project-number format
|
# Change generated name to project-number format
|
||||||
wandb.run.name = project + wandb.run.name.split("-")[-1]
|
wandb.run.name = project + wandb.run.name.split("-")[-1]
|
||||||
|
with open(os.path.join(log_dir, "wandb_info.json"), "w") as f:
|
||||||
|
json.dump({"wandb_run_id": wandb.run.id, "wandb_run_name": wandb.run.name}, f)
|
||||||
|
|
||||||
self.name_map = {
|
self.name_map = {
|
||||||
"Train/mean_reward/time": "Train/mean_reward_time",
|
"Train/mean_reward/time": "Train/mean_reward_time",
|
||||||
|
@ -45,12 +48,18 @@ class WandbSummaryWriter(SummaryWriter):
|
||||||
|
|
||||||
wandb.log({"log_dir": run_name})
|
wandb.log({"log_dir": run_name})
|
||||||
|
|
||||||
|
# Video logging.
|
||||||
|
self.saved_video_files = {}
|
||||||
|
|
||||||
def store_config(self, env_cfg, runner_cfg, alg_cfg, policy_cfg):
|
def store_config(self, env_cfg, runner_cfg, alg_cfg, policy_cfg):
|
||||||
wandb.config.update({"runner_cfg": runner_cfg})
|
wandb.config.update({"runner_cfg": runner_cfg})
|
||||||
wandb.config.update({"policy_cfg": policy_cfg})
|
wandb.config.update({"policy_cfg": policy_cfg})
|
||||||
wandb.config.update({"alg_cfg": alg_cfg})
|
wandb.config.update({"alg_cfg": alg_cfg})
|
||||||
wandb.config.update({"env_cfg": asdict(env_cfg)})
|
wandb.config.update({"env_cfg": asdict(env_cfg)})
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
return wandb.config
|
||||||
|
|
||||||
def _map_path(self, path):
|
def _map_path(self, path):
|
||||||
if path in self.name_map:
|
if path in self.name_map:
|
||||||
return self.name_map[path]
|
return self.name_map[path]
|
||||||
|
@ -67,6 +76,29 @@ class WandbSummaryWriter(SummaryWriter):
|
||||||
)
|
)
|
||||||
wandb.log({self._map_path(tag): scalar_value}, step=global_step)
|
wandb.log({self._map_path(tag): scalar_value}, step=global_step)
|
||||||
|
|
||||||
|
def update_video_files(self, log_name, fps: float):
|
||||||
|
# Check if there are new video files
|
||||||
|
log_dir = pathlib.Path(self.log_dir)
|
||||||
|
video_files = list(log_dir.rglob("*.mp4"))
|
||||||
|
for video_file in video_files:
|
||||||
|
file_size_kb = os.stat(str(video_file)).st_size / 1024
|
||||||
|
# If it is new file
|
||||||
|
if str(video_file) not in self.saved_video_files:
|
||||||
|
self.saved_video_files[str(video_file)] = {"size": file_size_kb, "added": False, "count": 0}
|
||||||
|
else:
|
||||||
|
# Only upload if the file size is not changing anymore to avoid uploading non-ready video.
|
||||||
|
video_info = self.saved_video_files[str(video_file)]
|
||||||
|
if video_info["size"] == file_size_kb and file_size_kb > 100 and not video_info["added"]:
|
||||||
|
if video_info["count"] > 10:
|
||||||
|
print(f"[Wandb] Uploading {os.path.basename(str(video_file))}.")
|
||||||
|
self.add_video(str(video_file), fps=fps, log_name=log_name)
|
||||||
|
self.saved_video_files[str(video_file)]["added"] = True
|
||||||
|
else:
|
||||||
|
video_info["count"] += 1
|
||||||
|
else:
|
||||||
|
self.saved_video_files[str(video_file)]["size"] = file_size_kb
|
||||||
|
video_info["count"] = 0
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
wandb.finish()
|
wandb.finish()
|
||||||
|
|
||||||
|
@ -78,3 +110,6 @@ class WandbSummaryWriter(SummaryWriter):
|
||||||
|
|
||||||
def save_file(self, path, iter=None):
|
def save_file(self, path, iter=None):
|
||||||
wandb.save(path, base_path=os.path.dirname(path))
|
wandb.save(path, base_path=os.path.dirname(path))
|
||||||
|
|
||||||
|
def add_video(self, video_path: str, fps: int = 4, log_name: str = "Video"):
|
||||||
|
wandb.log({log_name: wandb.Video(video_path, fps=fps)})
|
||||||
|
|
Loading…
Reference in New Issue