Add DEVICE constant from LEROBOT_TESTS_DEVICE
This commit is contained in:
parent
29c73844b1
commit
5881eec376
|
@ -21,6 +21,9 @@ jobs:
|
|||
TMPDIR: ~/tmp
|
||||
TEMP: ~/tmp
|
||||
TMP: ~/tmp
|
||||
PYOPENGL_PLATFORM: egl
|
||||
MUJOCO_GL: egl
|
||||
LEROBOT_TESTS_DEVICE: cpu
|
||||
steps:
|
||||
#----------------------------------------------
|
||||
# check-out repo and set-up python
|
||||
|
|
|
@ -3,7 +3,7 @@ import torch
|
|||
|
||||
from lerobot.common.datasets.factory import make_offline_buffer
|
||||
|
||||
from .utils import init_config
|
||||
from .utils import DEVICE, init_config
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -20,7 +20,7 @@ from .utils import init_config
|
|||
],
|
||||
)
|
||||
def test_factory(env_name, dataset_id):
|
||||
cfg = init_config(overrides=[f"env={env_name}", f"env.task={dataset_id}"])
|
||||
cfg = init_config(overrides=[f"env={env_name}", f"env.task={dataset_id}", f"device={DEVICE}"])
|
||||
offline_buffer = make_offline_buffer(cfg)
|
||||
for key in offline_buffer.image_keys:
|
||||
img = offline_buffer[0].get(key)
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import os
|
||||
import pytest
|
||||
from tensordict import TensorDict
|
||||
import torch
|
||||
|
@ -8,7 +9,7 @@ from lerobot.common.envs.factory import make_env
|
|||
from lerobot.common.envs.pusht.env import PushtEnv
|
||||
from lerobot.common.envs.simxarm import SimxarmEnv
|
||||
|
||||
from .utils import init_config
|
||||
from .utils import DEVICE, init_config
|
||||
|
||||
|
||||
def print_spec_rollout(env):
|
||||
|
@ -89,7 +90,7 @@ def test_pusht(from_pixels, pixels_only):
|
|||
],
|
||||
)
|
||||
def test_factory(env_name):
|
||||
cfg = init_config(overrides=[f"env={env_name}"])
|
||||
cfg = init_config(overrides=[f"env={env_name}", f"device={DEVICE}"])
|
||||
|
||||
offline_buffer = make_offline_buffer(cfg)
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@ import pytest
|
|||
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
|
||||
from .utils import init_config
|
||||
from .utils import DEVICE, init_config
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -19,6 +19,7 @@ def test_factory(env_name, policy_name):
|
|||
overrides=[
|
||||
f"env={env_name}",
|
||||
f"policy={policy_name}",
|
||||
f"device={DEVICE}",
|
||||
]
|
||||
)
|
||||
policy = make_policy(cfg)
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
import os
|
||||
import hydra
|
||||
from hydra import compose, initialize
|
||||
|
||||
CONFIG_PATH = "../lerobot/configs"
|
||||
|
||||
DEVICE = os.environ.get('LEROBOT_TESTS_DEVICE', "cuda")
|
||||
|
||||
def init_config(config_name="default", overrides=None):
|
||||
hydra.core.global_hydra.GlobalHydra.instance().clear()
|
||||
|
|
Loading…
Reference in New Issue