This commit is contained in:
Remi Cadene 2024-10-01 15:57:09 +02:00
parent ec1efc64b4
commit 1f59edd5e7
2 changed files with 21 additions and 8 deletions

View File

@ -72,7 +72,7 @@ class TemporalQueue:
nearest_idx = distances.argmin()
# print(float(distances[nearest_idx]))
if float(distances[nearest_idx]) > 1 / 5:
if float(distances[nearest_idx]) > 1 / 30:
raise ValueError()
return self.items[nearest_idx], self.timestamps[nearest_idx]

View File

@ -1,7 +1,10 @@
import time
import torch
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.utils import busy_wait
from lerobot.common.utils.utils import init_hydra_config, set_global_seed
from tests.utils import DEFAULT_CONFIG_PATH
@ -12,7 +15,7 @@ def main(env_name, policy_name, extra_overrides):
overrides=[
f"env={env_name}",
f"policy={policy_name}",
"device=cpu",
"device=mps",
]
+ extra_overrides,
)
@ -31,15 +34,25 @@ def main(env_name, policy_name, extra_overrides):
obs = {}
for k in batch:
if k.startswith("observation"):
obs[k] = batch[k]
obs[k] = batch[k].to("mps")
actions = policy.inference(obs)
# actions = policy.inference(obs)
action, timestamp = policy.select_action(obs)
fps = 30
print(actions[0])
print(action)
for i in range(100):
start_loop_t = time.perf_counter()
next_action, timestamp = policy.select_action(obs)
dt_s = time.perf_counter() - start_loop_t
print(f"{i=}, {timestamp}, {dt_s * 1000:5.2f} ({1/ dt_s:3.1f}hz)") # , {next_action.mean().item()}")
busy_wait(1 / fps - dt_s)
# time.sleep(1/30) # frequency at which we receive a new observation (30 Hz = 0.03 s)
# time.sleep(0.5) # frequency at which we receive a new observation (5 Hz = 0.2 s)
if __name__ == "__main__":
main("aloha", "act", ["policy.n_action_steps=10"])
main("aloha", "act", ["policy.n_action_steps=100"])