diff --git a/examples/2_evaluate_pretrained_policy.py b/examples/2_evaluate_pretrained_policy.py
index 46409041..bb73167b 100644
--- a/examples/2_evaluate_pretrained_policy.py
+++ b/examples/2_evaluate_pretrained_policy.py
@@ -1 +1,39 @@
-# TODO
+"""
+This script demonstrates how to evaluate a pretrained policy from the HuggingFace Hub or from your local
+training outputs directory. In the latter case, you might want to run examples/3_train_policy.py first.
+"""
+
+from pathlib import Path
+
+from huggingface_hub import snapshot_download
+
+from lerobot.common.utils import init_hydra_config
+from lerobot.scripts.eval import eval
+
+# Get a pretrained policy from the hub.
+hub_id = "lerobot/diffusion_policy_pusht_image"
+folder = Path(snapshot_download(hub_id))
+# OR uncomment the following to evaluate a policy from the local outputs/train folder.
+# folder = Path("outputs/train/example_pusht_diffusion")
+
+config_path = folder / "config.yaml"
+weights_path = folder / "model.pt"
+stats_path = folder / "stats.pth"  # normalization stats
+
+# Override some config parameters related to evaluation.
+overrides = [
+    f"policy.pretrained_model_path={weights_path}",
+    "eval_episodes=10",
+    "rollout_batch_size=10",
+    "device=cuda",
+]
+
+# Create a Hydra config.
+cfg = init_hydra_config(config_path, overrides)
+
+# Evaluate the policy and save the outputs, including metrics and videos.
+eval(
+    cfg,
+    out_dir=f"outputs/eval/example_{cfg.env.name}_{cfg.policy.name}",
+    stats_path=stats_path,
+)
diff --git a/examples/3_train_policy.py b/examples/3_train_policy.py
index 46409041..01a4cf76 100644
--- a/examples/3_train_policy.py
+++ b/examples/3_train_policy.py
@@ -1 +1,56 @@
-# TODO
+"""This script demonstrates how to train Diffusion Policy on the PushT environment.
+
+Once you have trained a model with this script, you can try to evaluate it with
+examples/2_evaluate_pretrained_policy.py.
+"""
+
+import os
+from pathlib import Path
+
+import torch
+from omegaconf import OmegaConf
+from tqdm import trange
+
+from lerobot.common.datasets.factory import make_offline_buffer
+from lerobot.common.policies.diffusion.policy import DiffusionPolicy
+from lerobot.common.utils import init_hydra_config
+
+output_directory = Path("outputs/train/example_pusht_diffusion")
+os.makedirs(output_directory, exist_ok=True)
+
+overrides = [
+    "env=pusht",
+    "policy=diffusion",
+    # Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
+    "offline_steps=5000",
+    "log_freq=250",
+    "device=cuda",
+]
+
+cfg = init_hydra_config("lerobot/configs/default.yaml", overrides)
+
+policy = DiffusionPolicy(
+    cfg=cfg.policy,
+    cfg_device=cfg.device,
+    cfg_noise_scheduler=cfg.noise_scheduler,
+    cfg_rgb_model=cfg.rgb_model,
+    cfg_obs_encoder=cfg.obs_encoder,
+    cfg_optimizer=cfg.optimizer,
+    cfg_ema=cfg.ema,
+    n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
+    **cfg.policy,
+)
+policy.train()
+
+offline_buffer = make_offline_buffer(cfg)
+
+for offline_step in trange(cfg.offline_steps):
+    train_info = policy.update(offline_buffer, offline_step)
+    if offline_step % cfg.log_freq == 0:
+        print(train_info)
+
+# Save the policy, configuration, and normalization stats for later use.
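+# The stats are taken from the buffer's last transform, which is the normalization transform.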
+policy.save(output_directory / "model.pt")
+OmegaConf.save(cfg, output_directory / "config.yaml")
+torch.save(offline_buffer.transform[-1].stats, output_directory / "stats.pth")
diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py
index a81de49b..c05d25c0 100644
--- a/lerobot/common/datasets/abstract.py
+++ b/lerobot/common/datasets/abstract.py
@@ -59,6 +59,8 @@ class AbstractDataset(TensorDictReplayBuffer):
         collate_fn: Callable | None = None,
         writer: Writer | None = None,
         transform: "torchrl.envs.Transform" = None,
+        # Don't actually load any data. This is a stand-in solution to get the transforms.
+        dummy: bool = False,
     ):
         assert (
             self.available_datasets is not None
@@ -77,7 +79,7 @@ class AbstractDataset(TensorDictReplayBuffer):
                 f"The version of the dataset ({self.version}) is not enforced when root is provided ({self.root})."
             )
 
-        storage = self._download_or_load_dataset()
+        storage = None if dummy else self._download_or_load_dataset()
 
         super().__init__(
             storage=storage,
diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py
index 031c2cd3..83d1581a 100644
--- a/lerobot/common/datasets/aloha.py
+++ b/lerobot/common/datasets/aloha.py
@@ -97,6 +97,7 @@ class AlohaDataset(AbstractDataset):
         collate_fn: Callable | None = None,
         writer: Writer | None = None,
         transform: "torchrl.envs.Transform" = None,
+        dummy: bool = False,
     ):
         super().__init__(
             dataset_id,
@@ -110,6 +111,7 @@ class AlohaDataset(AbstractDataset):
             collate_fn=collate_fn,
             writer=writer,
             transform=transform,
+            dummy=dummy,
         )
 
     @property
diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py
index 4212e023..276dc761 100644
--- a/lerobot/common/datasets/factory.py
+++ b/lerobot/common/datasets/factory.py
@@ -21,7 +21,13 @@ def make_offline_buffer(
     overwrite_batch_size=None,
     overwrite_prefetch=None,
     stats_path=None,
+    # Don't actually load any data. This is a stand-in solution to get the transforms.
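+    # (Used at evaluation time, when only the transforms and normalization stats are needed.)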
+    dummy=False,
 ):
+    if dummy and normalize and stats_path is None:
+        raise ValueError("`stats_path` is required if `dummy` and `normalize` are True.")
+
     if cfg.policy.balanced_sampling:
         assert cfg.online_steps > 0
         batch_size = None
@@ -93,6 +99,7 @@ def make_offline_buffer(
         root=DATA_DIR,
         pin_memory=pin_memory,
         prefetch=prefetch if isinstance(prefetch, int) else None,
+        dummy=dummy,
     )
 
     if cfg.policy.name == "tdmpc":
diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py
index 624fb140..d167f3ea 100644
--- a/lerobot/common/datasets/pusht.py
+++ b/lerobot/common/datasets/pusht.py
@@ -100,6 +100,7 @@ class PushtDataset(AbstractDataset):
         collate_fn: Callable | None = None,
         writer: Writer | None = None,
         transform: "torchrl.envs.Transform" = None,
+        dummy: bool = False,
     ):
         super().__init__(
             dataset_id,
@@ -113,6 +114,7 @@ class PushtDataset(AbstractDataset):
             collate_fn=collate_fn,
             writer=writer,
             transform=transform,
+            dummy=dummy,
         )
 
     def _download_and_preproc_obsolete(self):
diff --git a/lerobot/common/datasets/simxarm.py b/lerobot/common/datasets/simxarm.py
index dc30e69e..06931d3f 100644
--- a/lerobot/common/datasets/simxarm.py
+++ b/lerobot/common/datasets/simxarm.py
@@ -51,6 +51,7 @@ class SimxarmDataset(AbstractDataset):
         collate_fn: Callable | None = None,
         writer: Writer | None = None,
         transform: "torchrl.envs.Transform" = None,
+        dummy: bool = False,
     ):
         super().__init__(
             dataset_id,
@@ -64,6 +65,7 @@ class SimxarmDataset(AbstractDataset):
             collate_fn=collate_fn,
             writer=writer,
             transform=transform,
+            dummy=dummy,
         )
 
     def _download_and_preproc_obsolete(self):
diff --git a/lerobot/common/utils.py b/lerobot/common/utils.py
index 2af1d966..86383cdc 100644
--- a/lerobot/common/utils.py
+++ b/lerobot/common/utils.py
@@ -1,9 +1,13 @@
 import logging
+import os.path as osp
 import random
 from datetime import datetime
+from pathlib import Path
 
+import hydra
 import numpy as np
 import torch
+from omegaconf import DictConfig
 
 
 def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
@@ -63,3 +67,31 @@ def format_big_number(num):
         num /= divisor
 
     return num
+
+
+def _relative_path_between(path1: Path, path2: Path) -> Path:
+    """Returns path1 relative to path2."""
+    path1 = path1.absolute()
+    path2 = path2.absolute()
+    try:
+        return path1.relative_to(path2)
+    except ValueError:  # most likely because path1 is not a subpath of path2
+        # Walk up from path2 to the common ancestor, then down to path1.
+        common_parts = Path(osp.commonpath([path1, path2])).parts
+        return Path(
+            "/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :]))
+        )
+
+
+def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> DictConfig:
+    """Initialize a Hydra config given only the path to the relevant config file.
+
+    For config resolution, it is assumed that the config file's parent is the Hydra config dir.
+    """
+    # Hydra needs a config path relative to this file.
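+    # For example: "../configs" when initializing from "lerobot/configs/default.yaml".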
+    hydra.initialize(
+        config_path=str(_relative_path_between(Path(config_path).absolute().parent, Path(__file__).absolute().parent))
+    )
+    cfg = hydra.compose(Path(config_path).stem, overrides)
+    return cfg
diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index 1de0bb0e..72517504 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -30,14 +30,12 @@ python lerobot/scripts/eval.py --hub-id HUB/ID --revision v1.0 eval_episodes=10
 import argparse
 import json
 import logging
-import os.path as osp
 import threading
 import time
 from datetime import datetime as dt
 from pathlib import Path
 
 import einops
-import hydra
 import imageio
 import numpy as np
 import torch
@@ -52,7 +50,7 @@ from lerobot.common.envs.factory import make_env
 from lerobot.common.logger import log_output_dir
 from lerobot.common.policies.abstract import AbstractPolicy
 from lerobot.common.policies.factory import make_policy
-from lerobot.common.utils import get_safe_torch_device, init_logging, set_global_seed
+from lerobot.common.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
 
 
 def write_video(video_path, stacked_frames, fps):
@@ -195,7 +193,8 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
     log_output_dir(out_dir)
 
     logging.info("Making transforms.")
-    offline_buffer = make_offline_buffer(cfg, stats_path=stats_path)
+    # TODO(alexander-soare): Completely decouple datasets from evaluation.
+    offline_buffer = make_offline_buffer(cfg, stats_path=stats_path, dummy=stats_path is not None)
 
     logging.info("Making environment.")
     env = make_env(cfg, transform=offline_buffer.transform)
@@ -229,19 +228,6 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
     logging.info("End of eval")
 
 
-def _relative_path_between(path1: Path, path2: Path) -> Path:
-    """Returns path1 relative to path2."""
-    path1 = path1.absolute()
-    path2 = path2.absolute()
-    try:
-        return path1.relative_to(path2)
-    except ValueError:  # most likely because path1 is not a subpath of path2
-        common_parts = Path(osp.commonpath([path1, path2])).parts
-        return Path(
-            "/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :]))
-        )
-
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
@@ -259,19 +245,14 @@ if __name__ == "__main__":
 
     if args.config is not None:
-        # Note: For the config_path, Hydra wants a path relative to this script file.
-        hydra.initialize(
-            config_path=str(
-                _relative_path_between(Path(args.config).absolute().parent, Path(__file__).parent)
-            )
-        )
-        cfg = hydra.compose(Path(args.config).stem, args.overrides)
+        cfg = init_hydra_config(args.config, args.overrides)
         # TODO(alexander-soare): Save and load stats in trained model directory.
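+        # With stats_path=None, make_offline_buffer loads the dataset to compute normalization stats.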
         stats_path = None
     elif args.hub_id is not None:
         folder = Path(snapshot_download(args.hub_id, revision="v1.0"))
-        cfg = hydra.initialize(config_path=str(_relative_path_between(folder, Path(__file__).parent)))
-        cfg = hydra.compose("config", args.overrides)
-        cfg.policy.pretrained_model_path = folder / "model.pt"
+        cfg = init_hydra_config(
+            folder / "config.yaml", [f"policy.pretrained_model_path={folder / 'model.pt'}", *args.overrides]
+        )
         stats_path = folder / "stats.pth"
 
     eval(
diff --git a/tests/test_examples.py b/tests/test_examples.py
index 6c21eb4c..9da7a663 100644
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -1,19 +1,68 @@
-import pytest
 from pathlib import Path
 
 
-@pytest.mark.parametrize(
-    "path",
-    [
-        "examples/1_visualize_dataset.py",
-        "examples/2_evaluate_pretrained_policy.py",
-        "examples/3_train_policy.py",
-    ],
-)
-def test_example(path):
-    with open(path, 'r') as file:
+def _find_and_replace(text: str, finds: list[str], replaces: list[str]) -> str:
+    for f, r in zip(finds, replaces):
+        assert f in text
+        text = text.replace(f, r)
+    return text
+
+
+def test_example_1():
+    path = "examples/1_visualize_dataset.py"
+
+    with open(path, "r") as file:
         file_contents = file.read()
     exec(file_contents)
 
-    if path == "examples/1_visualize_dataset.py":
-        assert Path("outputs/visualize_dataset/example/episode_0.mp4").exists()
+    assert Path("outputs/visualize_dataset/example/episode_0.mp4").exists()
+
+
+def test_examples_3_and_2():
+    """
+    Train a model with example 3, check the outputs.
+    Evaluate the trained model with example 2, check the outputs.
+    """
+
+    path = "examples/3_train_policy.py"
+
+    with open(path, "r") as file:
+        file_contents = file.read()
+
+    # Do fewer steps and use CPU.
+    file_contents = _find_and_replace(
+        file_contents,
+        ['"offline_steps=5000"', '"device=cuda"'],
+        ['"offline_steps=1"', '"device=cpu"'],
+    )
+
+    exec(file_contents)
+
+    for file_name in ["model.pt", "stats.pth", "config.yaml"]:
+        assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()
+
+    path = "examples/2_evaluate_pretrained_policy.py"
+
+    with open(path, "r") as file:
+        file_contents = file.read()
+
+    # Do fewer eval episodes, use CPU, and evaluate the locally trained model.
+    file_contents = _find_and_replace(
+        file_contents,
+        [
+            '"eval_episodes=10"',
+            '"rollout_batch_size=10"',
+            '"device=cuda"',
+            '# folder = Path("outputs/train/example_pusht_diffusion")',
+        ],
+        [
+            '"eval_episodes=1"',
+            '"rollout_batch_size=1"',
+            '"device=cpu"',
+            'folder = Path("outputs/train/example_pusht_diffusion")',
+        ],
+    )
+
+    exec(file_contents)
+
+    assert Path("outputs/eval/example_pusht_diffusion").exists()