Add ability to configure when custom classification models run (#18380)

* Add config to control when classification models are run

* Cleanup
This commit is contained in:
Nicolas Mowen 2025-05-24 10:18:46 -06:00 committed by GitHub
parent 3892f8c732
commit 87d0102624
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 65 additions and 10 deletions

View File

@@ -38,12 +38,24 @@ class CustomClassificationStateCameraConfig(FrigateBaseModel):
crop: list[int, int, int, int] = Field( crop: list[int, int, int, int] = Field(
title="Crop of image frame on this camera to run classification on." title="Crop of image frame on this camera to run classification on."
) )
threshold: float = Field(
default=0.8, title="Classification score threshold to change the state."
)
class CustomClassificationStateConfig(FrigateBaseModel): class CustomClassificationStateConfig(FrigateBaseModel):
cameras: Dict[str, CustomClassificationStateCameraConfig] = Field( cameras: Dict[str, CustomClassificationStateCameraConfig] = Field(
title="Cameras to run classification on." title="Cameras to run classification on."
) )
motion: bool = Field(
default=False,
title="If classification should be run when motion is detected in the crop.",
)
interval: int | None = Field(
default=None,
title="Interval to run classification on in seconds.",
gt=0,
)
class CustomClassificationObjectConfig(FrigateBaseModel): class CustomClassificationObjectConfig(FrigateBaseModel):

View File

@@ -1,5 +1,6 @@
"""Real time processor that works with classification tflite models.""" """Real time processor that works with classification tflite models."""
import datetime
import logging import logging
from typing import Any from typing import Any
@@ -10,10 +11,11 @@ from frigate.comms.event_metadata_updater import (
EventMetadataPublisher, EventMetadataPublisher,
EventMetadataTypeEnum, EventMetadataTypeEnum,
) )
from frigate.comms.inter_process import InterProcessRequestor
from frigate.config import FrigateConfig from frigate.config import FrigateConfig
from frigate.config.classification import CustomClassificationConfig from frigate.config.classification import CustomClassificationConfig
from frigate.util.builtin import load_labels from frigate.util.builtin import load_labels
from frigate.util.object import calculate_region from frigate.util.object import box_overlaps, calculate_region
from ..types import DataProcessorMetrics from ..types import DataProcessorMetrics
from .api import RealTimeProcessorApi from .api import RealTimeProcessorApi
@@ -31,14 +33,19 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
self, self,
config: FrigateConfig, config: FrigateConfig,
model_config: CustomClassificationConfig, model_config: CustomClassificationConfig,
name: str,
requestor: InterProcessRequestor,
metrics: DataProcessorMetrics, metrics: DataProcessorMetrics,
): ):
super().__init__(config, metrics) super().__init__(config, metrics)
self.model_config = model_config self.model_config = model_config
self.name = name
self.requestor = requestor
self.interpreter: Interpreter = None self.interpreter: Interpreter = None
self.tensor_input_details: dict[str, Any] = None self.tensor_input_details: dict[str, Any] = None
self.tensor_output_details: dict[str, Any] = None self.tensor_output_details: dict[str, Any] = None
self.labelmap: dict[int, str] = {} self.labelmap: dict[int, str] = {}
self.last_run = datetime.datetime.now().timestamp()
self.__build_detector() self.__build_detector()
def __build_detector(self) -> None: def __build_detector(self) -> None:
@@ -53,16 +60,46 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray): def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray):
camera = frame_data.get("camera") camera = frame_data.get("camera")
if camera not in self.model_config.state_config.cameras: if camera not in self.model_config.state_config.cameras:
return return
camera_config = self.model_config.state_config.cameras[camera] camera_config = self.model_config.state_config.cameras[camera]
x, y, x2, y2 = calculate_region( crop = [
frame.shape,
camera_config.crop[0], camera_config.crop[0],
camera_config.crop[1], camera_config.crop[1],
camera_config.crop[2], camera_config.crop[2],
camera_config.crop[3], camera_config.crop[3],
]
should_run = False
now = datetime.datetime.now().timestamp()
if (
self.model_config.state_config.interval
and now > self.last_run + self.model_config.state_config.interval
):
self.last_run = now
should_run = True
if (
not should_run
and self.model_config.state_config.motion
and any([box_overlaps(crop, mb) for mb in frame_data.get("motion", [])])
):
# classification should run at most once per second
if now > self.last_run + 1:
self.last_run = now
should_run = True
if not should_run:
return
x, y, x2, y2 = calculate_region(
frame.shape,
crop[0],
crop[1],
crop[2],
crop[3],
224, 224,
1.0, 1.0,
) )
@@ -82,12 +119,14 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
res: np.ndarray = self.interpreter.get_tensor( res: np.ndarray = self.interpreter.get_tensor(
self.tensor_output_details[0]["index"] self.tensor_output_details[0]["index"]
)[0] )[0]
print(f"the gate res is {res}")
probs = res / res.sum(axis=0) probs = res / res.sum(axis=0)
best_id = np.argmax(probs) best_id = np.argmax(probs)
score = round(probs[best_id], 2) score = round(probs[best_id], 2)
print(f"got {self.labelmap[best_id]} with score {score}") if score >= camera_config.threshold:
self.requestor.send_data(
f"{camera}/classification/{self.name}", self.labelmap[best_id]
)
def handle_request(self, topic, request_data): def handle_request(self, topic, request_data):
return None return None

View File

@@ -147,13 +147,15 @@ class EmbeddingMaintainer(threading.Thread):
) )
) )
for model in self.config.classification.custom.values(): for name, model_config in self.config.classification.custom.items():
self.realtime_processors.append( self.realtime_processors.append(
CustomStateClassificationProcessor(self.config, model, self.metrics) CustomStateClassificationProcessor(
if model.state_config != None self.config, model_config, name, self.requestor, self.metrics
)
if model_config.state_config != None
else CustomObjectClassificationProcessor( else CustomObjectClassificationProcessor(
self.config, self.config,
model, model_config,
self.event_metadata_publisher, self.event_metadata_publisher,
self.metrics, self.metrics,
) )
@@ -504,7 +506,9 @@ class EmbeddingMaintainer(threading.Thread):
processor.process_frame(camera, yuv_frame, True) processor.process_frame(camera, yuv_frame, True)
if isinstance(processor, CustomStateClassificationProcessor): if isinstance(processor, CustomStateClassificationProcessor):
processor.process_frame({"camera": camera}, yuv_frame) processor.process_frame(
{"camera": camera, "motion": motion_boxes}, yuv_frame
)
self.frame_manager.close(frame_name) self.frame_manager.close(frame_name)