This commit is contained in:
Farbod Farshidian 2024-02-27 13:12:52 -05:00
parent 0a23fff25e
commit 64c2d15b41
2 changed files with 6 additions and 9 deletions

View File

@ -212,7 +212,7 @@ class OnPolicyRunner:
# Video recording for wandb # Video recording for wandb
if self.logger_type == "wandb": if self.logger_type == "wandb":
self.writer.update_video_files(log_name="Video", fps=30.0) self.writer.update_video_files(log_name="Video", fps=30)
str = f" \033[1m Learning iteration {locs['it']}/{locs['tot_iter']} \033[0m " str = f" \033[1m Learning iteration {locs['it']}/{locs['tot_iter']} \033[0m "

View File

@ -19,7 +19,7 @@ class WandbSummaryWriter(SummaryWriter):
"""Summary writer for Weights and Biases.""" """Summary writer for Weights and Biases."""
def __init__(self, log_dir: str, flush_secs: int, cfg): def __init__(self, log_dir: str, flush_secs: int, cfg):
super().__init__(log_dir, flush_secs) super().__init__(log_dir=log_dir, flush_secs=flush_secs)
try: try:
project = cfg["wandb_project"] project = cfg["wandb_project"]
@ -30,7 +30,7 @@ class WandbSummaryWriter(SummaryWriter):
entity = os.environ["WANDB_USERNAME"] entity = os.environ["WANDB_USERNAME"]
except KeyError: except KeyError:
entity = None entity = None
print("Wandb username not found. Using default.") print("`WANDB_USERNAME` is not found! The run will be sent to your username.")
wandb.init(project=project, entity=entity) wandb.init(project=project, entity=entity)
@ -76,7 +76,7 @@ class WandbSummaryWriter(SummaryWriter):
) )
wandb.log({self._map_path(tag): scalar_value}, step=global_step) wandb.log({self._map_path(tag): scalar_value}, step=global_step)
def update_video_files(self, log_name, fps: float): def update_video_files(self, log_name: str, fps: int):
# Check if there are new video files # Check if there are new video files
log_dir = pathlib.Path(self.log_dir) log_dir = pathlib.Path(self.log_dir)
video_files = list(log_dir.rglob("*.mp4")) video_files = list(log_dir.rglob("*.mp4"))
@ -88,10 +88,10 @@ class WandbSummaryWriter(SummaryWriter):
else: else:
# Only upload if the file size is not changing anymore to avoid uploading non-ready video. # Only upload if the file size is not changing anymore to avoid uploading non-ready video.
video_info = self.saved_video_files[str(video_file)] video_info = self.saved_video_files[str(video_file)]
if video_info["size"] == file_size_kb and file_size_kb > 100 and not video_info["added"]: if video_info["added"] is False and video_info["size"] == file_size_kb and file_size_kb > 100:
if video_info["count"] > 10: if video_info["count"] > 10:
print(f"[Wandb] Uploading {os.path.basename(str(video_file))}.") print(f"[Wandb] Uploading {os.path.basename(str(video_file))}.")
self.add_video(str(video_file), fps=fps, log_name=log_name) wandb.log({log_name: wandb.Video(str(video_file), fps=fps)})
self.saved_video_files[str(video_file)]["added"] = True self.saved_video_files[str(video_file)]["added"] = True
else: else:
video_info["count"] += 1 video_info["count"] += 1
@ -110,6 +110,3 @@ class WandbSummaryWriter(SummaryWriter):
def save_file(self, path, iter=None): def save_file(self, path, iter=None):
wandb.save(path, base_path=os.path.dirname(path)) wandb.save(path, base_path=os.path.dirname(path))
def add_video(self, video_path: str, fps: int = 4, log_name: str = "Video"):
wandb.log({log_name: wandb.Video(video_path, fps=fps)})