This commit is contained in:
Cadene 2024-03-23 11:41:56 +00:00
parent a80d9c0257
commit 40f3783fca
8 changed files with 5 additions and 2 deletions

View File

@ -84,7 +84,7 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
def __init__( def __init__(
self, self,
dataset_id: str, dataset_id: str,
version: str | None = "v1.1", version: str | None = "v1.2",
batch_size: int = None, batch_size: int = None,
*, *,
shuffle: bool = True, shuffle: bool = True,

View File

@ -87,7 +87,7 @@ class PushtExperienceReplay(AbstractExperienceReplay):
def __init__( def __init__(
self, self,
dataset_id: str, dataset_id: str,
version: str | None = "v1.1", version: str | None = "v1.2",
batch_size: int = None, batch_size: int = None,
*, *,
shuffle: bool = True, shuffle: bool = True,

Binary file not shown.

View File

@ -22,6 +22,9 @@ from .utils import DEVICE, init_config
("simxarm", "diffusion", []), ("simxarm", "diffusion", []),
("pusht", "diffusion", []), ("pusht", "diffusion", []),
("aloha", "act", ["env.task=sim_insertion_scripted"]), ("aloha", "act", ["env.task=sim_insertion_scripted"]),
("aloha", "act", ["env.task=sim_insertion_human"]),
("aloha", "act", ["env.task=sim_transfer_cube_scripted"]),
("aloha", "act", ["env.task=sim_transfer_cube_human"]),
], ],
) )
def test_concrete_policy(env_name, policy_name, extra_overrides): def test_concrete_policy(env_name, policy_name, extra_overrides):