backup wip
parent c643a1ba7f
commit 361eb007d7
@@ -2,9 +2,11 @@ import logging
 import os
 from pathlib import Path

 import torch
 from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
-from omegaconf import OmegaConf
+from omegaconf import DictConfig, OmegaConf
 from termcolor import colored
+from torch import Tensor


 def log_output_dir(out_dir):
@@ -24,7 +26,7 @@ def cfg_to_group(cfg, return_list=False):
 class Logger:
     """Primary logger object. Logs either locally or using wandb."""

-    def __init__(self, log_dir, job_name, cfg):
+    def __init__(self, log_dir, job_name, cfg: DictConfig, stats: dict[str, Tensor]):
         self._log_dir = Path(log_dir)
         self._log_dir.mkdir(parents=True, exist_ok=True)
         self._job_name = job_name
@@ -36,6 +38,7 @@ class Logger:
         self._group = cfg_to_group(cfg)
         self._seed = cfg.seed
         self._cfg = cfg
+        self._stats = stats
         self._eval = []
         project = cfg.get("wandb", {}).get("project")
         entity = cfg.get("wandb", {}).get("entity")
@@ -79,6 +82,8 @@ class Logger:
         else:
             save_dir = self._model_dir / str(identifier)
         policy.save_pretrained(save_dir)
+        OmegaConf.save(self._cfg, save_dir / "config.yaml")
+        torch.save(self._stats, save_dir / "stats.pth")
         model_path = save_dir / SAFETENSORS_SINGLE_FILE
         if self._wandb and not self._disable_wandb_artifact:
             # note wandb artifact does not accept ":" in its name
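Note (not part of the diff): Logger now receives the dataset normalization statistics at construction time and persists them, together with the Hydra config, whenever a model checkpoint is written. A minimal sketch of the new call contract; `log_dir`, `job_name`, `cfg` and `stats` stand for the values the training script already holds, and the file layout follows the hunk above:

    # Construction now requires the stats (a dict[str, Tensor]) alongside the DictConfig.
    logger = Logger(log_dir, job_name, cfg, stats)
    # After saving a policy, the checkpoint directory contains:
    #   config.yaml        (OmegaConf.save of the Hydra config)
    #   stats.pth          (torch.save of the stats dict)
    #   model.safetensors  (policy.save_pretrained; SAFETENSORS_SINGLE_FILE)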
@@ -5,7 +5,17 @@ from omegaconf import DictConfig, OmegaConf
 from lerobot.common.utils import get_safe_torch_device


-def make_policy(cfg: DictConfig):
+def make_policy(cfg: DictConfig | None = None, pretrained_policy_name_or_path: str | None = None):
+    """
+    Args:
+        cfg: Hydra configuration.
+        pretrained_policy_name_or_path: Hugging Face hub ID (repository name), or path to a local folder with
+            the policy weights and configuration.
+
+    TODO(alexander-soare): This function is currently in a transitional state where we are using both Hydra
+    configurations and policy dataclass configurations. We will remove the Hydra configuration from this file
+    once all models are set up to use dataclass configurations.
+    """
     if cfg.policy.name == "tdmpc":
         from lerobot.common.policies.tdmpc.policy import TDMPCPolicy

@@ -31,6 +41,7 @@ def make_policy(cfg: DictConfig):
         from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
         from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy

+        if pretrained_policy_name_or_path is None:
         expected_kwargs = set(inspect.signature(ActionChunkingTransformerConfig).parameters)
         assert set(cfg.policy).issuperset(
             expected_kwargs
@@ -43,11 +54,13 @@ def make_policy(cfg: DictConfig):
             }
         )
         policy = ActionChunkingTransformerPolicy(policy_cfg)
+        else:
+            policy = ActionChunkingTransformerPolicy.from_pretrained(pretrained_policy_name_or_path)
         policy.to(get_safe_torch_device(cfg.device))
     else:
         raise ValueError(cfg.policy.name)

-    if cfg.policy.pretrained_model_path:
+    if cfg.policy.pretrained_model_path and pretrained_policy_name_or_path is None:
         # TODO(rcadene): hack for old pretrained models from fowm
         if cfg.policy.name == "tdmpc" and "fowm" in cfg.policy.pretrained_model_path:
             if "offline" in cfg.policy.pretrained_model_path:
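Note (not part of the diff): with this change make_policy has two paths for the ACT policy: build it from the Hydra config as before, or load weights with from_pretrained. Even on the pretrained path the Hydra config is still consulted (for cfg.policy.name and cfg.device). A sketch of both call shapes; the repo ID below is a placeholder, not something from this commit:

    # 1) Previous behaviour: instantiate the policy purely from the Hydra config.
    policy = make_policy(cfg)

    # 2) New: load pretrained weights from the Hub or from a local checkpoint folder.
    policy = make_policy(cfg, pretrained_policy_name_or_path="someuser/act_aloha_sim")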
@@ -291,30 +291,45 @@ def eval_policy(
     return info


-def eval(cfg: DictConfig, out_dir=None, stats_path=None):
+def eval(hydra_cfg: DictConfig, pretrained_policy_name_or_path: str, out_dir=None, stats_path=None):
+    """Evaluate a policy.
+
+    Args:
+        hydra_cfg: Hydra config.
+        pretrained_policy_name_or_path: Hugging Face hub ID (repository name), or path to a local folder with
+            the policy weights and configuration.
+        out_dir: The directory to save the evaluation results (JSON file and videos).
+        stats_path: The path to the stats file.
+
+    TODO(alexander-soare): This function is currently in a transitional state where we are using both Hydra
+    configurations and policy dataclass configurations. This is because we still need the Hydra config to get
+    the transforms and environment. The transforms should be absorbed into the policy configuration, and the
+    environment configuration should be standalone.
+    """
     if out_dir is None:
         raise NotImplementedError()

     init_logging()

     # Check device is available
-    get_safe_torch_device(cfg.device, log=True)
+    get_safe_torch_device(hydra_cfg.device, log=True)

     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
-    set_global_seed(cfg.seed)
+    set_global_seed(hydra_cfg.seed)

     log_output_dir(out_dir)

     logging.info("Making transforms.")
     # TODO(alexander-soare): Completely decouple datasets from evaluation.
-    transform = make_dataset(cfg, stats_path=stats_path).transform
+    transform = make_dataset(hydra_cfg, stats_path=stats_path).transform

     logging.info("Making environment.")
-    env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
+    env = make_env(hydra_cfg, num_parallel_envs=hydra_cfg.eval_episodes)

     logging.info("Making policy.")
-    policy = make_policy(cfg)
+    policy = make_policy(hydra_cfg, pretrained_policy_name_or_path)

     info = eval_policy(
         env,
@@ -322,7 +337,7 @@ def eval(cfg: DictConfig, out_dir=None, stats_path=None):
         max_episodes_rendered=10,
         video_dir=Path(out_dir) / "eval",
         transform=transform,
-        seed=cfg.seed,
+        seed=hydra_cfg.seed,
     )
     print(info["aggregated"])
@@ -340,9 +355,9 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
     )
-    group = parser.add_mutually_exclusive_group(required=True)
-    group.add_argument("--config", help="Path to a specific yaml config you want to use.")
-    group.add_argument("--hub-id", help="HuggingFace Hub ID for a pretrained model.")
+    parser.add_argument(
+        "-p", "--policy-name-or-path", help="HuggingFace Hub ID, or path to a pretrained model."
+    )
     parser.add_argument("--revision", help="Optionally provide the HuggingFace Hub revision ID.")
     parser.add_argument(
         "overrides",
@@ -351,20 +366,17 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()

-    if args.config is not None:
-        # Note: For the config_path, Hydra wants a path relative to this script file.
-        cfg = init_hydra_config(args.config, args.overrides)
-        # TODO(alexander-soare): Save and load stats in trained model directory.
-        stats_path = None
-    elif args.hub_id is not None:
-        folder = Path(snapshot_download(args.hub_id, revision=args.revision))
-        cfg = init_hydra_config(
-            folder / "config.yaml", [f"policy.pretrained_model_path={folder / 'model.pt'}", *args.overrides]
+    path = Path(args.policy_name_or_path)
+    if not path.exists():
+        path = Path(snapshot_download(args.policy_name_or_path, revision=args.revision))
+    hydra_cfg = init_hydra_config(
+        path / "config.yaml", [f"policy.pretrained_model_path={path / 'model.pt'}", *args.overrides]
     )
-        stats_path = folder / "stats.pth"
+    stats_path = path / "stats.pth"

     eval(
-        cfg,
-        out_dir=f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{cfg.env.name}_{cfg.policy.name}",
+        hydra_cfg,
+        args.policy_name_or_path,
+        out_dir=f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}",
         stats_path=stats_path,
     )
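Note (not part of the diff): the __main__ block now takes a single -p/--policy-name-or-path argument in place of the mutually exclusive --config/--hub-id pair. An existing local folder is used directly; anything else is treated as a Hugging Face Hub repo ID and fetched with snapshot_download. A condensed sketch of that resolution; the argument value is illustrative:

    from pathlib import Path
    from huggingface_hub import snapshot_download

    policy_name_or_path = "someuser/act_aloha_sim"  # or a local checkpoint folder
    path = Path(policy_name_or_path)
    if not path.exists():
        # Not a local folder, so treat it as a Hub repo ID and download a snapshot.
        path = Path(snapshot_download(policy_name_or_path, revision=None))
    # config.yaml, stats.pth and the weights are then read from `path`.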
@@ -196,7 +196,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
     num_total_params = sum(p.numel() for p in policy.parameters())

     # log metrics to terminal and wandb
-    logger = Logger(out_dir, job_name, cfg)
+    logger = Logger(out_dir, job_name, cfg, offline_dataset.transform.transforms[1].stats)

     log_output_dir(out_dir)
     logging.info(f"{cfg.env.task=}")
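Note (not part of the diff): the training script now forwards the dataset's normalization statistics straight into the Logger. This assumes the dataset transform is a compose whose second entry is the normalization transform carrying a .stats dict, as written in the hunk above; a slightly more explicit version of the same call:

    # transforms[1] is assumed to be the normalization transform holding per-key stats.
    stats = offline_dataset.transform.transforms[1].stats
    logger = Logger(out_dir, job_name, cfg, stats)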