From eb1fe9fe2072f9a79c72f7a4b28b53328719a2d5 Mon Sep 17 00:00:00 2001 From: Josh Hawkins <32435876+hawkeye217@users.noreply.github.com> Date: Tue, 3 Jun 2025 06:53:48 -0500 Subject: [PATCH] Audio transcription tweaks (#18540) * use model runner * unload whisper model when live transcription is complete --- .../common/audio_transcription/model.py | 81 +++++++ .../real_time/audio_transcription.py | 205 +++++++++--------- .../real_time/whisper_online.py | 5 +- frigate/data_processing/types.py | 7 + frigate/events/audio.py | 42 ++-- 5 files changed, 220 insertions(+), 120 deletions(-) create mode 100644 frigate/data_processing/common/audio_transcription/model.py diff --git a/frigate/data_processing/common/audio_transcription/model.py b/frigate/data_processing/common/audio_transcription/model.py new file mode 100644 index 000000000..0fe5ddb5c --- /dev/null +++ b/frigate/data_processing/common/audio_transcription/model.py @@ -0,0 +1,81 @@ +"""Set up audio transcription models based on model size.""" + +import logging +import os + +import sherpa_onnx +from faster_whisper.utils import download_model + +from frigate.comms.inter_process import InterProcessRequestor +from frigate.const import MODEL_CACHE_DIR +from frigate.data_processing.types import AudioTranscriptionModel +from frigate.util.downloader import ModelDownloader + +logger = logging.getLogger(__name__) + + +class AudioTranscriptionModelRunner: + def __init__( + self, + device: str = "CPU", + model_size: str = "small", + ): + self.model: AudioTranscriptionModel = None + self.requestor = InterProcessRequestor() + + if model_size == "large": + # use the Whisper download function instead of our own + logger.debug("Downloading Whisper audio transcription model") + download_model( + size_or_id="small" if device == "cuda" else "tiny", + local_files_only=False, + cache_dir=os.path.join(MODEL_CACHE_DIR, "whisper"), + ) + logger.debug("Whisper audio transcription model downloaded") + + else: + # small model as default + download_path = os.path.join(MODEL_CACHE_DIR, "sherpa-onnx") + HF_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co") + self.model_files = { + "encoder.onnx": f"{HF_ENDPOINT}/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26/resolve/main/encoder-epoch-99-avg-1-chunk-16-left-128.onnx", + "decoder.onnx": f"{HF_ENDPOINT}/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26/resolve/main/decoder-epoch-99-avg-1-chunk-16-left-128.onnx", + "joiner.onnx": f"{HF_ENDPOINT}/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26/resolve/main/joiner-epoch-99-avg-1-chunk-16-left-128.onnx", + "tokens.txt": f"{HF_ENDPOINT}/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26/resolve/main/tokens.txt", + } + + if not all( + os.path.exists(os.path.join(download_path, n)) + for n in self.model_files.keys() + ): + self.downloader = ModelDownloader( + model_name="sherpa-onnx", + download_path=download_path, + file_names=self.model_files.keys(), + download_func=self.__download_models, + ) + self.downloader.ensure_model_files() + self.downloader.wait_for_download() + + self.model = sherpa_onnx.OnlineRecognizer.from_transducer( + tokens=os.path.join(MODEL_CACHE_DIR, "sherpa-onnx/tokens.txt"), + encoder=os.path.join(MODEL_CACHE_DIR, "sherpa-onnx/encoder.onnx"), + decoder=os.path.join(MODEL_CACHE_DIR, "sherpa-onnx/decoder.onnx"), + joiner=os.path.join(MODEL_CACHE_DIR, "sherpa-onnx/joiner.onnx"), + num_threads=2, + sample_rate=16000, + feature_dim=80, + enable_endpoint_detection=True, + rule1_min_trailing_silence=2.4, + 
rule2_min_trailing_silence=1.2, + rule3_min_utterance_length=300, + decoding_method="greedy_search", + provider="cpu", + ) + + def __download_models(self, path: str) -> None: + try: + file_name = os.path.basename(path) + ModelDownloader.download_from_url(self.model_files[file_name], path) + except Exception as e: + logger.error(f"Failed to download {path}: {e}") diff --git a/frigate/data_processing/real_time/audio_transcription.py b/frigate/data_processing/real_time/audio_transcription.py index 7ed644498..2e6d599eb 100644 --- a/frigate/data_processing/real_time/audio_transcription.py +++ b/frigate/data_processing/real_time/audio_transcription.py @@ -7,16 +7,20 @@ import threading from typing import Optional import numpy as np -import sherpa_onnx from frigate.comms.inter_process import InterProcessRequestor from frigate.config import CameraConfig, FrigateConfig from frigate.const import MODEL_CACHE_DIR -from frigate.util.downloader import ModelDownloader +from frigate.data_processing.common.audio_transcription.model import ( + AudioTranscriptionModelRunner, +) +from frigate.data_processing.real_time.whisper_online import ( + FasterWhisperASR, + OnlineASRProcessor, +) from ..types import DataProcessorMetrics from .api import RealTimeProcessorApi -from .whisper_online import FasterWhisperASR, OnlineASRProcessor logger = logging.getLogger(__name__) @@ -27,6 +31,7 @@ class AudioTranscriptionRealTimeProcessor(RealTimeProcessorApi): config: FrigateConfig, camera_config: CameraConfig, requestor: InterProcessRequestor, + model_runner: AudioTranscriptionModelRunner, metrics: DataProcessorMetrics, stop_event: threading.Event, ): @@ -34,95 +39,55 @@ class AudioTranscriptionRealTimeProcessor(RealTimeProcessorApi): self.config = config self.camera_config = camera_config self.requestor = requestor - self.recognizer = None self.stream = None + self.whisper_model = None + self.model_runner = model_runner self.transcription_segments = [] self.audio_queue = queue.Queue() self.stop_event = stop_event - if self.config.audio_transcription.model_size == "large": - self.asr = FasterWhisperASR( - modelsize="tiny", - device="cuda" - if self.config.audio_transcription.device == "GPU" - else "cpu", - lan=config.audio_transcription.language, - model_dir=os.path.join(MODEL_CACHE_DIR, "whisper"), - ) - self.asr.use_vad() # Enable Silero VAD for low-RMS audio - - else: - # small model as default - download_path = os.path.join(MODEL_CACHE_DIR, "sherpa-onnx") - HF_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co") - self.model_files = { - "encoder.onnx": f"{HF_ENDPOINT}/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26/resolve/main/encoder-epoch-99-avg-1-chunk-16-left-128.onnx", - "decoder.onnx": f"{HF_ENDPOINT}/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26/resolve/main/decoder-epoch-99-avg-1-chunk-16-left-128.onnx", - "joiner.onnx": f"{HF_ENDPOINT}/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26/resolve/main/joiner-epoch-99-avg-1-chunk-16-left-128.onnx", - "tokens.txt": f"{HF_ENDPOINT}/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26/resolve/main/tokens.txt", - } - - if not all( - os.path.exists(os.path.join(download_path, n)) - for n in self.model_files.keys() - ): - self.downloader = ModelDownloader( - model_name="sherpa-onnx", - download_path=download_path, - file_names=self.model_files.keys(), - download_func=self.__download_models, - complete_func=self.__build_recognizer, - ) - self.downloader.ensure_model_files() - - self.__build_recognizer() - - def 
__download_models(self, path: str) -> None: - try: - file_name = os.path.basename(path) - ModelDownloader.download_from_url(self.model_files[file_name], path) - except Exception as e: - logger.error(f"Failed to download {path}: {e}") - def __build_recognizer(self) -> None: try: if self.config.audio_transcription.model_size == "large": - self.online = OnlineASRProcessor( - asr=self.asr, + # Whisper models need to be per-process and can only run one stream at a time + # TODO: try parallel: https://github.com/SYSTRAN/faster-whisper/issues/100 + logger.debug(f"Loading Whisper model for {self.camera_config.name}") + self.whisper_model = FasterWhisperASR( + modelsize="tiny", + device="cuda" + if self.config.audio_transcription.device == "GPU" + else "cpu", + lan=self.config.audio_transcription.language, + model_dir=os.path.join(MODEL_CACHE_DIR, "whisper"), + ) + self.whisper_model.use_vad() + self.stream = OnlineASRProcessor( + asr=self.whisper_model, ) else: - self.recognizer = sherpa_onnx.OnlineRecognizer.from_transducer( - tokens=os.path.join(MODEL_CACHE_DIR, "sherpa-onnx/tokens.txt"), - encoder=os.path.join(MODEL_CACHE_DIR, "sherpa-onnx/encoder.onnx"), - decoder=os.path.join(MODEL_CACHE_DIR, "sherpa-onnx/decoder.onnx"), - joiner=os.path.join(MODEL_CACHE_DIR, "sherpa-onnx/joiner.onnx"), - num_threads=2, - sample_rate=16000, - feature_dim=80, - enable_endpoint_detection=True, - rule1_min_trailing_silence=2.4, - rule2_min_trailing_silence=1.2, - rule3_min_utterance_length=300, - decoding_method="greedy_search", - provider="cpu", - ) - self.stream = self.recognizer.create_stream() - logger.debug("Audio transcription (live) initialized") + logger.debug(f"Loading sherpa stream for {self.camera_config.name}") + self.stream = self.model_runner.model.create_stream() + logger.debug( + f"Audio transcription (live) initialized for {self.camera_config.name}" + ) except Exception as e: logger.error( f"Failed to initialize live streaming audio transcription: {e}" ) - self.recognizer = None def __process_audio_stream( self, audio_data: np.ndarray ) -> Optional[tuple[str, bool]]: - if (not self.recognizer or not self.stream) and not self.online: - logger.debug( - "Audio transcription (streaming) recognizer or stream not initialized" - ) + if ( + self.model_runner.model is None + and self.config.audio_transcription.model_size == "small" + ): + logger.debug("Audio transcription (live) model not initialized") return None + if not self.stream: + self.__build_recognizer() + try: if audio_data.dtype != np.float32: audio_data = audio_data.astype(np.float32) @@ -135,10 +100,14 @@ class AudioTranscriptionRealTimeProcessor(RealTimeProcessorApi): if self.config.audio_transcription.model_size == "large": # large model - self.online.insert_audio_chunk(audio_data) - output = self.online.process_iter() + self.stream.insert_audio_chunk(audio_data) + output = self.stream.process_iter() text = output[2].strip() - is_endpoint = text.endswith((".", "!", "?")) + is_endpoint = ( + text.endswith((".", "!", "?")) + and sum(len(str(lines)) for lines in self.transcription_segments) + > 300 + ) if text: self.transcription_segments.append(text) @@ -150,11 +119,11 @@ class AudioTranscriptionRealTimeProcessor(RealTimeProcessorApi): # small model self.stream.accept_waveform(16000, audio_data) - while self.recognizer.is_ready(self.stream): - self.recognizer.decode_stream(self.stream) + while self.model_runner.model.is_ready(self.stream): + self.model_runner.model.decode_stream(self.stream) - text = 
self.recognizer.get_result(self.stream).strip() - is_endpoint = self.recognizer.is_endpoint(self.stream) + text = self.model_runner.model.get_result(self.stream).strip() + is_endpoint = self.model_runner.model.is_endpoint(self.stream) logger.debug(f"Transcription result: '{text}'") @@ -166,7 +135,7 @@ class AudioTranscriptionRealTimeProcessor(RealTimeProcessorApi): if is_endpoint and self.config.audio_transcription.model_size == "small": # reset sherpa if we've reached an endpoint - self.recognizer.reset(self.stream) + self.model_runner.model.reset(self.stream) return text, is_endpoint except Exception as e: @@ -190,10 +159,17 @@ class AudioTranscriptionRealTimeProcessor(RealTimeProcessorApi): logger.debug( f"Starting audio transcription thread for {self.camera_config.name}" ) + + # start with an empty transcription + self.requestor.send_data( + f"{self.camera_config.name}/audio/transcription", + "", + ) + while not self.stop_event.is_set(): try: # Get audio data from queue with a timeout to check stop_event - obj_data, audio = self.audio_queue.get(timeout=0.1) + _, audio = self.audio_queue.get(timeout=0.1) result = self.__process_audio_stream(audio) if not result: @@ -209,7 +185,7 @@ class AudioTranscriptionRealTimeProcessor(RealTimeProcessorApi): self.audio_queue.task_done() if is_endpoint: - self.reset(obj_data["camera"]) + self.reset() except queue.Empty: continue @@ -221,23 +197,7 @@ class AudioTranscriptionRealTimeProcessor(RealTimeProcessorApi): f"Stopping audio transcription thread for {self.camera_config.name}" ) - def reset(self, camera: str) -> None: - if self.config.audio_transcription.model_size == "large": - # get final output from whisper - output = self.online.finish() - self.transcription_segments = [] - - self.requestor.send_data( - f"{self.camera_config.name}/audio/transcription", - (output[2].strip() + " "), - ) - - # reset whisper - self.online.init() - else: - # reset sherpa - self.recognizer.reset(self.stream) - + def clear_audio_queue(self) -> None: # Clear the audio queue while not self.audio_queue.empty(): try: @@ -246,8 +206,54 @@ class AudioTranscriptionRealTimeProcessor(RealTimeProcessorApi): except queue.Empty: break + def reset(self) -> None: + if self.config.audio_transcription.model_size == "large": + # get final output from whisper + output = self.stream.finish() + self.transcription_segments = [] + + self.requestor.send_data( + f"{self.camera_config.name}/audio/transcription", + (output[2].strip() + " "), + ) + + # reset whisper + self.stream.init() + self.transcription_segments = [] + else: + # reset sherpa + self.model_runner.model.reset(self.stream) + logger.debug("Stream reset") + def check_unload_model(self) -> None: + # regularly called in the loop in audio maintainer + if ( + self.config.audio_transcription.model_size == "large" + and self.whisper_model is not None + ): + logger.debug(f"Unloading Whisper model for {self.camera_config.name}") + self.clear_audio_queue() + self.transcription_segments = [] + self.stream = None + self.whisper_model = None + + self.requestor.send_data( + f"{self.camera_config.name}/audio/transcription", + "", + ) + if ( + self.config.audio_transcription.model_size == "small" + and self.stream is not None + ): + logger.debug(f"Clearing sherpa stream for {self.camera_config.name}") + self.stream = None + + self.requestor.send_data( + f"{self.camera_config.name}/audio/transcription", + "", + ) + def stop(self) -> None: """Stop the transcription thread and clean up.""" self.stop_event.set() @@ -266,7 +272,6 @@ class 
AudioTranscriptionRealTimeProcessor(RealTimeProcessorApi): self, topic: str, request_data: dict[str, any] ) -> dict[str, any] | None: if topic == "clear_audio_recognizer": - self.recognizer = None self.stream = None self.__build_recognizer() return {"message": "Audio recognizer cleared and rebuilt", "success": True} diff --git a/frigate/data_processing/real_time/whisper_online.py b/frigate/data_processing/real_time/whisper_online.py index 96c1ce0cf..9b81d7fbe 100644 --- a/frigate/data_processing/real_time/whisper_online.py +++ b/frigate/data_processing/real_time/whisper_online.py @@ -139,8 +139,11 @@ class FasterWhisperASR(ASRBase): return model def transcribe(self, audio, init_prompt=""): + from faster_whisper import BatchedInferencePipeline + # tested: beam_size=5 is faster and better than 1 (on one 200 second document from En ESIC, min chunk 0.01) - segments, info = self.model.transcribe( + batched_model = BatchedInferencePipeline(model=self.model) + segments, info = batched_model.transcribe( audio, language=self.original_language, initial_prompt=init_prompt, diff --git a/frigate/data_processing/types.py b/frigate/data_processing/types.py index a19a856bf..5d083b32e 100644 --- a/frigate/data_processing/types.py +++ b/frigate/data_processing/types.py @@ -4,6 +4,10 @@ import multiprocessing as mp from enum import Enum from multiprocessing.sharedctypes import Synchronized +import sherpa_onnx + +from frigate.data_processing.real_time.whisper_online import FasterWhisperASR + class DataProcessorMetrics: image_embeddings_speed: Synchronized @@ -41,3 +45,6 @@ class PostProcessDataEnum(str, Enum): recording = "recording" review = "review" tracked_object = "tracked_object" + + +AudioTranscriptionModel = FasterWhisperASR | sherpa_onnx.OnlineRecognizer | None diff --git a/frigate/events/audio.py b/frigate/events/audio.py index dc6ee7128..aeeaf3b4f 100644 --- a/frigate/events/audio.py +++ b/frigate/events/audio.py @@ -30,6 +30,9 @@ from frigate.const import ( AUDIO_MIN_CONFIDENCE, AUDIO_SAMPLE_RATE, ) +from frigate.data_processing.common.audio_transcription.model import ( + AudioTranscriptionModelRunner, +) from frigate.data_processing.real_time.audio_transcription import ( AudioTranscriptionRealTimeProcessor, ) @@ -87,6 +90,10 @@ class AudioProcessor(util.Process): self.camera_metrics = camera_metrics self.cameras = cameras self.config = config + self.transcription_model_runner = AudioTranscriptionModelRunner( + self.config.audio_transcription.device, + self.config.audio_transcription.model_size, + ) def run(self) -> None: audio_threads: list[AudioEventMaintainer] = [] @@ -101,6 +108,7 @@ class AudioProcessor(util.Process): camera, self.config, self.camera_metrics, + self.transcription_model_runner, self.stop_event, ) audio_threads.append(audio_thread) @@ -130,6 +138,7 @@ class AudioEventMaintainer(threading.Thread): camera: CameraConfig, config: FrigateConfig, camera_metrics: dict[str, CameraMetrics], + audio_transcription_model_runner: AudioTranscriptionModelRunner, stop_event: threading.Event, ) -> None: super().__init__(name=f"{camera.name}_audio_event_processor") @@ -146,6 +155,7 @@ class AudioEventMaintainer(threading.Thread): self.ffmpeg_cmd = get_ffmpeg_command(self.camera_config.ffmpeg) self.logpipe = LogPipe(f"ffmpeg.{self.camera_config.name}.audio") self.audio_listener = None + self.audio_transcription_model_runner = audio_transcription_model_runner self.transcription_processor = None self.transcription_thread = None @@ -168,6 +178,7 @@ class AudioEventMaintainer(threading.Thread): 
config=self.config, camera_config=self.camera_config, requestor=self.requestor, + model_runner=self.audio_transcription_model_runner, metrics=self.camera_metrics[self.camera_config.name], stop_event=self.stop_event, ) @@ -223,18 +234,18 @@ class AudioEventMaintainer(threading.Thread): ) # run audio transcription - if self.transcription_processor is not None and ( - self.camera_config.audio_transcription.live_enabled - ): - self.transcribing = True - # process audio until we've reached the endpoint - self.transcription_processor.process_audio( - { - "id": f"{self.camera_config.name}_audio", - "camera": self.camera_config.name, - }, - audio, - ) + if self.transcription_processor is not None: + if self.camera_config.audio_transcription.live_enabled: + # process audio until we've reached the endpoint + self.transcription_processor.process_audio( + { + "id": f"{self.camera_config.name}_audio", + "camera": self.camera_config.name, + }, + audio, + ) + else: + self.transcription_processor.check_unload_model() self.expire_detections() @@ -309,13 +320,6 @@ class AudioEventMaintainer(threading.Thread): ) self.detections[detection["label"]] = None - # clear real-time transcription - if self.transcription_processor is not None: - self.transcription_processor.reset(self.camera_config.name) - self.requestor.send_data( - f"{self.camera_config.name}/audio/transcription", "" - ) - def expire_all_detections(self) -> None: """Immediately end all current detections""" now = datetime.datetime.now().timestamp()
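
For reference, the overall pattern this patch introduces can be summarized outside the diff: one model-runner object is built once in the parent audio process, handed to every per-camera maintainer thread, and each camera's per-stream state (and, for the large model, the Whisper instance itself) is torn down whenever live transcription is not active. The sketch below is a simplified, standalone illustration of that pattern; SharedModelRunner and CameraTranscriber are hypothetical stand-ins written for this note, not Frigate's actual AudioTranscriptionModelRunner or AudioEventMaintainer classes.

# Minimal sketch of the "shared model runner" pattern from this patch.
# The classes here are illustrative stubs, not Frigate's real implementations.
import threading
from typing import Any, Optional


class SharedModelRunner:
    """Loaded once per audio process and shared by all camera threads."""

    def __init__(self, model_size: str = "small") -> None:
        self.model_size = model_size
        # In the real patch this is a sherpa-onnx OnlineRecognizer (small) or a
        # lazily created faster-whisper model (large); here it is just a stub.
        self.model: Optional[object] = object() if model_size == "small" else None


class CameraTranscriber(threading.Thread):
    """One per camera; borrows the shared runner instead of loading its own model."""

    def __init__(self, name: str, runner: SharedModelRunner, live_enabled: bool) -> None:
        super().__init__(name=f"{name}_audio")
        self.runner = runner
        self.live_enabled = live_enabled
        self.stream: Optional[object] = None  # per-camera decoding state

    def process_chunk(self, audio: Any) -> None:
        if self.live_enabled:
            if self.stream is None:
                # Lazily create per-camera state on top of the shared model.
                self.stream = object()
            # ...feed `audio` into the stream and publish partial text...
        else:
            # Mirrors check_unload_model(): drop per-camera state when live
            # transcription is off so memory is reclaimed between sessions.
            self.stream = None


# Usage: one runner shared by every camera thread in the audio process.
runner = SharedModelRunner(model_size="small")
cameras = [
    CameraTranscriber(name, runner, live_enabled=(name == "doorbell"))
    for name in ("doorbell", "backyard")
]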