commit b7c9c33072
parent 120f0aef5c

@@ -59,8 +59,6 @@ class AbstractDataset(TensorDictReplayBuffer):
         collate_fn: Callable | None = None,
         writer: Writer | None = None,
         transform: "torchrl.envs.Transform" = None,
-        # Don't actually load any data. This is a stand-in solution to get the transforms.
-        dummy: bool = False,
     ):
         assert (
             self.available_datasets is not None

@@ -79,7 +77,7 @@ class AbstractDataset(TensorDictReplayBuffer):
                 f"The version of the dataset ({self.version}) is not enforced when root is provided ({self.root})."
             )
 
-        storage = self._download_or_load_dataset() if not dummy else None
+        storage = self._download_or_load_dataset()
 
         super().__init__(
             storage=storage,

@@ -97,7 +97,6 @@ class AlohaDataset(AbstractDataset):
         collate_fn: Callable | None = None,
         writer: Writer | None = None,
         transform: "torchrl.envs.Transform" = None,
-        dummy: bool = False,
     ):
         super().__init__(
             dataset_id,

@@ -111,7 +110,6 @@ class AlohaDataset(AbstractDataset):
             collate_fn=collate_fn,
             writer=writer,
             transform=transform,
-            dummy=dummy,
         )
 
     @property

@@ -21,12 +21,7 @@ def make_offline_buffer(
     overwrite_batch_size=None,
     overwrite_prefetch=None,
     stats_path=None,
-    # Don't actually load any data. This is a stand-in solution to get the transforms.
-    dummy=False,
 ):
-    if dummy and normalize and stats_path is None:
-        raise ValueError("`stats_path` is required if `dummy` and `normalize` are True.")
-
     if cfg.policy.balanced_sampling:
         assert cfg.online_steps > 0
         batch_size = None

@@ -93,7 +88,6 @@ def make_offline_buffer(
         root=DATA_DIR,
         pin_memory=pin_memory,
         prefetch=prefetch if isinstance(prefetch, int) else None,
-        dummy=dummy,
     )
 
     if cfg.policy.name == "tdmpc":

@@ -100,7 +100,6 @@ class PushtDataset(AbstractDataset):
         collate_fn: Callable | None = None,
         writer: Writer | None = None,
         transform: "torchrl.envs.Transform" = None,
-        dummy: bool = False,
     ):
         super().__init__(
             dataset_id,

@@ -114,7 +113,6 @@ class PushtDataset(AbstractDataset):
             collate_fn=collate_fn,
             writer=writer,
             transform=transform,
-            dummy=dummy,
         )
 
     def _download_and_preproc_obsolete(self):

@@ -51,7 +51,6 @@ class SimxarmDataset(AbstractDataset):
         collate_fn: Callable | None = None,
         writer: Writer | None = None,
         transform: "torchrl.envs.Transform" = None,
-        dummy: bool = False,
     ):
         super().__init__(
             dataset_id,

@@ -65,7 +64,6 @@ class SimxarmDataset(AbstractDataset):
            collate_fn=collate_fn,
            writer=writer,
            transform=transform,
-           dummy=dummy,
        )
 
    def _download_and_preproc_obsolete(self):

@@ -194,7 +194,7 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
 
     logging.info("Making transforms.")
     # TODO(alexander-soare): Completely decouple datasets from evaluation.
-    offline_buffer = make_offline_buffer(cfg, stats_path=stats_path, dummy=stats_path is not None)
+    offline_buffer = make_offline_buffer(cfg, stats_path=stats_path)
 
     logging.info("Making environment.")
     env = make_env(cfg, transform=offline_buffer.transform)

@@ -2,8 +2,9 @@ import pytest
 import torch
 
 from lerobot.common.datasets.factory import make_offline_buffer
+from lerobot.common.utils import init_hydra_config
 
-from .utils import DEVICE, init_config
+from .utils import DEVICE, DEFAULT_CONFIG_PATH
 
 
 @pytest.mark.parametrize(

@@ -18,7 +19,10 @@ from .utils import DEVICE, init_config
     ],
 )
 def test_factory(env_name, dataset_id):
-    cfg = init_config(overrides=[f"env={env_name}", f"env.task={dataset_id}", f"device={DEVICE}"])
+    cfg = init_hydra_config(
+        DEFAULT_CONFIG_PATH,
+        overrides=[f"env={env_name}", f"env.task={dataset_id}", f"device={DEVICE}"]
+    )
     offline_buffer = make_offline_buffer(cfg)
     for key in offline_buffer.image_keys:
         img = offline_buffer[0].get(key)

@@ -7,8 +7,9 @@ from lerobot.common.datasets.factory import make_offline_buffer
 from lerobot.common.envs.factory import make_env
 from lerobot.common.envs.pusht.env import PushtEnv
 from lerobot.common.envs.simxarm.env import SimxarmEnv
+from lerobot.common.utils import init_hydra_config
 
-from .utils import DEVICE, init_config
+from .utils import DEVICE, DEFAULT_CONFIG_PATH
 
 
 def print_spec_rollout(env):

@@ -89,7 +90,10 @@ def test_pusht(from_pixels, pixels_only):
     ],
 )
 def test_factory(env_name):
-    cfg = init_config(overrides=[f"env={env_name}", f"device={DEVICE}"])
+    cfg = init_hydra_config(
+        DEFAULT_CONFIG_PATH,
+        overrides=[f"env={env_name}", f"device={DEVICE}"],
+    )
 
     offline_buffer = make_offline_buffer(cfg)
 

@@ -1,4 +1,3 @@
-from omegaconf import open_dict
 import pytest
 from tensordict import TensorDict
 from tensordict.nn import TensorDictModule

@@ -10,8 +9,8 @@ from lerobot.common.policies.factory import make_policy
 from lerobot.common.envs.factory import make_env
 from lerobot.common.datasets.factory import make_offline_buffer
 from lerobot.common.policies.abstract import AbstractPolicy
-
-from .utils import DEVICE, init_config
+from lerobot.common.utils import init_hydra_config
+from .utils import DEVICE, DEFAULT_CONFIG_PATH
 
 @pytest.mark.parametrize(
     "env_name,policy_name,extra_overrides",

@@ -34,7 +33,8 @@ def test_concrete_policy(env_name, policy_name, extra_overrides):
     - Updating the policy.
     - Using the policy to select actions at inference time.
     """
-    cfg = init_config(
+    cfg = init_hydra_config(
+        DEFAULT_CONFIG_PATH,
         overrides=[
             f"env={env_name}",
             f"policy={policy_name}",

@@ -1,13 +1,6 @@
 import os
-import hydra
-from hydra import compose, initialize
 
-CONFIG_PATH = "../lerobot/configs"
+# Pass this as the first argument to init_hydra_config.
+DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml"
 
 DEVICE = os.environ.get('LEROBOT_TESTS_DEVICE', "cuda")
-
-def init_config(config_name="default", overrides=None):
-    hydra.core.global_hydra.GlobalHydra.instance().clear()
-    initialize(config_path=CONFIG_PATH)
-    cfg = compose(config_name=config_name, overrides=overrides)
-    return cfg

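The `init_hydra_config` helper that the tests now import from `lerobot.common.utils` is not shown in this diff. Below is a minimal sketch of what it presumably looks like, assuming it mirrors the removed `init_config` helper above but takes the path to a config file instead of a config name; the use of `initialize_config_dir`, the `job_name`, the docstring, and the exact signature are assumptions, not code from this commit.

from pathlib import Path

import hydra
from hydra import compose, initialize_config_dir
from omegaconf import DictConfig


def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> DictConfig:
    """Sketch: build a config from the path to a yaml file plus optional Hydra overrides."""
    # Hydra keeps global state; clear any previous initialization so repeated calls
    # (e.g. across parametrized tests) do not fail.
    hydra.core.global_hydra.GlobalHydra.instance().clear()
    # initialize_config_dir expects an absolute directory; the config name is the file's stem ("default").
    initialize_config_dir(config_dir=str(Path(config_path).parent.absolute()), job_name="lerobot")
    return compose(config_name=Path(config_path).stem, overrides=overrides or [])

The call sites in this diff, e.g. `init_hydra_config(DEFAULT_CONFIG_PATH, overrides=[f"env={env_name}", f"device={DEVICE}"])`, are consistent with this signature.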