Updated version with a queue of action sequences, instead of queue of action

This commit is contained in:
Remi Cadene 2024-10-14 12:23:29 +02:00
parent 1f59edd5e7
commit 6afdf2f626
3 changed files with 78 additions and 27 deletions

View File

@ -94,7 +94,9 @@ class ACTPolicy(
# TODO(rcadene): Add delta timestamps in policy
FPS = 10 # noqa: N806
self.delta_timestamps = [i / FPS for i in range(self.config.n_action_steps)]
self.delta_timestamps = {
"action": [i / FPS for i in range(self.config.n_action_steps)],
}
def reset(self):
"""This should be called whenever the environment is reset."""
@ -103,7 +105,11 @@ class ACTPolicy(
else:
# TODO(rcadene): set proper maxlen
self._obs_queue = TemporalQueue(maxlen=1)
self._action_queue = TemporalQueue(maxlen=200)
self._action_seq_queue = TemporalQueue(maxlen=200)
self._action_sequence = None
self._action_seq_index = 0
self._action_seq_timestamp = None
@torch.no_grad
def inference(self, batch: dict[str, Tensor]) -> Tensor:
@ -135,13 +141,20 @@ class ACTPolicy(
if prev_timestamp is not None and prev_timestamp == last_timestamp:
# in case inference ran faster than recording/adding a new observation in the queue
# print("WAIT INFERENCE")
time.sleep(0.1)
continue
start_t = time.perf_counter()
pred_action_sequence = self.inference(last_observation)
for action, delta_ts in zip(pred_action_sequence, self.delta_timestamps, strict=False):
self._action_queue.add(action, last_timestamp + delta_ts)
dt_s = time.perf_counter() - start_t
print(
f"Inference, {dt_s * 1000:5.2f} ({1/ dt_s:3.1f}hz) -- {last_timestamp}"
) # , {next_action.mean().item()}")
self._action_seq_queue.add(pred_action_sequence, last_timestamp)
prev_timestamp = last_timestamp
@ -155,14 +168,47 @@ class ACTPolicy(
self.thread.daemon = True
self.thread.start()
next_action = None
while next_action is None:
try:
next_action = self._action_queue.get(present_time)
except ValueError:
while len(self._action_seq_queue) == 0:
# print("WAIT")
time.sleep(0.1) # no action available at this present time, we wait a bit
return next_action
latest_action_sequence, latest_seq_timestamp = self._action_seq_queue.get_latest()
if self._action_seq_index == len(self.delta_timestamps["action"]):
while self._action_seq_timestamp == latest_seq_timestamp:
latest_action_sequence, latest_seq_timestamp = self._action_seq_queue.get_latest()
# print("WAIT")
time.sleep(0.1)
if self._action_seq_timestamp is None:
self._action_sequence = latest_action_sequence
self._action_seq_timestamp = latest_seq_timestamp
elif self._action_seq_index == 100 and self._action_seq_timestamp < latest_seq_timestamp:
# update sequence index
seq_timestamps = self._action_seq_timestamp + np.array(self.delta_timestamps["action"])
if self._action_seq_index == len(self.delta_timestamps["action"]):
current_timestamp = seq_timestamps[-1]
else:
current_timestamp = seq_timestamps[self._action_seq_index]
latest_seq_timestamps = latest_seq_timestamp + np.array(self.delta_timestamps["action"])
distances = np.abs(latest_seq_timestamps - current_timestamp)
nearest_idx = distances.argmin()
# TODO(rcadene): handle edge cases
self._action_seq_index = nearest_idx
# update action sequence
self._action_sequence = latest_action_sequence
# update inference timestamp (when this action sequence has been computed)
self._action_seq_timestamp = latest_seq_timestamp
seq_timestamps = self._action_seq_timestamp + np.array(self.delta_timestamps["action"])
current_timestamp = seq_timestamps[self._action_seq_index]
action = self._action_sequence[:, self._action_seq_index]
self._action_seq_index += 1
return action, present_time, current_timestamp
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""

View File

@ -1,17 +1,19 @@
import platform
import time
def busy_wait(seconds):
if platform.system() == "Darwin":
# On Mac, `time.sleep` is not accurate and we need to use this while loop trick,
# but it consumes CPU cycles.
# TODO(rcadene): find an alternative: from python 11, time.sleep is precise
end_time = time.perf_counter() + seconds
while time.perf_counter() < end_time:
pass
else:
# On Linux time.sleep is accurate
# if platform.system() == "Darwin":
# # On Mac, `time.sleep` is not accurate and we need to use this while loop trick,
# # but it consumes CPU cycles.
# # TODO(rcadene): find an alternative: from python 11, time.sleep is precise
# end_time = time.perf_counter() + seconds
# while time.perf_counter() < end_time:
# pass
# else:
# # On Linux time.sleep is accurate
# if seconds > 0:
# time.sleep(seconds)
if seconds > 0:
time.sleep(seconds)

View File

@ -40,16 +40,19 @@ def main(env_name, policy_name, extra_overrides):
fps = 30
for i in range(100):
for i in range(200):
start_loop_t = time.perf_counter()
next_action, timestamp = policy.select_action(obs)
next_action, present_time, action_ts = policy.select_action(obs)
dt_s = time.perf_counter() - start_loop_t
print(f"{i=}, {timestamp}, {dt_s * 1000:5.2f} ({1/ dt_s:3.1f}hz)") # , {next_action.mean().item()}")
busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - start_loop_t
print(
f"{i=}, {dt_s * 1000:5.2f} ({1/ dt_s:3.1f}hz) \t{present_time}\t{action_ts}"
) # , {next_action.mean().item()}")
# time.sleep(1/30) # frequency at which we receive a new observation (30 Hz = 0.03 s)
# time.sleep(0.5) # frequency at which we receive a new observation (5 Hz = 0.2 s)