WIP
This commit is contained in:
parent
ec1efc64b4
commit
1f59edd5e7
|
@ -72,7 +72,7 @@ class TemporalQueue:
|
||||||
nearest_idx = distances.argmin()
|
nearest_idx = distances.argmin()
|
||||||
|
|
||||||
# print(float(distances[nearest_idx]))
|
# print(float(distances[nearest_idx]))
|
||||||
if float(distances[nearest_idx]) > 1 / 5:
|
if float(distances[nearest_idx]) > 1 / 30:
|
||||||
raise ValueError()
|
raise ValueError()
|
||||||
|
|
||||||
return self.items[nearest_idx], self.timestamps[nearest_idx]
|
return self.items[nearest_idx], self.timestamps[nearest_idx]
|
||||||
|
|
27
test2.py
27
test2.py
|
@ -1,7 +1,10 @@
|
||||||
|
import time
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.common.datasets.factory import make_dataset
|
from lerobot.common.datasets.factory import make_dataset
|
||||||
from lerobot.common.policies.factory import make_policy
|
from lerobot.common.policies.factory import make_policy
|
||||||
|
from lerobot.common.robot_devices.utils import busy_wait
|
||||||
from lerobot.common.utils.utils import init_hydra_config, set_global_seed
|
from lerobot.common.utils.utils import init_hydra_config, set_global_seed
|
||||||
from tests.utils import DEFAULT_CONFIG_PATH
|
from tests.utils import DEFAULT_CONFIG_PATH
|
||||||
|
|
||||||
|
@ -12,7 +15,7 @@ def main(env_name, policy_name, extra_overrides):
|
||||||
overrides=[
|
overrides=[
|
||||||
f"env={env_name}",
|
f"env={env_name}",
|
||||||
f"policy={policy_name}",
|
f"policy={policy_name}",
|
||||||
"device=cpu",
|
"device=mps",
|
||||||
]
|
]
|
||||||
+ extra_overrides,
|
+ extra_overrides,
|
||||||
)
|
)
|
||||||
|
@ -31,15 +34,25 @@ def main(env_name, policy_name, extra_overrides):
|
||||||
obs = {}
|
obs = {}
|
||||||
for k in batch:
|
for k in batch:
|
||||||
if k.startswith("observation"):
|
if k.startswith("observation"):
|
||||||
obs[k] = batch[k]
|
obs[k] = batch[k].to("mps")
|
||||||
|
|
||||||
actions = policy.inference(obs)
|
# actions = policy.inference(obs)
|
||||||
|
|
||||||
action, timestamp = policy.select_action(obs)
|
fps = 30
|
||||||
|
|
||||||
print(actions[0])
|
for i in range(100):
|
||||||
print(action)
|
start_loop_t = time.perf_counter()
|
||||||
|
|
||||||
|
next_action, timestamp = policy.select_action(obs)
|
||||||
|
|
||||||
|
dt_s = time.perf_counter() - start_loop_t
|
||||||
|
print(f"{i=}, {timestamp}, {dt_s * 1000:5.2f} ({1/ dt_s:3.1f}hz)") # , {next_action.mean().item()}")
|
||||||
|
|
||||||
|
busy_wait(1 / fps - dt_s)
|
||||||
|
|
||||||
|
# time.sleep(1/30) # frequency at which we receive a new observation (30 Hz = 0.03 s)
|
||||||
|
# time.sleep(0.5) # frequency at which we receive a new observation (5 Hz = 0.2 s)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main("aloha", "act", ["policy.n_action_steps=10"])
|
main("aloha", "act", ["policy.n_action_steps=100"])
|
||||||
|
|
Loading…
Reference in New Issue