# Python source file: 91 lines, 2.2 KiB
from collections import defaultdict
|
|
|
|
from ml_logger import logger
|
|
import numpy as np
|
|
import torch
|
|
|
|
|
|
class DistCache:
    """Keep a running mean of every value logged under a given key.

    For each key ``k`` the cache stores the current mean under ``k`` and the
    number of samples seen so far under ``k + '@counts'``.
    """

    def __init__(self):
        """Create an empty cache.

        Missing entries default to ``0`` so the first sample of a key can go
        through the same update formula as every later one.
        """
        self.cache = defaultdict(lambda: 0)

    def log(self, **key_vals):
        """Fold one new sample per key into that key's running mean.

        Args:
            **key_vals: Mapping of key name to the new sample. A sample may
                be a scalar or an array-like that supports ``+``, ``*`` and
                ``/``; successive samples of one key must be compatible
                under those operators.
        """
        for k, v in key_vals.items():
            count = self.cache[k + '@counts'] + 1
            self.cache[k + '@counts'] = count
            # mean_new = (v + (count - 1) * mean_old) / count.
            # The division is done out of place on purpose: an in-place
            # ``/=`` fails on integer NumPy arrays, because the float
            # result cannot be cast back into the integer buffer.
            self.cache[k] = (v + (count - 1) * self.cache[k]) / count

    def get_summary(self):
        """Return the current means and reset the cache.

        Returns:
            dict: key -> running mean for every logged key. The internal
            ``'@counts'`` entries are left out, so a user key that itself
            ends with ``'@counts'`` is dropped as well.
        """
        ret = {
            k: v
            for k, v in self.cache.items()
            if not k.endswith("@counts")
        }
        # Clearing restarts every average from scratch on the next log().
        self.cache.clear()
        return ret
|
|
|
|
|
|
if __name__ == '__main__':
    # Demo for DistCache: log two samples per key, then print the means.
    cl = DistCache()
    lin_vel = np.ones((11, 11))
    ang_vel = np.zeros((5, 5))
    cl.log(lin_vel=lin_vel, ang_vel=ang_vel)
    lin_vel = np.zeros((11, 11))
    ang_vel = np.zeros((5, 5))
    cl.log(lin_vel=lin_vel, ang_vel=ang_vel)
    # lin_vel averages ones and zeros (0.5 everywhere); ang_vel stays 0.
    print(cl.get_summary())
|
|
|
|
|
|
class SlotCache:
    """Keep a per-slot running mean for every logged key.

    Each key ``k`` owns a length-``n`` float array of means stored under
    ``k`` and a matching array of per-slot sample counts stored under
    ``k + '@counts'``.
    """

    def __init__(self, n):
        """Create an empty cache.

        Args:
            n: Number of slots for the cache
        """
        self.n = n
        # Unseen keys start as all-zero arrays, so the first sample of a
        # slot goes through the same update formula as later ones.
        self.cache = defaultdict(lambda: np.zeros(n))

    def log(self, slots=None, **key_vals):
        """Fold new samples into the running means of the selected slots.

        Args:
            slots: ids for the array, i.e. the slot indices to update.
                ``None`` (the default) updates every slot. Each id is
                counted once per call, so duplicate ids are not
                accumulated.
            **key_vals: Mapping of key name to the new samples: either one
                value per selected slot, or anything that broadcasts
                against the selected slots.
        """
        if slots is None:
            # A full slice selects all ``n`` slots without materialising
            # an index array.
            slots = slice(None)

        for k, v in key_vals.items():
            counts = self.cache[k + '@counts'][slots] + 1
            self.cache[k + '@counts'][slots] = counts
            # mean_new = (v + (counts - 1) * mean_old) / counts, per slot.
            self.cache[k][slots] = (v + (counts - 1) * self.cache[k][slots]) / counts

    def get_summary(self):
        """Return the current per-slot means and reset the cache.

        Returns:
            dict: key -> length-``n`` array of running means. Slots that
            were never logged hold ``0``. The internal ``'@counts'``
            entries are left out.
        """
        ret = {
            k: v
            for k, v in self.cache.items()
            if not k.endswith("@counts")
        }
        # Clearing restarts every average from scratch on the next log().
        self.cache.clear()
        return ret
|
|
|
|
|
|
if __name__ == '__main__':
    # Demo for SlotCache: update three slots of a 100-slot cache.
    cl = SlotCache(100)
    reset_env_ids = [2, 5, 6]
    lin_vel = [0.1, 0.5, 0.8]
    ang_vel = [0.4, -0.4, 0.2]
    cl.log(reset_env_ids, lin_vel=lin_vel, ang_vel=ang_vel)

    # Leaving out ``slots`` applies the sample to all 100 slots.
    cl.log(lin_vel=np.ones(100))
|