import numpy as np
from collections import defaultdict
from multiprocessing import Process, Value


class Logger:
    """Accumulate per-step state values and episode-weighted reward terms.

    State values are appended to per-key lists in ``state_log``. Reward terms
    (keys containing ``'rew'``) are stored in ``rew_log`` multiplied by the
    number of episodes they were averaged over, so that ``print_rewards`` can
    report a mean weighted by episode count.
    """

    def __init__(self, dt):
        """Create an empty logger.

        Args:
            dt: Time step, stored as ``self.dt``. It is not used by the methods
                of this class itself (presumably consumed by plotting code
                elsewhere -- TODO confirm).
        """
        self.state_log = defaultdict(list)
        self.rew_log = defaultdict(list)
        self.dt = dt
        self.num_episodes = 0
        # Optional handle to a background plotting process; killed in __del__.
        self.plot_process = None

    def log_state(self, key, value):
        """Append ``value`` to the list recorded under ``key``."""
        self.state_log[key].append(value)

    def log_states(self, dict):
        """Record every ``key -> value`` pair of ``dict`` via :meth:`log_state`."""
        # NOTE: the parameter shadows the builtin ``dict``; the name is kept
        # so existing keyword callers keep working.
        for key, value in dict.items():
            self.log_state(key, value)

    def log_rewards(self, dict, num_episodes):
        """Record reward terms averaged over ``num_episodes`` episodes.

        Only keys containing ``'rew'`` are logged. Each value is multiplied by
        ``num_episodes`` before storage so per-batch means can later be
        recombined into an overall episode-weighted mean.

        Args:
            dict: Mapping of term name to a scalar. Objects exposing ``.item()``
                (NumPy scalars, one-element arrays/tensors) and plain numbers
                are accepted.
            num_episodes: Number of episodes the values were averaged over.
        """
        for key, value in dict.items():
            if 'rew' in key:
                # Support plain Python numbers in addition to array scalars.
                scalar = value.item() if hasattr(value, "item") else float(value)
                self.rew_log[key].append(scalar * num_episodes)
        self.num_episodes += num_episodes

    def reset(self):
        """Clear all logged states and rewards and the episode counter."""
        self.state_log.clear()
        self.rew_log.clear()
        # The counter must be reset together with rew_log; otherwise
        # print_rewards would divide fresh sums by a stale episode count.
        self.num_episodes = 0

    def print_rewards(self):
        """Print the episode-weighted mean of every logged reward term."""
        print("Average rewards per second:")
        for key, values in self.rew_log.items():
            total = np.sum(np.array(values))
            # Avoid a divide-by-zero warning if only zero-episode batches were
            # logged; report nan in that case.
            mean = total / self.num_episodes if self.num_episodes else float("nan")
            print(f" - {key}: {mean}")
        print(f"Total number of episodes: {self.num_episodes}")

    def __del__(self):
        """Kill the plotting process, if one was attached."""
        # getattr: __del__ can run even if __init__ failed before the
        # attribute was assigned.
        plot_process = getattr(self, "plot_process", None)
        if plot_process is not None:
            plot_process.kill()