39 lines
1.2 KiB
Python
39 lines
1.2 KiB
Python
import numpy as np
|
|
from collections import defaultdict
|
|
from multiprocessing import Process, Value
|
|
|
|
class Logger:
    """Collect logged state values and episode-weighted reward sums.

    ``state_log`` maps each key to the list of values logged under it.
    ``rew_log`` maps each reward key to a list of ``value * num_episodes``
    products, and ``num_episodes`` holds the running episode count, so that
    :meth:`print_rewards` can report an episode-weighted average per key.
    """

    def __init__(self, dt):
        """Create an empty logger.

        Args:
            dt: Time step stored as ``self.dt``. No method shown here reads it;
                presumably it is used by plotting code elsewhere — confirm.
        """
        self.state_log = defaultdict(list)
        self.rew_log = defaultdict(list)
        self.dt = dt
        self.num_episodes = 0
        # Handle to a background plotting process; killed in __del__ when set.
        # NOTE(review): never assigned here — presumably set by plotting code
        # outside this view.
        self.plot_process = None

    def log_state(self, key, value):
        """Append ``value`` to the list logged under ``key``."""
        self.state_log[key].append(value)

    def log_states(self, dict):
        """Append every ``key -> value`` pair of a mapping via :meth:`log_state`.

        Args:
            dict: Mapping of keys to values. The name shadows the builtin but
                is kept so existing keyword callers keep working.
        """
        for key, value in dict.items():
            self.log_state(key, value)

    def log_rewards(self, dict, num_episodes):
        """Record reward entries weighted by the number of episodes they cover.

        Only keys containing the substring ``'rew'`` are recorded. Each stored
        entry is ``value.item() * num_episodes``, so that dividing the summed
        entries by the total episode count gives an episode-weighted mean.

        Args:
            dict: Mapping of names to scalar-like objects exposing ``.item()``
                (assumed 0-d tensors/arrays — confirm against callers). The
                name shadows the builtin but is kept for keyword callers.
            num_episodes: Weight applied to each recorded value; also added to
                the running ``self.num_episodes`` count.
        """
        for key, value in dict.items():
            if 'rew' in key:
                self.rew_log[key].append(value.item() * num_episodes)
        self.num_episodes += num_episodes

    def reset(self):
        """Discard all logged states and rewards and restart the episode count."""
        self.state_log.clear()
        self.rew_log.clear()
        # The episode count is the denominator for the reward sums in rew_log;
        # keeping the old count after clearing those sums would make
        # print_rewards divide fresh sums by a stale total.
        self.num_episodes = 0

    def print_rewards(self):
        """Print, per reward key, the summed weighted entries / total episodes.

        NOTE(review): if ``num_episodes`` is 0 while ``rew_log`` is non-empty,
        the numpy division yields inf/nan with a RuntimeWarning.
        """
        print("Average rewards per second:")
        for key, values in self.rew_log.items():
            mean = np.sum(np.array(values)) / self.num_episodes
            print(f" - {key}: {mean}")
        print(f"Total number of episodes: {self.num_episodes}")

    def __del__(self):
        """Kill the plotting process, if one was set, when the logger is finalized."""
        # getattr guards against finalizing a partially-initialized instance
        # (e.g. __init__ raised, or the object was built via __new__), where a
        # plain attribute access would raise AttributeError during __del__.
        plot_process = getattr(self, "plot_process", None)
        if plot_process is not None:
            plot_process.kill()