From 87d0102624adc4e15000c6e41f33762f76fcc95d Mon Sep 17 00:00:00 2001 From: Nicolas Mowen Date: Sat, 24 May 2025 10:18:46 -0600 Subject: [PATCH] Add ability to configure when custom classification models run (#18380) * Add config to control when classification models are run * Cleanup --- frigate/config/classification.py | 12 +++++ .../real_time/custom_classification.py | 49 +++++++++++++++++-- frigate/embeddings/maintainer.py | 14 ++++-- 3 files changed, 65 insertions(+), 10 deletions(-) diff --git a/frigate/config/classification.py b/frigate/config/classification.py index 284136076..4af60df4f 100644 --- a/frigate/config/classification.py +++ b/frigate/config/classification.py @@ -38,12 +38,24 @@ class CustomClassificationStateCameraConfig(FrigateBaseModel): crop: list[int, int, int, int] = Field( title="Crop of image frame on this camera to run classification on." ) + threshold: float = Field( + default=0.8, title="Classification score threshold to change the state." + ) class CustomClassificationStateConfig(FrigateBaseModel): cameras: Dict[str, CustomClassificationStateCameraConfig] = Field( title="Cameras to run classification on." ) + motion: bool = Field( + default=False, + title="If classification should be run when motion is detected in the crop.", + ) + interval: int | None = Field( + default=None, + title="Interval to run classification on in seconds.", + gt=0, + ) class CustomClassificationObjectConfig(FrigateBaseModel): diff --git a/frigate/data_processing/real_time/custom_classification.py b/frigate/data_processing/real_time/custom_classification.py index 1848968bb..cd99508c9 100644 --- a/frigate/data_processing/real_time/custom_classification.py +++ b/frigate/data_processing/real_time/custom_classification.py @@ -1,5 +1,6 @@ """Real time processor that works with classification tflite models.""" +import datetime import logging from typing import Any @@ -10,10 +11,11 @@ from frigate.comms.event_metadata_updater import ( EventMetadataPublisher, EventMetadataTypeEnum, ) +from frigate.comms.inter_process import InterProcessRequestor from frigate.config import FrigateConfig from frigate.config.classification import CustomClassificationConfig from frigate.util.builtin import load_labels -from frigate.util.object import calculate_region +from frigate.util.object import box_overlaps, calculate_region from ..types import DataProcessorMetrics from .api import RealTimeProcessorApi @@ -31,14 +33,19 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): self, config: FrigateConfig, model_config: CustomClassificationConfig, + name: str, + requestor: InterProcessRequestor, metrics: DataProcessorMetrics, ): super().__init__(config, metrics) self.model_config = model_config + self.name = name + self.requestor = requestor self.interpreter: Interpreter = None self.tensor_input_details: dict[str, Any] = None self.tensor_output_details: dict[str, Any] = None self.labelmap: dict[int, str] = {} + self.last_run = datetime.datetime.now().timestamp() self.__build_detector() def __build_detector(self) -> None: @@ -53,16 +60,46 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray): camera = frame_data.get("camera") + if camera not in self.model_config.state_config.cameras: return camera_config = self.model_config.state_config.cameras[camera] - x, y, x2, y2 = calculate_region( - frame.shape, + crop = [ camera_config.crop[0], camera_config.crop[1], camera_config.crop[2], camera_config.crop[3], + ] + should_run = False + + now = datetime.datetime.now().timestamp() + if ( + self.model_config.state_config.interval + and now > self.last_run + self.model_config.state_config.interval + ): + self.last_run = now + should_run = True + + if ( + not should_run + and self.model_config.state_config.motion + and any([box_overlaps(crop, mb) for mb in frame_data.get("motion", [])]) + ): + # classification should run at most once per second + if now > self.last_run + 1: + self.last_run = now + should_run = True + + if not should_run: + return + + x, y, x2, y2 = calculate_region( + frame.shape, + crop[0], + crop[1], + crop[2], + crop[3], 224, 1.0, ) @@ -82,12 +119,14 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): res: np.ndarray = self.interpreter.get_tensor( self.tensor_output_details[0]["index"] )[0] - print(f"the gate res is {res}") probs = res / res.sum(axis=0) best_id = np.argmax(probs) score = round(probs[best_id], 2) - print(f"got {self.labelmap[best_id]} with score {score}") + if score >= camera_config.threshold: + self.requestor.send_data( + f"{camera}/classification/{self.name}", self.labelmap[best_id] + ) def handle_request(self, topic, request_data): return None diff --git a/frigate/embeddings/maintainer.py b/frigate/embeddings/maintainer.py index 9838f4a21..6cce9ba98 100644 --- a/frigate/embeddings/maintainer.py +++ b/frigate/embeddings/maintainer.py @@ -147,13 +147,15 @@ class EmbeddingMaintainer(threading.Thread): ) ) - for model in self.config.classification.custom.values(): + for name, model_config in self.config.classification.custom.items(): self.realtime_processors.append( - CustomStateClassificationProcessor(self.config, model, self.metrics) - if model.state_config != None + CustomStateClassificationProcessor( + self.config, model_config, name, self.requestor, self.metrics + ) + if model_config.state_config != None else CustomObjectClassificationProcessor( self.config, - model, + model_config, self.event_metadata_publisher, self.metrics, ) @@ -504,7 +506,9 @@ class EmbeddingMaintainer(threading.Thread): processor.process_frame(camera, yuv_frame, True) if isinstance(processor, CustomStateClassificationProcessor): - processor.process_frame({"camera": camera}, yuv_frame) + processor.process_frame( + {"camera": camera, "motion": motion_boxes}, yuv_frame + ) self.frame_manager.close(frame_name)