Alexander Soare 2024-03-27 18:33:48 +00:00
parent 120f0aef5c
commit b7c9c33072
10 changed files with 20 additions and 33 deletions

View File

@@ -59,8 +59,6 @@ class AbstractDataset(TensorDictReplayBuffer):
         collate_fn: Callable | None = None,
         writer: Writer | None = None,
         transform: "torchrl.envs.Transform" = None,
-        # Don't actually load any data. This is a stand-in solution to get the transforms.
-        dummy: bool = False,
     ):
         assert (
             self.available_datasets is not None
@@ -79,7 +77,7 @@ class AbstractDataset(TensorDictReplayBuffer):
                 f"The version of the dataset ({self.version}) is not enforced when root is provided ({self.root})."
             )

-        storage = self._download_or_load_dataset() if not dummy else None
+        storage = self._download_or_load_dataset()

         super().__init__(
             storage=storage,

View File

@@ -97,7 +97,6 @@ class AlohaDataset(AbstractDataset):
         collate_fn: Callable | None = None,
         writer: Writer | None = None,
         transform: "torchrl.envs.Transform" = None,
-        dummy: bool = False,
     ):
         super().__init__(
             dataset_id,
@@ -111,7 +110,6 @@ class AlohaDataset(AbstractDataset):
             collate_fn=collate_fn,
             writer=writer,
             transform=transform,
-            dummy=dummy,
         )

     @property

View File

@@ -21,12 +21,7 @@ def make_offline_buffer(
     overwrite_batch_size=None,
     overwrite_prefetch=None,
     stats_path=None,
-    # Don't actually load any data. This is a stand-in solution to get the transforms.
-    dummy=False,
 ):
-    if dummy and normalize and stats_path is None:
-        raise ValueError("`stats_path` is required if `dummy` and `normalize` are True.")
     if cfg.policy.balanced_sampling:
         assert cfg.online_steps > 0
         batch_size = None
@@ -93,7 +88,6 @@ def make_offline_buffer(
         root=DATA_DIR,
         pin_memory=pin_memory,
         prefetch=prefetch if isinstance(prefetch, int) else None,
-        dummy=dummy,
     )

     if cfg.policy.name == "tdmpc":
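With the `dummy` escape hatch gone, a caller that only needs the dataset's transforms still builds the full replay buffer. A minimal sketch of the updated call pattern, assuming the repo's default config file and illustrative overrides (the `env=pusht` and `device=cpu` values are examples, not part of this diff):

from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.utils import init_hydra_config

# Compose a config from the default yaml with a couple of example overrides.
cfg = init_hydra_config(
    "lerobot/configs/default.yaml",  # exposed to the tests below as DEFAULT_CONFIG_PATH
    overrides=["env=pusht", "device=cpu"],
)

# Storage is always downloaded or loaded now; stats_path only selects which
# normalization statistics are applied, it no longer gates data loading.
offline_buffer = make_offline_buffer(cfg)

# eval.py (further down) reuses this transform when constructing the environment.
transform = offline_buffer.transform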

View File

@@ -100,7 +100,6 @@ class PushtDataset(AbstractDataset):
         collate_fn: Callable | None = None,
         writer: Writer | None = None,
         transform: "torchrl.envs.Transform" = None,
-        dummy: bool = False,
     ):
         super().__init__(
             dataset_id,
@@ -114,7 +113,6 @@ class PushtDataset(AbstractDataset):
             collate_fn=collate_fn,
             writer=writer,
             transform=transform,
-            dummy=dummy,
         )

     def _download_and_preproc_obsolete(self):

View File

@@ -51,7 +51,6 @@ class SimxarmDataset(AbstractDataset):
         collate_fn: Callable | None = None,
         writer: Writer | None = None,
         transform: "torchrl.envs.Transform" = None,
-        dummy: bool = False,
     ):
         super().__init__(
             dataset_id,
@@ -65,7 +64,6 @@ class SimxarmDataset(AbstractDataset):
             collate_fn=collate_fn,
             writer=writer,
             transform=transform,
-            dummy=dummy,
         )

     def _download_and_preproc_obsolete(self):

View File

@@ -194,7 +194,7 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
     logging.info("Making transforms.")
     # TODO(alexander-soare): Completely decouple datasets from evaluation.
-    offline_buffer = make_offline_buffer(cfg, stats_path=stats_path, dummy=stats_path is not None)
+    offline_buffer = make_offline_buffer(cfg, stats_path=stats_path)

     logging.info("Making environment.")
     env = make_env(cfg, transform=offline_buffer.transform)

View File

@@ -2,8 +2,9 @@ import pytest
 import torch

 from lerobot.common.datasets.factory import make_offline_buffer
+from lerobot.common.utils import init_hydra_config

-from .utils import DEVICE, init_config
+from .utils import DEVICE, DEFAULT_CONFIG_PATH


 @pytest.mark.parametrize(
@@ -18,7 +19,10 @@ from .utils import DEVICE, init_config
     ],
 )
 def test_factory(env_name, dataset_id):
-    cfg = init_config(overrides=[f"env={env_name}", f"env.task={dataset_id}", f"device={DEVICE}"])
+    cfg = init_hydra_config(
+        DEFAULT_CONFIG_PATH,
+        overrides=[f"env={env_name}", f"env.task={dataset_id}", f"device={DEVICE}"]
+    )
     offline_buffer = make_offline_buffer(cfg)
     for key in offline_buffer.image_keys:
         img = offline_buffer[0].get(key)

View File

@@ -7,8 +7,9 @@ from lerobot.common.datasets.factory import make_offline_buffer
 from lerobot.common.envs.factory import make_env
 from lerobot.common.envs.pusht.env import PushtEnv
 from lerobot.common.envs.simxarm.env import SimxarmEnv
+from lerobot.common.utils import init_hydra_config

-from .utils import DEVICE, init_config
+from .utils import DEVICE, DEFAULT_CONFIG_PATH


 def print_spec_rollout(env):
@@ -89,7 +90,10 @@ def test_pusht(from_pixels, pixels_only):
     ],
 )
 def test_factory(env_name):
-    cfg = init_config(overrides=[f"env={env_name}", f"device={DEVICE}"])
+    cfg = init_hydra_config(
+        DEFAULT_CONFIG_PATH,
+        overrides=[f"env={env_name}", f"device={DEVICE}"],
+    )
     offline_buffer = make_offline_buffer(cfg)

View File

@@ -1,4 +1,3 @@
-from omegaconf import open_dict
 import pytest
 from tensordict import TensorDict
 from tensordict.nn import TensorDictModule
@@ -10,8 +9,8 @@ from lerobot.common.policies.factory import make_policy
 from lerobot.common.envs.factory import make_env
 from lerobot.common.datasets.factory import make_offline_buffer
 from lerobot.common.policies.abstract import AbstractPolicy
+from lerobot.common.utils import init_hydra_config

-from .utils import DEVICE, init_config
+from .utils import DEVICE, DEFAULT_CONFIG_PATH


 @pytest.mark.parametrize(
     "env_name,policy_name,extra_overrides",
@@ -34,7 +33,8 @@ def test_concrete_policy(env_name, policy_name, extra_overrides):
     - Updating the policy.
     - Using the policy to select actions at inference time.
     """
-    cfg = init_config(
+    cfg = init_hydra_config(
+        DEFAULT_CONFIG_PATH,
         overrides=[
             f"env={env_name}",
             f"policy={policy_name}",

View File

@@ -1,13 +1,6 @@
 import os

-import hydra
-from hydra import compose, initialize
-CONFIG_PATH = "../lerobot/configs"
+# Pass this as the first argument to init_hydra_config.
+DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml"

 DEVICE = os.environ.get('LEROBOT_TESTS_DEVICE', "cuda")
-
-def init_config(config_name="default", overrides=None):
-    hydra.core.global_hydra.GlobalHydra.instance().clear()
-    initialize(config_path=CONFIG_PATH)
-    cfg = compose(config_name=config_name, overrides=overrides)
-    return cfg
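The `init_hydra_config` helper that the tests now import from `lerobot.common.utils` is not part of this diff. A minimal sketch of what it presumably does, inferred from its call sites (config file path first, then Hydra-style overrides) and from the `init_config` removed above; the actual implementation in the repository may differ:

from pathlib import Path

import hydra
from hydra.core.global_hydra import GlobalHydra
from omegaconf import DictConfig


def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> DictConfig:
    """Compose a config from a yaml file path plus Hydra overrides (sketch only)."""
    # Hydra keeps global state; clear it so the helper can be called repeatedly,
    # e.g. once per parametrized test case, just as the removed init_config did.
    GlobalHydra.instance().clear()
    config_path = Path(config_path).absolute()
    # initialize_config_dir expects the absolute directory containing the yaml file,
    # and compose takes the file's stem as the config name.
    hydra.initialize_config_dir(config_dir=str(config_path.parent))
    return hydra.compose(config_name=config_path.stem, overrides=overrides or [])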