Updated version with a queue of action sequences, instead of queue of action
This commit is contained in:
parent
1f59edd5e7
commit
6afdf2f626
|
@ -94,7 +94,9 @@ class ACTPolicy(
|
||||||
|
|
||||||
# TODO(rcadene): Add delta timestamps in policy
|
# TODO(rcadene): Add delta timestamps in policy
|
||||||
FPS = 10 # noqa: N806
|
FPS = 10 # noqa: N806
|
||||||
self.delta_timestamps = [i / FPS for i in range(self.config.n_action_steps)]
|
self.delta_timestamps = {
|
||||||
|
"action": [i / FPS for i in range(self.config.n_action_steps)],
|
||||||
|
}
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""This should be called whenever the environment is reset."""
|
"""This should be called whenever the environment is reset."""
|
||||||
|
@ -103,7 +105,11 @@ class ACTPolicy(
|
||||||
else:
|
else:
|
||||||
# TODO(rcadene): set proper maxlen
|
# TODO(rcadene): set proper maxlen
|
||||||
self._obs_queue = TemporalQueue(maxlen=1)
|
self._obs_queue = TemporalQueue(maxlen=1)
|
||||||
self._action_queue = TemporalQueue(maxlen=200)
|
self._action_seq_queue = TemporalQueue(maxlen=200)
|
||||||
|
|
||||||
|
self._action_sequence = None
|
||||||
|
self._action_seq_index = 0
|
||||||
|
self._action_seq_timestamp = None
|
||||||
|
|
||||||
@torch.no_grad
|
@torch.no_grad
|
||||||
def inference(self, batch: dict[str, Tensor]) -> Tensor:
|
def inference(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
|
@ -135,13 +141,20 @@ class ACTPolicy(
|
||||||
|
|
||||||
if prev_timestamp is not None and prev_timestamp == last_timestamp:
|
if prev_timestamp is not None and prev_timestamp == last_timestamp:
|
||||||
# in case inference ran faster than recording/adding a new observation in the queue
|
# in case inference ran faster than recording/adding a new observation in the queue
|
||||||
|
# print("WAIT INFERENCE")
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
start_t = time.perf_counter()
|
||||||
|
|
||||||
pred_action_sequence = self.inference(last_observation)
|
pred_action_sequence = self.inference(last_observation)
|
||||||
|
|
||||||
for action, delta_ts in zip(pred_action_sequence, self.delta_timestamps, strict=False):
|
dt_s = time.perf_counter() - start_t
|
||||||
self._action_queue.add(action, last_timestamp + delta_ts)
|
print(
|
||||||
|
f"Inference, {dt_s * 1000:5.2f} ({1/ dt_s:3.1f}hz) -- {last_timestamp}"
|
||||||
|
) # , {next_action.mean().item()}")
|
||||||
|
|
||||||
|
self._action_seq_queue.add(pred_action_sequence, last_timestamp)
|
||||||
|
|
||||||
prev_timestamp = last_timestamp
|
prev_timestamp = last_timestamp
|
||||||
|
|
||||||
|
@ -155,14 +168,47 @@ class ACTPolicy(
|
||||||
self.thread.daemon = True
|
self.thread.daemon = True
|
||||||
self.thread.start()
|
self.thread.start()
|
||||||
|
|
||||||
next_action = None
|
while len(self._action_seq_queue) == 0:
|
||||||
while next_action is None:
|
# print("WAIT")
|
||||||
try:
|
time.sleep(0.1) # no action available at this present time, we wait a bit
|
||||||
next_action = self._action_queue.get(present_time)
|
|
||||||
except ValueError:
|
|
||||||
time.sleep(0.1) # no action available at this present time, we wait a bit
|
|
||||||
|
|
||||||
return next_action
|
latest_action_sequence, latest_seq_timestamp = self._action_seq_queue.get_latest()
|
||||||
|
|
||||||
|
if self._action_seq_index == len(self.delta_timestamps["action"]):
|
||||||
|
while self._action_seq_timestamp == latest_seq_timestamp:
|
||||||
|
latest_action_sequence, latest_seq_timestamp = self._action_seq_queue.get_latest()
|
||||||
|
# print("WAIT")
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
if self._action_seq_timestamp is None:
|
||||||
|
self._action_sequence = latest_action_sequence
|
||||||
|
self._action_seq_timestamp = latest_seq_timestamp
|
||||||
|
|
||||||
|
elif self._action_seq_index == 100 and self._action_seq_timestamp < latest_seq_timestamp:
|
||||||
|
# update sequence index
|
||||||
|
seq_timestamps = self._action_seq_timestamp + np.array(self.delta_timestamps["action"])
|
||||||
|
if self._action_seq_index == len(self.delta_timestamps["action"]):
|
||||||
|
current_timestamp = seq_timestamps[-1]
|
||||||
|
else:
|
||||||
|
current_timestamp = seq_timestamps[self._action_seq_index]
|
||||||
|
|
||||||
|
latest_seq_timestamps = latest_seq_timestamp + np.array(self.delta_timestamps["action"])
|
||||||
|
distances = np.abs(latest_seq_timestamps - current_timestamp)
|
||||||
|
nearest_idx = distances.argmin()
|
||||||
|
# TODO(rcadene): handle edge cases
|
||||||
|
self._action_seq_index = nearest_idx
|
||||||
|
|
||||||
|
# update action sequence
|
||||||
|
self._action_sequence = latest_action_sequence
|
||||||
|
# update inference timestamp (when this action sequence has been computed)
|
||||||
|
self._action_seq_timestamp = latest_seq_timestamp
|
||||||
|
|
||||||
|
seq_timestamps = self._action_seq_timestamp + np.array(self.delta_timestamps["action"])
|
||||||
|
current_timestamp = seq_timestamps[self._action_seq_index]
|
||||||
|
|
||||||
|
action = self._action_sequence[:, self._action_seq_index]
|
||||||
|
self._action_seq_index += 1
|
||||||
|
return action, present_time, current_timestamp
|
||||||
|
|
||||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||||
"""Run the batch through the model and compute the loss for training or validation."""
|
"""Run the batch through the model and compute the loss for training or validation."""
|
||||||
|
|
|
@ -1,19 +1,21 @@
|
||||||
import platform
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
|
||||||
def busy_wait(seconds):
|
def busy_wait(seconds):
|
||||||
if platform.system() == "Darwin":
|
# if platform.system() == "Darwin":
|
||||||
# On Mac, `time.sleep` is not accurate and we need to use this while loop trick,
|
# # On Mac, `time.sleep` is not accurate and we need to use this while loop trick,
|
||||||
# but it consumes CPU cycles.
|
# # but it consumes CPU cycles.
|
||||||
# TODO(rcadene): find an alternative: from python 11, time.sleep is precise
|
# # TODO(rcadene): find an alternative: from python 11, time.sleep is precise
|
||||||
end_time = time.perf_counter() + seconds
|
# end_time = time.perf_counter() + seconds
|
||||||
while time.perf_counter() < end_time:
|
# while time.perf_counter() < end_time:
|
||||||
pass
|
# pass
|
||||||
else:
|
# else:
|
||||||
# On Linux time.sleep is accurate
|
# # On Linux time.sleep is accurate
|
||||||
if seconds > 0:
|
# if seconds > 0:
|
||||||
time.sleep(seconds)
|
# time.sleep(seconds)
|
||||||
|
|
||||||
|
if seconds > 0:
|
||||||
|
time.sleep(seconds)
|
||||||
|
|
||||||
|
|
||||||
class RobotDeviceNotConnectedError(Exception):
|
class RobotDeviceNotConnectedError(Exception):
|
||||||
|
|
11
test2.py
11
test2.py
|
@ -40,16 +40,19 @@ def main(env_name, policy_name, extra_overrides):
|
||||||
|
|
||||||
fps = 30
|
fps = 30
|
||||||
|
|
||||||
for i in range(100):
|
for i in range(200):
|
||||||
start_loop_t = time.perf_counter()
|
start_loop_t = time.perf_counter()
|
||||||
|
|
||||||
next_action, timestamp = policy.select_action(obs)
|
next_action, present_time, action_ts = policy.select_action(obs)
|
||||||
|
|
||||||
dt_s = time.perf_counter() - start_loop_t
|
dt_s = time.perf_counter() - start_loop_t
|
||||||
print(f"{i=}, {timestamp}, {dt_s * 1000:5.2f} ({1/ dt_s:3.1f}hz)") # , {next_action.mean().item()}")
|
|
||||||
|
|
||||||
busy_wait(1 / fps - dt_s)
|
busy_wait(1 / fps - dt_s)
|
||||||
|
|
||||||
|
dt_s = time.perf_counter() - start_loop_t
|
||||||
|
print(
|
||||||
|
f"{i=}, {dt_s * 1000:5.2f} ({1/ dt_s:3.1f}hz) \t{present_time}\t{action_ts}"
|
||||||
|
) # , {next_action.mean().item()}")
|
||||||
|
|
||||||
# time.sleep(1/30) # frequency at which we receive a new observation (30 Hz = 0.03 s)
|
# time.sleep(1/30) # frequency at which we receive a new observation (30 Hz = 0.03 s)
|
||||||
# time.sleep(0.5) # frequency at which we receive a new observation (5 Hz = 0.2 s)
|
# time.sleep(0.5) # frequency at which we receive a new observation (5 Hz = 0.2 s)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue