From 731743c7e5df99046d8c818c53113930ab3b35cb Mon Sep 17 00:00:00 2001
From: Sergey Krashevich
Date: Fri, 1 Sep 2023 15:00:11 +0300
Subject: [PATCH] Refactor AudioTfl class to accept the number of detection threads as a parameter in the constructor, and update the usage of the num_threads attribute accordingly (#7588)

---
 frigate/config.py       |  1 +
 frigate/events/audio.py | 11 ++++++-----
 frigate/util/builtin.py |  4 ++--
 3 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/frigate/config.py b/frigate/config.py
index 7f1624ed4..f98da3855 100644
--- a/frigate/config.py
+++ b/frigate/config.py
@@ -444,6 +444,7 @@ class AudioConfig(FrigateBaseModel):
     enabled_in_config: Optional[bool] = Field(
         title="Keep track of original state of audio detection."
     )
+    num_threads: int = Field(default=2, title="Number of detection threads", ge=1)
 
 
 class BirdseyeModeEnum(str, Enum):
diff --git a/frigate/events/audio.py b/frigate/events/audio.py
index 911013a38..76bfd5fa8 100644
--- a/frigate/events/audio.py
+++ b/frigate/events/audio.py
@@ -89,12 +89,13 @@ def listen_to_audio(
 
 
 class AudioTfl:
-    def __init__(self, stop_event: mp.Event):
+    def __init__(self, stop_event: mp.Event, num_threads=2):
         self.stop_event = stop_event
-        self.labels = load_labels("/audio-labelmap.txt")
+        self.num_threads = num_threads
+        self.labels = load_labels("/audio-labelmap.txt", prefill=521)
         self.interpreter = Interpreter(
             model_path="/cpu_audio_model.tflite",
-            num_threads=2,
+            num_threads=self.num_threads,
         )
         self.interpreter.allocate_tensors()
 
@@ -117,7 +118,7 @@ class AudioTfl:
         count = len(scores)
 
         for i in range(count):
-            if scores[i] < 0.4 or i == 20:
+            if scores[i] < AUDIO_MIN_CONFIDENCE or i == 20:
                 break
             detections[i] = [
                 class_ids[i],
@@ -164,7 +165,7 @@ class AudioEventMaintainer(threading.Thread):
         self.inter_process_communicator = inter_process_communicator
         self.detections: dict[dict[str, any]] = feature_metrics
         self.stop_event = stop_event
-        self.detector = AudioTfl(stop_event)
+        self.detector = AudioTfl(stop_event, self.config.audio.num_threads)
         self.shape = (int(round(AUDIO_DURATION * AUDIO_SAMPLE_RATE)),)
         self.chunk_size = int(round(AUDIO_DURATION * AUDIO_SAMPLE_RATE * 2))
         self.logger = logging.getLogger(f"audio.{self.config.name}")
diff --git a/frigate/util/builtin.py b/frigate/util/builtin.py
index 7eafc9d33..9c776f5df 100644
--- a/frigate/util/builtin.py
+++ b/frigate/util/builtin.py
@@ -134,7 +134,7 @@ def get_ffmpeg_arg_list(arg: Any) -> list:
     return arg if isinstance(arg, list) else shlex.split(arg)
 
 
-def load_labels(path, encoding="utf-8"):
+def load_labels(path, encoding="utf-8", prefill=91):
     """Loads labels from file (with or without index numbers).
     Args:
       path: path to label file.
@@ -143,7 +143,7 @@ def load_labels(path, encoding="utf-8"):
       Dictionary mapping indices to labels.
     """
     with open(path, "r", encoding=encoding) as f:
-        labels = {index: "unknown" for index in range(91)}
+        labels = {index: "unknown" for index in range(prefill)}
         lines = f.readlines()
         if not lines:
             return {}