Simplify
This commit is contained in:
parent
0a23fff25e
commit
64c2d15b41
|
@ -212,7 +212,7 @@ class OnPolicyRunner:
|
||||||
|
|
||||||
# Video recording for wandb
|
# Video recording for wandb
|
||||||
if self.logger_type == "wandb":
|
if self.logger_type == "wandb":
|
||||||
self.writer.update_video_files(log_name="Video", fps=30.0)
|
self.writer.update_video_files(log_name="Video", fps=30)
|
||||||
|
|
||||||
str = f" \033[1m Learning iteration {locs['it']}/{locs['tot_iter']} \033[0m "
|
str = f" \033[1m Learning iteration {locs['it']}/{locs['tot_iter']} \033[0m "
|
||||||
|
|
||||||
|
|
|
@ -19,7 +19,7 @@ class WandbSummaryWriter(SummaryWriter):
|
||||||
"""Summary writer for Weights and Biases."""
|
"""Summary writer for Weights and Biases."""
|
||||||
|
|
||||||
def __init__(self, log_dir: str, flush_secs: int, cfg):
|
def __init__(self, log_dir: str, flush_secs: int, cfg):
|
||||||
super().__init__(log_dir, flush_secs)
|
super().__init__(log_dir=log_dir, flush_secs=flush_secs)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
project = cfg["wandb_project"]
|
project = cfg["wandb_project"]
|
||||||
|
@ -30,7 +30,7 @@ class WandbSummaryWriter(SummaryWriter):
|
||||||
entity = os.environ["WANDB_USERNAME"]
|
entity = os.environ["WANDB_USERNAME"]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
entity = None
|
entity = None
|
||||||
print("Wandb username not found. Using default.")
|
print("`WANDB_USERNAME` is not found! The run will be sent to your username.")
|
||||||
|
|
||||||
wandb.init(project=project, entity=entity)
|
wandb.init(project=project, entity=entity)
|
||||||
|
|
||||||
|
@ -76,7 +76,7 @@ class WandbSummaryWriter(SummaryWriter):
|
||||||
)
|
)
|
||||||
wandb.log({self._map_path(tag): scalar_value}, step=global_step)
|
wandb.log({self._map_path(tag): scalar_value}, step=global_step)
|
||||||
|
|
||||||
def update_video_files(self, log_name, fps: float):
|
def update_video_files(self, log_name: str, fps: int):
|
||||||
# Check if there are new video files
|
# Check if there are new video files
|
||||||
log_dir = pathlib.Path(self.log_dir)
|
log_dir = pathlib.Path(self.log_dir)
|
||||||
video_files = list(log_dir.rglob("*.mp4"))
|
video_files = list(log_dir.rglob("*.mp4"))
|
||||||
|
@ -88,10 +88,10 @@ class WandbSummaryWriter(SummaryWriter):
|
||||||
else:
|
else:
|
||||||
# Only upload if the file size is not changing anymore to avoid uploading non-ready video.
|
# Only upload if the file size is not changing anymore to avoid uploading non-ready video.
|
||||||
video_info = self.saved_video_files[str(video_file)]
|
video_info = self.saved_video_files[str(video_file)]
|
||||||
if video_info["size"] == file_size_kb and file_size_kb > 100 and not video_info["added"]:
|
if video_info["added"] is False and video_info["size"] == file_size_kb and file_size_kb > 100:
|
||||||
if video_info["count"] > 10:
|
if video_info["count"] > 10:
|
||||||
print(f"[Wandb] Uploading {os.path.basename(str(video_file))}.")
|
print(f"[Wandb] Uploading {os.path.basename(str(video_file))}.")
|
||||||
self.add_video(str(video_file), fps=fps, log_name=log_name)
|
wandb.log({log_name: wandb.Video(str(video_file), fps=fps)})
|
||||||
self.saved_video_files[str(video_file)]["added"] = True
|
self.saved_video_files[str(video_file)]["added"] = True
|
||||||
else:
|
else:
|
||||||
video_info["count"] += 1
|
video_info["count"] += 1
|
||||||
|
@ -110,6 +110,3 @@ class WandbSummaryWriter(SummaryWriter):
|
||||||
|
|
||||||
def save_file(self, path, iter=None):
|
def save_file(self, path, iter=None):
|
||||||
wandb.save(path, base_path=os.path.dirname(path))
|
wandb.save(path, base_path=os.path.dirname(path))
|
||||||
|
|
||||||
def add_video(self, video_path: str, fps: int = 4, log_name: str = "Video"):
|
|
||||||
wandb.log({log_name: wandb.Video(video_path, fps=fps)})
|
|
||||||
|
|
Loading…
Reference in New Issue