diff --git a/rsl_rl/runners/on_policy_runner.py b/rsl_rl/runners/on_policy_runner.py index 08f91b6..937b36d 100644 --- a/rsl_rl/runners/on_policy_runner.py +++ b/rsl_rl/runners/on_policy_runner.py @@ -212,7 +212,7 @@ class OnPolicyRunner: # Video recording for wandb if self.logger_type == "wandb": - self.writer.update_video_files(log_name="Video", fps=30.0) + self.writer.update_video_files(log_name="Video", fps=30) str = f" \033[1m Learning iteration {locs['it']}/{locs['tot_iter']} \033[0m " diff --git a/rsl_rl/utils/wandb_utils.py b/rsl_rl/utils/wandb_utils.py index aff69f3..e14661e 100644 --- a/rsl_rl/utils/wandb_utils.py +++ b/rsl_rl/utils/wandb_utils.py @@ -19,7 +19,7 @@ class WandbSummaryWriter(SummaryWriter): """Summary writer for Weights and Biases.""" def __init__(self, log_dir: str, flush_secs: int, cfg): - super().__init__(log_dir, flush_secs) + super().__init__(log_dir=log_dir, flush_secs=flush_secs) try: project = cfg["wandb_project"] @@ -30,7 +30,7 @@ class WandbSummaryWriter(SummaryWriter): entity = os.environ["WANDB_USERNAME"] except KeyError: entity = None - print("Wandb username not found. Using default.") + print("`WANDB_USERNAME` is not found! The run will be sent to your username.") wandb.init(project=project, entity=entity) @@ -76,7 +76,7 @@ class WandbSummaryWriter(SummaryWriter): ) wandb.log({self._map_path(tag): scalar_value}, step=global_step) - def update_video_files(self, log_name, fps: float): + def update_video_files(self, log_name: str, fps: int): # Check if there are new video files log_dir = pathlib.Path(self.log_dir) video_files = list(log_dir.rglob("*.mp4")) @@ -88,10 +88,10 @@ class WandbSummaryWriter(SummaryWriter): else: # Only upload if the file size is not changing anymore to avoid uploading non-ready video. video_info = self.saved_video_files[str(video_file)] - if video_info["size"] == file_size_kb and file_size_kb > 100 and not video_info["added"]: + if video_info["added"] is False and video_info["size"] == file_size_kb and file_size_kb > 100: if video_info["count"] > 10: print(f"[Wandb] Uploading {os.path.basename(str(video_file))}.") - self.add_video(str(video_file), fps=fps, log_name=log_name) + wandb.log({log_name: wandb.Video(str(video_file), fps=fps)}) self.saved_video_files[str(video_file)]["added"] = True else: video_info["count"] += 1 @@ -110,6 +110,3 @@ class WandbSummaryWriter(SummaryWriter): def save_file(self, path, iter=None): wandb.save(path, base_path=os.path.dirname(path)) - - def add_video(self, video_path: str, fps: int = 4, log_name: str = "Video"): - wandb.log({log_name: wandb.Video(video_path, fps=fps)})