Add ability to configure when custom classification models run (#18380)

* Add config to control when classification models are run

* Cleanup
This commit is contained in:
Nicolas Mowen 2025-05-24 10:18:46 -06:00 committed by GitHub
parent 3892f8c732
commit 87d0102624
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 65 additions and 10 deletions

View File

@@ -38,12 +38,24 @@ class CustomClassificationStateCameraConfig(FrigateBaseModel):
crop: list[int, int, int, int] = Field(
title="Crop of image frame on this camera to run classification on."
)
threshold: float = Field(
default=0.8, title="Classification score threshold to change the state."
)
class CustomClassificationStateConfig(FrigateBaseModel):
cameras: Dict[str, CustomClassificationStateCameraConfig] = Field(
title="Cameras to run classification on."
)
motion: bool = Field(
default=False,
title="If classification should be run when motion is detected in the crop.",
)
interval: int | None = Field(
default=None,
title="Interval to run classification on in seconds.",
gt=0,
)
class CustomClassificationObjectConfig(FrigateBaseModel):

View File

@@ -1,5 +1,6 @@
"""Real time processor that works with classification tflite models."""
import datetime
import logging
from typing import Any
@@ -10,10 +11,11 @@ from frigate.comms.event_metadata_updater import (
EventMetadataPublisher,
EventMetadataTypeEnum,
)
from frigate.comms.inter_process import InterProcessRequestor
from frigate.config import FrigateConfig
from frigate.config.classification import CustomClassificationConfig
from frigate.util.builtin import load_labels
from frigate.util.object import calculate_region
from frigate.util.object import box_overlaps, calculate_region
from ..types import DataProcessorMetrics
from .api import RealTimeProcessorApi
@@ -31,14 +33,19 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
self,
config: FrigateConfig,
model_config: CustomClassificationConfig,
name: str,
requestor: InterProcessRequestor,
metrics: DataProcessorMetrics,
):
super().__init__(config, metrics)
self.model_config = model_config
self.name = name
self.requestor = requestor
self.interpreter: Interpreter = None
self.tensor_input_details: dict[str, Any] = None
self.tensor_output_details: dict[str, Any] = None
self.labelmap: dict[int, str] = {}
self.last_run = datetime.datetime.now().timestamp()
self.__build_detector()
def __build_detector(self) -> None:
@@ -53,16 +60,46 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray):
camera = frame_data.get("camera")
if camera not in self.model_config.state_config.cameras:
return
camera_config = self.model_config.state_config.cameras[camera]
x, y, x2, y2 = calculate_region(
frame.shape,
crop = [
camera_config.crop[0],
camera_config.crop[1],
camera_config.crop[2],
camera_config.crop[3],
]
should_run = False
now = datetime.datetime.now().timestamp()
if (
self.model_config.state_config.interval
and now > self.last_run + self.model_config.state_config.interval
):
self.last_run = now
should_run = True
if (
not should_run
and self.model_config.state_config.motion
and any([box_overlaps(crop, mb) for mb in frame_data.get("motion", [])])
):
# classification should run at most once per second
if now > self.last_run + 1:
self.last_run = now
should_run = True
if not should_run:
return
x, y, x2, y2 = calculate_region(
frame.shape,
crop[0],
crop[1],
crop[2],
crop[3],
224,
1.0,
)
@@ -82,12 +119,14 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
res: np.ndarray = self.interpreter.get_tensor(
self.tensor_output_details[0]["index"]
)[0]
print(f"the gate res is {res}")
probs = res / res.sum(axis=0)
best_id = np.argmax(probs)
score = round(probs[best_id], 2)
print(f"got {self.labelmap[best_id]} with score {score}")
if score >= camera_config.threshold:
self.requestor.send_data(
f"{camera}/classification/{self.name}", self.labelmap[best_id]
)
def handle_request(self, topic, request_data):
return None

View File

@@ -147,13 +147,15 @@ class EmbeddingMaintainer(threading.Thread):
)
)
for model in self.config.classification.custom.values():
for name, model_config in self.config.classification.custom.items():
self.realtime_processors.append(
CustomStateClassificationProcessor(self.config, model, self.metrics)
if model.state_config != None
CustomStateClassificationProcessor(
self.config, model_config, name, self.requestor, self.metrics
)
if model_config.state_config != None
else CustomObjectClassificationProcessor(
self.config,
model,
model_config,
self.event_metadata_publisher,
self.metrics,
)
@@ -504,7 +506,9 @@ class EmbeddingMaintainer(threading.Thread):
processor.process_frame(camera, yuv_frame, True)
if isinstance(processor, CustomStateClassificationProcessor):
processor.process_frame({"camera": camera}, yuv_frame)
processor.process_frame(
{"camera": camera, "motion": motion_boxes}, yuv_frame
)
self.frame_manager.close(frame_name)