From 8720c568d0c962f2f59e31f378fdfb99e1d77f17 Mon Sep 17 00:00:00 2001
From: Alexander Soare
Date: Fri, 22 Mar 2024 10:26:55 +0000
Subject: [PATCH 1/5] Add ability to eval hub model

---
 .gitignore                         |  1 +
 README.md                          | 35 +++++++++++
 lerobot/common/datasets/factory.py | 15 +++--
 lerobot/scripts/eval.py            | 97 ++++++++++++++++++++++++++----
 4 files changed, 132 insertions(+), 16 deletions(-)

diff --git a/.gitignore b/.gitignore
index ad9892d4..f07ce341 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,6 +4,7 @@ tmp
 wandb
 data
 outputs
+eval_outputs
 .vscode
 rl
diff --git a/README.md b/README.md
index 18af3242..076252c8 100644
--- a/README.md
+++ b/README.md
@@ -223,3 +223,38 @@ Finally, you might want to mock the dataset if you need to update the unit tests
 ```
 python tests/scripts/mock_dataset.py --in-data-dir data/$DATASET --out-data-dir tests/data/$DATASET
 ```
+
+**Models**
+
+Once you have trained a model, you may upload it to the HuggingFace hub.
+
+Firstly, make sure you have a model repository set up on the hub. The hub ID looks like HF_USER/REPO_NAME.
+
+Secondly, assuming you have trained a model, you need:
+
+- `config.yaml`, which you can get from the `.hydra` directory of your training output folder.
+- `model.pt`, which should be one of the saved models in the `models` directory of your training output folder (they won't be named `model.pt` but you will need to choose one).
+- `stats.pth`, which should point to the same file in the dataset directory (found in `data/{dataset_name}`).
+
+To upload these to the hub, prepare a folder with the following structure (you can use symlinks rather than copying):
+
+```
+to_upload
+    ├── config.yaml
+    ├── model.pt
+    └── stats.pth
+```
+
+With the folder prepared, run the following with a desired revision ID.
+
+```
+huggingface-cli upload $HUB_ID to_upload --revision $REVISION_ID
+```
+
+If you want this to be the default revision, also run the following (don't worry, it won't upload the files again; it will just adjust the file pointers):
+
+```
+huggingface-cli upload $HUB_ID to_upload
+```
+
+See `eval.py` for an example of how a user may use your model.
diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py
index 3f4772c4..40a49cc1 100644
--- a/lerobot/common/datasets/factory.py
+++ b/lerobot/common/datasets/factory.py
@@ -14,7 +14,12 @@ DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
 
 
 def make_offline_buffer(
-    cfg, overwrite_sampler=None, normalize=True, overwrite_batch_size=None, overwrite_prefetch=None
+    cfg,
+    overwrite_sampler=None,
+    normalize=True,
+    overwrite_batch_size=None,
+    overwrite_prefetch=None,
+    stats_path=None,
 ):
     if cfg.policy.balanced_sampling:
         assert cfg.online_steps > 0
@@ -98,10 +103,12 @@ def make_offline_buffer(
     transforms = [Prod(in_keys=img_keys, prod=1 / 255)]
 
     if normalize:
-        # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
-        stats = offline_buffer.compute_or_load_stats()
+        # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max,
+        # min_max_from_spec
+        stats = offline_buffer.compute_or_load_stats() if stats_path is None else torch.load(stats_path)
 
-        # we only normalize the state and action, since the images are usually normalized inside the model for now (except for tdmpc: see the following)
+        # we only normalize the state and action, since the images are usually normalized inside the model for
+        # now (except for tdmpc: see the following)
         in_keys = [("observation", "state"), ("action")]
 
         if cfg.policy.name == "tdmpc":
diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index 76deb2fe..480fca59 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -1,6 +1,36 @@
+"""Evaluate a policy on an environment by running rollouts and computing metrics.
+
+The script may be run in one of two ways:
+
+1. By providing the path to a config file with the --config argument.
+2. By providing a HuggingFace Hub ID with the --hub-id argument. You may also provide a revision number with the
+   --revision argument.
+
+In either case, it is possible to override config arguments by adding a list of config.key=value arguments.
+
+Examples:
+
+You have a specific config file to go with trained model weights, and want to run 10 episodes.
+
+```
+python lerobot/scripts/eval.py --config PATH/TO/FOLDER/config.yaml \
+    policy.pretrained_model_path=PATH/TO/FOLDER/weights.pth eval_episodes=10
+```
+
+You have a HuggingFace Hub ID, you know which revision you want, and want to run 10 episodes (note that in this case,
+you don't need to specify which weights to use):
+
+```
+python lerobot/scripts/eval.py --hub-id HUB/ID --revision v1.0 eval_episodes=10
+```
+"""
+
+import argparse
 import logging
+import os.path as osp
 import threading
 import time
+from datetime import datetime as dt
 from pathlib import Path
 
 import einops
@@ -9,6 +39,7 @@ import imageio
 import numpy as np
 import torch
 import tqdm
+from huggingface_hub import snapshot_download
 from tensordict.nn import TensorDictModule
 from torchrl.envs import EnvBase
 from torchrl.envs.batched_envs import BatchedEnvBase
@@ -65,8 +96,8 @@ def eval_policy(
             callback=maybe_render_frame,
             break_when_any_done=env.batch_size[0] == 1,
         )
-    # Figure out where in each rollout sequence the first done condition was encountered (results after this won't
-    # be included).
+    # Figure out where in each rollout sequence the first done condition was encountered (results after
+    # this won't be included).
     # Note: this assumes that the shape of the done key is (batch_size, max_steps, 1).
    # Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker.
     rollout_steps = rollout["next", "done"].shape[1]
@@ -119,12 +150,7 @@ def eval_policy(
     return info
 
 
-@hydra.main(version_base=None, config_name="default", config_path="../configs")
-def eval_cli(cfg: dict):
-    eval(cfg, out_dir=hydra.core.hydra_config.HydraConfig.get().runtime.output_dir)
-
-
-def eval(cfg: dict, out_dir=None):
+def eval(cfg: dict, out_dir=None, stats_path=None):
     if out_dir is None:
         raise NotImplementedError()
 
@@ -139,10 +165,10 @@ def eval(cfg: dict, out_dir=None):
 
     log_output_dir(out_dir)
 
-    logging.info("make_offline_buffer")
-    offline_buffer = make_offline_buffer(cfg)
+    logging.info("Making transforms.")
+    offline_buffer = make_offline_buffer(cfg, stats_path=stats_path)
 
-    logging.info("make_env")
+    logging.info("Making environment.")
     env = make_env(cfg, transform=offline_buffer.transform)
 
     if cfg.policy.pretrained_model_path:
@@ -170,5 +196,52 @@ def eval(cfg: dict, out_dir=None):
     logging.info("End of eval")
 
 
+def _relative_path_between(path1: Path, path2: Path) -> Path:
+    """Returns path1 relative to path2."""
+    path1 = path1.absolute()
+    path2 = path2.absolute()
+    try:
+        return path1.relative_to(path2)
+    except ValueError:  # most likely because path1 is not a subpath of path2
+        common_parts = Path(osp.commonpath([path1, path2])).parts
+        return Path(
+            "/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :]))
+        )
+
+
 if __name__ == "__main__":
-    eval_cli()
+    parser = argparse.ArgumentParser(
+        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
+    )
+    group = parser.add_mutually_exclusive_group(required=True)
+    group.add_argument("--config", help="Path to a specific yaml config you want to use.")
+    group.add_argument("--hub-id", help="HuggingFace Hub ID for a pretrained model.")
+    parser.add_argument("--revision", help="Optionally provide the HuggingFace Hub revision ID.")
+    parser.add_argument(
+        "overrides",
+        nargs="*",
+        help="Any key=value arguments to override config values (use dots for nested overrides)",
+    )
+    args = parser.parse_args()
+
+    if args.config is not None:
+        # Note: For the config_path, Hydra wants a path relative to this script file.
+        hydra.initialize(
+            config_path=str(
+                _relative_path_between(Path(args.config).absolute().parent, Path(__file__).parent)
+            )
+        )
+        cfg = hydra.compose(Path(args.config).stem, args.overrides)
+        # TODO(alexander-soare): Save and load stats in trained model directory.
+        stats_path = None
+    elif args.hub_id is not None:
+        folder = Path(snapshot_download(args.hub_id, revision=args.revision))
+        hydra.initialize(config_path=str(_relative_path_between(folder, Path(__file__).parent)))
+        cfg = hydra.compose("config", args.overrides)
+        cfg.policy.pretrained_model_path = folder / "model.pt"
+        stats_path = folder / "stats.pth"
+
+    eval(
+        cfg,
+        out_dir=f"eval_outputs/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{cfg.env.name}_{cfg.policy.name}",
+    )

From 3f0f95f4c02162ff70743b2ba162d429b683d84f Mon Sep 17 00:00:00 2001
From: Alexander Soare
Date: Fri, 22 Mar 2024 10:34:22 +0000
Subject: [PATCH 2/5] update readme

---
 README.md | 15 +++++----------
 1 file changed, 5 insertions(+), 10 deletions(-)

diff --git a/README.md b/README.md
index 076252c8..9bc8e56e 100644
--- a/README.md
+++ b/README.md
@@ -48,7 +48,6 @@ wandb login
 
 ## Usage
 
-
 ### Train
 
 ```
@@ -65,14 +64,9 @@ hydra.run.dir=tmp/$(date +"%Y_%m_%d") \
 env=pusht
 ```
 
-### Visualize online buffer / Eval
-
-```
-python lerobot/scripts/eval.py \
-hydra.run.dir=tmp/$(date +"%Y_%m_%d") \
-env=pusht
-```
+### Eval
 
+Run `python lerobot/scripts/eval.py --help` for instructions.
 
 ## TODO
 
@@ -106,8 +100,9 @@ with profile(
 ```bash
 python lerobot/scripts/eval.py \
-pretrained_model_path=/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt \
-eval_episodes=7
+    --config /home/rcadene/code/fowm/logs/xarm_lift/all/default/2/.hydra/config.yaml \
+    pretrained_model_path=/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt \
+    eval_episodes=7
 ```
 
 ## Contribute

From 1b279a1fc0394c9cbd76b6d7c0ba6d5b690cce3c Mon Sep 17 00:00:00 2001
From: Alexander Soare
Date: Fri, 22 Mar 2024 10:53:34 +0000
Subject: [PATCH 3/5] fix test

---
 .github/workflows/test.yml | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 000777dc..728a9786 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -135,8 +135,7 @@ jobs:
         run: |
           source .venv/bin/activate
           python lerobot/scripts/eval.py \
-          hydra.job.name=pusht \
-          env=pusht \
+          --config tests/outputs/.hydra/config.yaml \
           wandb.enable=False \
           eval_episodes=1 \
           env.episode_length=8 \

From 529f42643dd2930fbaa1150a954f860d7de003fe Mon Sep 17 00:00:00 2001
From: Alexander Soare
Date: Fri, 22 Mar 2024 12:33:25 +0000
Subject: [PATCH 4/5] revision

---
 .gitignore                   | 1 -
 lerobot/configs/default.yaml | 2 +-
 lerobot/scripts/eval.py      | 2 +-
 3 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/.gitignore b/.gitignore
index f07ce341..ad9892d4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,7 +4,6 @@ tmp
 wandb
 data
 outputs
-eval_outputs
 .vscode
 rl
diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml
index 52fd1d60..2dc313e4 100644
--- a/lerobot/configs/default.yaml
+++ b/lerobot/configs/default.yaml
@@ -5,7 +5,7 @@ defaults:
 
 hydra:
   run:
-    dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}_${env.name}_${policy.name}_${hydra.job.name}
+    dir: outputs/train/${now:%Y-%m-%d}/${now:%H-%M-%S}_${env.name}_${policy.name}_${hydra.job.name}
   job:
     name: default
diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index 480fca59..2b5611c1 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -243,5 +243,5 @@ if __name__ == "__main__":
 
     eval(
         cfg,
-        out_dir=f"eval_outputs/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{cfg.env.name}_{cfg.policy.name}",
+        out_dir=f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{cfg.env.name}_{cfg.policy.name}",
     )

From 115927d0f66cef9d383f0dccc5e09711b2317e50 Mon Sep 17 00:00:00 2001
From: Alexander Soare
Date: Fri, 22 Mar 2024 12:58:59 +0000
Subject: [PATCH 5/5] make sure to pass stats.pth arg

---
 lerobot/scripts/eval.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index 2b5611c1..6ff25562 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -244,4 +244,5 @@ if __name__ == "__main__":
     eval(
         cfg,
         out_dir=f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{cfg.env.name}_{cfg.policy.name}",
+        stats_path=stats_path,
     )