2024-09-18 06:35:20 +08:00
|
|
|
import math
|
|
|
|
import threading
|
|
|
|
import time
|
|
|
|
from threading import Thread
|
|
|
|
|
|
|
|
|
|
|
|
class TemporalQueue:
|
|
|
|
def __init__(self):
|
|
|
|
self.items = []
|
|
|
|
self.timestamps = []
|
|
|
|
|
|
|
|
def add(self, item, timestamp):
|
|
|
|
self.items.append(item)
|
|
|
|
self.timestamps.append(timestamp)
|
|
|
|
|
|
|
|
def get(self, timestamp=None):
|
|
|
|
if timestamp is None:
|
|
|
|
return self.items[-1], self.timestamps[-1]
|
|
|
|
|
|
|
|
# TODO(rcadene): implement nearest neighbor instead of hacky floor
|
|
|
|
for idx, t in list(enumerate(self.timestamps))[::-1]:
|
|
|
|
if math.floor(t) == math.floor(timestamp):
|
|
|
|
return self.items[idx], t
|
|
|
|
|
|
|
|
raise ValueError()
|
|
|
|
|
|
|
|
def __len__(self):
|
|
|
|
return len(self.items)
|
|
|
|
|
|
|
|
|
|
|
|
class Policy:
|
|
|
|
def __init__(self):
|
|
|
|
self.obs_queue = TemporalQueue()
|
|
|
|
self.action_queue = TemporalQueue()
|
|
|
|
self.thread = None
|
|
|
|
|
2024-09-18 06:53:43 +08:00
|
|
|
self.n_action = 2
|
|
|
|
FPS = 10 # noqa: N806
|
|
|
|
self.delta_timestamps = [i / FPS for i in range(self.n_action)]
|
|
|
|
|
2024-09-18 06:35:20 +08:00
|
|
|
def inference(self, observation):
|
|
|
|
# TODO
|
|
|
|
time.sleep(0.5)
|
2024-09-18 06:53:43 +08:00
|
|
|
return [observation] * self.n_action
|
2024-09-18 06:35:20 +08:00
|
|
|
|
|
|
|
def inference_loop(self):
|
2024-09-18 06:53:43 +08:00
|
|
|
prev_timestamp = None
|
2024-09-18 06:35:20 +08:00
|
|
|
while not self.stop_event.is_set():
|
2024-09-18 06:53:43 +08:00
|
|
|
last_observation, last_timestamp = self.obs_queue.get()
|
2024-09-18 06:35:20 +08:00
|
|
|
|
2024-09-18 06:53:43 +08:00
|
|
|
if prev_timestamp is not None and prev_timestamp == last_timestamp:
|
2024-09-18 06:36:08 +08:00
|
|
|
# in case inference ran faster than recording/adding a new observation in the queue
|
|
|
|
time.sleep(0.1)
|
2024-09-18 06:53:43 +08:00
|
|
|
continue
|
|
|
|
|
|
|
|
pred_action_sequence = self.inference(last_observation)
|
|
|
|
|
|
|
|
for action, delta_ts in zip(pred_action_sequence, self.delta_timestamps, strict=False):
|
|
|
|
self.action_queue.add(action, last_timestamp + delta_ts)
|
|
|
|
|
|
|
|
prev_timestamp = last_timestamp
|
2024-09-18 06:35:20 +08:00
|
|
|
|
|
|
|
def select_action(
|
|
|
|
self,
|
|
|
|
new_observation: int,
|
|
|
|
) -> list[int]:
|
|
|
|
present_time = time.time()
|
|
|
|
self.obs_queue.add(new_observation, present_time)
|
|
|
|
|
|
|
|
if self.thread is None:
|
|
|
|
self.stop_event = threading.Event()
|
|
|
|
self.thread = Thread(target=self.inference_loop, args=())
|
|
|
|
self.thread.daemon = True
|
|
|
|
self.thread.start()
|
|
|
|
|
|
|
|
next_action = None
|
|
|
|
while next_action is None:
|
|
|
|
try:
|
|
|
|
next_action = self.action_queue.get(present_time)
|
|
|
|
except ValueError:
|
|
|
|
time.sleep(0.1) # no action available at this present time, we wait a bit
|
|
|
|
|
|
|
|
return next_action
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
time.sleep(1)
|
|
|
|
policy = Policy()
|
|
|
|
|
|
|
|
for new_observation in range(10):
|
|
|
|
next_action = policy.select_action(new_observation)
|
|
|
|
print(f"{new_observation=}, {next_action=}")
|
|
|
|
time.sleep(0.5) # frequency at which we receive a new observation (5 Hz = 0.2 s)
|