Mirror of https://github.com/blakeblackshear/frigate.git (synced 2026-02-20 13:54:36 +01:00)
Audio transcription tweaks (#18540)

* use model runner
* unload whisper model when live transcription is complete
committed by Blake Blackshear
parent b77e6f5ebc
commit ac7fb29b32
frigate/data_processing/common/audio_transcription/model.py (new file, 81 lines)
@@ -0,0 +1,81 @@
"""Set up audio transcription models based on model size."""

import logging
import os

import sherpa_onnx
from faster_whisper.utils import download_model

from frigate.comms.inter_process import InterProcessRequestor
from frigate.const import MODEL_CACHE_DIR
from frigate.data_processing.types import AudioTranscriptionModel
from frigate.util.downloader import ModelDownloader

logger = logging.getLogger(__name__)


class AudioTranscriptionModelRunner:
    def __init__(
        self,
        device: str = "CPU",
        model_size: str = "small",
    ):
        self.model: AudioTranscriptionModel = None
        self.requestor = InterProcessRequestor()

        if model_size == "large":
            # use the Whisper download function instead of our own
            logger.debug("Downloading Whisper audio transcription model")
            download_model(
                size_or_id="small" if device == "cuda" else "tiny",
                local_files_only=False,
                cache_dir=os.path.join(MODEL_CACHE_DIR, "whisper"),
            )
            logger.debug("Whisper audio transcription model downloaded")

        else:
            # small model as default
            download_path = os.path.join(MODEL_CACHE_DIR, "sherpa-onnx")
            HF_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
            self.model_files = {
                "encoder.onnx": f"{HF_ENDPOINT}/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26/resolve/main/encoder-epoch-99-avg-1-chunk-16-left-128.onnx",
                "decoder.onnx": f"{HF_ENDPOINT}/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26/resolve/main/decoder-epoch-99-avg-1-chunk-16-left-128.onnx",
                "joiner.onnx": f"{HF_ENDPOINT}/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26/resolve/main/joiner-epoch-99-avg-1-chunk-16-left-128.onnx",
                "tokens.txt": f"{HF_ENDPOINT}/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26/resolve/main/tokens.txt",
            }

            if not all(
                os.path.exists(os.path.join(download_path, n))
                for n in self.model_files.keys()
            ):
                self.downloader = ModelDownloader(
                    model_name="sherpa-onnx",
                    download_path=download_path,
                    file_names=self.model_files.keys(),
                    download_func=self.__download_models,
                )
                self.downloader.ensure_model_files()
                self.downloader.wait_for_download()

            self.model = sherpa_onnx.OnlineRecognizer.from_transducer(
                tokens=os.path.join(MODEL_CACHE_DIR, "sherpa-onnx/tokens.txt"),
                encoder=os.path.join(MODEL_CACHE_DIR, "sherpa-onnx/encoder.onnx"),
                decoder=os.path.join(MODEL_CACHE_DIR, "sherpa-onnx/decoder.onnx"),
                joiner=os.path.join(MODEL_CACHE_DIR, "sherpa-onnx/joiner.onnx"),
                num_threads=2,
                sample_rate=16000,
                feature_dim=80,
                enable_endpoint_detection=True,
                rule1_min_trailing_silence=2.4,
                rule2_min_trailing_silence=1.2,
                rule3_min_utterance_length=300,
                decoding_method="greedy_search",
                provider="cpu",
            )

    def __download_models(self, path: str) -> None:
        try:
            file_name = os.path.basename(path)
            ModelDownloader.download_from_url(self.model_files[file_name], path)
        except Exception as e:
            logger.error(f"Failed to download {path}: {e}")
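
For orientation, here is a minimal sketch of how the recognizer this class constructs might be driven, using sherpa-onnx's standard streaming API. The audio_chunks() generator and the loop itself are illustrative assumptions for this note, not part of the commit:

    import numpy as np

    from frigate.data_processing.common.audio_transcription.model import (
        AudioTranscriptionModelRunner,
    )

    def audio_chunks():
        # Hypothetical stand-in for Frigate's live audio feed: ten
        # half-second buffers of 16 kHz float32 silence in [-1.0, 1.0].
        for _ in range(10):
            yield np.zeros(8000, dtype=np.float32)

    runner = AudioTranscriptionModelRunner(device="CPU", model_size="small")
    stream = runner.model.create_stream()

    for chunk in audio_chunks():
        stream.accept_waveform(16000, chunk)
        while runner.model.is_ready(stream):
            runner.model.decode_stream(stream)
        if runner.model.is_endpoint(stream):
            # The rule1/rule2/rule3 endpoint settings above decide when an
            # utterance has ended; emit its text and reset for the next one.
            print(runner.model.get_result(stream))
            runner.model.reset(stream)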
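
One detail worth noting in the sherpa-onnx branch: the download URLs are built from the HF_ENDPOINT environment variable, so a Hugging Face mirror can be substituted without code changes. A short example (the mirror URL is a placeholder, not a real endpoint):

    import os

    # Hypothetical: point model downloads at a HF mirror before the runner
    # is constructed; when unset, the default https://huggingface.co is used.
    os.environ["HF_ENDPOINT"] = "https://example-hf-mirror.internal"
    runner = AudioTranscriptionModelRunner(model_size="small")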
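
The second commit-message bullet (unloading the Whisper model once live transcription completes) is not visible in this file's diff. As an illustration of the general pattern only, assuming it is not Frigate's actual code, releasing a loaded model in Python typically amounts to dropping the reference and forcing a collection:

    import gc

    def unload(runner: AudioTranscriptionModelRunner) -> None:
        # Illustrative pattern, not from this diff: clearing the reference
        # lets the model's memory be reclaimed once nothing else holds it.
        runner.model = None
        gc.collect()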