From b548422076a3ea99e896e6c1ef124a305da90927 Mon Sep 17 00:00:00 2001 From: Nikia Rudin Date: Fri, 29 Oct 2021 18:27:04 +0200 Subject: [PATCH] fixes to runner --- rsl_rl/runners/on_policy_runner.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/rsl_rl/runners/on_policy_runner.py b/rsl_rl/runners/on_policy_runner.py index fe3d930..e11ae4d 100644 --- a/rsl_rl/runners/on_policy_runner.py +++ b/rsl_rl/runners/on_policy_runner.py @@ -137,8 +137,9 @@ class OnPolicyRunner: if it % self.save_interval == 0: self.save(os.path.join(self.log_dir, 'model_{}.pt'.format(it))) ep_infos.clear() - self.current_learning_iteration = num_learning_iterations - self.save(os.path.join(self.log_dir, 'model_{}.pt'.format(num_learning_iterations))) + + self.current_learning_iteration += num_learning_iterations + self.save(os.path.join(self.log_dir, 'model_{}.pt'.format(self.current_learning_iteration))) def log(self, locs, width=80, pad=35): self.tot_timesteps += self.num_steps_per_env * self.env.num_envs @@ -175,7 +176,7 @@ class OnPolicyRunner: self.writer.add_scalar('Train/mean_reward/time', statistics.mean(locs['rewbuffer']), self.tot_time) self.writer.add_scalar('Train/mean_episode_length/time', statistics.mean(locs['lenbuffer']), self.tot_time) - str = f" \033[1m Learning iteration {locs['it']}/{locs['num_learning_iterations']} \033[0m " + str = f" \033[1m Learning iteration {locs['it']}/{self.current_learning_iteration + locs['num_learning_iterations']} \033[0m " if len(locs['rewbuffer']) > 0: log_string = (f"""{'#' * width}\n"""