Seems to be working

This commit is contained in:
Remi Cadene 2024-10-15 18:02:24 +02:00
parent 6afdf2f626
commit a7841afaa4
2 changed files with 38 additions and 44 deletions

View File

@ -135,28 +135,22 @@ class ACTPolicy(
return actions return actions
def inference_loop(self): def inference_loop(self):
prev_timestamp = None
while not self.stop_event.is_set(): while not self.stop_event.is_set():
last_observation, last_timestamp = self._obs_queue.get_latest() with self.condition:
self.condition.wait()
if prev_timestamp is not None and prev_timestamp == last_timestamp: start_t = time.perf_counter()
# in case inference ran faster than recording/adding a new observation in the queue
# print("WAIT INFERENCE")
time.sleep(0.1)
continue
start_t = time.perf_counter() last_observation, last_timestamp = self._obs_queue.get_latest()
pred_action_sequence = self.inference(last_observation)
self._action_seq_queue.add(pred_action_sequence, last_timestamp)
pred_action_sequence = self.inference(last_observation) dt_s = time.perf_counter() - start_t
print(
f"Inference, {dt_s * 1000:5.2f} ({1/ dt_s:3.1f}hz) -- {last_timestamp}"
) # , {next_action.mean().item()}")
dt_s = time.perf_counter() - start_t self.new_action_seq_event.set()
print(
f"Inference, {dt_s * 1000:5.2f} ({1/ dt_s:3.1f}hz) -- {last_timestamp}"
) # , {next_action.mean().item()}")
self._action_seq_queue.add(pred_action_sequence, last_timestamp)
prev_timestamp = last_timestamp
def select_action(self, batch: dict[str, Tensor]) -> Tensor: def select_action(self, batch: dict[str, Tensor]) -> Tensor:
present_time = time.time() present_time = time.time()
@ -164,27 +158,29 @@ class ACTPolicy(
if self.thread is None: if self.thread is None:
self.stop_event = threading.Event() self.stop_event = threading.Event()
self.new_action_seq_event = threading.Event()
self.condition = threading.Condition()
self.thread = Thread(target=self.inference_loop, args=()) self.thread = Thread(target=self.inference_loop, args=())
self.thread.daemon = True self.thread.daemon = True
self.thread.start() self.thread.start()
while len(self._action_seq_queue) == 0: # Ask thread to run first inference
# print("WAIT") with self.condition:
time.sleep(0.1) # no action available at this present time, we wait a bit self.condition.notify()
latest_action_sequence, latest_seq_timestamp = self._action_seq_queue.get_latest() # Block main process until the thread ran it's first inference
self.new_action_seq_event.wait()
self._action_sequence, self._action_seq_timestamp = self._action_seq_queue.get_latest()
if self._action_seq_index == len(self.delta_timestamps["action"]): if self._action_seq_index == 97:
while self._action_seq_timestamp == latest_seq_timestamp: with self.condition:
latest_action_sequence, latest_seq_timestamp = self._action_seq_queue.get_latest() self.condition.notify()
# print("WAIT")
time.sleep(0.1)
if self._action_seq_timestamp is None: if self._action_seq_index >= len(self._action_sequence):
self._action_sequence = latest_action_sequence self.new_action_seq_event.wait()
self._action_seq_timestamp = latest_seq_timestamp latest_action_sequence, latest_seq_timestamp = self._action_seq_queue.get_latest()
elif self._action_seq_index == 100 and self._action_seq_timestamp < latest_seq_timestamp:
# update sequence index # update sequence index
seq_timestamps = self._action_seq_timestamp + np.array(self.delta_timestamps["action"]) seq_timestamps = self._action_seq_timestamp + np.array(self.delta_timestamps["action"])
if self._action_seq_index == len(self.delta_timestamps["action"]): if self._action_seq_index == len(self.delta_timestamps["action"]):

View File

@ -1,21 +1,19 @@
import platform
import time import time
def busy_wait(seconds): def busy_wait(seconds):
# if platform.system() == "Darwin": if platform.system() == "Darwin":
# # On Mac, `time.sleep` is not accurate and we need to use this while loop trick, # On Mac, `time.sleep` is not accurate and we need to use this while loop trick,
# # but it consumes CPU cycles. # but it consumes CPU cycles.
# # TODO(rcadene): find an alternative: from python 11, time.sleep is precise # TODO(rcadene): find an alternative: from python 11, time.sleep is precise
# end_time = time.perf_counter() + seconds end_time = time.perf_counter() + seconds
# while time.perf_counter() < end_time: while time.perf_counter() < end_time:
# pass pass
# else: else:
# # On Linux time.sleep is accurate # On Linux time.sleep is accurate
# if seconds > 0: if seconds > 0:
# time.sleep(seconds) time.sleep(seconds)
if seconds > 0:
time.sleep(seconds)
class RobotDeviceNotConnectedError(Exception): class RobotDeviceNotConnectedError(Exception):