2024-02-25 18:50:23 +08:00
|
|
|
import pytest
|
|
|
|
|
|
|
|
from lerobot.common.policies.factory import make_policy
|
|
|
|
|
2024-03-12 22:14:39 +08:00
|
|
|
from .utils import DEVICE, init_config
|
2024-02-25 18:50:23 +08:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
    "env_name,policy_name",
    [
        ("simxarm", "tdmpc"),
        ("pusht", "tdmpc"),
        ("simxarm", "diffusion"),
        ("pusht", "diffusion"),
    ],
)
def test_factory(env_name, policy_name):
    """Smoke-test policy construction for one environment/policy pairing.

    A config is built via ``init_config`` with three overrides (the
    environment name, the policy name, and the ``DEVICE`` imported from the
    test utilities) and then handed to ``make_policy``.

    Args:
        env_name: Environment identifier placed in the ``env=`` override.
        policy_name: Policy identifier placed in the ``policy=`` override.
    """
    selection = [
        f"env={env_name}",
        f"policy={policy_name}",
        f"device={DEVICE}",
    ]
    config = init_config(overrides=selection)

    # Nothing is asserted on the result: the test passes as long as building
    # the policy from this config does not raise.
    policy = make_policy(config)
|