backup wip

Alexander Soare 2024-04-15 16:39:04 +01:00
parent c643a1ba7f
commit 361eb007d7
4 changed files with 70 additions and 40 deletions

View File

@@ -2,9 +2,11 @@ import logging
import os
from pathlib import Path
import torch
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from omegaconf import OmegaConf
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
from torch import Tensor
def log_output_dir(out_dir):
@@ -24,7 +26,7 @@ def cfg_to_group(cfg, return_list=False):
class Logger:
"""Primary logger object. Logs either locally or using wandb."""
def __init__(self, log_dir, job_name, cfg):
def __init__(self, log_dir, job_name, cfg: DictConfig, stats: dict[str, Tensor]):
self._log_dir = Path(log_dir)
self._log_dir.mkdir(parents=True, exist_ok=True)
self._job_name = job_name
@@ -36,6 +38,7 @@ class Logger:
self._group = cfg_to_group(cfg)
self._seed = cfg.seed
self._cfg = cfg
self._stats = stats
self._eval = []
project = cfg.get("wandb", {}).get("project")
entity = cfg.get("wandb", {}).get("entity")
@@ -79,6 +82,8 @@ class Logger:
else:
save_dir = self._model_dir / str(identifier)
policy.save_pretrained(save_dir)
OmegaConf.save(self._cfg, save_dir / "config.yaml")
torch.save(self._stats, save_dir / "stats.pth")
model_path = save_dir / SAFETENSORS_SINGLE_FILE
if self._wandb and not self._disable_wandb_artifact:
# note wandb artifact does not accept ":" in its name
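
Taken together, these changes make a checkpoint directory self-contained: policy.save_pretrained writes the weights, OmegaConf.save the Hydra config, and torch.save the normalization stats. A minimal sketch of reading the pieces back, assuming a hypothetical checkpoint path (load_file is safetensors' standard loader, and SAFETENSORS_SINGLE_FILE resolves to "model.safetensors"):

from pathlib import Path

import torch
from omegaconf import OmegaConf
from safetensors.torch import load_file

save_dir = Path("outputs/train/example/models/final")   # hypothetical checkpoint dir
cfg = OmegaConf.load(save_dir / "config.yaml")          # written by OmegaConf.save above
stats = torch.load(save_dir / "stats.pth")              # written by torch.save above
state_dict = load_file(save_dir / "model.safetensors")  # written by policy.save_pretrained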

View File

@@ -5,7 +5,17 @@ from omegaconf import DictConfig, OmegaConf
from lerobot.common.utils import get_safe_torch_device
def make_policy(cfg: DictConfig):
def make_policy(cfg: DictConfig | None = None, pretrained_policy_name_or_path: str | None = None):
"""
Args:
cfg: Hydra configuration.
pretrained_policy_name_or_path: Hugging Face hub ID (repository name), or path to a local folder with
the policy weights and configuration.
TODO(alexander-soare): This function is currently in a transitional state where we are using both Hydra
configurations and policy dataclass configurations. We will remove the Hydra configuration from this file
once all models are set up to use dataclass configurations.
"""
if cfg.policy.name == "tdmpc":
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
@@ -31,6 +41,7 @@ def make_policy(cfg: DictConfig):
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
if pretrained_policy_name_or_path is None:
expected_kwargs = set(inspect.signature(ActionChunkingTransformerConfig).parameters)
assert set(cfg.policy).issuperset(
expected_kwargs
@@ -43,11 +54,13 @@ def make_policy(cfg: DictConfig):
}
)
policy = ActionChunkingTransformerPolicy(policy_cfg)
else:
policy = ActionChunkingTransformerPolicy.from_pretrained(pretrained_policy_name_or_path)
policy.to(get_safe_torch_device(cfg.device))
else:
raise ValueError(cfg.policy.name)
if cfg.policy.pretrained_model_path:
if cfg.policy.pretrained_model_path and pretrained_policy_name_or_path is None:
# TODO(rcadene): hack for old pretrained models from fowm
if cfg.policy.name == "tdmpc" and "fowm" in cfg.policy.pretrained_model_path:
if "offline" in cfg.policy.pretrained_model_path:

View File

@@ -291,30 +291,45 @@ def eval_policy(
return info
def eval(cfg: DictConfig, out_dir=None, stats_path=None):
def eval(hydra_cfg: DictConfig, pretrained_policy_name_or_path: str, out_dir=None, stats_path=None):
"""Evaluate a policy.
Args:
hydra_cfg: Hydra config.
pretrained_policy_name_or_path: Hugging Face hub ID (repository name), or path to a local folder with
the policy weights and configuration.
out_dir: The directory in which to save the evaluation results (JSON file and videos).
stats_path: The path to the stats file.
TODO(alexander-soare): This function is currently in a transitional state where we are using both Hydra
configurations and policy dataclass configurations. This is because we still need the Hydra config to get
the transforms and environment. The transforms should be absorbed into the policy configuration, and the
environment configuration should be standalone.
"""
if out_dir is None:
raise NotImplementedError()
init_logging()
# Check device is available
get_safe_torch_device(cfg.device, log=True)
get_safe_torch_device(hydra_cfg.device, log=True)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
set_global_seed(cfg.seed)
set_global_seed(hydra_cfg.seed)
log_output_dir(out_dir)
logging.info("Making transforms.")
# TODO(alexander-soare): Completely decouple datasets from evaluation.
transform = make_dataset(cfg, stats_path=stats_path).transform
transform = make_dataset(hydra_cfg, stats_path=stats_path).transform
logging.info("Making environment.")
env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
env = make_env(hydra_cfg, num_parallel_envs=hydra_cfg.eval_episodes)
logging.info("Making policy.")
policy = make_policy(cfg)
policy = make_policy(hydra_cfg, pretrained_policy_name_or_path)
info = eval_policy(
env,
@@ -322,7 +337,7 @@ def eval(cfg: DictConfig, out_dir=None, stats_path=None):
max_episodes_rendered=10,
video_dir=Path(out_dir) / "eval",
transform=transform,
seed=cfg.seed,
seed=hydra_cfg.seed,
)
print(info["aggregated"])
@@ -340,9 +355,9 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
)
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("--config", help="Path to a specific yaml config you want to use.")
group.add_argument("--hub-id", help="HuggingFace Hub ID for a pretrained model.")
parser.add_argument(
"-p", "--policy-name-or-path", help="HuggingFace Hub ID, or path to a pretrained model."
)
parser.add_argument("--revision", help="Optionally provide the HuggingFace Hub revision ID.")
parser.add_argument(
"overrides",
@@ -351,20 +366,17 @@
)
args = parser.parse_args()
if args.config is not None:
# Note: For the config_path, Hydra wants a path relative to this script file.
cfg = init_hydra_config(args.config, args.overrides)
# TODO(alexander-soare): Save and load stats in trained model directory.
stats_path = None
elif args.hub_id is not None:
folder = Path(snapshot_download(args.hub_id, revision=args.revision))
cfg = init_hydra_config(
folder / "config.yaml", [f"policy.pretrained_model_path={folder / 'model.pt'}", *args.overrides]
path = Path(args.policy_name_or_path)
if not path.exists():
path = Path(snapshot_download(args.policy_name_or_path, revision=args.revision))
hydra_cfg = init_hydra_config(
path / "config.yaml", [f"policy.pretrained_model_path={path / 'model.pt'}", *args.overrides]
)
stats_path = folder / "stats.pth"
stats_path = path / "stats.pth"
eval(
cfg,
out_dir=f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{cfg.env.name}_{cfg.policy.name}",
hydra_cfg,
args.policy_name_or_path,
out_dir=f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}",
stats_path=stats_path,
)
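
With the mutually exclusive --config/--hub-id pair replaced by a single --policy-name-or-path flag, the script resolves a local directory first and falls back to snapshot_download. A sketch of driving the new eval signature programmatically, mirroring the __main__ block above (import paths are assumptions; the checkpoint path is a placeholder):

from datetime import datetime as dt
from pathlib import Path

from lerobot.common.utils import init_hydra_config  # assumed import path
from lerobot.scripts.eval import eval  # assumed import path

path = Path("outputs/train/example/models/final")  # placeholder local checkpoint
hydra_cfg = init_hydra_config(
    path / "config.yaml", [f"policy.pretrained_model_path={path / 'model.pt'}"]
)
eval(
    hydra_cfg,
    str(path),
    out_dir=f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}",
    stats_path=path / "stats.pth",
)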

View File

@@ -196,7 +196,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
num_total_params = sum(p.numel() for p in policy.parameters())
# log metrics to terminal and wandb
logger = Logger(out_dir, job_name, cfg)
logger = Logger(out_dir, job_name, cfg, offline_dataset.transform.transforms[1].stats)
log_output_dir(out_dir)
logging.info(f"{cfg.env.task=}")
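
Note the transforms[1] indexing: it assumes the dataset transform is a composition whose second element is the normalization step exposing a stats attribute, i.e. the dict[str, Tensor] that Logger now saves alongside every checkpoint. A sketch of the assumed structure:

# Assumed transform chain (illustrative, not the exact lerobot classes):
#   offline_dataset.transform.transforms[0] -> formatting / to-tensor step
#   offline_dataset.transform.transforms[1] -> normalization step with a .stats dict
stats = offline_dataset.transform.transforms[1].stats
logger = Logger(out_dir, job_name, cfg, stats)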