"""
|
|
Routines for loading DeepSpeech model.
|
|
"""
|
|
|
|
__all__ = ['get_deepspeech_model_file']
|
|
|
|
import os
|
|
import zipfile
|
|
import logging
|
|
import hashlib
|
|
|
|
|
|
deepspeech_features_repo_url = 'https://github.com/osmr/deepspeech_features'
|
|
|
|
|
|
def get_deepspeech_model_file(local_model_store_dir_path=os.path.join("~", ".tensorflow", "models")):
    """
    Return location for the pretrained on local file system. This function will download from online model zoo when
    model cannot be found or has mismatch. The root directory will be created if it doesn't exist.

    Parameters
    ----------
    local_model_store_dir_path : str, default $TENSORFLOW_HOME/models
        Location for keeping the model parameters.

    Returns
    -------
    file_path
        Path to the requested pretrained model file.

    Raises
    ------
    ValueError
        If the downloaded file's sha1 hash does not match the expected one.
    """
    # Pinned artifact: hash and file name identify one specific release asset.
    sha1_hash = "b90017e816572ddce84f5843f1fa21e6a377975e"
    file_name = "deepspeech-0_1_0-b90017e8.pb"
    local_model_store_dir_path = os.path.expanduser(local_model_store_dir_path)
    file_path = os.path.join(local_model_store_dir_path, file_name)
    if os.path.exists(file_path):
        # Reuse the cached file only when its content hash still matches.
        if _check_sha1(file_path, sha1_hash):
            return file_path
        logging.warning("Mismatch in the content of model file detected. Downloading again.")
    else:
        logging.info("Model file not found. Downloading to {}.".format(file_path))

    # exist_ok=True avoids a race when another process creates the directory
    # between an existence check and makedirs (the original check-then-create was racy).
    os.makedirs(local_model_store_dir_path, exist_ok=True)

    # The release asset is a zip containing the .pb; download next to the target and unpack.
    zip_file_path = file_path + ".zip"
    _download(
        url="{repo_url}/releases/download/{repo_release_tag}/{file_name}.zip".format(
            repo_url=deepspeech_features_repo_url,
            repo_release_tag="v0.0.1",
            file_name=file_name),
        path=zip_file_path,
        overwrite=True)
    with zipfile.ZipFile(zip_file_path) as zf:
        zf.extractall(local_model_store_dir_path)
    os.remove(zip_file_path)

    # Verify the unpacked model before handing the path back to the caller.
    if not _check_sha1(file_path, sha1_hash):
        raise ValueError("Downloaded file has different hash. Please try again.")
    return file_path


def _download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True):
    """
    Download an given URL

    Parameters
    ----------
    url : str
        URL to download
    path : str, optional
        Destination path to store downloaded file. By default stores to the
        current directory with same name as in url.
    overwrite : bool, optional
        Whether to overwrite destination file if already exists.
    sha1_hash : str, optional
        Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
        but doesn't match.
    retries : integer, default 5
        The number of times to attempt the download in case of failure or non 200 return codes
    verify_ssl : bool, default True
        Verify SSL certificates.

    Returns
    -------
    str
        The file path of the downloaded file.
    """
    import warnings
    try:
        import requests
    except ImportError:
        # Defer the import failure: an error surfaces only if a download is actually attempted.
        class requests_failed_to_import(object):
            pass
        requests = requests_failed_to_import

    if path is None:
        fname = url.split("/")[-1]
        # Empty filenames are invalid
        assert fname, "Can't construct file-name from this URL. Please set the `path` option manually."
    else:
        path = os.path.expanduser(path)
        if os.path.isdir(path):
            fname = os.path.join(path, url.split("/")[-1])
        else:
            fname = path
    assert retries >= 0, "Number of retries should be at least 0"

    if not verify_ssl:
        warnings.warn(
            "Unverified HTTPS request is being made (verify_ssl=False). "
            "Adding certificate verification is strongly advised.")

    # Skip the download entirely when a valid file already exists and overwrite is off.
    if overwrite or not os.path.exists(fname) or (sha1_hash and not _check_sha1(fname, sha1_hash)):
        dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
        # exist_ok=True avoids a race with a concurrent process creating the same directory.
        os.makedirs(dirname, exist_ok=True)
        # retries counts re-attempts, so the total number of tries is retries + 1.
        while retries + 1 > 0:
            # Disable pylint too broad Exception
            # pylint: disable=W0703
            try:
                print("Downloading {} from {}...".format(fname, url))
                r = requests.get(url, stream=True, verify=verify_ssl)
                try:
                    if r.status_code != 200:
                        raise RuntimeError("Failed downloading url {}".format(url))
                    with open(fname, "wb") as f:
                        for chunk in r.iter_content(chunk_size=1024):
                            if chunk:  # filter out keep-alive new chunks
                                f.write(chunk)
                finally:
                    # Always release the streamed HTTP connection; the original
                    # leaked it on both the success and the error path.
                    r.close()
                if sha1_hash and not _check_sha1(fname, sha1_hash):
                    raise UserWarning("File {} is downloaded but the content hash does not match."
                                      " The repo may be outdated or download may be incomplete. "
                                      "If the `repo_url` is overridden, consider switching to "
                                      "the default repo.".format(fname))
                break
            except Exception as e:
                retries -= 1
                if retries <= 0:
                    raise e
                else:
                    print("download failed, retrying, {} attempt{} left"
                          .format(retries, "s" if retries > 1 else ""))

    return fname


def _check_sha1(filename, sha1_hash):
|
|
"""
|
|
Check whether the sha1 hash of the file content matches the expected hash.
|
|
|
|
Parameters
|
|
----------
|
|
filename : str
|
|
Path to the file.
|
|
sha1_hash : str
|
|
Expected sha1 hash in hexadecimal digits.
|
|
|
|
Returns
|
|
-------
|
|
bool
|
|
Whether the file content matches the expected hash.
|
|
"""
|
|
sha1 = hashlib.sha1()
|
|
with open(filename, "rb") as f:
|
|
while True:
|
|
data = f.read(1048576)
|
|
if not data:
|
|
break
|
|
sha1.update(data)
|
|
|
|
return sha1.hexdigest() == sha1_hash