Add ability to configure when custom classification models run (#18380)

* Add config to control when classification models are run

* Cleanup
This commit is contained in:
Nicolas Mowen
2025-05-24 10:18:46 -06:00
committed by Blake Blackshear
parent 53ff33135b
commit 723553edb7
3 changed files with 65 additions and 10 deletions

View File

@@ -1,5 +1,6 @@
"""Real time processor that works with classification tflite models."""
import datetime
import logging
from typing import Any
@@ -10,10 +11,11 @@ from frigate.comms.event_metadata_updater import (
EventMetadataPublisher,
EventMetadataTypeEnum,
)
from frigate.comms.inter_process import InterProcessRequestor
from frigate.config import FrigateConfig
from frigate.config.classification import CustomClassificationConfig
from frigate.util.builtin import load_labels
from frigate.util.object import calculate_region
from frigate.util.object import box_overlaps, calculate_region
from ..types import DataProcessorMetrics
from .api import RealTimeProcessorApi
@@ -31,14 +33,19 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
self,
config: FrigateConfig,
model_config: CustomClassificationConfig,
name: str,
requestor: InterProcessRequestor,
metrics: DataProcessorMetrics,
):
super().__init__(config, metrics)
self.model_config = model_config
self.name = name
self.requestor = requestor
self.interpreter: Interpreter = None
self.tensor_input_details: dict[str, Any] = None
self.tensor_output_details: dict[str, Any] = None
self.labelmap: dict[int, str] = {}
self.last_run = datetime.datetime.now().timestamp()
self.__build_detector()
def __build_detector(self) -> None:
@@ -53,16 +60,46 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray):
camera = frame_data.get("camera")
if camera not in self.model_config.state_config.cameras:
return
camera_config = self.model_config.state_config.cameras[camera]
x, y, x2, y2 = calculate_region(
frame.shape,
crop = [
camera_config.crop[0],
camera_config.crop[1],
camera_config.crop[2],
camera_config.crop[3],
]
should_run = False
now = datetime.datetime.now().timestamp()
if (
self.model_config.state_config.interval
and now > self.last_run + self.model_config.state_config.interval
):
self.last_run = now
should_run = True
if (
not should_run
and self.model_config.state_config.motion
and any([box_overlaps(crop, mb) for mb in frame_data.get("motion", [])])
):
# classification should run at most once per second
if now > self.last_run + 1:
self.last_run = now
should_run = True
if not should_run:
return
x, y, x2, y2 = calculate_region(
frame.shape,
crop[0],
crop[1],
crop[2],
crop[3],
224,
1.0,
)
@@ -82,12 +119,14 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
res: np.ndarray = self.interpreter.get_tensor(
self.tensor_output_details[0]["index"]
)[0]
print(f"the gate res is {res}")
probs = res / res.sum(axis=0)
best_id = np.argmax(probs)
score = round(probs[best_id], 2)
print(f"got {self.labelmap[best_id]} with score {score}")
if score >= camera_config.threshold:
self.requestor.send_data(
f"{camera}/classification/{self.name}", self.labelmap[best_id]
)
def handle_request(self, topic, request_data):
return None