diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py
index be75396f..45a0cdb4 100644
--- a/lerobot/common/policies/act/modeling_act.py
+++ b/lerobot/common/policies/act/modeling_act.py
@@ -94,7 +94,9 @@ class ACTPolicy(
         # TODO(rcadene): Add delta timestamps in policy
         FPS = 10  # noqa: N806
-        self.delta_timestamps = [i / FPS for i in range(self.config.n_action_steps)]
+        self.delta_timestamps = {
+            "action": [i / FPS for i in range(self.config.n_action_steps)],
+        }
 
     def reset(self):
         """This should be called whenever the environment is reset."""
@@ -103,7 +105,11 @@ class ACTPolicy(
         else:
             # TODO(rcadene): set proper maxlen
             self._obs_queue = TemporalQueue(maxlen=1)
-            self._action_queue = TemporalQueue(maxlen=200)
+            self._action_seq_queue = TemporalQueue(maxlen=200)
+
+        self._action_sequence = None
+        self._action_seq_index = 0
+        self._action_seq_timestamp = None
 
     @torch.no_grad
     def inference(self, batch: dict[str, Tensor]) -> Tensor:
@@ -135,13 +141,20 @@ class ACTPolicy(
 
             if prev_timestamp is not None and prev_timestamp == last_timestamp:
                 # in case inference ran faster than recording/adding a new observation in the queue
+                # print("WAIT INFERENCE")
                 time.sleep(0.1)
                 continue
 
+            start_t = time.perf_counter()
+
             pred_action_sequence = self.inference(last_observation)
 
-            for action, delta_ts in zip(pred_action_sequence, self.delta_timestamps, strict=False):
-                self._action_queue.add(action, last_timestamp + delta_ts)
+            dt_s = time.perf_counter() - start_t
+            print(
+                f"Inference, {dt_s * 1000:5.2f} ({1/ dt_s:3.1f}hz) -- {last_timestamp}"
+            )  # , {next_action.mean().item()}")
+
+            self._action_seq_queue.add(pred_action_sequence, last_timestamp)
 
             prev_timestamp = last_timestamp
@@ -155,14 +168,47 @@ class ACTPolicy(
         self.thread.daemon = True
         self.thread.start()
 
-        next_action = None
-        while next_action is None:
-            try:
-                next_action = self._action_queue.get(present_time)
-            except ValueError:
-                time.sleep(0.1)  # no action available at this present time, we wait a bit
+        while len(self._action_seq_queue) == 0:
+            # print("WAIT")
+            time.sleep(0.1)  # no action available at this present time, we wait a bit
 
-        return next_action
+        latest_action_sequence, latest_seq_timestamp = self._action_seq_queue.get_latest()
+
+        if self._action_seq_index == len(self.delta_timestamps["action"]):
+            while self._action_seq_timestamp == latest_seq_timestamp:
+                latest_action_sequence, latest_seq_timestamp = self._action_seq_queue.get_latest()
+                # print("WAIT")
+                time.sleep(0.1)
+
+        if self._action_seq_timestamp is None:
+            self._action_sequence = latest_action_sequence
+            self._action_seq_timestamp = latest_seq_timestamp
+
+        elif self._action_seq_index == 100 and self._action_seq_timestamp < latest_seq_timestamp:
+            # update sequence index
+            seq_timestamps = self._action_seq_timestamp + np.array(self.delta_timestamps["action"])
+            if self._action_seq_index == len(self.delta_timestamps["action"]):
+                current_timestamp = seq_timestamps[-1]
+            else:
+                current_timestamp = seq_timestamps[self._action_seq_index]
+
+            latest_seq_timestamps = latest_seq_timestamp + np.array(self.delta_timestamps["action"])
+            distances = np.abs(latest_seq_timestamps - current_timestamp)
+            nearest_idx = distances.argmin()
+            # TODO(rcadene): handle edge cases
+            self._action_seq_index = nearest_idx
+
+            # update action sequence
+            self._action_sequence = latest_action_sequence
+            # update inference timestamp (when this action sequence has been computed)
+            self._action_seq_timestamp = latest_seq_timestamp
+
+        seq_timestamps = self._action_seq_timestamp + np.array(self.delta_timestamps["action"])
+        current_timestamp = seq_timestamps[self._action_seq_index]
+
+        action = self._action_sequence[:, self._action_seq_index]
+        self._action_seq_index += 1
+        return action, present_time, current_timestamp
 
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """Run the batch through the model and compute the loss for training or validation."""
diff --git a/lerobot/common/robot_devices/utils.py b/lerobot/common/robot_devices/utils.py
index bcbeb8e0..cabde8a4 100644
--- a/lerobot/common/robot_devices/utils.py
+++ b/lerobot/common/robot_devices/utils.py
@@ -1,19 +1,21 @@
-import platform
 import time
 
 
 def busy_wait(seconds):
-    if platform.system() == "Darwin":
-        # On Mac, `time.sleep` is not accurate and we need to use this while loop trick,
-        # but it consumes CPU cycles.
-        # TODO(rcadene): find an alternative: from python 11, time.sleep is precise
-        end_time = time.perf_counter() + seconds
-        while time.perf_counter() < end_time:
-            pass
-    else:
-        # On Linux time.sleep is accurate
-        if seconds > 0:
-            time.sleep(seconds)
+    # if platform.system() == "Darwin":
+    #     # On Mac, `time.sleep` is not accurate and we need to use this while loop trick,
+    #     # but it consumes CPU cycles.
+    #     # TODO(rcadene): find an alternative: from python 11, time.sleep is precise
+    #     end_time = time.perf_counter() + seconds
+    #     while time.perf_counter() < end_time:
+    #         pass
+    # else:
+    #     # On Linux time.sleep is accurate
+    #     if seconds > 0:
+    #         time.sleep(seconds)
+
+    if seconds > 0:
+        time.sleep(seconds)
 
 
 class RobotDeviceNotConnectedError(Exception):
diff --git a/test2.py b/test2.py
index 58975a9d..eae8210b 100644
--- a/test2.py
+++ b/test2.py
@@ -40,16 +40,19 @@ def main(env_name, policy_name, extra_overrides):
 
     fps = 30
 
-    for i in range(100):
+    for i in range(200):
         start_loop_t = time.perf_counter()
 
-        next_action, timestamp = policy.select_action(obs)
+        next_action, present_time, action_ts = policy.select_action(obs)
 
         dt_s = time.perf_counter() - start_loop_t
-        print(f"{i=}, {timestamp}, {dt_s * 1000:5.2f} ({1/ dt_s:3.1f}hz)")  # , {next_action.mean().item()}")
-
         busy_wait(1 / fps - dt_s)
+
+        dt_s = time.perf_counter() - start_loop_t
+        print(
+            f"{i=}, {dt_s * 1000:5.2f} ({1/ dt_s:3.1f}hz) \t{present_time}\t{action_ts}"
+        )  # , {next_action.mean().item()}")
+
         # time.sleep(1/30)  # frequency at which we receive a new observation (30 Hz = 0.03 s)
         # time.sleep(0.5)  # frequency at which we receive a new observation (5 Hz = 0.2 s)
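The core of the `select_action` rewrite is the re-indexing step: the background thread pushes whole predicted action sequences stamped with the observation time, and when a newer sequence arrives mid-chunk, the consumer jumps to the step of the new sequence whose absolute timestamp is nearest to the step it was about to execute. Below is a standalone sketch of just that computation; the function name and example values are illustrative, not from the patch:

```python
import numpy as np


def nearest_step_index(
    current_seq_start: float,  # timestamp at which the current sequence was inferred
    current_step: int,  # index of the step about to be executed
    latest_seq_start: float,  # timestamp at which the newest sequence was inferred
    delta_timestamps: list[float],  # per-step offsets, i.e. [i / FPS for i in range(n)]
) -> int:
    """Map a step of the old sequence to the step of the newest sequence
    whose absolute timestamp is closest (the `nearest_idx` logic above)."""
    current_ts = current_seq_start + delta_timestamps[current_step]
    latest_ts = latest_seq_start + np.array(delta_timestamps)
    return int(np.abs(latest_ts - current_ts).argmin())


# Example at FPS = 10: the old sequence started at t=0.0 s and we are at step 30
# (t=3.0 s); a newer sequence inferred at t=0.3 s has its step 27 at 0.3 + 2.7 = 3.0 s.
deltas = [i / 10 for i in range(100)]
assert nearest_step_index(0.0, 30, 0.3, deltas) == 27
```

The `# TODO(rcadene): handle edge cases` in the patch presumably covers the situations this sketch also ignores, such as the current timestamp falling past the end of the newest sequence.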
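Both sides of that hand-off go through `TemporalQueue`, which the patch uses only via `add(item, timestamp)`, `get_latest()`, and `len()`. A minimal stand-in covering just those calls, assuming the queue stores `(item, timestamp)` pairs and is written by the inference thread while the control loop reads it (the internals here are a guess, not lerobot's real implementation):

```python
import threading
from collections import deque


class TemporalQueueSketch:
    """Minimal stand-in for lerobot's TemporalQueue, mirroring only the calls
    the patch makes: add(), get_latest(), and len(). Internals are assumed."""

    def __init__(self, maxlen: int):
        self._items = deque(maxlen=maxlen)  # holds (item, timestamp) pairs
        self._lock = threading.Lock()  # inference thread writes, control loop reads

    def add(self, item, timestamp: float) -> None:
        with self._lock:
            self._items.append((item, timestamp))

    def get_latest(self):
        # Most recently added (item, timestamp) pair; IndexError if empty,
        # which is why select_action spins until len() > 0.
        with self._lock:
            return self._items[-1]

    def __len__(self) -> int:
        return len(self._items)
```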