mirror of
https://github.com/blakeblackshear/frigate.git
synced 2025-01-02 00:07:11 +01:00
Refactor AudioTfl class to accept the number of detection threads as a parameter in the constructor, and update the usage of the num_threads attribute accordingly (#7588)
This commit is contained in:
parent
7c629c1874
commit
731743c7e5
@ -444,6 +444,7 @@ class AudioConfig(FrigateBaseModel):
|
|||||||
enabled_in_config: Optional[bool] = Field(
|
enabled_in_config: Optional[bool] = Field(
|
||||||
title="Keep track of original state of audio detection."
|
title="Keep track of original state of audio detection."
|
||||||
)
|
)
|
||||||
|
num_threads: int = Field(default=2, title="Number of detection threads", ge=1)
|
||||||
|
|
||||||
|
|
||||||
class BirdseyeModeEnum(str, Enum):
|
class BirdseyeModeEnum(str, Enum):
|
||||||
|
@ -89,12 +89,13 @@ def listen_to_audio(
|
|||||||
|
|
||||||
|
|
||||||
class AudioTfl:
|
class AudioTfl:
|
||||||
def __init__(self, stop_event: mp.Event):
|
def __init__(self, stop_event: mp.Event, num_threads=2):
|
||||||
self.stop_event = stop_event
|
self.stop_event = stop_event
|
||||||
self.labels = load_labels("/audio-labelmap.txt")
|
self.num_threads = num_threads
|
||||||
|
self.labels = load_labels("/audio-labelmap.txt", prefill=521)
|
||||||
self.interpreter = Interpreter(
|
self.interpreter = Interpreter(
|
||||||
model_path="/cpu_audio_model.tflite",
|
model_path="/cpu_audio_model.tflite",
|
||||||
num_threads=2,
|
num_threads=self.num_threads,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.interpreter.allocate_tensors()
|
self.interpreter.allocate_tensors()
|
||||||
@ -117,7 +118,7 @@ class AudioTfl:
|
|||||||
count = len(scores)
|
count = len(scores)
|
||||||
|
|
||||||
for i in range(count):
|
for i in range(count):
|
||||||
if scores[i] < 0.4 or i == 20:
|
if scores[i] < AUDIO_MIN_CONFIDENCE or i == 20:
|
||||||
break
|
break
|
||||||
detections[i] = [
|
detections[i] = [
|
||||||
class_ids[i],
|
class_ids[i],
|
||||||
@ -164,7 +165,7 @@ class AudioEventMaintainer(threading.Thread):
|
|||||||
self.inter_process_communicator = inter_process_communicator
|
self.inter_process_communicator = inter_process_communicator
|
||||||
self.detections: dict[dict[str, any]] = feature_metrics
|
self.detections: dict[dict[str, any]] = feature_metrics
|
||||||
self.stop_event = stop_event
|
self.stop_event = stop_event
|
||||||
self.detector = AudioTfl(stop_event)
|
self.detector = AudioTfl(stop_event, self.config.audio.num_threads)
|
||||||
self.shape = (int(round(AUDIO_DURATION * AUDIO_SAMPLE_RATE)),)
|
self.shape = (int(round(AUDIO_DURATION * AUDIO_SAMPLE_RATE)),)
|
||||||
self.chunk_size = int(round(AUDIO_DURATION * AUDIO_SAMPLE_RATE * 2))
|
self.chunk_size = int(round(AUDIO_DURATION * AUDIO_SAMPLE_RATE * 2))
|
||||||
self.logger = logging.getLogger(f"audio.{self.config.name}")
|
self.logger = logging.getLogger(f"audio.{self.config.name}")
|
||||||
|
@ -134,7 +134,7 @@ def get_ffmpeg_arg_list(arg: Any) -> list:
|
|||||||
return arg if isinstance(arg, list) else shlex.split(arg)
|
return arg if isinstance(arg, list) else shlex.split(arg)
|
||||||
|
|
||||||
|
|
||||||
def load_labels(path, encoding="utf-8"):
|
def load_labels(path, encoding="utf-8", prefill=91):
|
||||||
"""Loads labels from file (with or without index numbers).
|
"""Loads labels from file (with or without index numbers).
|
||||||
Args:
|
Args:
|
||||||
path: path to label file.
|
path: path to label file.
|
||||||
@ -143,7 +143,7 @@ def load_labels(path, encoding="utf-8"):
|
|||||||
Dictionary mapping indices to labels.
|
Dictionary mapping indices to labels.
|
||||||
"""
|
"""
|
||||||
with open(path, "r", encoding=encoding) as f:
|
with open(path, "r", encoding=encoding) as f:
|
||||||
labels = {index: "unknown" for index in range(91)}
|
labels = {index: "unknown" for index in range(prefill)}
|
||||||
lines = f.readlines()
|
lines = f.readlines()
|
||||||
if not lines:
|
if not lines:
|
||||||
return {}
|
return {}
|
||||||
|
Loading…
Reference in New Issue
Block a user