2024-02-26 01:42:47 +08:00
|
|
|
import pytest
|
|
|
|
|
|
|
|
from lerobot.common.datasets.factory import make_offline_buffer
|
|
|
|
|
|
|
|
from .utils import init_config
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
    "env_name,dataset_id",
    [
        # TODO(rcadene): simxarm is deprecated for now
        # ("simxarm", "lift"),
        ("pusht", "pusht"),
        # TODO(aliberts): add aloha when dataset is available on hub
        # ("aloha", "sim_insertion_human"),
        # ("aloha", "sim_insertion_scripted"),
        # ("aloha", "sim_transfer_cube_human"),
        # ("aloha", "sim_transfer_cube_scripted"),
    ],
)
def test_factory(env_name, dataset_id):
    """Smoke-test building an offline buffer for each (env, dataset) pair.

    A config is created with ``init_config`` using ``env=<env_name>`` and
    ``env.task=<dataset_id>`` overrides, then handed to ``make_offline_buffer``.
    No assertions are made on the result: the test passes as long as
    config creation and buffer construction do not raise.
    """
    cfg = init_config(overrides=[f"env={env_name}", f"env.task={dataset_id}"])
    # Only construction is exercised here, so the returned buffer is not bound.
    make_offline_buffer(cfg)
|