Add DEVICE constant from LEROBOT_TESTS_DEVICE

This commit is contained in:
Cadene 2024-03-12 14:14:39 +00:00
parent 29c73844b1
commit 5881eec376
5 changed files with 12 additions and 5 deletions

View File

@ -21,6 +21,9 @@ jobs:
TMPDIR: ~/tmp
TEMP: ~/tmp
TMP: ~/tmp
PYOPENGL_PLATFORM: egl
MUJOCO_GL: egl
LEROBOT_TESTS_DEVICE: cpu
steps:
#----------------------------------------------
# check-out repo and set-up python

View File

@ -3,7 +3,7 @@ import torch
from lerobot.common.datasets.factory import make_offline_buffer
from .utils import init_config
from .utils import DEVICE, init_config
@pytest.mark.parametrize(
@ -20,7 +20,7 @@ from .utils import init_config
],
)
def test_factory(env_name, dataset_id):
cfg = init_config(overrides=[f"env={env_name}", f"env.task={dataset_id}"])
cfg = init_config(overrides=[f"env={env_name}", f"env.task={dataset_id}", f"device={DEVICE}"])
offline_buffer = make_offline_buffer(cfg)
for key in offline_buffer.image_keys:
img = offline_buffer[0].get(key)

View File

@ -1,3 +1,4 @@
import os
import pytest
from tensordict import TensorDict
import torch
@ -8,7 +9,7 @@ from lerobot.common.envs.factory import make_env
from lerobot.common.envs.pusht.env import PushtEnv
from lerobot.common.envs.simxarm import SimxarmEnv
from .utils import init_config
from .utils import DEVICE, init_config
def print_spec_rollout(env):
@ -89,7 +90,7 @@ def test_pusht(from_pixels, pixels_only):
],
)
def test_factory(env_name):
cfg = init_config(overrides=[f"env={env_name}"])
cfg = init_config(overrides=[f"env={env_name}", f"device={DEVICE}"])
offline_buffer = make_offline_buffer(cfg)

View File

@ -2,7 +2,7 @@ import pytest
from lerobot.common.policies.factory import make_policy
from .utils import init_config
from .utils import DEVICE, init_config
@pytest.mark.parametrize(
@ -19,6 +19,7 @@ def test_factory(env_name, policy_name):
overrides=[
f"env={env_name}",
f"policy={policy_name}",
f"device={DEVICE}",
]
)
policy = make_policy(cfg)

View File

@ -1,8 +1,10 @@
import os
import hydra
from hydra import compose, initialize
CONFIG_PATH = "../lerobot/configs"
DEVICE = os.environ.get('LEROBOT_TESTS_DEVICE', "cuda")
def init_config(config_name="default", overrides=None):
hydra.core.global_hydra.GlobalHydra.instance().clear()