diff --git a/rsl_rl/utils/utils.py b/rsl_rl/utils/utils.py index 4a9abaa..88dd741 100644 --- a/rsl_rl/utils/utils.py +++ b/rsl_rl/utils/utils.py @@ -59,6 +59,8 @@ def unpad_trajectories(trajectories, masks): def store_code_state(logdir, repositories) -> list: + git_log_dir = os.path.join(logdir, "git") + os.makedirs(git_log_dir, exist_ok=True) file_paths = [] for repository_file_path in repositories: try: @@ -69,7 +71,7 @@ def store_code_state(logdir, repositories) -> list: # get the name of the repository repo_name = pathlib.Path(repo.working_dir).name t = repo.head.commit.tree - diff_file_name = os.path.join(logdir, f"{repo_name}_git.diff") + diff_file_name = os.path.join(git_log_dir, f"{repo_name}.diff") content = f"--- git status ---\n{repo.git.status()} \n\n\n--- git diff ---\n{repo.git.diff(t)}" with open(diff_file_name, "x") as f: f.write(content) diff --git a/rsl_rl/utils/wandb_utils.py b/rsl_rl/utils/wandb_utils.py index 630f870..2868ce9 100644 --- a/rsl_rl/utils/wandb_utils.py +++ b/rsl_rl/utils/wandb_utils.py @@ -74,7 +74,7 @@ class WandbSummaryWriter(SummaryWriter): self.store_config(env_cfg, runner_cfg, alg_cfg, policy_cfg) def save_model(self, model_path, iter): - wandb.save(model_path) + wandb.save(model_path, base_path=os.path.dirname(model_path)) def save_file(self, path, iter=None): - wandb.save(path) + wandb.save(path, base_path=os.path.dirname(path))