Updated version with a queue of action sequences instead of a queue of actions
parent 1f59edd5e7
commit 6afdf2f626
@@ -94,7 +94,9 @@ class ACTPolicy(

         # TODO(rcadene): Add delta timestamps in policy
         FPS = 10  # noqa: N806
-        self.delta_timestamps = [i / FPS for i in range(self.config.n_action_steps)]
+        self.delta_timestamps = {
+            "action": [i / FPS for i in range(self.config.n_action_steps)],
+        }

     def reset(self):
         """This should be called whenever the environment is reset."""
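For intuition, the new mapping ties each predicted action to a relative timestamp in seconds. A minimal illustration, using an arbitrary n_action_steps of 3 rather than the value from the real policy config:

FPS = 10
n_action_steps = 3  # illustrative only; the actual value comes from self.config.n_action_steps
delta_timestamps = {"action": [i / FPS for i in range(n_action_steps)]}
# {"action": [0.0, 0.1, 0.2]} -> offsets of each predicted action relative to
# the timestamp of the observation the sequence was inferred from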
@@ -103,7 +105,11 @@ class ACTPolicy(
         else:
             # TODO(rcadene): set proper maxlen
             self._obs_queue = TemporalQueue(maxlen=1)
-            self._action_queue = TemporalQueue(maxlen=200)
+            self._action_seq_queue = TemporalQueue(maxlen=200)
+
+            self._action_sequence = None
+            self._action_seq_index = 0
+            self._action_seq_timestamp = None

     @torch.no_grad
     def inference(self, batch: dict[str, Tensor]) -> Tensor:
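The diff relies on a TemporalQueue helper that is not part of this commit. Judging only from the call sites in these hunks (add, get, get_latest, len), it appears to behave roughly like the sketch below; this is an assumption about the interface, not the repository's implementation:

from collections import deque


class TemporalQueue:
    """Assumed interface: a bounded queue of (item, timestamp) pairs."""

    def __init__(self, maxlen: int):
        self.items = deque(maxlen=maxlen)
        self.timestamps = deque(maxlen=maxlen)

    def add(self, item, timestamp: float):
        self.items.append(item)
        self.timestamps.append(timestamp)

    def get_latest(self):
        # most recently added item and the timestamp it was added with
        return self.items[-1], self.timestamps[-1]

    def get(self, timestamp: float):
        if not self.timestamps:
            raise ValueError("No item available for the requested timestamp")
        # naive nearest-timestamp lookup
        idx = min(range(len(self.timestamps)), key=lambda i: abs(self.timestamps[i] - timestamp))
        return self.items[idx]

    def __len__(self):
        return len(self.items)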
@@ -135,13 +141,20 @@ class ACTPolicy(

             if prev_timestamp is not None and prev_timestamp == last_timestamp:
                 # in case inference ran faster than recording/adding a new observation in the queue
+                # print("WAIT INFERENCE")
                 time.sleep(0.1)
                 continue

+            start_t = time.perf_counter()
+
             pred_action_sequence = self.inference(last_observation)

-            for action, delta_ts in zip(pred_action_sequence, self.delta_timestamps, strict=False):
-                self._action_queue.add(action, last_timestamp + delta_ts)
+            dt_s = time.perf_counter() - start_t
+            print(
+                f"Inference, {dt_s * 1000:5.2f} ({1/ dt_s:3.1f}hz) -- {last_timestamp}"
+            )  # , {next_action.mean().item()}")
+
+            self._action_seq_queue.add(pred_action_sequence, last_timestamp)

             prev_timestamp = last_timestamp

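The hunk above is a fragment of the background inference loop: a daemon thread keeps turning the latest observation into a predicted action sequence and stores it, timestamped, for select_action to consume. A rough, self-contained sketch of that producer pattern (names and surrounding structure are assumptions, not the commit's code):

import time


def inference_loop(policy, obs_queue, action_seq_queue):
    # Sketch of the producer side: intended to run in a daemon thread.
    prev_timestamp = None
    while True:
        last_observation, last_timestamp = obs_queue.get_latest()
        if prev_timestamp is not None and prev_timestamp == last_timestamp:
            # inference ran faster than new observations arrive; wait a bit
            time.sleep(0.1)
            continue
        pred_action_sequence = policy.inference(last_observation)
        # store the whole sequence, stamped with the observation time it was computed from
        action_seq_queue.add(pred_action_sequence, last_timestamp)
        prev_timestamp = last_timestamp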
@@ -155,14 +168,47 @@ class ACTPolicy(
             self.thread.daemon = True
             self.thread.start()

-        next_action = None
-        while next_action is None:
-            try:
-                next_action = self._action_queue.get(present_time)
-            except ValueError:
-                time.sleep(0.1)  # no action available at this present time, we wait a bit
+        while len(self._action_seq_queue) == 0:
+            # print("WAIT")
+            time.sleep(0.1)  # no action available at this present time, we wait a bit

-        return next_action
+        latest_action_sequence, latest_seq_timestamp = self._action_seq_queue.get_latest()
+
+        if self._action_seq_index == len(self.delta_timestamps["action"]):
+            while self._action_seq_timestamp == latest_seq_timestamp:
+                latest_action_sequence, latest_seq_timestamp = self._action_seq_queue.get_latest()
+                # print("WAIT")
+                time.sleep(0.1)
+
+        if self._action_seq_timestamp is None:
+            self._action_sequence = latest_action_sequence
+            self._action_seq_timestamp = latest_seq_timestamp
+
+        elif self._action_seq_index == 100 and self._action_seq_timestamp < latest_seq_timestamp:
+            # update sequence index
+            seq_timestamps = self._action_seq_timestamp + np.array(self.delta_timestamps["action"])
+            if self._action_seq_index == len(self.delta_timestamps["action"]):
+                current_timestamp = seq_timestamps[-1]
+            else:
+                current_timestamp = seq_timestamps[self._action_seq_index]
+
+            latest_seq_timestamps = latest_seq_timestamp + np.array(self.delta_timestamps["action"])
+            distances = np.abs(latest_seq_timestamps - current_timestamp)
+            nearest_idx = distances.argmin()
+            # TODO(rcadene): handle edge cases
+            self._action_seq_index = nearest_idx
+
+            # update action sequence
+            self._action_sequence = latest_action_sequence
+            # update inference timestamp (when this action sequence has been computed)
+            self._action_seq_timestamp = latest_seq_timestamp
+
+        seq_timestamps = self._action_seq_timestamp + np.array(self.delta_timestamps["action"])
+        current_timestamp = seq_timestamps[self._action_seq_index]
+
+        action = self._action_sequence[:, self._action_seq_index]
+        self._action_seq_index += 1
+        return action, present_time, current_timestamp

     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """Run the batch through the model and compute the loss for training or validation."""
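To make the resynchronization in select_action easier to follow: when a newer sequence is available, the current position (observation timestamp of the old sequence plus the current action offset) is mapped onto the index of the nearest timestamp in the new sequence. A standalone numerical sketch with made-up timestamps, not code from the commit:

import numpy as np

# Made-up numbers to illustrate the nearest-timestamp resync above.
delta_ts = np.array([i / 10 for i in range(5)])  # 10 Hz action steps: [0.0, 0.1, 0.2, 0.3, 0.4]

old_seq_timestamp = 100.00  # observation time of the sequence currently being executed
new_seq_timestamp = 100.23  # observation time of the freshly inferred sequence
old_index = 3               # next action of the old sequence was planned for 100.00 + 0.3

current_timestamp = old_seq_timestamp + delta_ts[old_index]  # 100.30
new_seq_timestamps = new_seq_timestamp + delta_ts            # [100.23, 100.33, 100.43, ...]
nearest_idx = np.abs(new_seq_timestamps - current_timestamp).argmin()
print(nearest_idx)  # 1 -> resume from the new sequence's action planned closest to 100.30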
@@ -1,19 +1,21 @@
-import platform
 import time


 def busy_wait(seconds):
-    if platform.system() == "Darwin":
-        # On Mac, `time.sleep` is not accurate and we need to use this while loop trick,
-        # but it consumes CPU cycles.
-        # TODO(rcadene): find an alternative: from python 11, time.sleep is precise
-        end_time = time.perf_counter() + seconds
-        while time.perf_counter() < end_time:
-            pass
-    else:
-        # On Linux time.sleep is accurate
-        if seconds > 0:
-            time.sleep(seconds)
+    # if platform.system() == "Darwin":
+    #     # On Mac, `time.sleep` is not accurate and we need to use this while loop trick,
+    #     # but it consumes CPU cycles.
+    #     # TODO(rcadene): find an alternative: from python 11, time.sleep is precise
+    #     end_time = time.perf_counter() + seconds
+    #     while time.perf_counter() < end_time:
+    #         pass
+    # else:
+    #     # On Linux time.sleep is accurate
+    #     if seconds > 0:
+    #         time.sleep(seconds)
+
+    if seconds > 0:
+        time.sleep(seconds)


 class RobotDeviceNotConnectedError(Exception):
test2.py
@@ -40,16 +40,19 @@ def main(env_name, policy_name, extra_overrides):

     fps = 30

-    for i in range(100):
+    for i in range(200):
         start_loop_t = time.perf_counter()

-        next_action, timestamp = policy.select_action(obs)
+        next_action, present_time, action_ts = policy.select_action(obs)

         dt_s = time.perf_counter() - start_loop_t
-        print(f"{i=}, {timestamp}, {dt_s * 1000:5.2f} ({1/ dt_s:3.1f}hz)")  # , {next_action.mean().item()}")

         busy_wait(1 / fps - dt_s)

+        dt_s = time.perf_counter() - start_loop_t
+        print(
+            f"{i=}, {dt_s * 1000:5.2f} ({1/ dt_s:3.1f}hz) \t{present_time}\t{action_ts}"
+        )  # , {next_action.mean().item()}")

         # time.sleep(1/30) # frequency at which we receive a new observation (30 Hz = 0.03 s)
         # time.sleep(0.5) # frequency at which we receive a new observation (5 Hz = 0.2 s)
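A side note on the new return value of select_action: present_time is when the caller asked for an action, while action_ts is the timestamp that action was planned for (observation time plus its delta offset). A hypothetical helper, not part of the commit, that a caller could use to monitor the drift between the two:

def log_action_lag(present_time: float, action_ts: float) -> float:
    # Hypothetical helper: positive lag means the executed action was planned for the past.
    lag_s = present_time - action_ts
    print(f"action lag: {lag_s * 1000:+.1f} ms")
    return lag_s


# e.g. next_action, present_time, action_ts = policy.select_action(obs)
#      log_action_lag(present_time, action_ts)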