From 1ed0110900db5d8db8cbf7757705c65025a61321 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Tue, 26 Mar 2024 16:13:40 +0000 Subject: [PATCH 1/5] finish examples 2 and 3 --- examples/2_evaluate_pretrained_policy.py | 40 ++++++++++++++- examples/3_train_policy.py | 56 ++++++++++++++++++++- lerobot/common/datasets/abstract.py | 4 +- lerobot/common/datasets/aloha.py | 2 + lerobot/common/datasets/factory.py | 6 +++ lerobot/common/datasets/pusht.py | 2 + lerobot/common/datasets/simxarm.py | 2 + lerobot/common/utils.py | 30 +++++++++++ lerobot/scripts/eval.py | 33 +++---------- tests/test_examples.py | 63 +++++++++++++++++++----- 10 files changed, 196 insertions(+), 42 deletions(-) diff --git a/examples/2_evaluate_pretrained_policy.py b/examples/2_evaluate_pretrained_policy.py index 46409041..bb73167b 100644 --- a/examples/2_evaluate_pretrained_policy.py +++ b/examples/2_evaluate_pretrained_policy.py @@ -1 +1,39 @@ -# TODO +""" +This scripts demonstrates how to evaluate a pretrained policy from the HuggingFace Hub or from your local +training outputs directory. In the latter case, you might want to run examples/3_train_policy.py first. +""" + +from pathlib import Path + +from huggingface_hub import snapshot_download + +from lerobot.common.utils import init_hydra_config +from lerobot.scripts.eval import eval + +# Get a pretrained policy from the hub. +hub_id = "lerobot/diffusion_policy_pusht_image" +folder = Path(snapshot_download(hub_id)) +# OR uncomment the following to evaluate a policy from the local outputs/train folder. +folder = Path("outputs/train/example_pusht_diffusion") + +config_path = folder / "config.yaml" +weights_path = folder / "model.pt" +stats_path = folder / "stats.pth" # normalization stats + +# Override some config parameters to do with evaluation. +overrides = [ + f"policy.pretrained_model_path={weights_path}", + "eval_episodes=10", + "rollout_batch_size=10", + "device=cuda", +] + +# Create a Hydra config. +cfg = init_hydra_config(config_path, overrides) + +# Evaluate the policy and save the outputs including metrics and videos. +eval( + cfg, + out_dir=f"outputs/eval/example_{cfg.env.name}_{cfg.policy.name}", + stats_path=stats_path, +) diff --git a/examples/3_train_policy.py b/examples/3_train_policy.py index 46409041..01a4cf76 100644 --- a/examples/3_train_policy.py +++ b/examples/3_train_policy.py @@ -1 +1,55 @@ -# TODO +"""This scripts demonstrates how to train Diffusion Policy on the PushT environment. + +Once you have trained a model with this script, you can try to evaluate it on +examples/2_evaluate_pretrained_policy.py +""" + +import os +from pathlib import Path + +import torch +from omegaconf import OmegaConf +from tqdm import trange + +from lerobot.common.datasets.factory import make_offline_buffer +from lerobot.common.policies.diffusion.policy import DiffusionPolicy +from lerobot.common.utils import init_hydra_config + +output_directory = Path("outputs/train/example_pusht_diffusion") +os.makedirs(output_directory, exist_ok=True) + +overrides = [ + "env=pusht", + "policy=diffusion", + # Adjust as you prefer. 5000 steps are needed to get something worth evaluating. + "offline_steps=5000", + "log_freq=250", + "device=cuda", +] + +cfg = init_hydra_config("lerobot/configs/default.yaml", overrides) + +policy = DiffusionPolicy( + cfg=cfg.policy, + cfg_device=cfg.device, + cfg_noise_scheduler=cfg.noise_scheduler, + cfg_rgb_model=cfg.rgb_model, + cfg_obs_encoder=cfg.obs_encoder, + cfg_optimizer=cfg.optimizer, + cfg_ema=cfg.ema, + n_action_steps=cfg.n_action_steps + cfg.n_latency_steps, + **cfg.policy, +) +policy.train() + +offline_buffer = make_offline_buffer(cfg) + +for offline_step in trange(cfg.offline_steps): + train_info = policy.update(offline_buffer, offline_step) + if offline_step % cfg.log_freq == 0: + print(train_info) + +# Save the policy, configuration, and normalization stats for later use. +policy.save(output_directory / "model.pt") +OmegaConf.save(cfg, output_directory / "config.yaml") +torch.save(offline_buffer.transform[-1].stats, output_directory / "stats.pth") diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py index a81de49b..c05d25c0 100644 --- a/lerobot/common/datasets/abstract.py +++ b/lerobot/common/datasets/abstract.py @@ -59,6 +59,8 @@ class AbstractDataset(TensorDictReplayBuffer): collate_fn: Callable | None = None, writer: Writer | None = None, transform: "torchrl.envs.Transform" = None, + # Don't actually load any data. This is a stand-in solution to get the transforms. + dummy: bool = False, ): assert ( self.available_datasets is not None @@ -77,7 +79,7 @@ class AbstractDataset(TensorDictReplayBuffer): f"The version of the dataset ({self.version}) is not enforced when root is provided ({self.root})." ) - storage = self._download_or_load_dataset() + storage = self._download_or_load_dataset() if not dummy else None super().__init__( storage=storage, diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py index 031c2cd3..83d1581a 100644 --- a/lerobot/common/datasets/aloha.py +++ b/lerobot/common/datasets/aloha.py @@ -97,6 +97,7 @@ class AlohaDataset(AbstractDataset): collate_fn: Callable | None = None, writer: Writer | None = None, transform: "torchrl.envs.Transform" = None, + dummy: bool = False, ): super().__init__( dataset_id, @@ -110,6 +111,7 @@ class AlohaDataset(AbstractDataset): collate_fn=collate_fn, writer=writer, transform=transform, + dummy=dummy, ) @property diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 4212e023..276dc761 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -21,7 +21,12 @@ def make_offline_buffer( overwrite_batch_size=None, overwrite_prefetch=None, stats_path=None, + # Don't actually load any data. This is a stand-in solution to get the transforms. + dummy=False, ): + if dummy and normalize and stats_path is None: + raise ValueError("`stats_path` is required if `dummy` and `normalize` are True.") + if cfg.policy.balanced_sampling: assert cfg.online_steps > 0 batch_size = None @@ -93,6 +98,7 @@ def make_offline_buffer( root=DATA_DIR, pin_memory=pin_memory, prefetch=prefetch if isinstance(prefetch, int) else None, + dummy=dummy, ) if cfg.policy.name == "tdmpc": diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index 624fb140..d167f3ea 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -100,6 +100,7 @@ class PushtDataset(AbstractDataset): collate_fn: Callable | None = None, writer: Writer | None = None, transform: "torchrl.envs.Transform" = None, + dummy: bool = False, ): super().__init__( dataset_id, @@ -113,6 +114,7 @@ class PushtDataset(AbstractDataset): collate_fn=collate_fn, writer=writer, transform=transform, + dummy=dummy, ) def _download_and_preproc_obsolete(self): diff --git a/lerobot/common/datasets/simxarm.py b/lerobot/common/datasets/simxarm.py index dc30e69e..06931d3f 100644 --- a/lerobot/common/datasets/simxarm.py +++ b/lerobot/common/datasets/simxarm.py @@ -51,6 +51,7 @@ class SimxarmDataset(AbstractDataset): collate_fn: Callable | None = None, writer: Writer | None = None, transform: "torchrl.envs.Transform" = None, + dummy: bool = False, ): super().__init__( dataset_id, @@ -64,6 +65,7 @@ class SimxarmDataset(AbstractDataset): collate_fn=collate_fn, writer=writer, transform=transform, + dummy=dummy, ) def _download_and_preproc_obsolete(self): diff --git a/lerobot/common/utils.py b/lerobot/common/utils.py index 2af1d966..86383cdc 100644 --- a/lerobot/common/utils.py +++ b/lerobot/common/utils.py @@ -1,9 +1,13 @@ import logging +import os.path as osp import random from datetime import datetime +from pathlib import Path +import hydra import numpy as np import torch +from omegaconf import DictConfig def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device: @@ -63,3 +67,29 @@ def format_big_number(num): num /= divisor return num + + +def _relative_path_between(path1: Path, path2: Path) -> Path: + """Returns path1 relative to path2.""" + path1 = path1.absolute() + path2 = path2.absolute() + try: + return path1.relative_to(path2) + except ValueError: # most likely because path1 is not a subpath of path2 + common_parts = Path(osp.commonpath([path1, path2])).parts + return Path( + "/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :])) + ) + + +def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> DictConfig: + """Initialize a Hydra config given only the path to the relevant config file. + + For config resolution, it is assumed that the config file's parent is the Hydra config dir. + """ + # Hydra needs a path relative to this file. + hydra.initialize( + str(_relative_path_between(Path(config_path).absolute().parent, Path(__file__).absolute().parent)) + ) + cfg = hydra.compose(Path(config_path).stem, overrides) + return cfg diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 1de0bb0e..72517504 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -30,14 +30,12 @@ python lerobot/scripts/eval.py --hub-id HUB/ID --revision v1.0 eval_episodes=10 import argparse import json import logging -import os.path as osp import threading import time from datetime import datetime as dt from pathlib import Path import einops -import hydra import imageio import numpy as np import torch @@ -52,7 +50,7 @@ from lerobot.common.envs.factory import make_env from lerobot.common.logger import log_output_dir from lerobot.common.policies.abstract import AbstractPolicy from lerobot.common.policies.factory import make_policy -from lerobot.common.utils import get_safe_torch_device, init_logging, set_global_seed +from lerobot.common.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed def write_video(video_path, stacked_frames, fps): @@ -195,7 +193,8 @@ def eval(cfg: dict, out_dir=None, stats_path=None): log_output_dir(out_dir) logging.info("Making transforms.") - offline_buffer = make_offline_buffer(cfg, stats_path=stats_path) + # TODO(alexander-soare): Completely decouple datasets from evaluation. + offline_buffer = make_offline_buffer(cfg, stats_path=stats_path, dummy=stats_path is not None) logging.info("Making environment.") env = make_env(cfg, transform=offline_buffer.transform) @@ -229,19 +228,6 @@ def eval(cfg: dict, out_dir=None, stats_path=None): logging.info("End of eval") -def _relative_path_between(path1: Path, path2: Path) -> Path: - """Returns path1 relative to path2.""" - path1 = path1.absolute() - path2 = path2.absolute() - try: - return path1.relative_to(path2) - except ValueError: # most likely because path1 is not a subpath of path2 - common_parts = Path(osp.commonpath([path1, path2])).parts - return Path( - "/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :])) - ) - - if __name__ == "__main__": parser = argparse.ArgumentParser( description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter @@ -259,19 +245,14 @@ if __name__ == "__main__": if args.config is not None: # Note: For the config_path, Hydra wants a path relative to this script file. - hydra.initialize( - config_path=str( - _relative_path_between(Path(args.config).absolute().parent, Path(__file__).parent) - ) - ) - cfg = hydra.compose(Path(args.config).stem, args.overrides) + cfg = init_hydra_config(args.config, args.overrides) # TODO(alexander-soare): Save and load stats in trained model directory. stats_path = None elif args.hub_id is not None: folder = Path(snapshot_download(args.hub_id, revision="v1.0")) - cfg = hydra.initialize(config_path=str(_relative_path_between(folder, Path(__file__).parent))) - cfg = hydra.compose("config", args.overrides) - cfg.policy.pretrained_model_path = folder / "model.pt" + cfg = init_hydra_config( + folder / "config.yaml", [f"policy.pretrained_model_path={folder / 'model.pt'}", *args.overrides] + ) stats_path = folder / "stats.pth" eval( diff --git a/tests/test_examples.py b/tests/test_examples.py index 6c21eb4c..9da7a663 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -1,19 +1,56 @@ -import pytest from pathlib import Path -@pytest.mark.parametrize( - "path", - [ - "examples/1_visualize_dataset.py", - "examples/2_evaluate_pretrained_policy.py", - "examples/3_train_policy.py", - ], -) -def test_example(path): - with open(path, 'r') as file: +def _find_and_replace(text: str, finds: list[str], replaces: list[str]) -> str: + for f, r in zip(finds, replaces): + assert f in text + text = text.replace(f, r) + return text + + +def test_example_1(): + path = "examples/1_visualize_dataset.py" + + with open(path, "r") as file: file_contents = file.read() exec(file_contents) - if path == "examples/1_visualize_dataset.py": - assert Path("outputs/visualize_dataset/example/episode_0.mp4").exists() + assert Path("outputs/visualize_dataset/example/episode_0.mp4").exists() + + +def test_examples_3_and_2(): + """ + Train a model with example 3, check the outputs. + Evaluate the trained model with example 2, check the outputs. + """ + + path = "examples/3_train_policy.py" + + with open(path, "r") as file: + file_contents = file.read() + + # Do less steps and use CPU. + file_contents = _find_and_replace( + file_contents, + ['"offline_steps=5000"', '"device=cuda"'], + ['"offline_steps=1"', '"device=cpu"'], + ) + + exec(file_contents) + + for file_name in ["model.pt", "stats.pth", "config.yaml"]: + assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists() + + path = "examples/2_evaluate_pretrained_policy.py" + + with open(path, "r") as file: + file_contents = file.read() + + # Do less evals and use CPU. + file_contents = _find_and_replace( + file_contents, + ['"eval_episodes=10"', '"rollout_batch_size=10"', '"device=cuda"'], + ['"eval_episodes=1"', '"rollout_batch_size=1"','"device=cpu"'], + ) + + assert Path(f"outputs/train/example_pusht_diffusion").exists() \ No newline at end of file From be4441c7ff423435abae2d08edb35a7c53df87f3 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Tue, 26 Mar 2024 16:28:16 +0000 Subject: [PATCH 2/5] update README --- README.md | 22 ++++------------------ examples/2_evaluate_pretrained_policy.py | 2 +- tests/test_examples.py | 22 ++++++++++++++++++---- 3 files changed, 23 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index 35911869..d71fad79 100644 --- a/README.md +++ b/README.md @@ -135,11 +135,7 @@ hydra.run.dir=outputs/visualize_dataset/example ### Evaluate a pretrained policy -You can import our environment class, download pretrained policies from the HuggingFace hub, and use our rollout utilities with rendering: -```python -""" Copy pasted from `examples/2_evaluate_pretrained_policy.py` -# TODO -``` +Check out [example 2](./examples/2_evaluate_pretrained_policy.py) to see how you can load a pretrained policy from HuggingFace hub, load up the corresponding environment and model, and run an evaluation. Or you can achieve the same result by executing our script from the command line: ```bash @@ -150,7 +146,7 @@ eval_episodes=10 \ hydra.run.dir=outputs/eval/example_hub ``` -After launching training of your own policy, you can also re-evaluate the checkpoints with: +After training your own policy, you can also re-evaluate the checkpoints with: ```bash python lerobot/scripts/eval.py \ --config PATH/TO/FOLDER/config.yaml \ @@ -163,19 +159,9 @@ See `python lerobot/scripts/eval.py --help` for more instructions. ### Train your own policy -You can import our dataset, environment, policy classes, and use our training utilities (if some data is missing, it will be automatically downloaded from HuggingFace hub): -```python -""" Copy pasted from `examples/3_train_policy.py` -# TODO -``` +You can import our dataset, environment, policy classes, and use our training utilities (if some data is missing, it will be automatically downloaded from HuggingFace hub): check out [example 3](./examples/3_train_policy.py). After you run this, you may want to revisit [example 2](./examples/2_evaluate_pretrained_policy.py) to evaluate your training output! -Or you can achieve the same result by executing our script from the command line: -```bash -python lerobot/scripts/train.py \ -hydra.run.dir=outputs/train/example -``` - -You can easily train any policy on any environment: +In general, you can use our training script to easily train any policy on any environment: ```bash python lerobot/scripts/train.py \ env=aloha \ diff --git a/examples/2_evaluate_pretrained_policy.py b/examples/2_evaluate_pretrained_policy.py index bb73167b..be6abd1b 100644 --- a/examples/2_evaluate_pretrained_policy.py +++ b/examples/2_evaluate_pretrained_policy.py @@ -14,7 +14,7 @@ from lerobot.scripts.eval import eval hub_id = "lerobot/diffusion_policy_pusht_image" folder = Path(snapshot_download(hub_id)) # OR uncomment the following to evaluate a policy from the local outputs/train folder. -folder = Path("outputs/train/example_pusht_diffusion") +# folder = Path("outputs/train/example_pusht_diffusion") config_path = folder / "config.yaml" weights_path = folder / "model.pt" diff --git a/tests/test_examples.py b/tests/test_examples.py index 9da7a663..4263e452 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -46,11 +46,25 @@ def test_examples_3_and_2(): with open(path, "r") as file: file_contents = file.read() - # Do less evals and use CPU. + # Do less evals, use CPU, and use the local model. file_contents = _find_and_replace( file_contents, - ['"eval_episodes=10"', '"rollout_batch_size=10"', '"device=cuda"'], - ['"eval_episodes=1"', '"rollout_batch_size=1"','"device=cpu"'], + [ + '"eval_episodes=10"', + '"rollout_batch_size=10"', + '"device=cuda"', + '# folder = Path("outputs/train/example_pusht_diffusion")', + 'hub_id = "lerobot/diffusion_policy_pusht_image"', + "folder = Path(snapshot_download(hub_id)", + ], + [ + '"eval_episodes=1"', + '"rollout_batch_size=1"', + '"device=cpu"', + 'folder = Path("outputs/train/example_pusht_diffusion")', + "", + "", + ], ) - assert Path(f"outputs/train/example_pusht_diffusion").exists() \ No newline at end of file + assert Path(f"outputs/train/example_pusht_diffusion").exists() From 011f2d27febf57686fe5143a12ff6798a40e38c4 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Tue, 26 Mar 2024 16:40:54 +0000 Subject: [PATCH 3/5] fix tests --- lerobot/common/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lerobot/common/utils.py b/lerobot/common/utils.py index 86383cdc..7ed29334 100644 --- a/lerobot/common/utils.py +++ b/lerobot/common/utils.py @@ -87,6 +87,8 @@ def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> D For config resolution, it is assumed that the config file's parent is the Hydra config dir. """ + # TODO(alexander-soare): Resolve configs without Hydra initialization. + hydra.core.global_hydra.GlobalHydra.instance().clear() # Hydra needs a path relative to this file. hydra.initialize( str(_relative_path_between(Path(config_path).absolute().parent, Path(__file__).absolute().parent)) From 6cd671040fbff2c49778176aae263b27a4d943db Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Wed, 27 Mar 2024 13:22:14 +0000 Subject: [PATCH 4/5] fix revision --- README.md | 1 - lerobot/scripts/eval.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/README.md b/README.md index d71fad79..0786c6d6 100644 --- a/README.md +++ b/README.md @@ -141,7 +141,6 @@ Or you can achieve the same result by executing our script from the command line ```bash python lerobot/scripts/eval.py \ --hub-id lerobot/diffusion_policy_pusht_image \ ---revision v1.0 \ eval_episodes=10 \ hydra.run.dir=outputs/eval/example_hub ``` diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 72517504..2a3ab13b 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -249,7 +249,7 @@ if __name__ == "__main__": # TODO(alexander-soare): Save and load stats in trained model directory. stats_path = None elif args.hub_id is not None: - folder = Path(snapshot_download(args.hub_id, revision="v1.0")) + folder = Path(snapshot_download(args.hub_id, revision=args.revision)) cfg = init_hydra_config( folder / "config.yaml", [f"policy.pretrained_model_path={folder / 'model.pt'}", *args.overrides] ) From b7c9c330725450d86ef24957c96d7710cf2edaee Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Wed, 27 Mar 2024 18:33:48 +0000 Subject: [PATCH 5/5] revision --- lerobot/common/datasets/abstract.py | 4 +--- lerobot/common/datasets/aloha.py | 2 -- lerobot/common/datasets/factory.py | 6 ------ lerobot/common/datasets/pusht.py | 2 -- lerobot/common/datasets/simxarm.py | 2 -- lerobot/scripts/eval.py | 2 +- tests/test_datasets.py | 8 ++++++-- tests/test_envs.py | 8 ++++++-- tests/test_policies.py | 8 ++++---- tests/utils.py | 11 ++--------- 10 files changed, 20 insertions(+), 33 deletions(-) diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py index c05d25c0..a81de49b 100644 --- a/lerobot/common/datasets/abstract.py +++ b/lerobot/common/datasets/abstract.py @@ -59,8 +59,6 @@ class AbstractDataset(TensorDictReplayBuffer): collate_fn: Callable | None = None, writer: Writer | None = None, transform: "torchrl.envs.Transform" = None, - # Don't actually load any data. This is a stand-in solution to get the transforms. - dummy: bool = False, ): assert ( self.available_datasets is not None @@ -79,7 +77,7 @@ class AbstractDataset(TensorDictReplayBuffer): f"The version of the dataset ({self.version}) is not enforced when root is provided ({self.root})." ) - storage = self._download_or_load_dataset() if not dummy else None + storage = self._download_or_load_dataset() super().__init__( storage=storage, diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py index 83d1581a..031c2cd3 100644 --- a/lerobot/common/datasets/aloha.py +++ b/lerobot/common/datasets/aloha.py @@ -97,7 +97,6 @@ class AlohaDataset(AbstractDataset): collate_fn: Callable | None = None, writer: Writer | None = None, transform: "torchrl.envs.Transform" = None, - dummy: bool = False, ): super().__init__( dataset_id, @@ -111,7 +110,6 @@ class AlohaDataset(AbstractDataset): collate_fn=collate_fn, writer=writer, transform=transform, - dummy=dummy, ) @property diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 4e02f704..04077034 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -21,12 +21,7 @@ def make_offline_buffer( overwrite_batch_size=None, overwrite_prefetch=None, stats_path=None, - # Don't actually load any data. This is a stand-in solution to get the transforms. - dummy=False, ): - if dummy and normalize and stats_path is None: - raise ValueError("`stats_path` is required if `dummy` and `normalize` are True.") - if cfg.policy.balanced_sampling: assert cfg.online_steps > 0 batch_size = None @@ -93,7 +88,6 @@ def make_offline_buffer( root=DATA_DIR, pin_memory=pin_memory, prefetch=prefetch if isinstance(prefetch, int) else None, - dummy=dummy, ) if cfg.policy.name == "tdmpc": diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index d167f3ea..624fb140 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -100,7 +100,6 @@ class PushtDataset(AbstractDataset): collate_fn: Callable | None = None, writer: Writer | None = None, transform: "torchrl.envs.Transform" = None, - dummy: bool = False, ): super().__init__( dataset_id, @@ -114,7 +113,6 @@ class PushtDataset(AbstractDataset): collate_fn=collate_fn, writer=writer, transform=transform, - dummy=dummy, ) def _download_and_preproc_obsolete(self): diff --git a/lerobot/common/datasets/simxarm.py b/lerobot/common/datasets/simxarm.py index 06931d3f..dc30e69e 100644 --- a/lerobot/common/datasets/simxarm.py +++ b/lerobot/common/datasets/simxarm.py @@ -51,7 +51,6 @@ class SimxarmDataset(AbstractDataset): collate_fn: Callable | None = None, writer: Writer | None = None, transform: "torchrl.envs.Transform" = None, - dummy: bool = False, ): super().__init__( dataset_id, @@ -65,7 +64,6 @@ class SimxarmDataset(AbstractDataset): collate_fn=collate_fn, writer=writer, transform=transform, - dummy=dummy, ) def _download_and_preproc_obsolete(self): diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 2a3ab13b..216769d6 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -194,7 +194,7 @@ def eval(cfg: dict, out_dir=None, stats_path=None): logging.info("Making transforms.") # TODO(alexander-soare): Completely decouple datasets from evaluation. - offline_buffer = make_offline_buffer(cfg, stats_path=stats_path, dummy=stats_path is not None) + offline_buffer = make_offline_buffer(cfg, stats_path=stats_path) logging.info("Making environment.") env = make_env(cfg, transform=offline_buffer.transform) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 252e0046..adaefcf5 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -2,8 +2,9 @@ import pytest import torch from lerobot.common.datasets.factory import make_offline_buffer +from lerobot.common.utils import init_hydra_config -from .utils import DEVICE, init_config +from .utils import DEVICE, DEFAULT_CONFIG_PATH @pytest.mark.parametrize( @@ -18,7 +19,10 @@ from .utils import DEVICE, init_config ], ) def test_factory(env_name, dataset_id): - cfg = init_config(overrides=[f"env={env_name}", f"env.task={dataset_id}", f"device={DEVICE}"]) + cfg = init_hydra_config( + DEFAULT_CONFIG_PATH, + overrides=[f"env={env_name}", f"env.task={dataset_id}", f"device={DEVICE}"] + ) offline_buffer = make_offline_buffer(cfg) for key in offline_buffer.image_keys: img = offline_buffer[0].get(key) diff --git a/tests/test_envs.py b/tests/test_envs.py index 2beafbda..2bd5e65c 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -7,8 +7,9 @@ from lerobot.common.datasets.factory import make_offline_buffer from lerobot.common.envs.factory import make_env from lerobot.common.envs.pusht.env import PushtEnv from lerobot.common.envs.simxarm.env import SimxarmEnv +from lerobot.common.utils import init_hydra_config -from .utils import DEVICE, init_config +from .utils import DEVICE, DEFAULT_CONFIG_PATH def print_spec_rollout(env): @@ -89,7 +90,10 @@ def test_pusht(from_pixels, pixels_only): ], ) def test_factory(env_name): - cfg = init_config(overrides=[f"env={env_name}", f"device={DEVICE}"]) + cfg = init_hydra_config( + DEFAULT_CONFIG_PATH, + overrides=[f"env={env_name}", f"device={DEVICE}"], + ) offline_buffer = make_offline_buffer(cfg) diff --git a/tests/test_policies.py b/tests/test_policies.py index d3dc0bc5..5d6b46d0 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -1,4 +1,3 @@ -from omegaconf import open_dict import pytest from tensordict import TensorDict from tensordict.nn import TensorDictModule @@ -10,8 +9,8 @@ from lerobot.common.policies.factory import make_policy from lerobot.common.envs.factory import make_env from lerobot.common.datasets.factory import make_offline_buffer from lerobot.common.policies.abstract import AbstractPolicy - -from .utils import DEVICE, init_config +from lerobot.common.utils import init_hydra_config +from .utils import DEVICE, DEFAULT_CONFIG_PATH @pytest.mark.parametrize( "env_name,policy_name,extra_overrides", @@ -34,7 +33,8 @@ def test_concrete_policy(env_name, policy_name, extra_overrides): - Updating the policy. - Using the policy to select actions at inference time. """ - cfg = init_config( + cfg = init_hydra_config( + DEFAULT_CONFIG_PATH, overrides=[ f"env={env_name}", f"policy={policy_name}", diff --git a/tests/utils.py b/tests/utils.py index 55709330..6169c3b6 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,13 +1,6 @@ import os -import hydra -from hydra import compose, initialize -CONFIG_PATH = "../lerobot/configs" +# Pass this as the first argument to init_hydra_config. +DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml" DEVICE = os.environ.get('LEROBOT_TESTS_DEVICE', "cuda") - -def init_config(config_name="default", overrides=None): - hydra.core.global_hydra.GlobalHydra.instance().clear() - initialize(config_path=CONFIG_PATH) - cfg = compose(config_name=config_name, overrides=overrides) - return cfg