This commit is contained in:
Remi Cadene 2024-09-18 00:35:20 +02:00
parent 72f402d44b
commit 850e6d540d
1 changed files with 86 additions and 0 deletions

86
test.py Normal file
View File

@ -0,0 +1,86 @@
import math
import threading
import time
from threading import Thread
class TemporalQueue:
def __init__(self):
self.items = []
self.timestamps = []
def add(self, item, timestamp):
self.items.append(item)
self.timestamps.append(timestamp)
def get(self, timestamp=None):
if timestamp is None:
return self.items[-1], self.timestamps[-1]
# TODO(rcadene): implement nearest neighbor instead of hacky floor
for idx, t in list(enumerate(self.timestamps))[::-1]:
if math.floor(t) == math.floor(timestamp):
return self.items[idx], t
raise ValueError()
def __len__(self):
return len(self.items)
class Policy:
def __init__(self):
self.obs_queue = TemporalQueue()
self.action_queue = TemporalQueue()
self.thread = None
def inference(self, observation):
# TODO
time.sleep(0.5)
return observation
def inference_loop(self):
previous_timestamp = None
while not self.stop_event.is_set():
latest_observation, latest_timestamp = self.obs_queue.get()
if previous_timestamp is not None and previous_timestamp == latest_timestamp:
time.sleep(
0.1
) # in case inference ran faster than recording/adding a new observation in the queue
else:
predicted_action_sequence = self.inference(latest_observation)
self.action_queue.add(predicted_action_sequence, latest_timestamp)
previous_timestamp = latest_timestamp
def select_action(
self,
new_observation: int,
) -> list[int]:
present_time = time.time()
self.obs_queue.add(new_observation, present_time)
if self.thread is None:
self.stop_event = threading.Event()
self.thread = Thread(target=self.inference_loop, args=())
self.thread.daemon = True
self.thread.start()
next_action = None
while next_action is None:
try:
next_action = self.action_queue.get(present_time)
except ValueError:
time.sleep(0.1) # no action available at this present time, we wait a bit
return next_action
if __name__ == "__main__":
time.sleep(1)
policy = Policy()
for new_observation in range(10):
next_action = policy.select_action(new_observation)
print(f"{new_observation=}, {next_action=}")
time.sleep(0.5) # frequency at which we receive a new observation (5 Hz = 0.2 s)