diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py index c05d25c0..a81de49b 100644 --- a/lerobot/common/datasets/abstract.py +++ b/lerobot/common/datasets/abstract.py @@ -59,8 +59,6 @@ class AbstractDataset(TensorDictReplayBuffer): collate_fn: Callable | None = None, writer: Writer | None = None, transform: "torchrl.envs.Transform" = None, - # Don't actually load any data. This is a stand-in solution to get the transforms. - dummy: bool = False, ): assert ( self.available_datasets is not None @@ -79,7 +77,7 @@ class AbstractDataset(TensorDictReplayBuffer): f"The version of the dataset ({self.version}) is not enforced when root is provided ({self.root})." ) - storage = self._download_or_load_dataset() if not dummy else None + storage = self._download_or_load_dataset() super().__init__( storage=storage, diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py index 83d1581a..031c2cd3 100644 --- a/lerobot/common/datasets/aloha.py +++ b/lerobot/common/datasets/aloha.py @@ -97,7 +97,6 @@ class AlohaDataset(AbstractDataset): collate_fn: Callable | None = None, writer: Writer | None = None, transform: "torchrl.envs.Transform" = None, - dummy: bool = False, ): super().__init__( dataset_id, @@ -111,7 +110,6 @@ class AlohaDataset(AbstractDataset): collate_fn=collate_fn, writer=writer, transform=transform, - dummy=dummy, ) @property diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 4e02f704..04077034 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -21,12 +21,7 @@ def make_offline_buffer( overwrite_batch_size=None, overwrite_prefetch=None, stats_path=None, - # Don't actually load any data. This is a stand-in solution to get the transforms. - dummy=False, ): - if dummy and normalize and stats_path is None: - raise ValueError("`stats_path` is required if `dummy` and `normalize` are True.") - if cfg.policy.balanced_sampling: assert cfg.online_steps > 0 batch_size = None @@ -93,7 +88,6 @@ def make_offline_buffer( root=DATA_DIR, pin_memory=pin_memory, prefetch=prefetch if isinstance(prefetch, int) else None, - dummy=dummy, ) if cfg.policy.name == "tdmpc": diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index d167f3ea..624fb140 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -100,7 +100,6 @@ class PushtDataset(AbstractDataset): collate_fn: Callable | None = None, writer: Writer | None = None, transform: "torchrl.envs.Transform" = None, - dummy: bool = False, ): super().__init__( dataset_id, @@ -114,7 +113,6 @@ class PushtDataset(AbstractDataset): collate_fn=collate_fn, writer=writer, transform=transform, - dummy=dummy, ) def _download_and_preproc_obsolete(self): diff --git a/lerobot/common/datasets/simxarm.py b/lerobot/common/datasets/simxarm.py index 06931d3f..dc30e69e 100644 --- a/lerobot/common/datasets/simxarm.py +++ b/lerobot/common/datasets/simxarm.py @@ -51,7 +51,6 @@ class SimxarmDataset(AbstractDataset): collate_fn: Callable | None = None, writer: Writer | None = None, transform: "torchrl.envs.Transform" = None, - dummy: bool = False, ): super().__init__( dataset_id, @@ -65,7 +64,6 @@ class SimxarmDataset(AbstractDataset): collate_fn=collate_fn, writer=writer, transform=transform, - dummy=dummy, ) def _download_and_preproc_obsolete(self): diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 2a3ab13b..216769d6 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -194,7 +194,7 @@ def eval(cfg: dict, out_dir=None, stats_path=None): logging.info("Making transforms.") # TODO(alexander-soare): Completely decouple datasets from evaluation. - offline_buffer = make_offline_buffer(cfg, stats_path=stats_path, dummy=stats_path is not None) + offline_buffer = make_offline_buffer(cfg, stats_path=stats_path) logging.info("Making environment.") env = make_env(cfg, transform=offline_buffer.transform) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 252e0046..adaefcf5 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -2,8 +2,9 @@ import pytest import torch from lerobot.common.datasets.factory import make_offline_buffer +from lerobot.common.utils import init_hydra_config -from .utils import DEVICE, init_config +from .utils import DEVICE, DEFAULT_CONFIG_PATH @pytest.mark.parametrize( @@ -18,7 +19,10 @@ from .utils import DEVICE, init_config ], ) def test_factory(env_name, dataset_id): - cfg = init_config(overrides=[f"env={env_name}", f"env.task={dataset_id}", f"device={DEVICE}"]) + cfg = init_hydra_config( + DEFAULT_CONFIG_PATH, + overrides=[f"env={env_name}", f"env.task={dataset_id}", f"device={DEVICE}"] + ) offline_buffer = make_offline_buffer(cfg) for key in offline_buffer.image_keys: img = offline_buffer[0].get(key) diff --git a/tests/test_envs.py b/tests/test_envs.py index 2beafbda..2bd5e65c 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -7,8 +7,9 @@ from lerobot.common.datasets.factory import make_offline_buffer from lerobot.common.envs.factory import make_env from lerobot.common.envs.pusht.env import PushtEnv from lerobot.common.envs.simxarm.env import SimxarmEnv +from lerobot.common.utils import init_hydra_config -from .utils import DEVICE, init_config +from .utils import DEVICE, DEFAULT_CONFIG_PATH def print_spec_rollout(env): @@ -89,7 +90,10 @@ def test_pusht(from_pixels, pixels_only): ], ) def test_factory(env_name): - cfg = init_config(overrides=[f"env={env_name}", f"device={DEVICE}"]) + cfg = init_hydra_config( + DEFAULT_CONFIG_PATH, + overrides=[f"env={env_name}", f"device={DEVICE}"], + ) offline_buffer = make_offline_buffer(cfg) diff --git a/tests/test_policies.py b/tests/test_policies.py index d3dc0bc5..5d6b46d0 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -1,4 +1,3 @@ -from omegaconf import open_dict import pytest from tensordict import TensorDict from tensordict.nn import TensorDictModule @@ -10,8 +9,8 @@ from lerobot.common.policies.factory import make_policy from lerobot.common.envs.factory import make_env from lerobot.common.datasets.factory import make_offline_buffer from lerobot.common.policies.abstract import AbstractPolicy - -from .utils import DEVICE, init_config +from lerobot.common.utils import init_hydra_config +from .utils import DEVICE, DEFAULT_CONFIG_PATH @pytest.mark.parametrize( "env_name,policy_name,extra_overrides", @@ -34,7 +33,8 @@ def test_concrete_policy(env_name, policy_name, extra_overrides): - Updating the policy. - Using the policy to select actions at inference time. """ - cfg = init_config( + cfg = init_hydra_config( + DEFAULT_CONFIG_PATH, overrides=[ f"env={env_name}", f"policy={policy_name}", diff --git a/tests/utils.py b/tests/utils.py index 55709330..6169c3b6 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,13 +1,6 @@ import os -import hydra -from hydra import compose, initialize -CONFIG_PATH = "../lerobot/configs" +# Pass this as the first argument to init_hydra_config. +DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml" DEVICE = os.environ.get('LEROBOT_TESTS_DEVICE', "cuda") - -def init_config(config_name="default", overrides=None): - hydra.core.global_hydra.GlobalHydra.instance().clear() - initialize(config_path=CONFIG_PATH) - cfg = compose(config_name=config_name, overrides=overrides) - return cfg