Seems to be working
This commit is contained in:
parent
6afdf2f626
commit
a7841afaa4
|
@ -135,28 +135,22 @@ class ACTPolicy(
|
|||
return actions
|
||||
|
||||
def inference_loop(self):
|
||||
prev_timestamp = None
|
||||
while not self.stop_event.is_set():
|
||||
last_observation, last_timestamp = self._obs_queue.get_latest()
|
||||
|
||||
if prev_timestamp is not None and prev_timestamp == last_timestamp:
|
||||
# in case inference ran faster than recording/adding a new observation in the queue
|
||||
# print("WAIT INFERENCE")
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
with self.condition:
|
||||
self.condition.wait()
|
||||
|
||||
start_t = time.perf_counter()
|
||||
|
||||
last_observation, last_timestamp = self._obs_queue.get_latest()
|
||||
pred_action_sequence = self.inference(last_observation)
|
||||
self._action_seq_queue.add(pred_action_sequence, last_timestamp)
|
||||
|
||||
dt_s = time.perf_counter() - start_t
|
||||
print(
|
||||
f"Inference, {dt_s * 1000:5.2f} ({1/ dt_s:3.1f}hz) -- {last_timestamp}"
|
||||
) # , {next_action.mean().item()}")
|
||||
|
||||
self._action_seq_queue.add(pred_action_sequence, last_timestamp)
|
||||
|
||||
prev_timestamp = last_timestamp
|
||||
self.new_action_seq_event.set()
|
||||
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
present_time = time.time()
|
||||
|
@ -164,27 +158,29 @@ class ACTPolicy(
|
|||
|
||||
if self.thread is None:
|
||||
self.stop_event = threading.Event()
|
||||
self.new_action_seq_event = threading.Event()
|
||||
|
||||
self.condition = threading.Condition()
|
||||
self.thread = Thread(target=self.inference_loop, args=())
|
||||
self.thread.daemon = True
|
||||
self.thread.start()
|
||||
|
||||
while len(self._action_seq_queue) == 0:
|
||||
# print("WAIT")
|
||||
time.sleep(0.1) # no action available at this present time, we wait a bit
|
||||
# Ask thread to run first inference
|
||||
with self.condition:
|
||||
self.condition.notify()
|
||||
|
||||
# Block main process until the thread ran it's first inference
|
||||
self.new_action_seq_event.wait()
|
||||
self._action_sequence, self._action_seq_timestamp = self._action_seq_queue.get_latest()
|
||||
|
||||
if self._action_seq_index == 97:
|
||||
with self.condition:
|
||||
self.condition.notify()
|
||||
|
||||
if self._action_seq_index >= len(self._action_sequence):
|
||||
self.new_action_seq_event.wait()
|
||||
latest_action_sequence, latest_seq_timestamp = self._action_seq_queue.get_latest()
|
||||
|
||||
if self._action_seq_index == len(self.delta_timestamps["action"]):
|
||||
while self._action_seq_timestamp == latest_seq_timestamp:
|
||||
latest_action_sequence, latest_seq_timestamp = self._action_seq_queue.get_latest()
|
||||
# print("WAIT")
|
||||
time.sleep(0.1)
|
||||
|
||||
if self._action_seq_timestamp is None:
|
||||
self._action_sequence = latest_action_sequence
|
||||
self._action_seq_timestamp = latest_seq_timestamp
|
||||
|
||||
elif self._action_seq_index == 100 and self._action_seq_timestamp < latest_seq_timestamp:
|
||||
# update sequence index
|
||||
seq_timestamps = self._action_seq_timestamp + np.array(self.delta_timestamps["action"])
|
||||
if self._action_seq_index == len(self.delta_timestamps["action"]):
|
||||
|
|
|
@ -1,19 +1,17 @@
|
|||
import platform
|
||||
import time
|
||||
|
||||
|
||||
def busy_wait(seconds):
|
||||
# if platform.system() == "Darwin":
|
||||
# # On Mac, `time.sleep` is not accurate and we need to use this while loop trick,
|
||||
# # but it consumes CPU cycles.
|
||||
# # TODO(rcadene): find an alternative: from python 11, time.sleep is precise
|
||||
# end_time = time.perf_counter() + seconds
|
||||
# while time.perf_counter() < end_time:
|
||||
# pass
|
||||
# else:
|
||||
# # On Linux time.sleep is accurate
|
||||
# if seconds > 0:
|
||||
# time.sleep(seconds)
|
||||
|
||||
if platform.system() == "Darwin":
|
||||
# On Mac, `time.sleep` is not accurate and we need to use this while loop trick,
|
||||
# but it consumes CPU cycles.
|
||||
# TODO(rcadene): find an alternative: from python 11, time.sleep is precise
|
||||
end_time = time.perf_counter() + seconds
|
||||
while time.perf_counter() < end_time:
|
||||
pass
|
||||
else:
|
||||
# On Linux time.sleep is accurate
|
||||
if seconds > 0:
|
||||
time.sleep(seconds)
|
||||
|
||||
|
|
Loading…
Reference in New Issue