From a47be12ac54ac6613ab72715c7c1b12e753f2a90 Mon Sep 17 00:00:00 2001 From: Josh Hawkins <32435876+hawkeye217@users.noreply.github.com> Date: Tue, 14 Apr 2026 22:39:44 -0500 Subject: [PATCH] Add deferred real-time processor for enrichments (#22880) * implement deferred real-time processor with background task handling * add tests * fix typing --- frigate/data_processing/real_time/api.py | 126 ++++++++- .../real_time/custom_classification.py | 248 ++++++++++-------- frigate/embeddings/maintainer.py | 84 +++++- frigate/test/test_deferred_processor.py | 211 +++++++++++++++ 4 files changed, 546 insertions(+), 123 deletions(-) create mode 100644 frigate/test/test_deferred_processor.py diff --git a/frigate/data_processing/real_time/api.py b/frigate/data_processing/real_time/api.py index 31127220f..b9b7ba26e 100644 --- a/frigate/data_processing/real_time/api.py +++ b/frigate/data_processing/real_time/api.py @@ -1,8 +1,12 @@ """Local only processors for handling real time object processing.""" import logging +import threading from abc import ABC, abstractmethod -from typing import Any +from collections import deque +from concurrent.futures import Future +from queue import Empty, Full, Queue +from typing import Any, Callable import numpy as np @@ -74,3 +78,123 @@ class RealTimeProcessorApi(ABC): payload: The updated configuration object. """ pass + + def drain_results(self) -> list[dict[str, Any]]: + """Return pending results that need IPC side-effects. + + Deferred processors accumulate results on a worker thread. + The maintainer calls this each loop iteration to collect them + and perform publishes on the main thread. + + Synchronous processors return an empty list (default). + """ + return [] + + def shutdown(self) -> None: + """Stop any background work and release resources. + + Called when the processor is being removed or the maintainer + is shutting down. Default is a no-op for synchronous processors. + """ + pass + + +class DeferredRealtimeProcessorApi(RealTimeProcessorApi): + """Base class for processors that offload heavy work to a background thread. + + Subclasses implement: + - process_frame(): do cheap gating + crop + copy, then call _enqueue_task() + - _process_task(task): heavy work (inference, consensus) on the worker thread + - handle_request(): optionally use _enqueue_request() for sync request/response + - expire_object(): call _enqueue_task() with a control message + + The worker thread owns all processor state. No locks are needed because + only the worker mutates state. Results that need IPC are placed in + _pending_results via _emit_result(), and the maintainer drains them + each loop iteration. + """ + + def __init__( + self, + config: FrigateConfig, + metrics: DataProcessorMetrics, + max_queue: int = 8, + ) -> None: + super().__init__(config, metrics) + self._task_queue: Queue = Queue(maxsize=max_queue) + self._pending_results: deque[dict[str, Any]] = deque() + self._results_lock = threading.Lock() + self._stop_event = threading.Event() + self._worker = threading.Thread( + target=self._drain_loop, + daemon=True, + name=f"{type(self).__name__}_worker", + ) + self._worker.start() + + def _drain_loop(self) -> None: + """Worker thread main loop — drains the task queue until stopped.""" + while not self._stop_event.is_set(): + try: + task = self._task_queue.get(timeout=0.5) + except Empty: + continue + + if ( + isinstance(task, tuple) + and len(task) == 2 + and isinstance(task[1], Future) + ): + # Request/response: (callable_and_args, future) + (func, args), future = task + try: + result = func(args) + future.set_result(result) + except Exception as e: + future.set_exception(e) + else: + try: + self._process_task(task) + except Exception: + logger.exception("Error processing deferred task") + + def _enqueue_task(self, task: Any) -> bool: + """Enqueue a task for the worker. Returns False if queue is full (dropped).""" + try: + self._task_queue.put_nowait(task) + return True + except Full: + logger.debug("Deferred processor queue full, dropping task") + return False + + def _enqueue_request(self, func: Callable, args: Any, timeout: float = 10.0) -> Any: + """Enqueue a request and block until the worker returns a result.""" + future: Future = Future() + self._task_queue.put(((func, args), future), timeout=timeout) + return future.result(timeout=timeout) + + def _emit_result(self, result: dict[str, Any]) -> None: + """Called by the worker thread to stage a result for the maintainer.""" + with self._results_lock: + self._pending_results.append(result) + + def drain_results(self) -> list[dict[str, Any]]: + """Called by the maintainer on the main thread to collect pending results.""" + with self._results_lock: + results = list(self._pending_results) + self._pending_results.clear() + return results + + def shutdown(self) -> None: + """Signal the worker to stop and wait for it to finish.""" + self._stop_event.set() + self._worker.join(timeout=5.0) + + @abstractmethod + def _process_task(self, task: Any) -> None: + """Process a single task on the worker thread. + + Subclasses implement inference, consensus, training image saves here. + Call _emit_result() to stage results for the maintainer to publish. + """ + pass diff --git a/frigate/data_processing/real_time/custom_classification.py b/frigate/data_processing/real_time/custom_classification.py index 1dcf59052..e3b0e23ed 100644 --- a/frigate/data_processing/real_time/custom_classification.py +++ b/frigate/data_processing/real_time/custom_classification.py @@ -1,7 +1,6 @@ """Real time processor that works with classification tflite models.""" import datetime -import json import logging import os from typing import Any @@ -10,25 +9,18 @@ import cv2 import numpy as np from frigate.comms.embeddings_updater import EmbeddingsRequestEnum -from frigate.comms.event_metadata_updater import ( - EventMetadataPublisher, - EventMetadataTypeEnum, -) +from frigate.comms.event_metadata_updater import EventMetadataPublisher from frigate.comms.inter_process import InterProcessRequestor from frigate.config import FrigateConfig -from frigate.config.classification import ( - CustomClassificationConfig, - ObjectClassificationType, -) +from frigate.config.classification import CustomClassificationConfig from frigate.const import CLIPS_DIR, MODEL_CACHE_DIR from frigate.log import suppress_stderr_during -from frigate.types import TrackedObjectUpdateTypesEnum from frigate.util.builtin import EventsPerSecond, InferenceSpeed, load_labels from frigate.util.image import calculate_region from frigate.util.object import box_overlaps from ..types import DataProcessorMetrics -from .api import RealTimeProcessorApi +from .api import DeferredRealtimeProcessorApi try: from tflite_runtime.interpreter import Interpreter @@ -40,7 +32,7 @@ logger = logging.getLogger(__name__) MAX_OBJECT_CLASSIFICATIONS = 16 -class CustomStateClassificationProcessor(RealTimeProcessorApi): +class CustomStateClassificationProcessor(DeferredRealtimeProcessorApi): def __init__( self, config: FrigateConfig, @@ -48,7 +40,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): requestor: InterProcessRequestor, metrics: DataProcessorMetrics, ): - super().__init__(config, metrics) + super().__init__(config, metrics, max_queue=4) self.model_config = model_config if not self.model_config.name: @@ -259,14 +251,34 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): ) return - frame = rgb[y1:y2, x1:x2] + cropped_frame = rgb[y1:y2, x1:x2] try: - resized_frame = cv2.resize(frame, (224, 224)) + resized_frame = cv2.resize(cropped_frame, (224, 224)) except Exception: logger.warning("Failed to resize image for state classification") return + # Copy for training image saves on worker thread + crop_bgr = cv2.cvtColor(cropped_frame, cv2.COLOR_RGB2BGR) + + self._enqueue_task(("classify", camera, now, resized_frame, crop_bgr)) + + def _process_task(self, task: Any) -> None: + kind = task[0] + if kind == "classify": + _, camera, timestamp, resized_frame, crop_bgr = task + self._classify_state(camera, timestamp, resized_frame, crop_bgr) + elif kind == "reload": + self.__build_detector() + + def _classify_state( + self, + camera: str, + timestamp: float, + resized_frame: np.ndarray, + crop_bgr: np.ndarray, + ) -> None: if self.interpreter is None: # When interpreter is None, always save (score is 0.0, which is < 1.0) if self._should_save_image(camera, "unknown", 0.0): @@ -277,15 +289,18 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): ) write_classification_attempt( self.train_dir, - cv2.cvtColor(frame, cv2.COLOR_RGB2BGR), + crop_bgr, "none-none", - now, + timestamp, "unknown", 0.0, max_files=save_attempts, ) return + if not self.tensor_input_details or not self.tensor_output_details: + return + input = np.expand_dims(resized_frame, axis=0) self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input) self.interpreter.invoke() @@ -298,7 +313,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): ) best_id = int(np.argmax(probs)) score = round(probs[best_id], 2) - self.__update_metrics(datetime.datetime.now().timestamp() - now) + self.__update_metrics(datetime.datetime.now().timestamp() - timestamp) detected_state = self.labelmap[best_id] @@ -310,9 +325,9 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): ) write_classification_attempt( self.train_dir, - cv2.cvtColor(frame, cv2.COLOR_RGB2BGR), + crop_bgr, "none-none", - now, + timestamp, detected_state, score, max_files=save_attempts, @@ -327,9 +342,14 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): verified_state = self.verify_state_change(camera, detected_state) if verified_state is not None: - self.requestor.send_data( - f"{camera}/classification/{self.model_config.name}", - verified_state, + self._emit_result( + { + "type": "classification", + "processor": "state", + "model_name": self.model_config.name, + "camera": camera, + "state": verified_state, + } ) def handle_request( @@ -337,14 +357,19 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): ) -> dict[str, Any] | None: if topic == EmbeddingsRequestEnum.reload_classification_model.value: if request_data.get("model_name") == self.model_config.name: - self.__build_detector() - logger.info( - f"Successfully loaded updated model for {self.model_config.name}" - ) - return { - "success": True, - "message": f"Loaded {self.model_config.name} model.", - } + + def _do_reload(data: dict[str, Any]) -> dict[str, Any]: + self.__build_detector() + logger.info( + f"Successfully loaded updated model for {self.model_config.name}" + ) + return { + "success": True, + "message": f"Loaded {self.model_config.name} model.", + } + + result: dict[str, Any] = self._enqueue_request(_do_reload, request_data) + return result else: return None else: @@ -354,7 +379,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): pass -class CustomObjectClassificationProcessor(RealTimeProcessorApi): +class CustomObjectClassificationProcessor(DeferredRealtimeProcessorApi): def __init__( self, config: FrigateConfig, @@ -363,7 +388,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): requestor: InterProcessRequestor, metrics: DataProcessorMetrics, ): - super().__init__(config, metrics) + super().__init__(config, metrics, max_queue=8) self.model_config = model_config if not self.model_config.name: @@ -536,18 +561,41 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): ) rgb = cv2.cvtColor(frame, cv2.COLOR_YUV2RGB_I420) - crop = rgb[ - y:y2, - x:x2, - ] + crop = rgb[y:y2, x:x2] - if crop.shape != (224, 224): - try: - resized_crop = cv2.resize(crop, (224, 224)) - except Exception: - logger.warning("Failed to resize image for state classification") - return + try: + resized_crop = cv2.resize(crop, (224, 224)) + except Exception: + logger.warning("Failed to resize image for object classification") + return + # Copy crop for training images (will be used on worker thread) + crop_bgr = cv2.cvtColor(crop, cv2.COLOR_RGB2BGR) + + self._enqueue_task( + ("classify", object_id, obj_data["camera"], now, resized_crop, crop_bgr) + ) + + def _process_task(self, task: Any) -> None: + kind = task[0] + if kind == "classify": + _, object_id, camera, timestamp, resized_crop, crop_bgr = task + self._classify_object(object_id, camera, timestamp, resized_crop, crop_bgr) + elif kind == "expire": + _, object_id = task + if object_id in self.classification_history: + self.classification_history.pop(object_id) + elif kind == "reload": + self.__build_detector() + + def _classify_object( + self, + object_id: str, + camera: str, + timestamp: float, + resized_crop: np.ndarray, + crop_bgr: np.ndarray, + ) -> None: if self.interpreter is None: save_attempts = ( self.model_config.save_attempts @@ -556,9 +604,9 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): ) write_classification_attempt( self.train_dir, - cv2.cvtColor(crop, cv2.COLOR_RGB2BGR), + crop_bgr, object_id, - now, + timestamp, "unknown", 0.0, max_files=save_attempts, @@ -569,7 +617,10 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): if object_id not in self.classification_history: self.classification_history[object_id] = [] - self.classification_history[object_id].append(("unknown", 0.0, now)) + self.classification_history[object_id].append(("unknown", 0.0, timestamp)) + return + + if not self.tensor_input_details or not self.tensor_output_details: return input = np.expand_dims(resized_crop, axis=0) @@ -584,7 +635,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): ) best_id = int(np.argmax(probs)) score = round(probs[best_id], 2) - self.__update_metrics(datetime.datetime.now().timestamp() - now) + self.__update_metrics(datetime.datetime.now().timestamp() - timestamp) save_attempts = ( self.model_config.save_attempts @@ -593,9 +644,9 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): ) write_classification_attempt( self.train_dir, - cv2.cvtColor(crop, cv2.COLOR_RGB2BGR), + crop_bgr, object_id, - now, + timestamp, self.labelmap[best_id], score, max_files=save_attempts, @@ -610,92 +661,57 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): sub_label = self.labelmap[best_id] logger.debug( - f"{self.model_config.name}: Object {object_id} (label={obj_data['label']}) passed threshold with sub_label={sub_label}, score={score}" + f"{self.model_config.name}: Object {object_id} passed threshold with sub_label={sub_label}, score={score}" ) consensus_label, consensus_score = self.get_weighted_score( - object_id, sub_label, score, now + object_id, sub_label, score, timestamp ) logger.debug( f"{self.model_config.name}: get_weighted_score returned consensus_label={consensus_label}, consensus_score={consensus_score} for {object_id}" ) - if consensus_label is not None: - camera = obj_data["camera"] - logger.debug( - f"{self.model_config.name}: Publishing sub_label={consensus_label} for {obj_data['label']} object {object_id} on {camera}" + if consensus_label is not None and self.model_config.object_config is not None: + self._emit_result( + { + "type": "classification", + "processor": "object", + "model_name": self.model_config.name, + "classification_type": self.model_config.object_config.classification_type, + "object_id": object_id, + "camera": camera, + "timestamp": timestamp, + "label": consensus_label, + "score": consensus_score, + } ) - if ( - self.model_config.object_config.classification_type - == ObjectClassificationType.sub_label - ): - self.sub_label_publisher.publish( - (object_id, consensus_label, consensus_score), - EventMetadataTypeEnum.sub_label, - ) - self.requestor.send_data( - "tracked_object_update", - json.dumps( - { - "type": TrackedObjectUpdateTypesEnum.classification, - "id": object_id, - "camera": camera, - "timestamp": now, - "model": self.model_config.name, - "sub_label": consensus_label, - "score": consensus_score, - } - ), - ) - elif ( - self.model_config.object_config.classification_type - == ObjectClassificationType.attribute - ): - self.sub_label_publisher.publish( - ( - object_id, - self.model_config.name, - consensus_label, - consensus_score, - ), - EventMetadataTypeEnum.attribute.value, - ) - self.requestor.send_data( - "tracked_object_update", - json.dumps( - { - "type": TrackedObjectUpdateTypesEnum.classification, - "id": object_id, - "camera": camera, - "timestamp": now, - "model": self.model_config.name, - "attribute": consensus_label, - "score": consensus_score, - } - ), - ) - - def handle_request(self, topic: str, request_data: dict) -> dict | None: + def handle_request( + self, topic: str, request_data: dict[str, Any] + ) -> dict[str, Any] | None: if topic == EmbeddingsRequestEnum.reload_classification_model.value: if request_data.get("model_name") == self.model_config.name: - self.__build_detector() - logger.info( - f"Successfully loaded updated model for {self.model_config.name}" - ) - return { - "success": True, - "message": f"Loaded {self.model_config.name} model.", - } + + def _do_reload(data: dict[str, Any]) -> dict[str, Any]: + self.__build_detector() + logger.info( + f"Successfully loaded updated model for {self.model_config.name}" + ) + return { + "success": True, + "message": f"Loaded {self.model_config.name} model.", + } + + result: dict[str, Any] = self._enqueue_request(_do_reload, request_data) + return result else: return None else: return None def expire_object(self, object_id: str, camera: str) -> None: - if object_id in self.classification_history: - self.classification_history.pop(object_id) + self._enqueue_task(("expire", object_id)) def write_classification_attempt( diff --git a/frigate/embeddings/maintainer.py b/frigate/embeddings/maintainer.py index 3f066a860..ea1c9a118 100644 --- a/frigate/embeddings/maintainer.py +++ b/frigate/embeddings/maintainer.py @@ -2,6 +2,7 @@ import base64 import datetime +import json import logging import threading from multiprocessing.synchronize import Event as MpEvent @@ -33,6 +34,7 @@ from frigate.config.camera.updater import ( CameraConfigUpdateEnum, CameraConfigUpdateSubscriber, ) +from frigate.config.classification import ObjectClassificationType from frigate.data_processing.common.license_plate.model import ( LicensePlateModelRunner, ) @@ -61,6 +63,7 @@ from frigate.db.sqlitevecq import SqliteVecQueueDatabase from frigate.events.types import EventTypeEnum, RegenerateDescriptionEnum from frigate.genai import GenAIClientManager from frigate.models import Event, Recordings, ReviewSegment, Trigger +from frigate.types import TrackedObjectUpdateTypesEnum from frigate.util.builtin import serialize from frigate.util.file import get_event_thumbnail_bytes from frigate.util.image import SharedMemoryFrameManager @@ -274,10 +277,15 @@ class EmbeddingMaintainer(threading.Thread): self._process_recordings_updates() self._process_review_updates() self._process_frame_updates() + self._process_deferred_results() self._expire_dedicated_lpr() self._process_finalized() self._process_event_metadata() + # Shutdown deferred processors + for processor in self.realtime_processors: + processor.shutdown() + self.config_updater.stop() self.enrichment_config_subscriber.stop() self.event_subscriber.stop() @@ -316,10 +324,9 @@ class EmbeddingMaintainer(threading.Thread): model_name = topic.split("/")[-1] if model_config is None: - self.realtime_processors = [ - processor - for processor in self.realtime_processors - if not ( + remaining = [] + for processor in self.realtime_processors: + if ( isinstance( processor, ( @@ -328,8 +335,11 @@ class EmbeddingMaintainer(threading.Thread): ), ) and processor.model_config.name == model_name - ) - ] + ): + processor.shutdown() + else: + remaining.append(processor) + self.realtime_processors = remaining logger.info( f"Successfully removed classification processor for model: {model_name}" @@ -697,6 +707,68 @@ class EmbeddingMaintainer(threading.Thread): self.frame_manager.close(frame_name) + def _process_deferred_results(self) -> None: + """Drain results from deferred processors and perform IPC side-effects.""" + for processor in self.realtime_processors: + results = processor.drain_results() + + for result in results: + if result.get("type") != "classification": + continue + + if result["processor"] == "state": + self.requestor.send_data( + f"{result['camera']}/classification/{result['model_name']}", + result["state"], + ) + elif result["processor"] == "object": + object_id = result["object_id"] + camera = result["camera"] + timestamp = result["timestamp"] + model_name = result["model_name"] + label = result["label"] + score = result["score"] + classification_type = result["classification_type"] + + if classification_type == ObjectClassificationType.sub_label: + self.event_metadata_publisher.publish( + (object_id, label, score), + EventMetadataTypeEnum.sub_label, + ) + self.requestor.send_data( + "tracked_object_update", + json.dumps( + { + "type": TrackedObjectUpdateTypesEnum.classification, + "id": object_id, + "camera": camera, + "timestamp": timestamp, + "model": model_name, + "sub_label": label, + "score": score, + } + ), + ) + elif classification_type == ObjectClassificationType.attribute: + self.event_metadata_publisher.publish( + (object_id, model_name, label, score), + EventMetadataTypeEnum.attribute.value, + ) + self.requestor.send_data( + "tracked_object_update", + json.dumps( + { + "type": TrackedObjectUpdateTypesEnum.classification, + "id": object_id, + "camera": camera, + "timestamp": timestamp, + "model": model_name, + "attribute": label, + "score": score, + } + ), + ) + def _embed_thumbnail(self, event_id: str, thumbnail: bytes) -> None: """Embed the thumbnail for an event.""" if not self.config.semantic_search.enabled: diff --git a/frigate/test/test_deferred_processor.py b/frigate/test/test_deferred_processor.py new file mode 100644 index 000000000..c76b445fa --- /dev/null +++ b/frigate/test/test_deferred_processor.py @@ -0,0 +1,211 @@ +"""Tests for DeferredRealtimeProcessorApi.""" + +import sys +import time +import unittest +from typing import Any +from unittest.mock import MagicMock, patch + +import numpy as np + +from frigate.data_processing.real_time.api import DeferredRealtimeProcessorApi + +# Mock TFLite before importing classification module +_MOCK_MODULES = [ + "tflite_runtime", + "tflite_runtime.interpreter", + "ai_edge_litert", + "ai_edge_litert.interpreter", +] +for mod in _MOCK_MODULES: + if mod not in sys.modules: + sys.modules[mod] = MagicMock() + +from frigate.data_processing.real_time.custom_classification import ( # noqa: E402 + CustomObjectClassificationProcessor, +) + + +class StubDeferredProcessor(DeferredRealtimeProcessorApi): + """Minimal concrete subclass for testing the deferred base.""" + + def __init__(self, max_queue: int = 8): + config = MagicMock() + metrics = MagicMock() + super().__init__(config, metrics, max_queue=max_queue) + self.processed_items: list[tuple] = [] + + def process_frame(self, obj_data: dict[str, Any], frame: np.ndarray) -> None: + """Enqueue every call — no gating logic in the stub.""" + self._enqueue_task(("frame", obj_data, frame.copy())) + + def _process_task(self, task: tuple) -> None: + kind = task[0] + if kind == "frame": + _, obj_data, frame = task + self.processed_items.append((obj_data["id"], frame.shape)) + self._emit_result( + { + "type": "test_result", + "id": obj_data["id"], + "label": "cat", + "score": 0.95, + } + ) + elif kind == "expire": + _, object_id = task + self.processed_items.append(("expired", object_id)) + + def handle_request( + self, topic: str, request_data: dict[str, Any] + ) -> dict[str, Any] | None: + if topic == "reload": + + def _do_reload(data): + return {"success": True, "model": data.get("name")} + + return self._enqueue_request(_do_reload, request_data) + return None + + def expire_object(self, object_id: str, camera: str) -> None: + self._enqueue_task(("expire", object_id)) + + +class TestDeferredProcessorBase(unittest.TestCase): + def test_enqueue_and_drain(self): + """Tasks enqueued on main thread are processed by worker, results are drainable.""" + proc = StubDeferredProcessor() + frame = np.zeros((100, 100, 3), dtype=np.uint8) + proc.process_frame({"id": "obj1"}, frame) + proc.process_frame({"id": "obj2"}, frame) + + # Give the worker time to process + time.sleep(0.1) + + results = proc.drain_results() + self.assertEqual(len(results), 2) + self.assertEqual(results[0]["id"], "obj1") + self.assertEqual(results[1]["id"], "obj2") + + # Second drain should be empty + self.assertEqual(len(proc.drain_results()), 0) + + def test_backpressure_drops_tasks(self): + """When queue is full, new tasks are silently dropped.""" + proc = StubDeferredProcessor(max_queue=2) + + frame = np.zeros((10, 10, 3), dtype=np.uint8) + for i in range(10): + proc.process_frame({"id": f"obj{i}"}, frame) + + time.sleep(0.2) + results = proc.drain_results() + # The key property: no crash, no unbounded growth + self.assertLessEqual(len(results), 10) + self.assertGreater(len(results), 0) + + def test_handle_request_through_worker(self): + """handle_request blocks until the worker processes it and returns a response.""" + proc = StubDeferredProcessor() + result = proc.handle_request("reload", {"name": "my_model"}) + self.assertEqual(result, {"success": True, "model": "my_model"}) + + def test_expire_object_serialized_with_work(self): + """expire_object goes through the queue, serialized with inference work.""" + proc = StubDeferredProcessor() + frame = np.zeros((10, 10, 3), dtype=np.uint8) + proc.process_frame({"id": "obj1"}, frame) + proc.expire_object("obj1", "front_door") + + time.sleep(0.1) + # Both should have been processed in order + self.assertEqual(len(proc.processed_items), 2) + self.assertEqual(proc.processed_items[0][0], "obj1") + self.assertEqual(proc.processed_items[1], ("expired", "obj1")) + + def test_shutdown_joins_worker(self): + """shutdown() signals the worker to stop and joins the thread.""" + proc = StubDeferredProcessor() + proc.shutdown() + self.assertFalse(proc._worker.is_alive()) + + def test_drain_results_returns_list(self): + """drain_results returns a plain list, not a deque.""" + proc = StubDeferredProcessor() + results = proc.drain_results() + self.assertIsInstance(results, list) + + +class TestCustomObjectClassificationDeferred(unittest.TestCase): + """Test that CustomObjectClassificationProcessor uses the deferred pattern correctly.""" + + def _make_processor(self): + config = MagicMock() + model_config = MagicMock() + model_config.name = "test_breed" + model_config.object_config = MagicMock() + model_config.object_config.objects = ["dog"] + model_config.threshold = 0.5 + model_config.save_attempts = 10 + model_config.object_config.classification_type = "sub_label" + publisher = MagicMock() + requestor = MagicMock() + metrics = MagicMock() + metrics.classification_speeds = {} + metrics.classification_cps = {} + + with patch.object( + CustomObjectClassificationProcessor, + "_CustomObjectClassificationProcessor__build_detector", + ): + proc = CustomObjectClassificationProcessor( + config, model_config, publisher, requestor, metrics + ) + proc.interpreter = None + proc.tensor_input_details = [{"index": 0}] + proc.tensor_output_details = [{"index": 0}] + proc.labelmap = {0: "labrador", 1: "poodle", 2: "none"} + return proc + + def test_is_deferred_processor(self): + """CustomObjectClassificationProcessor should be a DeferredRealtimeProcessorApi.""" + proc = self._make_processor() + self.assertIsInstance(proc, DeferredRealtimeProcessorApi) + + def test_expire_clears_history(self): + """expire_object should clear classification history for the object.""" + proc = self._make_processor() + proc.classification_history["obj1"] = [("labrador", 0.9, 1.0)] + + proc.expire_object("obj1", "front") + time.sleep(0.1) + + self.assertNotIn("obj1", proc.classification_history) + + def test_drain_results_empty_when_no_model(self): + """With no interpreter, process_frame saves training images but emits no results.""" + proc = self._make_processor() + proc.interpreter = None + + frame = np.zeros((150, 100), dtype=np.uint8) + obj_data = { + "id": "obj1", + "label": "dog", + "false_positive": False, + "end_time": None, + "box": [10, 10, 50, 50], + "camera": "front", + } + + with patch( + "frigate.data_processing.real_time.custom_classification.write_classification_attempt" + ): + proc.process_frame(obj_data, frame) + + time.sleep(0.1) + results = proc.drain_results() + self.assertEqual(len(results), 0) + + +if __name__ == "__main__": + unittest.main()