backup wip

Alexander Soare 2024-04-15 16:39:04 +01:00
parent c643a1ba7f
commit 361eb007d7
4 changed files with 70 additions and 40 deletions

View File

@@ -2,9 +2,11 @@ import logging
 import os
 from pathlib import Path
 
+import torch
 from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
-from omegaconf import OmegaConf
+from omegaconf import DictConfig, OmegaConf
 from termcolor import colored
+from torch import Tensor
 
 
 def log_output_dir(out_dir):
@@ -24,7 +26,7 @@ def cfg_to_group(cfg, return_list=False):
 class Logger:
     """Primary logger object. Logs either locally or using wandb."""
 
-    def __init__(self, log_dir, job_name, cfg):
+    def __init__(self, log_dir, job_name, cfg: DictConfig, stats: dict[str, Tensor]):
         self._log_dir = Path(log_dir)
         self._log_dir.mkdir(parents=True, exist_ok=True)
         self._job_name = job_name
@@ -36,6 +38,7 @@ class Logger:
         self._group = cfg_to_group(cfg)
         self._seed = cfg.seed
         self._cfg = cfg
+        self._stats = stats
         self._eval = []
         project = cfg.get("wandb", {}).get("project")
         entity = cfg.get("wandb", {}).get("entity")
@@ -79,6 +82,8 @@ class Logger:
         else:
             save_dir = self._model_dir / str(identifier)
         policy.save_pretrained(save_dir)
+        OmegaConf.save(self._cfg, save_dir / "config.yaml")
+        torch.save(self._stats, save_dir / "stats.pth")
         model_path = save_dir / SAFETENSORS_SINGLE_FILE
         if self._wandb and not self._disable_wandb_artifact:
             # note wandb artifact does not accept ":" in its name
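Taken together, the two added save calls mean a checkpoint directory now bundles the Hydra config and the dataset normalization stats next to the safetensors weights. A minimal round-trip sketch of that layout; the directory name and the contents of cfg and stats are illustrative assumptions, not values from this commit:

    from pathlib import Path

    import torch
    from omegaconf import OmegaConf

    # Hypothetical stand-ins for the trainer's Hydra config and dataset stats.
    cfg = OmegaConf.create({"seed": 1337, "policy": {"name": "act"}})
    stats = {"action": {"mean": torch.zeros(2), "std": torch.ones(2)}}

    save_dir = Path("outputs/example_checkpoint")  # hypothetical directory
    save_dir.mkdir(parents=True, exist_ok=True)

    # Mirrors the two files the Logger now writes next to the weights
    # saved by policy.save_pretrained(save_dir):
    OmegaConf.save(cfg, save_dir / "config.yaml")
    torch.save(stats, save_dir / "stats.pth")

    # Round trip:
    assert OmegaConf.load(save_dir / "config.yaml").seed == 1337
    assert torch.equal(torch.load(save_dir / "stats.pth")["action"]["std"], torch.ones(2))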

View File

@@ -5,7 +5,17 @@ from omegaconf import DictConfig, OmegaConf
 from lerobot.common.utils import get_safe_torch_device
 
 
-def make_policy(cfg: DictConfig):
+def make_policy(cfg: DictConfig | None = None, pretrained_policy_name_or_path: str | None = None):
+    """
+    Args:
+        cfg: Hydra configuration.
+        pretrained_policy_name_or_path: Hugging Face hub ID (repository name), or path to a local folder with
+            the policy weights and configuration.
+
+    TODO(alexander-soare): This function is currently in a transitional state where we are using both Hydra
+    configurations and policy dataclass configurations. We will remove the Hydra configuration from this file
+    once all models are set up to use dataclass configurations.
+    """
     if cfg.policy.name == "tdmpc":
         from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
@@ -31,23 +41,26 @@ def make_policy(cfg: DictConfig):
         from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
         from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
 
-        expected_kwargs = set(inspect.signature(ActionChunkingTransformerConfig).parameters)
-        assert set(cfg.policy).issuperset(
-            expected_kwargs
-        ), f"Hydra config is missing arguments: {set(cfg.policy).difference(expected_kwargs)}"
-        policy_cfg = ActionChunkingTransformerConfig(
-            **{
-                k: v
-                for k, v in OmegaConf.to_container(cfg.policy, resolve=True).items()
-                if k in expected_kwargs
-            }
-        )
-        policy = ActionChunkingTransformerPolicy(policy_cfg)
+        if pretrained_policy_name_or_path is None:
+            expected_kwargs = set(inspect.signature(ActionChunkingTransformerConfig).parameters)
+            assert set(cfg.policy).issuperset(
+                expected_kwargs
+            ), f"Hydra config is missing arguments: {set(cfg.policy).difference(expected_kwargs)}"
+            policy_cfg = ActionChunkingTransformerConfig(
+                **{
+                    k: v
+                    for k, v in OmegaConf.to_container(cfg.policy, resolve=True).items()
+                    if k in expected_kwargs
+                }
+            )
+            policy = ActionChunkingTransformerPolicy(policy_cfg)
+        else:
+            policy = ActionChunkingTransformerPolicy.from_pretrained(pretrained_policy_name_or_path)
 
         policy.to(get_safe_torch_device(cfg.device))
     else:
         raise ValueError(cfg.policy.name)
 
-    if cfg.policy.pretrained_model_path:
+    if cfg.policy.pretrained_model_path and pretrained_policy_name_or_path is None:
         # TODO(rcadene): hack for old pretrained models from fowm
         if cfg.policy.name == "tdmpc" and "fowm" in cfg.policy.pretrained_model_path:
             if "offline" in cfg.policy.pretrained_model_path:

View File

@@ -291,30 +291,45 @@ def eval_policy(
     return info
 
 
-def eval(cfg: DictConfig, out_dir=None, stats_path=None):
+def eval(hydra_cfg: DictConfig, pretrained_policy_name_or_path: str, out_dir=None, stats_path=None):
+    """Evaluate a policy.
+
+    Args:
+        hydra_cfg: Hydra config.
+        pretrained_policy_name_or_path: Hugging Face hub ID (repository name), or path to a local folder with
+            the policy weights and configuration.
+        out_dir: The directory to save the evaluation results (JSON file and videos).
+        stats_path: The path to the stats file.
+
+    TODO(alexander-soare): This function is currently in a transitional state where we are using both Hydra
+    configurations and policy dataclass configurations. This is because we still need the Hydra config to get
+    the transforms and environment. The transforms should be absorbed into the policy configuration, and the
+    environment configuration should be standalone.
+    """
     if out_dir is None:
         raise NotImplementedError()
 
     init_logging()
 
     # Check device is available
-    get_safe_torch_device(cfg.device, log=True)
+    get_safe_torch_device(hydra_cfg.device, log=True)
 
     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
-    set_global_seed(cfg.seed)
+    set_global_seed(hydra_cfg.seed)
 
     log_output_dir(out_dir)
 
     logging.info("Making transforms.")
     # TODO(alexander-soare): Completely decouple datasets from evaluation.
-    transform = make_dataset(cfg, stats_path=stats_path).transform
+    transform = make_dataset(hydra_cfg, stats_path=stats_path).transform
 
     logging.info("Making environment.")
-    env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
+    env = make_env(hydra_cfg, num_parallel_envs=hydra_cfg.eval_episodes)
 
     logging.info("Making policy.")
-    policy = make_policy(cfg)
+    policy = make_policy(hydra_cfg, pretrained_policy_name_or_path)
 
     info = eval_policy(
         env,
@@ -322,7 +337,7 @@ def eval(cfg: DictConfig, out_dir=None, stats_path=None):
         max_episodes_rendered=10,
         video_dir=Path(out_dir) / "eval",
         transform=transform,
-        seed=cfg.seed,
+        seed=hydra_cfg.seed,
     )
     print(info["aggregated"])
@@ -340,9 +355,9 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
     )
-    group = parser.add_mutually_exclusive_group(required=True)
-    group.add_argument("--config", help="Path to a specific yaml config you want to use.")
-    group.add_argument("--hub-id", help="HuggingFace Hub ID for a pretrained model.")
+    parser.add_argument(
+        "-p", "--policy-name-or-path", help="HuggingFace Hub ID, or path to a pretrained model."
+    )
     parser.add_argument("--revision", help="Optionally provide the HuggingFace Hub revision ID.")
     parser.add_argument(
         "overrides",
@@ -351,20 +366,17 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()
 
-    if args.config is not None:
-        # Note: For the config_path, Hydra wants a path relative to this script file.
-        cfg = init_hydra_config(args.config, args.overrides)
-        # TODO(alexander-soare): Save and load stats in trained model directory.
-        stats_path = None
-    elif args.hub_id is not None:
-        folder = Path(snapshot_download(args.hub_id, revision=args.revision))
-        cfg = init_hydra_config(
-            folder / "config.yaml", [f"policy.pretrained_model_path={folder / 'model.pt'}", *args.overrides]
-        )
-        stats_path = folder / "stats.pth"
+    path = Path(args.policy_name_or_path)
+    if not path.exists():
+        path = Path(snapshot_download(args.policy_name_or_path, revision=args.revision))
+    hydra_cfg = init_hydra_config(
+        path / "config.yaml", [f"policy.pretrained_model_path={path / 'model.pt'}", *args.overrides]
+    )
+    stats_path = path / "stats.pth"
 
     eval(
-        cfg,
-        out_dir=f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{cfg.env.name}_{cfg.policy.name}",
+        hydra_cfg,
+        args.policy_name_or_path,
+        out_dir=f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}",
         stats_path=stats_path,
     )
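With the mutually exclusive --config/--hub-id pair replaced by a single -p/--policy-name-or-path, the script now resolves one argument either way: an existing local directory is used directly, anything else is treated as a hub repo ID. A sketch of that resolution order, with the repo ID as a placeholder (the equivalent CLI would be something like python lerobot/scripts/eval.py -p <hub-id-or-local-dir>, script path assumed):

    from pathlib import Path

    from huggingface_hub import snapshot_download

    policy_name_or_path = "user/act-policy"  # placeholder: hub repo ID or local checkpoint dir
    path = Path(policy_name_or_path)
    if not path.exists():
        # Not a local checkpoint, so treat the argument as a hub repo ID and download a snapshot.
        path = Path(snapshot_download(policy_name_or_path))

    # The Hydra config and the stats now come from the same directory as the weights.
    config_path = path / "config.yaml"
    stats_path = path / "stats.pth"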

View File

@@ -196,7 +196,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
     num_total_params = sum(p.numel() for p in policy.parameters())
 
     # log metrics to terminal and wandb
-    logger = Logger(out_dir, job_name, cfg)
+    logger = Logger(out_dir, job_name, cfg, offline_dataset.transform.transforms[1].stats)
     log_output_dir(out_dir)
 
     logging.info(f"{cfg.env.task=}")