diff --git a/test.py b/test.py new file mode 100644 index 00000000..a1ad957c --- /dev/null +++ b/test.py @@ -0,0 +1,86 @@ +import math +import threading +import time +from threading import Thread + + +class TemporalQueue: + def __init__(self): + self.items = [] + self.timestamps = [] + + def add(self, item, timestamp): + self.items.append(item) + self.timestamps.append(timestamp) + + def get(self, timestamp=None): + if timestamp is None: + return self.items[-1], self.timestamps[-1] + + # TODO(rcadene): implement nearest neighbor instead of hacky floor + for idx, t in list(enumerate(self.timestamps))[::-1]: + if math.floor(t) == math.floor(timestamp): + return self.items[idx], t + + raise ValueError() + + def __len__(self): + return len(self.items) + + +class Policy: + def __init__(self): + self.obs_queue = TemporalQueue() + self.action_queue = TemporalQueue() + self.thread = None + + def inference(self, observation): + # TODO + time.sleep(0.5) + return observation + + def inference_loop(self): + previous_timestamp = None + while not self.stop_event.is_set(): + latest_observation, latest_timestamp = self.obs_queue.get() + + if previous_timestamp is not None and previous_timestamp == latest_timestamp: + time.sleep( + 0.1 + ) # in case inference ran faster than recording/adding a new observation in the queue + else: + predicted_action_sequence = self.inference(latest_observation) + self.action_queue.add(predicted_action_sequence, latest_timestamp) + previous_timestamp = latest_timestamp + + def select_action( + self, + new_observation: int, + ) -> list[int]: + present_time = time.time() + self.obs_queue.add(new_observation, present_time) + + if self.thread is None: + self.stop_event = threading.Event() + self.thread = Thread(target=self.inference_loop, args=()) + self.thread.daemon = True + self.thread.start() + + next_action = None + while next_action is None: + try: + next_action = self.action_queue.get(present_time) + except ValueError: + time.sleep(0.1) # no action available at this present time, we wait a bit + + return next_action + + +if __name__ == "__main__": + time.sleep(1) + policy = Policy() + + for new_observation in range(10): + next_action = policy.select_action(new_observation) + print(f"{new_observation=}, {next_action=}") + time.sleep(0.5) # frequency at which we receive a new observation (5 Hz = 0.2 s)