Add DEVICE constant from LEROBOT_TESTS_DEVICE
This commit is contained in:
parent
29c73844b1
commit
5881eec376
|
@ -21,6 +21,9 @@ jobs:
|
||||||
TMPDIR: ~/tmp
|
TMPDIR: ~/tmp
|
||||||
TEMP: ~/tmp
|
TEMP: ~/tmp
|
||||||
TMP: ~/tmp
|
TMP: ~/tmp
|
||||||
|
PYOPENGL_PLATFORM: egl
|
||||||
|
MUJOCO_GL: egl
|
||||||
|
LEROBOT_TESTS_DEVICE: cpu
|
||||||
steps:
|
steps:
|
||||||
#----------------------------------------------
|
#----------------------------------------------
|
||||||
# check-out repo and set-up python
|
# check-out repo and set-up python
|
||||||
|
|
|
@ -3,7 +3,7 @@ import torch
|
||||||
|
|
||||||
from lerobot.common.datasets.factory import make_offline_buffer
|
from lerobot.common.datasets.factory import make_offline_buffer
|
||||||
|
|
||||||
from .utils import init_config
|
from .utils import DEVICE, init_config
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -20,7 +20,7 @@ from .utils import init_config
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_factory(env_name, dataset_id):
|
def test_factory(env_name, dataset_id):
|
||||||
cfg = init_config(overrides=[f"env={env_name}", f"env.task={dataset_id}"])
|
cfg = init_config(overrides=[f"env={env_name}", f"env.task={dataset_id}", f"device={DEVICE}"])
|
||||||
offline_buffer = make_offline_buffer(cfg)
|
offline_buffer = make_offline_buffer(cfg)
|
||||||
for key in offline_buffer.image_keys:
|
for key in offline_buffer.image_keys:
|
||||||
img = offline_buffer[0].get(key)
|
img = offline_buffer[0].get(key)
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
import os
|
||||||
import pytest
|
import pytest
|
||||||
from tensordict import TensorDict
|
from tensordict import TensorDict
|
||||||
import torch
|
import torch
|
||||||
|
@ -8,7 +9,7 @@ from lerobot.common.envs.factory import make_env
|
||||||
from lerobot.common.envs.pusht.env import PushtEnv
|
from lerobot.common.envs.pusht.env import PushtEnv
|
||||||
from lerobot.common.envs.simxarm import SimxarmEnv
|
from lerobot.common.envs.simxarm import SimxarmEnv
|
||||||
|
|
||||||
from .utils import init_config
|
from .utils import DEVICE, init_config
|
||||||
|
|
||||||
|
|
||||||
def print_spec_rollout(env):
|
def print_spec_rollout(env):
|
||||||
|
@ -89,7 +90,7 @@ def test_pusht(from_pixels, pixels_only):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_factory(env_name):
|
def test_factory(env_name):
|
||||||
cfg = init_config(overrides=[f"env={env_name}"])
|
cfg = init_config(overrides=[f"env={env_name}", f"device={DEVICE}"])
|
||||||
|
|
||||||
offline_buffer = make_offline_buffer(cfg)
|
offline_buffer = make_offline_buffer(cfg)
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,7 @@ import pytest
|
||||||
|
|
||||||
from lerobot.common.policies.factory import make_policy
|
from lerobot.common.policies.factory import make_policy
|
||||||
|
|
||||||
from .utils import init_config
|
from .utils import DEVICE, init_config
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -19,6 +19,7 @@ def test_factory(env_name, policy_name):
|
||||||
overrides=[
|
overrides=[
|
||||||
f"env={env_name}",
|
f"env={env_name}",
|
||||||
f"policy={policy_name}",
|
f"policy={policy_name}",
|
||||||
|
f"device={DEVICE}",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
policy = make_policy(cfg)
|
policy = make_policy(cfg)
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
|
import os
|
||||||
import hydra
|
import hydra
|
||||||
from hydra import compose, initialize
|
from hydra import compose, initialize
|
||||||
|
|
||||||
CONFIG_PATH = "../lerobot/configs"
|
CONFIG_PATH = "../lerobot/configs"
|
||||||
|
|
||||||
|
DEVICE = os.environ.get('LEROBOT_TESTS_DEVICE', "cuda")
|
||||||
|
|
||||||
def init_config(config_name="default", overrides=None):
|
def init_config(config_name="default", overrides=None):
|
||||||
hydra.core.global_hydra.GlobalHydra.instance().clear()
|
hydra.core.global_hydra.GlobalHydra.instance().clear()
|
||||||
|
|
Loading…
Reference in New Issue