fix main annot

This commit is contained in:
Wael Karkoub 2024-06-10 16:24:04 +01:00
parent 62f0e9bc54
commit 066c732c1c
1 changed files with 6 additions and 5 deletions

View File

@ -514,16 +514,17 @@ def _compile_episode_data(
def main(
pretrained_policy_path: str | None = None,
pretrained_policy_path: Path | None = None,
hydra_cfg_path: str | None = None,
out_dir: str | None = None,
config_overrides: list[str] | None = None,
):
assert (pretrained_policy_path is None) ^ (hydra_cfg_path is None)
if hydra_cfg_path is None:
hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", config_overrides)
if pretrained_policy_path is not None:
hydra_cfg = init_hydra_config(str(pretrained_policy_path / "config.yaml"), config_overrides)
else:
hydra_cfg = init_hydra_config(hydra_cfg_path, config_overrides)
if out_dir is None:
out_dir = f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}"
@ -541,7 +542,7 @@ def main(
logging.info("Making policy.")
if hydra_cfg_path is None:
policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=str(pretrained_policy_path))
else:
# Note: We need the dataset stats to pass to the policy's normalization modules.
policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats)
@ -635,7 +636,7 @@ if __name__ == "__main__":
)
main(
pretrained_policy_path=str(pretrained_policy_path),
pretrained_policy_path=pretrained_policy_path,
out_dir=args.out_dir,
config_overrides=args.overrides,
)