From 361eb007d7cb3a345978db00e0bf7a0022486088 Mon Sep 17 00:00:00 2001
From: Alexander Soare
Date: Mon, 15 Apr 2024 16:39:04 +0100
Subject: [PATCH] backup wip

---
 lerobot/common/logger.py           |  9 +++--
 lerobot/common/policies/factory.py | 41 +++++++++++++--------
 lerobot/scripts/eval.py            | 58 ++++++++++++++++++------------
 lerobot/scripts/train.py           |  2 +-
 4 files changed, 70 insertions(+), 40 deletions(-)

diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py
index 059b6731..bf23d4e7 100644
--- a/lerobot/common/logger.py
+++ b/lerobot/common/logger.py
@@ -2,9 +2,11 @@ import logging
 import os
 from pathlib import Path
 
+import torch
 from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
-from omegaconf import OmegaConf
+from omegaconf import DictConfig, OmegaConf
 from termcolor import colored
+from torch import Tensor
 
 
 def log_output_dir(out_dir):
@@ -24,7 +26,7 @@ def cfg_to_group(cfg, return_list=False):
 class Logger:
     """Primary logger object. Logs either locally or using wandb."""
 
-    def __init__(self, log_dir, job_name, cfg):
+    def __init__(self, log_dir, job_name, cfg: DictConfig, stats: dict[str, Tensor]):
         self._log_dir = Path(log_dir)
         self._log_dir.mkdir(parents=True, exist_ok=True)
         self._job_name = job_name
@@ -36,6 +38,7 @@ class Logger:
         self._group = cfg_to_group(cfg)
         self._seed = cfg.seed
         self._cfg = cfg
+        self._stats = stats
         self._eval = []
         project = cfg.get("wandb", {}).get("project")
         entity = cfg.get("wandb", {}).get("entity")
@@ -79,6 +82,8 @@ class Logger:
         else:
             save_dir = self._model_dir / str(identifier)
         policy.save_pretrained(save_dir)
+        OmegaConf.save(self._cfg, save_dir / "config.yaml")
+        torch.save(self._stats, save_dir / "stats.pth")
         model_path = save_dir / SAFETENSORS_SINGLE_FILE
         if self._wandb and not self._disable_wandb_artifact:
             # note wandb artifact does not accept ":" in its name
diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py
index e021adf6..90b983af 100644
--- a/lerobot/common/policies/factory.py
+++ b/lerobot/common/policies/factory.py
@@ -5,7 +5,17 @@ from omegaconf import DictConfig, OmegaConf
 from lerobot.common.utils import get_safe_torch_device
 
 
-def make_policy(cfg: DictConfig):
+def make_policy(cfg: DictConfig | None = None, pretrained_policy_name_or_path: str | None = None):
+    """
+    Args:
+        cfg: Hydra configuration.
+        pretrained_policy_name_or_path: Hugging Face hub ID (repository name), or path to a local folder with
+            the policy weights and configuration.
+
+    TODO(alexander-soare): This function is currently in a transitional state where we are using both Hydra
+    configurations and policy dataclass configurations. We will remove the Hydra configuration from this file
+    once all models are set up to use dataclass configurations.
+    """
     if cfg.policy.name == "tdmpc":
         from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
 
@@ -31,23 +41,26 @@ def make_policy(cfg: DictConfig):
         from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
         from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
 
-        expected_kwargs = set(inspect.signature(ActionChunkingTransformerConfig).parameters)
-        assert set(cfg.policy).issuperset(
-            expected_kwargs
-        ), f"Hydra config is missing arguments: {set(cfg.policy).difference(expected_kwargs)}"
-        policy_cfg = ActionChunkingTransformerConfig(
-            **{
-                k: v
-                for k, v in OmegaConf.to_container(cfg.policy, resolve=True).items()
-                if k in expected_kwargs
-            }
-        )
-        policy = ActionChunkingTransformerPolicy(policy_cfg)
+        if pretrained_policy_name_or_path is None:
+            expected_kwargs = set(inspect.signature(ActionChunkingTransformerConfig).parameters)
+            assert set(cfg.policy).issuperset(
+                expected_kwargs
+            ), f"Hydra config is missing arguments: {set(cfg.policy).difference(expected_kwargs)}"
+            policy_cfg = ActionChunkingTransformerConfig(
+                **{
+                    k: v
+                    for k, v in OmegaConf.to_container(cfg.policy, resolve=True).items()
+                    if k in expected_kwargs
+                }
+            )
+            policy = ActionChunkingTransformerPolicy(policy_cfg)
+        else:
+            policy = ActionChunkingTransformerPolicy.from_pretrained(pretrained_policy_name_or_path)
         policy.to(get_safe_torch_device(cfg.device))
     else:
         raise ValueError(cfg.policy.name)
 
-    if cfg.policy.pretrained_model_path:
+    if cfg.policy.pretrained_model_path and pretrained_policy_name_or_path is None:
         # TODO(rcadene): hack for old pretrained models from fowm
         if cfg.policy.name == "tdmpc" and "fowm" in cfg.policy.pretrained_model_path:
             if "offline" in cfg.policy.pretrained_model_path:
diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index 9e1b8637..5ec0aea1 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -291,30 +291,45 @@ def eval_policy(
     return info
 
 
-def eval(cfg: DictConfig, out_dir=None, stats_path=None):
+def eval(hydra_cfg: DictConfig, pretrained_policy_name_or_path: str, out_dir=None, stats_path=None):
+    """Evaluate a policy.
+
+    Args:
+        hydra_cfg: Hydra config.
+        pretrained_policy_name_or_path: Hugging Face hub ID (repository name), or path to a local folder with
+            the policy weights and configuration.
+        out_dir: The directory to save the evaluation results (JSON file and videos).
+        stats_path: The path to the stats file.
+
+    TODO(alexander-soare): This function is currently in a transitional state where we are using both Hydra
+    configurations and policy dataclass configurations. This is because we still need the Hydra config to get
+    the transforms and environment. The transforms should be absorbed into the policy configuration, and the
+    environment configuration should be standalone.
+    """
     if out_dir is None:
         raise NotImplementedError()
 
     init_logging()
 
     # Check device is available
-    get_safe_torch_device(cfg.device, log=True)
+    get_safe_torch_device(hydra_cfg.device, log=True)
 
     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
-    set_global_seed(cfg.seed)
+    set_global_seed(hydra_cfg.seed)
 
     log_output_dir(out_dir)
 
     logging.info("Making transforms.")
     # TODO(alexander-soare): Completely decouple datasets from evaluation.
-    transform = make_dataset(cfg, stats_path=stats_path).transform
+    transform = make_dataset(hydra_cfg, stats_path=stats_path).transform
 
     logging.info("Making environment.")
-    env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
+    env = make_env(hydra_cfg, num_parallel_envs=hydra_cfg.eval_episodes)
 
     logging.info("Making policy.")
-    policy = make_policy(cfg)
+
+    policy = make_policy(hydra_cfg, pretrained_policy_name_or_path)
 
     info = eval_policy(
         env,
@@ -322,7 +337,7 @@ def eval(cfg: DictConfig, out_dir=None, stats_path=None):
         max_episodes_rendered=10,
         video_dir=Path(out_dir) / "eval",
         transform=transform,
-        seed=cfg.seed,
+        seed=hydra_cfg.seed,
     )
 
     print(info["aggregated"])
@@ -340,9 +355,9 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
     )
-    group = parser.add_mutually_exclusive_group(required=True)
-    group.add_argument("--config", help="Path to a specific yaml config you want to use.")
-    group.add_argument("--hub-id", help="HuggingFace Hub ID for a pretrained model.")
+    parser.add_argument(
+        "-p", "--policy-name-or-path", help="HuggingFace Hub ID, or path to a pretrained model."
+    )
    parser.add_argument("--revision", help="Optionally provide the HuggingFace Hub revision ID.")
     parser.add_argument(
         "overrides",
@@ -351,20 +366,17 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()
 
-    if args.config is not None:
-        # Note: For the config_path, Hydra wants a path relative to this script file.
-        cfg = init_hydra_config(args.config, args.overrides)
-        # TODO(alexander-soare): Save and load stats in trained model directory.
-        stats_path = None
-    elif args.hub_id is not None:
-        folder = Path(snapshot_download(args.hub_id, revision=args.revision))
-        cfg = init_hydra_config(
-            folder / "config.yaml", [f"policy.pretrained_model_path={folder / 'model.pt'}", *args.overrides]
-        )
-        stats_path = folder / "stats.pth"
+    path = Path(args.policy_name_or_path)
+    if not path.exists():
+        path = Path(snapshot_download(args.policy_name_or_path, revision=args.revision))
+    hydra_cfg = init_hydra_config(
+        path / "config.yaml", [f"policy.pretrained_model_path={path / 'model.pt'}", *args.overrides]
+    )
+    stats_path = path / "stats.pth"
 
     eval(
-        cfg,
-        out_dir=f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{cfg.env.name}_{cfg.policy.name}",
+        hydra_cfg,
+        args.policy_name_or_path,
+        out_dir=f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}",
         stats_path=stats_path,
     )
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index 5ff6538d..42f6afd9 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -196,7 +196,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
     num_total_params = sum(p.numel() for p in policy.parameters())
 
     # log metrics to terminal and wandb
-    logger = Logger(out_dir, job_name, cfg)
+    logger = Logger(out_dir, job_name, cfg, offline_dataset.transform.transforms[1].stats)
 
     log_output_dir(out_dir)
     logging.info(f"{cfg.env.task=}")
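
Usage note (editorial, not part of the patch): a minimal sketch of the loading path this change enables.
The checkpoint folder below is a hypothetical example, and the import location of init_hydra_config is
assumed for this revision of the repo; the folder only needs the config.yaml, stats.pth, and safetensors
weights that Logger.save_model() now writes next to each checkpoint. In this wip state only the "act"
branch of make_policy handles the new argument, and a Hydra config is still required alongside it.

    from lerobot.common.policies.factory import make_policy
    from lerobot.common.utils import init_hydra_config  # assumed import location in this revision

    # Hypothetical checkpoint folder written by Logger.save_model() after this patch.
    pretrained_dir = "outputs/train/example_run/models/final"

    # The Hydra config still supplies the device and the non-policy settings; the weights and the
    # policy dataclass config are restored via ActionChunkingTransformerPolicy.from_pretrained().
    hydra_cfg = init_hydra_config(f"{pretrained_dir}/config.yaml", [])
    policy = make_policy(hydra_cfg, pretrained_policy_name_or_path=pretrained_dir)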