-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
d8aa58c
commit d5a86ac
Showing
30 changed files
with
104,482 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,6 @@ | ||
models | ||
__pycache__ | ||
env | ||
env | ||
|
||
build | ||
dist |
File renamed without changes
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
import hashlib | ||
import io | ||
import os | ||
import urllib | ||
import warnings | ||
from typing import List, Optional, Union | ||
|
||
import torch | ||
from tqdm import tqdm | ||
|
||
from .audio import load_audio, log_mel_spectrogram, pad_or_trim | ||
from .decoding import DecodingOptions, DecodingResult, decode, detect_language | ||
from .model import Whisper, ModelDimensions | ||
from .transcribe import transcribe | ||
from .version import __version__ | ||
|
||
|
||
# Registry of downloadable checkpoints, keyed by model name. Every URL has the
# shape "<prefix>/<sha256 digest>/<file name>"; `_download` reads the digest
# from the second-to-last path component and the file name from the last one.
# "large" points at the same checkpoint as "large-v2".
_MODELS = {
    model_name: f"https://openaipublic.azureedge.net/main/whisper/models/{digest}/{file_name}"
    for model_name, digest, file_name in (
        ("tiny.en", "d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03", "tiny.en.pt"),
        ("tiny", "65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9", "tiny.pt"),
        ("base.en", "25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead", "base.en.pt"),
        ("base", "ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e", "base.pt"),
        ("small.en", "f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872", "small.en.pt"),
        ("small", "9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794", "small.pt"),
        ("medium.en", "d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f", "medium.en.pt"),
        ("medium", "345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1", "medium.pt"),
        ("large-v1", "e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a", "large-v1.pt"),
        ("large-v2", "81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524", "large-v2.pt"),
        ("large", "81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524", "large-v2.pt"),
    )
}
|
||
|
||
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
    """
    Fetch the file at `url` into `root`, reusing a cached copy when its checksum matches.

    The expected SHA256 digest is read from the second-to-last path component of
    `url`; the local file name is the last path component.

    Parameters
    ----------
    url : str
        download location, shaped like ".../<sha256 hex digest>/<file name>"
    root : str
        directory that holds the downloaded file; created if it does not exist
    in_memory : bool
        when True return the file contents, otherwise return the path on disk

    Returns
    -------
    Union[bytes, str]
        the file bytes if `in_memory` is set, else the path of the local file

    Raises
    ------
    RuntimeError
        if the target path exists but is not a regular file, or if the freshly
        downloaded file does not match the expected SHA256 digest
    """
    # A bare `import urllib` does not guarantee that the `request` submodule is
    # loaded; import it explicitly so the download does not depend on some other
    # module having imported it first.
    import urllib.request

    os.makedirs(root, exist_ok=True)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, os.path.basename(url))

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        with open(download_target, "rb") as f:
            model_bytes = f.read()
        if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
            return model_bytes if in_memory else download_target
        warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        # Content-Length may be absent (e.g. chunked responses); in that case let
        # the progress bar run without a known total instead of failing on int(None).
        content_length = source.info().get("Content-Length")
        total = int(content_length) if content_length is not None else None
        with tqdm(total=total, ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    # Re-read from disk so the checksum covers exactly what was written.
    with open(download_target, "rb") as f:
        model_bytes = f.read()
    if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model.")

    return model_bytes if in_memory else download_target
|
||
|
||
def available_models() -> List[str]:
    """Return the model names registered in `_MODELS`, in registration order."""
    return [model_name for model_name in _MODELS]
|
||
|
||
def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: Optional[str] = None, in_memory: bool = False) -> Whisper:
    """
    Load a Whisper ASR model

    Parameters
    ----------
    name : str
        one of the official model names listed by `whisper.available_models()`, or
        path to a model checkpoint containing the model dimensions and the model state_dict.
    device : Union[str, torch.device]
        the PyTorch device to put the model into; when omitted, "cuda" is used if
        `torch.cuda.is_available()` and "cpu" otherwise
    download_root: str
        path to download the model files; by default it is the "whisper" directory
        under $XDG_CACHE_HOME, falling back to "~/.cache/whisper" when that
        variable is unset
    in_memory: bool
        whether to preload the model weights into host memory

    Returns
    -------
    model : Whisper
        The Whisper ASR model instance

    Raises
    ------
    RuntimeError
        if `name` is neither a known model name nor a path to an existing file
    """

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if download_root is None:
        default_cache = os.path.join(os.path.expanduser("~"), ".cache")
        download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache), "whisper")

    if name in _MODELS:
        checkpoint_file = _download(_MODELS[name], download_root, in_memory)
    elif os.path.isfile(name):
        if in_memory:
            # Close the handle deterministically instead of leaving it to the GC.
            with open(name, "rb") as f:
                checkpoint_file = f.read()
        else:
            checkpoint_file = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    # NOTE(security): torch.load unpickles the checkpoint, which can execute
    # arbitrary code; only pass checkpoint paths that come from a trusted source.
    with (io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")) as fp:
        checkpoint = torch.load(fp, map_location=device)
    # Drop the reference so in-memory checkpoint bytes can be freed before the model is built.
    del checkpoint_file

    dims = ModelDimensions(**checkpoint["dims"])
    model = Whisper(dims)
    model.load_state_dict(checkpoint["model_state_dict"])

    return model.to(device)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# Imports `cli` from the sibling `transcribe` module and calls it as soon as
# this file is executed.
# NOTE(review): presumably the package's `__main__` module (the file name is
# not visible here) — confirm before relying on that.
from .transcribe import cli


cli()
Oops, something went wrong.