From 1f59edd5e7f892c87650c67bdf2ea8e8de87aa54 Mon Sep 17 00:00:00 2001 From: Remi Cadene Date: Tue, 1 Oct 2024 15:57:09 +0200 Subject: [PATCH] WIP --- lerobot/common/policies/utils.py | 2 +- test2.py | 27 ++++++++++++++++++++------- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/lerobot/common/policies/utils.py b/lerobot/common/policies/utils.py index beb1f1d4..7d48c0b8 100644 --- a/lerobot/common/policies/utils.py +++ b/lerobot/common/policies/utils.py @@ -72,7 +72,7 @@ class TemporalQueue: nearest_idx = distances.argmin() # print(float(distances[nearest_idx])) - if float(distances[nearest_idx]) > 1 / 5: + if float(distances[nearest_idx]) > 1 / 30: raise ValueError() return self.items[nearest_idx], self.timestamps[nearest_idx] diff --git a/test2.py b/test2.py index 50f8932f..58975a9d 100644 --- a/test2.py +++ b/test2.py @@ -1,7 +1,10 @@ +import time + import torch from lerobot.common.datasets.factory import make_dataset from lerobot.common.policies.factory import make_policy +from lerobot.common.robot_devices.utils import busy_wait from lerobot.common.utils.utils import init_hydra_config, set_global_seed from tests.utils import DEFAULT_CONFIG_PATH @@ -12,7 +15,7 @@ def main(env_name, policy_name, extra_overrides): overrides=[ f"env={env_name}", f"policy={policy_name}", - "device=cpu", + "device=mps", ] + extra_overrides, ) @@ -31,15 +34,25 @@ def main(env_name, policy_name, extra_overrides): obs = {} for k in batch: if k.startswith("observation"): - obs[k] = batch[k] + obs[k] = batch[k].to("mps") - actions = policy.inference(obs) + # actions = policy.inference(obs) - action, timestamp = policy.select_action(obs) + fps = 30 - print(actions[0]) - print(action) + for i in range(100): + start_loop_t = time.perf_counter() + + next_action, timestamp = policy.select_action(obs) + + dt_s = time.perf_counter() - start_loop_t + print(f"{i=}, {timestamp}, {dt_s * 1000:5.2f} ({1/ dt_s:3.1f}hz)") # , {next_action.mean().item()}") + + busy_wait(1 / fps - dt_s) + + # time.sleep(1/30) # frequency at which we receive a new observation (30 Hz = 0.03 s) + # time.sleep(0.5) # frequency at which we receive a new observation (5 Hz = 0.2 s) if __name__ == "__main__": - main("aloha", "act", ["policy.n_action_steps=10"]) + main("aloha", "act", ["policy.n_action_steps=100"])