From 5881eec3768006d481ef3098213c45b9346b96b7 Mon Sep 17 00:00:00 2001 From: Cadene Date: Tue, 12 Mar 2024 14:14:39 +0000 Subject: [PATCH] Add DEVICE constant from LEROBOT_TESTS_DEVICE --- .github/workflows/test.yml | 3 +++ tests/test_datasets.py | 4 ++-- tests/test_envs.py | 5 +++-- tests/test_policies.py | 3 ++- tests/utils.py | 2 ++ 5 files changed, 12 insertions(+), 5 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 483f235f..46d57945 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -21,6 +21,9 @@ jobs: TMPDIR: ~/tmp TEMP: ~/tmp TMP: ~/tmp + PYOPENGL_PLATFORM: egl + MUJOCO_GL: egl + LEROBOT_TESTS_DEVICE: cpu steps: #---------------------------------------------- # check-out repo and set-up python diff --git a/tests/test_datasets.py b/tests/test_datasets.py index b3aa7da6..b7d1e6f8 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -3,7 +3,7 @@ import torch from lerobot.common.datasets.factory import make_offline_buffer -from .utils import init_config +from .utils import DEVICE, init_config @pytest.mark.parametrize( @@ -20,7 +20,7 @@ from .utils import init_config ], ) def test_factory(env_name, dataset_id): - cfg = init_config(overrides=[f"env={env_name}", f"env.task={dataset_id}"]) + cfg = init_config(overrides=[f"env={env_name}", f"env.task={dataset_id}", f"device={DEVICE}"]) offline_buffer = make_offline_buffer(cfg) for key in offline_buffer.image_keys: img = offline_buffer[0].get(key) diff --git a/tests/test_envs.py b/tests/test_envs.py index aad13ed8..7776ba3c 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -1,3 +1,4 @@ +import os import pytest from tensordict import TensorDict import torch @@ -8,7 +9,7 @@ from lerobot.common.envs.factory import make_env from lerobot.common.envs.pusht.env import PushtEnv from lerobot.common.envs.simxarm import SimxarmEnv -from .utils import init_config +from .utils import DEVICE, init_config def print_spec_rollout(env): @@ -89,7 +90,7 @@ def test_pusht(from_pixels, pixels_only): ], ) def test_factory(env_name): - cfg = init_config(overrides=[f"env={env_name}"]) + cfg = init_config(overrides=[f"env={env_name}", f"device={DEVICE}"]) offline_buffer = make_offline_buffer(cfg) diff --git a/tests/test_policies.py b/tests/test_policies.py index 03f20bd0..f00429bc 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -2,7 +2,7 @@ import pytest from lerobot.common.policies.factory import make_policy -from .utils import init_config +from .utils import DEVICE, init_config @pytest.mark.parametrize( @@ -19,6 +19,7 @@ def test_factory(env_name, policy_name): overrides=[ f"env={env_name}", f"policy={policy_name}", + f"device={DEVICE}", ] ) policy = make_policy(cfg) diff --git a/tests/utils.py b/tests/utils.py index 40dc6de0..55709330 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,8 +1,10 @@ +import os import hydra from hydra import compose, initialize CONFIG_PATH = "../lerobot/configs" +DEVICE = os.environ.get('LEROBOT_TESTS_DEVICE', "cuda") def init_config(config_name="default", overrides=None): hydra.core.global_hydra.GlobalHydra.instance().clear()