revision
This commit is contained in:
parent
120f0aef5c
commit
b7c9c33072
|
@ -59,8 +59,6 @@ class AbstractDataset(TensorDictReplayBuffer):
|
||||||
collate_fn: Callable | None = None,
|
collate_fn: Callable | None = None,
|
||||||
writer: Writer | None = None,
|
writer: Writer | None = None,
|
||||||
transform: "torchrl.envs.Transform" = None,
|
transform: "torchrl.envs.Transform" = None,
|
||||||
# Don't actually load any data. This is a stand-in solution to get the transforms.
|
|
||||||
dummy: bool = False,
|
|
||||||
):
|
):
|
||||||
assert (
|
assert (
|
||||||
self.available_datasets is not None
|
self.available_datasets is not None
|
||||||
|
@ -79,7 +77,7 @@ class AbstractDataset(TensorDictReplayBuffer):
|
||||||
f"The version of the dataset ({self.version}) is not enforced when root is provided ({self.root})."
|
f"The version of the dataset ({self.version}) is not enforced when root is provided ({self.root})."
|
||||||
)
|
)
|
||||||
|
|
||||||
storage = self._download_or_load_dataset() if not dummy else None
|
storage = self._download_or_load_dataset()
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
storage=storage,
|
storage=storage,
|
||||||
|
|
|
@ -97,7 +97,6 @@ class AlohaDataset(AbstractDataset):
|
||||||
collate_fn: Callable | None = None,
|
collate_fn: Callable | None = None,
|
||||||
writer: Writer | None = None,
|
writer: Writer | None = None,
|
||||||
transform: "torchrl.envs.Transform" = None,
|
transform: "torchrl.envs.Transform" = None,
|
||||||
dummy: bool = False,
|
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
dataset_id,
|
dataset_id,
|
||||||
|
@ -111,7 +110,6 @@ class AlohaDataset(AbstractDataset):
|
||||||
collate_fn=collate_fn,
|
collate_fn=collate_fn,
|
||||||
writer=writer,
|
writer=writer,
|
||||||
transform=transform,
|
transform=transform,
|
||||||
dummy=dummy,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
|
@ -21,12 +21,7 @@ def make_offline_buffer(
|
||||||
overwrite_batch_size=None,
|
overwrite_batch_size=None,
|
||||||
overwrite_prefetch=None,
|
overwrite_prefetch=None,
|
||||||
stats_path=None,
|
stats_path=None,
|
||||||
# Don't actually load any data. This is a stand-in solution to get the transforms.
|
|
||||||
dummy=False,
|
|
||||||
):
|
):
|
||||||
if dummy and normalize and stats_path is None:
|
|
||||||
raise ValueError("`stats_path` is required if `dummy` and `normalize` are True.")
|
|
||||||
|
|
||||||
if cfg.policy.balanced_sampling:
|
if cfg.policy.balanced_sampling:
|
||||||
assert cfg.online_steps > 0
|
assert cfg.online_steps > 0
|
||||||
batch_size = None
|
batch_size = None
|
||||||
|
@ -93,7 +88,6 @@ def make_offline_buffer(
|
||||||
root=DATA_DIR,
|
root=DATA_DIR,
|
||||||
pin_memory=pin_memory,
|
pin_memory=pin_memory,
|
||||||
prefetch=prefetch if isinstance(prefetch, int) else None,
|
prefetch=prefetch if isinstance(prefetch, int) else None,
|
||||||
dummy=dummy,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.policy.name == "tdmpc":
|
if cfg.policy.name == "tdmpc":
|
||||||
|
|
|
@ -100,7 +100,6 @@ class PushtDataset(AbstractDataset):
|
||||||
collate_fn: Callable | None = None,
|
collate_fn: Callable | None = None,
|
||||||
writer: Writer | None = None,
|
writer: Writer | None = None,
|
||||||
transform: "torchrl.envs.Transform" = None,
|
transform: "torchrl.envs.Transform" = None,
|
||||||
dummy: bool = False,
|
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
dataset_id,
|
dataset_id,
|
||||||
|
@ -114,7 +113,6 @@ class PushtDataset(AbstractDataset):
|
||||||
collate_fn=collate_fn,
|
collate_fn=collate_fn,
|
||||||
writer=writer,
|
writer=writer,
|
||||||
transform=transform,
|
transform=transform,
|
||||||
dummy=dummy,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _download_and_preproc_obsolete(self):
|
def _download_and_preproc_obsolete(self):
|
||||||
|
|
|
@ -51,7 +51,6 @@ class SimxarmDataset(AbstractDataset):
|
||||||
collate_fn: Callable | None = None,
|
collate_fn: Callable | None = None,
|
||||||
writer: Writer | None = None,
|
writer: Writer | None = None,
|
||||||
transform: "torchrl.envs.Transform" = None,
|
transform: "torchrl.envs.Transform" = None,
|
||||||
dummy: bool = False,
|
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
dataset_id,
|
dataset_id,
|
||||||
|
@ -65,7 +64,6 @@ class SimxarmDataset(AbstractDataset):
|
||||||
collate_fn=collate_fn,
|
collate_fn=collate_fn,
|
||||||
writer=writer,
|
writer=writer,
|
||||||
transform=transform,
|
transform=transform,
|
||||||
dummy=dummy,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _download_and_preproc_obsolete(self):
|
def _download_and_preproc_obsolete(self):
|
||||||
|
|
|
@ -194,7 +194,7 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
|
||||||
|
|
||||||
logging.info("Making transforms.")
|
logging.info("Making transforms.")
|
||||||
# TODO(alexander-soare): Completely decouple datasets from evaluation.
|
# TODO(alexander-soare): Completely decouple datasets from evaluation.
|
||||||
offline_buffer = make_offline_buffer(cfg, stats_path=stats_path, dummy=stats_path is not None)
|
offline_buffer = make_offline_buffer(cfg, stats_path=stats_path)
|
||||||
|
|
||||||
logging.info("Making environment.")
|
logging.info("Making environment.")
|
||||||
env = make_env(cfg, transform=offline_buffer.transform)
|
env = make_env(cfg, transform=offline_buffer.transform)
|
||||||
|
|
|
@ -2,8 +2,9 @@ import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.common.datasets.factory import make_offline_buffer
|
from lerobot.common.datasets.factory import make_offline_buffer
|
||||||
|
from lerobot.common.utils import init_hydra_config
|
||||||
|
|
||||||
from .utils import DEVICE, init_config
|
from .utils import DEVICE, DEFAULT_CONFIG_PATH
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -18,7 +19,10 @@ from .utils import DEVICE, init_config
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_factory(env_name, dataset_id):
|
def test_factory(env_name, dataset_id):
|
||||||
cfg = init_config(overrides=[f"env={env_name}", f"env.task={dataset_id}", f"device={DEVICE}"])
|
cfg = init_hydra_config(
|
||||||
|
DEFAULT_CONFIG_PATH,
|
||||||
|
overrides=[f"env={env_name}", f"env.task={dataset_id}", f"device={DEVICE}"]
|
||||||
|
)
|
||||||
offline_buffer = make_offline_buffer(cfg)
|
offline_buffer = make_offline_buffer(cfg)
|
||||||
for key in offline_buffer.image_keys:
|
for key in offline_buffer.image_keys:
|
||||||
img = offline_buffer[0].get(key)
|
img = offline_buffer[0].get(key)
|
||||||
|
|
|
@ -7,8 +7,9 @@ from lerobot.common.datasets.factory import make_offline_buffer
|
||||||
from lerobot.common.envs.factory import make_env
|
from lerobot.common.envs.factory import make_env
|
||||||
from lerobot.common.envs.pusht.env import PushtEnv
|
from lerobot.common.envs.pusht.env import PushtEnv
|
||||||
from lerobot.common.envs.simxarm.env import SimxarmEnv
|
from lerobot.common.envs.simxarm.env import SimxarmEnv
|
||||||
|
from lerobot.common.utils import init_hydra_config
|
||||||
|
|
||||||
from .utils import DEVICE, init_config
|
from .utils import DEVICE, DEFAULT_CONFIG_PATH
|
||||||
|
|
||||||
|
|
||||||
def print_spec_rollout(env):
|
def print_spec_rollout(env):
|
||||||
|
@ -89,7 +90,10 @@ def test_pusht(from_pixels, pixels_only):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_factory(env_name):
|
def test_factory(env_name):
|
||||||
cfg = init_config(overrides=[f"env={env_name}", f"device={DEVICE}"])
|
cfg = init_hydra_config(
|
||||||
|
DEFAULT_CONFIG_PATH,
|
||||||
|
overrides=[f"env={env_name}", f"device={DEVICE}"],
|
||||||
|
)
|
||||||
|
|
||||||
offline_buffer = make_offline_buffer(cfg)
|
offline_buffer = make_offline_buffer(cfg)
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
from omegaconf import open_dict
|
|
||||||
import pytest
|
import pytest
|
||||||
from tensordict import TensorDict
|
from tensordict import TensorDict
|
||||||
from tensordict.nn import TensorDictModule
|
from tensordict.nn import TensorDictModule
|
||||||
|
@ -10,8 +9,8 @@ from lerobot.common.policies.factory import make_policy
|
||||||
from lerobot.common.envs.factory import make_env
|
from lerobot.common.envs.factory import make_env
|
||||||
from lerobot.common.datasets.factory import make_offline_buffer
|
from lerobot.common.datasets.factory import make_offline_buffer
|
||||||
from lerobot.common.policies.abstract import AbstractPolicy
|
from lerobot.common.policies.abstract import AbstractPolicy
|
||||||
|
from lerobot.common.utils import init_hydra_config
|
||||||
from .utils import DEVICE, init_config
|
from .utils import DEVICE, DEFAULT_CONFIG_PATH
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"env_name,policy_name,extra_overrides",
|
"env_name,policy_name,extra_overrides",
|
||||||
|
@ -34,7 +33,8 @@ def test_concrete_policy(env_name, policy_name, extra_overrides):
|
||||||
- Updating the policy.
|
- Updating the policy.
|
||||||
- Using the policy to select actions at inference time.
|
- Using the policy to select actions at inference time.
|
||||||
"""
|
"""
|
||||||
cfg = init_config(
|
cfg = init_hydra_config(
|
||||||
|
DEFAULT_CONFIG_PATH,
|
||||||
overrides=[
|
overrides=[
|
||||||
f"env={env_name}",
|
f"env={env_name}",
|
||||||
f"policy={policy_name}",
|
f"policy={policy_name}",
|
||||||
|
|
|
@ -1,13 +1,6 @@
|
||||||
import os
|
import os
|
||||||
import hydra
|
|
||||||
from hydra import compose, initialize
|
|
||||||
|
|
||||||
CONFIG_PATH = "../lerobot/configs"
|
# Pass this as the first argument to init_hydra_config.
|
||||||
|
DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml"
|
||||||
|
|
||||||
DEVICE = os.environ.get('LEROBOT_TESTS_DEVICE', "cuda")
|
DEVICE = os.environ.get('LEROBOT_TESTS_DEVICE', "cuda")
|
||||||
|
|
||||||
def init_config(config_name="default", overrides=None):
|
|
||||||
hydra.core.global_hydra.GlobalHydra.instance().clear()
|
|
||||||
initialize(config_path=CONFIG_PATH)
|
|
||||||
cfg = compose(config_name=config_name, overrides=overrides)
|
|
||||||
return cfg
|
|
||||||
|
|
Loading…
Reference in New Issue