unitree_rl_gym/legged_gym/utils/logger.py

39 lines
1.2 KiB
Python
Raw Normal View History

2023-10-11 15:38:49 +08:00
import numpy as np
from collections import defaultdict
from multiprocessing import Process, Value
class Logger:
    """Accumulate per-step state values and per-episode reward terms.

    States are stored as plain lists keyed by name. Reward terms are stored
    weighted by the number of episodes they summarise, so that
    ``print_rewards`` can report an episode-weighted mean for each term.
    """

    def __init__(self, dt):
        """Create an empty logger.

        Args:
            dt: Time step associated with the logged samples. It is only
                stored here; nothing in this class reads it back.
        """
        self.state_log = defaultdict(list)
        self.rew_log = defaultdict(list)
        self.dt = dt
        self.num_episodes = 0
        # Nothing in this class assigns a process to this attribute; it is
        # only inspected (and killed if set) in __del__.
        self.plot_process = None

    def log_state(self, key, value):
        """Append ``value`` to the history stored under ``key``."""
        self.state_log[key].append(value)

    def log_states(self, dict):
        """Append every ``key -> value`` pair of ``dict`` via ``log_state``."""
        for key, value in dict.items():
            self.log_state(key, value)

    def log_rewards(self, dict, num_episodes):
        """Record the reward terms of ``dict`` weighted by ``num_episodes``.

        Only entries whose key contains the substring ``'rew'`` are kept.
        Each kept value must expose ``.item()`` returning a Python scalar.

        Args:
            dict: Mapping from term name to a scalar-like value.
            num_episodes: Number of episodes the values summarise; it is
                used as the weight and added to the running episode count.
        """
        for key, value in dict.items():
            if 'rew' in key:
                self.rew_log[key].append(value.item() * num_episodes)
        self.num_episodes += num_episodes

    def reset(self):
        """Discard all logged states and rewards and zero the episode count."""
        self.state_log.clear()
        self.rew_log.clear()
        # The count is the denominator of the reward means; leaving it stale
        # after clearing rew_log would dilute every mean computed afterwards.
        self.num_episodes = 0

    def print_rewards(self):
        """Print the episode-weighted mean of every logged reward term."""
        print("Average rewards per second:")
        for key, values in self.rew_log.items():
            if self.num_episodes:
                mean = np.sum(np.array(values)) / self.num_episodes
            else:
                # No episodes counted: the mean is undefined. Report nan
                # explicitly rather than triggering a 0/0 RuntimeWarning.
                mean = float("nan")
            print(f" - {key}: {mean}")
        print(f"Total number of episodes: {self.num_episodes}")

    def __del__(self):
        """Kill the plotting process, if one was attached."""
        # getattr guards against finalizing an instance whose __init__ did
        # not run to completion (attribute never assigned).
        plot_process = getattr(self, "plot_process", None)
        if plot_process is not None:
            plot_process.kill()