diff --git a/.gitignore b/.gitignore index ad9892d4..f07ce341 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ tmp wandb data outputs +eval_outputs .vscode rl diff --git a/README.md b/README.md index 18af3242..076252c8 100644 --- a/README.md +++ b/README.md @@ -223,3 +223,38 @@ Finally, you might want to mock the dataset if you need to update the unit tests ``` python tests/scripts/mock_dataset.py --in-data-dir data/$DATASET --out-data-dir tests/data/$DATASET ``` + +**Models** + +Once you have trained a model you may upload it to the HuggingFace hub. + +Firstly, make sure you have a model repository set up on the hub. The hub ID looks like HF_USER/REPO_NAME. + +Secondly, assuming you have trained a model, you need: + +- `config.yaml` which you can get from the `.hydra` directory of your training output folder. +- `model.pt` which should be one of the saved models in the `models` directory of your training output folder (they won't be named `model.pt` but you will need to choose one). +- `staths.pth` which should point to the same file in the dataset directory (found in `data/{dataset_name}`). + +To upload these to the hub, prepare a folder with the following structure (you can use symlinks rather than copying): + +``` +to_upload + ├── config.yaml + ├── model.pt + └── stats.pth +``` + +With the folder prepared, run the following with a desired revision ID. + +``` +huggingface-cli upload $HUB_ID to_upload --revision $REVISION_ID +``` + +If you want this to be the default revision also run the following (don't worry, it won't upload the files again; it will just adjust the file pointers): + +``` +huggingface-cli upload $HUB_ID to_upload +``` + +See `eval.py` for an example of how a user may use your model. diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 3f4772c4..40a49cc1 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -14,7 +14,12 @@ DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None def make_offline_buffer( - cfg, overwrite_sampler=None, normalize=True, overwrite_batch_size=None, overwrite_prefetch=None + cfg, + overwrite_sampler=None, + normalize=True, + overwrite_batch_size=None, + overwrite_prefetch=None, + stats_path=None, ): if cfg.policy.balanced_sampling: assert cfg.online_steps > 0 @@ -98,10 +103,12 @@ def make_offline_buffer( transforms = [Prod(in_keys=img_keys, prod=1 / 255)] if normalize: - # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec - stats = offline_buffer.compute_or_load_stats() + # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, + # min_max_from_spec + stats = offline_buffer.compute_or_load_stats() if stats_path is None else torch.load(stats_path) - # we only normalize the state and action, since the images are usually normalized inside the model for now (except for tdmpc: see the following) + # we only normalize the state and action, since the images are usually normalized inside the model for + # now (except for tdmpc: see the following) in_keys = [("observation", "state"), ("action")] if cfg.policy.name == "tdmpc": diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 76deb2fe..480fca59 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -1,6 +1,36 @@ +"""Evaluate a policy on an environment by running rollouts and computing metrics. + +The script may be run in one of two ways: + +1. By providing the path to a config file with the --config argument. +2. By providing a HuggingFace Hub ID with the --hub-id argument. You may also provide a revision number with the + --revision argument. + +In either case, it is possible to override config arguments by adding a list of config.key=value arguments. + +Examples: + +You have a specific config file to go with trained model weights, and want to run 10 episodes. + +``` +python lerobot/scripts/eval.py --config PATH/TO/FOLDER/config.yaml \ + policy.pretrained_model_path=PATH/TO/FOLDER/weights.pth` eval_episodes=10 +``` + +You have a HuggingFace Hub ID, you know which revision you want, and want to run 10 episodes (note that in this case, +you don't need to specify which weights to use): + +``` +python lerobot/scripts/eval.py --hub-id HUB/ID --revision v1.0 eval_episodes=10 +``` +""" + +import argparse import logging +import os.path as osp import threading import time +from datetime import datetime as dt from pathlib import Path import einops @@ -9,6 +39,7 @@ import imageio import numpy as np import torch import tqdm +from huggingface_hub import snapshot_download from tensordict.nn import TensorDictModule from torchrl.envs import EnvBase from torchrl.envs.batched_envs import BatchedEnvBase @@ -65,8 +96,8 @@ def eval_policy( callback=maybe_render_frame, break_when_any_done=env.batch_size[0] == 1, ) - # Figure out where in each rollout sequence the first done condition was encountered (results after this won't - # be included). + # Figure out where in each rollout sequence the first done condition was encountered (results after + # this won't be included). # Note: this assumes that the shape of the done key is (batch_size, max_steps, 1). # Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker. rollout_steps = rollout["next", "done"].shape[1] @@ -119,12 +150,7 @@ def eval_policy( return info -@hydra.main(version_base=None, config_name="default", config_path="../configs") -def eval_cli(cfg: dict): - eval(cfg, out_dir=hydra.core.hydra_config.HydraConfig.get().runtime.output_dir) - - -def eval(cfg: dict, out_dir=None): +def eval(cfg: dict, out_dir=None, stats_path=None): if out_dir is None: raise NotImplementedError() @@ -139,10 +165,10 @@ def eval(cfg: dict, out_dir=None): log_output_dir(out_dir) - logging.info("make_offline_buffer") - offline_buffer = make_offline_buffer(cfg) + logging.info("Making transforms.") + offline_buffer = make_offline_buffer(cfg, stats_path=stats_path) - logging.info("make_env") + logging.info("Making environment.") env = make_env(cfg, transform=offline_buffer.transform) if cfg.policy.pretrained_model_path: @@ -170,5 +196,52 @@ def eval(cfg: dict, out_dir=None): logging.info("End of eval") +def _relative_path_between(path1: Path, path2: Path) -> Path: + """Returns path1 relative to path2.""" + path1 = path1.absolute() + path2 = path2.absolute() + try: + return path1.relative_to(path2) + except ValueError: # most likely because path1 is not a subpath of path2 + common_parts = Path(osp.commonpath([path1, path2])).parts + return Path( + "/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :])) + ) + + if __name__ == "__main__": - eval_cli() + parser = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter + ) + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument("--config", help="Path to a specific yaml config you want to use.") + group.add_argument("--hub-id", help="HuggingFace Hub ID for a pretrained model.") + parser.add_argument("--revision", help="Optionally provide the HuggingFace Hub revision ID.") + parser.add_argument( + "overrides", + nargs="*", + help="Any key=value arguments to override config values (use dots for.nested=overrides)", + ) + args = parser.parse_args() + + if args.config is not None: + # Note: For the config_path, Hydra wants a path relative to this script file. + hydra.initialize( + config_path=str( + _relative_path_between(Path(args.config).absolute().parent, Path(__file__).parent) + ) + ) + cfg = hydra.compose(Path(args.config).stem, args.overrides) + # TODO(alexander-soare): Save and load stats in trained model directory. + stats_path = None + elif args.hub_id is not None: + folder = Path(snapshot_download(args.hub_id, revision="v1.0")) + cfg = hydra.initialize(config_path=str(_relative_path_between(folder, Path(__file__).parent))) + cfg = hydra.compose("config", args.overrides) + cfg.policy.pretrained_model_path = folder / "model.pt" + stats_path = folder / "stats.pth" + + eval( + cfg, + out_dir=f"eval_outputs/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{cfg.env.name}_{cfg.policy.name}", + )