fixes to runner

This commit is contained in:
Nikia Rudin 2021-10-29 18:27:04 +02:00 committed by Nikita Rudin
parent ff9e971c97
commit b548422076
1 changed files with 4 additions and 3 deletions

View File

@ -137,8 +137,9 @@ class OnPolicyRunner:
if it % self.save_interval == 0:
self.save(os.path.join(self.log_dir, 'model_{}.pt'.format(it)))
ep_infos.clear()
self.current_learning_iteration = num_learning_iterations
self.save(os.path.join(self.log_dir, 'model_{}.pt'.format(num_learning_iterations)))
self.current_learning_iteration += num_learning_iterations
self.save(os.path.join(self.log_dir, 'model_{}.pt'.format(self.current_learning_iteration)))
def log(self, locs, width=80, pad=35):
self.tot_timesteps += self.num_steps_per_env * self.env.num_envs
@ -175,7 +176,7 @@ class OnPolicyRunner:
self.writer.add_scalar('Train/mean_reward/time', statistics.mean(locs['rewbuffer']), self.tot_time)
self.writer.add_scalar('Train/mean_episode_length/time', statistics.mean(locs['lenbuffer']), self.tot_time)
str = f" \033[1m Learning iteration {locs['it']}/{locs['num_learning_iterations']} \033[0m "
str = f" \033[1m Learning iteration {locs['it']}/{self.current_learning_iteration + locs['num_learning_iterations']} \033[0m "
if len(locs['rewbuffer']) > 0:
log_string = (f"""{'#' * width}\n"""