Fix pretrained_policy_path

This commit is contained in:
Simon Alibert 2024-05-04 22:19:52 +02:00
parent f23f5f977f
commit 7664ad8259
1 changed files with 12 additions and 11 deletions

View File

@ -583,17 +583,18 @@ if __name__ == "__main__":
pretrained_policy_path = Path( pretrained_policy_path = Path(
snapshot_download(args.pretrained_policy_name_or_path, revision=args.revision) snapshot_download(args.pretrained_policy_name_or_path, revision=args.revision)
) )
except HFValidationError: except (HFValidationError, RepositoryNotFoundError) as e:
logging.warning( if isinstance(e, HFValidationError):
"The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID. " error_message = (
"Treating it as a local directory." "The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID."
) )
except RepositoryNotFoundError: else:
logging.warning( error_message = (
"The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub. Treating " "The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub."
"it as a local directory." )
)
pretrained_policy_path = Path(args.pretrained_policy_name_or_path) logging.warning(f"{error_message} Treating it as a local directory.")
pretrained_policy_path = Path(args.pretrained_policy_name_or_path)
if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists(): if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists():
raise ValueError( raise ValueError(
"The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub " "The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub "