diff --git a/rsl_rl/runners/on_policy_runner.py b/rsl_rl/runners/on_policy_runner.py index 6ae0837..08f91b6 100644 --- a/rsl_rl/runners/on_policy_runner.py +++ b/rsl_rl/runners/on_policy_runner.py @@ -210,6 +210,10 @@ class OnPolicyRunner: "Train/mean_episode_length/time", statistics.mean(locs["lenbuffer"]), self.tot_time ) + # Video recording for wandb + if self.logger_type == "wandb": + self.writer.update_video_files(log_name="Video", fps=30.0) + str = f" \033[1m Learning iteration {locs['it']}/{locs['tot_iter']} \033[0m " if len(locs["rewbuffer"]) > 0: diff --git a/rsl_rl/utils/wandb_utils.py b/rsl_rl/utils/wandb_utils.py index 2868ce9..aff69f3 100644 --- a/rsl_rl/utils/wandb_utils.py +++ b/rsl_rl/utils/wandb_utils.py @@ -4,6 +4,8 @@ from __future__ import annotations import os +import json +import pathlib from dataclasses import asdict from torch.utils.tensorboard import SummaryWriter @@ -27,14 +29,15 @@ class WandbSummaryWriter(SummaryWriter): try: entity = os.environ["WANDB_USERNAME"] except KeyError: - raise KeyError( - "Wandb username not found. Please run or add to ~/.bashrc: export WANDB_USERNAME=YOUR_USERNAME" - ) + entity = None + print("Wandb username not found. Using default.") wandb.init(project=project, entity=entity) # Change generated name to project-number format wandb.run.name = project + wandb.run.name.split("-")[-1] + with open(os.path.join(log_dir, "wandb_info.json"), "w") as f: + json.dump({"wandb_run_id": wandb.run.id, "wandb_run_name": wandb.run.name}, f) self.name_map = { "Train/mean_reward/time": "Train/mean_reward_time", @@ -45,12 +48,18 @@ class WandbSummaryWriter(SummaryWriter): wandb.log({"log_dir": run_name}) + # Video logging. 
+        self.saved_video_files = {}
+
     def store_config(self, env_cfg, runner_cfg, alg_cfg, policy_cfg):
         wandb.config.update({"runner_cfg": runner_cfg})
         wandb.config.update({"policy_cfg": policy_cfg})
         wandb.config.update({"alg_cfg": alg_cfg})
         wandb.config.update({"env_cfg": asdict(env_cfg)})
 
+    def get_config(self):
+        return wandb.config
+
     def _map_path(self, path):
         if path in self.name_map:
             return self.name_map[path]
@@ -67,6 +76,44 @@ class WandbSummaryWriter(SummaryWriter):
         )
         wandb.log({self._map_path(tag): scalar_value}, step=global_step)
 
+    def update_video_files(self, log_name, fps: float, min_size_kb: float = 100.0, stable_checks: int = 10):
+        """Upload finished ``*.mp4`` recordings under the log directory to wandb.
+
+        Meant to be polled repeatedly. A file is uploaded at most once, and only after its
+        size has stayed unchanged and above ``min_size_kb`` for more than ``stable_checks``
+        consecutive polls, so that a video still being written is not uploaded.
+
+        Args:
+            log_name: Key under which the video is logged to wandb.
+            fps: Frame rate forwarded to :meth:`add_video`.
+            min_size_kb: Size in KiB a file must exceed to be considered for upload.
+            stable_checks: Count of unchanged-size polls that must be exceeded before uploading.
+        """
+        for video_file in pathlib.Path(self.log_dir).rglob("*.mp4"):
+            video_path = str(video_file)
+            video_info = self.saved_video_files.get(video_path)
+            if video_info is not None and video_info["added"]:
+                # Already uploaded; nothing left to track for this file.
+                continue
+            try:
+                file_size_kb = video_file.stat().st_size / 1024
+            except OSError:
+                # The file can disappear between the directory scan and the stat call.
+                continue
+            if video_info is None:
+                # First sighting: start tracking and decide on a later poll.
+                self.saved_video_files[video_path] = {"size": file_size_kb, "added": False, "count": 0}
+            elif video_info["size"] != file_size_kb or file_size_kb <= min_size_kb:
+                # Still changing, or too small to upload: restart the stability count.
+                video_info["size"] = file_size_kb
+                video_info["count"] = 0
+            elif video_info["count"] > stable_checks:
+                print(f"[Wandb] Uploading {video_file.name}.")
+                self.add_video(video_path, fps=fps, log_name=log_name)
+                video_info["added"] = True
+            else:
+                video_info["count"] += 1
+
     def stop(self):
         wandb.finish()
 
@@ -78,3 +125,7 @@ class WandbSummaryWriter(SummaryWriter):
 
     def save_file(self, path, iter=None):
         wandb.save(path, base_path=os.path.dirname(path))
+
+    def add_video(self, video_path: str, fps: float = 4, log_name: str = "Video"):
+        """Log the video file at ``video_path`` to wandb under ``log_name``, passing ``fps`` to ``wandb.Video``."""
+        wandb.log({log_name: wandb.Video(video_path, fps=fps)})