fix tests
This commit is contained in:
parent
18146c7419
commit
732814218f
|
@ -6,7 +6,7 @@ import torch
|
|||
from gymnasium.utils.env_checker import check_env
|
||||
|
||||
import lerobot
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.envs.factory import make_envs
|
||||
from lerobot.common.envs.utils import preprocess_observations
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
|
||||
|
@ -37,7 +37,7 @@ def test_factory(env_name):
|
|||
overrides=[f"env={env_name}", f"device={DEVICE}"],
|
||||
)
|
||||
|
||||
env = make_env(cfg, n_envs=1)
|
||||
env = make_envs(cfg, n_envs=1)
|
||||
obs, _ = env.reset()
|
||||
obs = preprocess_observations(obs)
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ from huggingface_hub import PyTorchModelHubMixin
|
|||
from lerobot import available_policies
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.utils import cycle
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.envs.factory import make_envs
|
||||
from lerobot.common.envs.utils import preprocess_observations
|
||||
from lerobot.common.policies.factory import get_policy_and_config_classes, make_policy
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
|
@ -80,7 +80,7 @@ def test_policy(env_name, policy_name, extra_overrides):
|
|||
assert isinstance(policy, PyTorchModelHubMixin)
|
||||
|
||||
# Check that we run select_actions and get the appropriate output.
|
||||
env = make_env(cfg, n_envs=2)
|
||||
env = make_envs(cfg, n_envs=2)
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
|
|
Loading…
Reference in New Issue