mirror of
https://github.com/blakeblackshear/frigate.git
synced 2025-09-23 17:52:05 +02:00
82 lines
3.5 KiB
Python
82 lines
3.5 KiB
Python
"""Set up audio transcription models based on model size."""
|
|
|
|
import logging
|
|
import os
|
|
|
|
import sherpa_onnx
|
|
from faster_whisper.utils import download_model
|
|
|
|
from frigate.comms.inter_process import InterProcessRequestor
|
|
from frigate.const import MODEL_CACHE_DIR
|
|
from frigate.data_processing.types import AudioTranscriptionModel
|
|
from frigate.util.downloader import ModelDownloader
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class AudioTranscriptionModelRunner:
|
|
def __init__(
|
|
self,
|
|
device: str = "CPU",
|
|
model_size: str = "small",
|
|
):
|
|
self.model: AudioTranscriptionModel = None
|
|
self.requestor = InterProcessRequestor()
|
|
|
|
if model_size == "large":
|
|
# use the Whisper download function instead of our own
|
|
logger.debug("Downloading Whisper audio transcription model")
|
|
download_model(
|
|
size_or_id="small" if device == "cuda" else "tiny",
|
|
local_files_only=False,
|
|
cache_dir=os.path.join(MODEL_CACHE_DIR, "whisper"),
|
|
)
|
|
logger.debug("Whisper audio transcription model downloaded")
|
|
|
|
else:
|
|
# small model as default
|
|
download_path = os.path.join(MODEL_CACHE_DIR, "sherpa-onnx")
|
|
HF_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
|
|
self.model_files = {
|
|
"encoder.onnx": f"{HF_ENDPOINT}/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26/resolve/main/encoder-epoch-99-avg-1-chunk-16-left-128.onnx",
|
|
"decoder.onnx": f"{HF_ENDPOINT}/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26/resolve/main/decoder-epoch-99-avg-1-chunk-16-left-128.onnx",
|
|
"joiner.onnx": f"{HF_ENDPOINT}/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26/resolve/main/joiner-epoch-99-avg-1-chunk-16-left-128.onnx",
|
|
"tokens.txt": f"{HF_ENDPOINT}/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26/resolve/main/tokens.txt",
|
|
}
|
|
|
|
if not all(
|
|
os.path.exists(os.path.join(download_path, n))
|
|
for n in self.model_files.keys()
|
|
):
|
|
self.downloader = ModelDownloader(
|
|
model_name="sherpa-onnx",
|
|
download_path=download_path,
|
|
file_names=self.model_files.keys(),
|
|
download_func=self.__download_models,
|
|
)
|
|
self.downloader.ensure_model_files()
|
|
self.downloader.wait_for_download()
|
|
|
|
self.model = sherpa_onnx.OnlineRecognizer.from_transducer(
|
|
tokens=os.path.join(MODEL_CACHE_DIR, "sherpa-onnx/tokens.txt"),
|
|
encoder=os.path.join(MODEL_CACHE_DIR, "sherpa-onnx/encoder.onnx"),
|
|
decoder=os.path.join(MODEL_CACHE_DIR, "sherpa-onnx/decoder.onnx"),
|
|
joiner=os.path.join(MODEL_CACHE_DIR, "sherpa-onnx/joiner.onnx"),
|
|
num_threads=2,
|
|
sample_rate=16000,
|
|
feature_dim=80,
|
|
enable_endpoint_detection=True,
|
|
rule1_min_trailing_silence=2.4,
|
|
rule2_min_trailing_silence=1.2,
|
|
rule3_min_utterance_length=300,
|
|
decoding_method="greedy_search",
|
|
provider="cpu",
|
|
)
|
|
|
|
def __download_models(self, path: str) -> None:
|
|
try:
|
|
file_name = os.path.basename(path)
|
|
ModelDownloader.download_from_url(self.model_files[file_name], path)
|
|
except Exception as e:
|
|
logger.error(f"Failed to download {path}: {e}")
|