mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-04-19 23:08:08 +02:00
* implement deferred real-time processor with background task handling * add tests * fix typing
212 lines
7.3 KiB
Python
212 lines
7.3 KiB
Python
"""Tests for DeferredRealtimeProcessorApi."""
|
|
|
|
import sys
|
|
import time
|
|
import unittest
|
|
from typing import Any
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import numpy as np
|
|
|
|
from frigate.data_processing.real_time.api import DeferredRealtimeProcessorApi
|
|
|
|
# Mock TFLite before importing classification module
|
|
_MOCK_MODULES = [
|
|
"tflite_runtime",
|
|
"tflite_runtime.interpreter",
|
|
"ai_edge_litert",
|
|
"ai_edge_litert.interpreter",
|
|
]
|
|
for mod in _MOCK_MODULES:
|
|
if mod not in sys.modules:
|
|
sys.modules[mod] = MagicMock()
|
|
|
|
from frigate.data_processing.real_time.custom_classification import ( # noqa: E402
|
|
CustomObjectClassificationProcessor,
|
|
)
|
|
|
|
|
|
class StubDeferredProcessor(DeferredRealtimeProcessorApi):
|
|
"""Minimal concrete subclass for testing the deferred base."""
|
|
|
|
def __init__(self, max_queue: int = 8):
|
|
config = MagicMock()
|
|
metrics = MagicMock()
|
|
super().__init__(config, metrics, max_queue=max_queue)
|
|
self.processed_items: list[tuple] = []
|
|
|
|
def process_frame(self, obj_data: dict[str, Any], frame: np.ndarray) -> None:
|
|
"""Enqueue every call — no gating logic in the stub."""
|
|
self._enqueue_task(("frame", obj_data, frame.copy()))
|
|
|
|
def _process_task(self, task: tuple) -> None:
|
|
kind = task[0]
|
|
if kind == "frame":
|
|
_, obj_data, frame = task
|
|
self.processed_items.append((obj_data["id"], frame.shape))
|
|
self._emit_result(
|
|
{
|
|
"type": "test_result",
|
|
"id": obj_data["id"],
|
|
"label": "cat",
|
|
"score": 0.95,
|
|
}
|
|
)
|
|
elif kind == "expire":
|
|
_, object_id = task
|
|
self.processed_items.append(("expired", object_id))
|
|
|
|
def handle_request(
|
|
self, topic: str, request_data: dict[str, Any]
|
|
) -> dict[str, Any] | None:
|
|
if topic == "reload":
|
|
|
|
def _do_reload(data):
|
|
return {"success": True, "model": data.get("name")}
|
|
|
|
return self._enqueue_request(_do_reload, request_data)
|
|
return None
|
|
|
|
def expire_object(self, object_id: str, camera: str) -> None:
|
|
self._enqueue_task(("expire", object_id))
|
|
|
|
|
|
class TestDeferredProcessorBase(unittest.TestCase):
|
|
def test_enqueue_and_drain(self):
|
|
"""Tasks enqueued on main thread are processed by worker, results are drainable."""
|
|
proc = StubDeferredProcessor()
|
|
frame = np.zeros((100, 100, 3), dtype=np.uint8)
|
|
proc.process_frame({"id": "obj1"}, frame)
|
|
proc.process_frame({"id": "obj2"}, frame)
|
|
|
|
# Give the worker time to process
|
|
time.sleep(0.1)
|
|
|
|
results = proc.drain_results()
|
|
self.assertEqual(len(results), 2)
|
|
self.assertEqual(results[0]["id"], "obj1")
|
|
self.assertEqual(results[1]["id"], "obj2")
|
|
|
|
# Second drain should be empty
|
|
self.assertEqual(len(proc.drain_results()), 0)
|
|
|
|
def test_backpressure_drops_tasks(self):
|
|
"""When queue is full, new tasks are silently dropped."""
|
|
proc = StubDeferredProcessor(max_queue=2)
|
|
|
|
frame = np.zeros((10, 10, 3), dtype=np.uint8)
|
|
for i in range(10):
|
|
proc.process_frame({"id": f"obj{i}"}, frame)
|
|
|
|
time.sleep(0.2)
|
|
results = proc.drain_results()
|
|
# The key property: no crash, no unbounded growth
|
|
self.assertLessEqual(len(results), 10)
|
|
self.assertGreater(len(results), 0)
|
|
|
|
def test_handle_request_through_worker(self):
|
|
"""handle_request blocks until the worker processes it and returns a response."""
|
|
proc = StubDeferredProcessor()
|
|
result = proc.handle_request("reload", {"name": "my_model"})
|
|
self.assertEqual(result, {"success": True, "model": "my_model"})
|
|
|
|
def test_expire_object_serialized_with_work(self):
|
|
"""expire_object goes through the queue, serialized with inference work."""
|
|
proc = StubDeferredProcessor()
|
|
frame = np.zeros((10, 10, 3), dtype=np.uint8)
|
|
proc.process_frame({"id": "obj1"}, frame)
|
|
proc.expire_object("obj1", "front_door")
|
|
|
|
time.sleep(0.1)
|
|
# Both should have been processed in order
|
|
self.assertEqual(len(proc.processed_items), 2)
|
|
self.assertEqual(proc.processed_items[0][0], "obj1")
|
|
self.assertEqual(proc.processed_items[1], ("expired", "obj1"))
|
|
|
|
def test_shutdown_joins_worker(self):
|
|
"""shutdown() signals the worker to stop and joins the thread."""
|
|
proc = StubDeferredProcessor()
|
|
proc.shutdown()
|
|
self.assertFalse(proc._worker.is_alive())
|
|
|
|
def test_drain_results_returns_list(self):
|
|
"""drain_results returns a plain list, not a deque."""
|
|
proc = StubDeferredProcessor()
|
|
results = proc.drain_results()
|
|
self.assertIsInstance(results, list)
|
|
|
|
|
|
class TestCustomObjectClassificationDeferred(unittest.TestCase):
|
|
"""Test that CustomObjectClassificationProcessor uses the deferred pattern correctly."""
|
|
|
|
def _make_processor(self):
|
|
config = MagicMock()
|
|
model_config = MagicMock()
|
|
model_config.name = "test_breed"
|
|
model_config.object_config = MagicMock()
|
|
model_config.object_config.objects = ["dog"]
|
|
model_config.threshold = 0.5
|
|
model_config.save_attempts = 10
|
|
model_config.object_config.classification_type = "sub_label"
|
|
publisher = MagicMock()
|
|
requestor = MagicMock()
|
|
metrics = MagicMock()
|
|
metrics.classification_speeds = {}
|
|
metrics.classification_cps = {}
|
|
|
|
with patch.object(
|
|
CustomObjectClassificationProcessor,
|
|
"_CustomObjectClassificationProcessor__build_detector",
|
|
):
|
|
proc = CustomObjectClassificationProcessor(
|
|
config, model_config, publisher, requestor, metrics
|
|
)
|
|
proc.interpreter = None
|
|
proc.tensor_input_details = [{"index": 0}]
|
|
proc.tensor_output_details = [{"index": 0}]
|
|
proc.labelmap = {0: "labrador", 1: "poodle", 2: "none"}
|
|
return proc
|
|
|
|
def test_is_deferred_processor(self):
|
|
"""CustomObjectClassificationProcessor should be a DeferredRealtimeProcessorApi."""
|
|
proc = self._make_processor()
|
|
self.assertIsInstance(proc, DeferredRealtimeProcessorApi)
|
|
|
|
def test_expire_clears_history(self):
|
|
"""expire_object should clear classification history for the object."""
|
|
proc = self._make_processor()
|
|
proc.classification_history["obj1"] = [("labrador", 0.9, 1.0)]
|
|
|
|
proc.expire_object("obj1", "front")
|
|
time.sleep(0.1)
|
|
|
|
self.assertNotIn("obj1", proc.classification_history)
|
|
|
|
def test_drain_results_empty_when_no_model(self):
|
|
"""With no interpreter, process_frame saves training images but emits no results."""
|
|
proc = self._make_processor()
|
|
proc.interpreter = None
|
|
|
|
frame = np.zeros((150, 100), dtype=np.uint8)
|
|
obj_data = {
|
|
"id": "obj1",
|
|
"label": "dog",
|
|
"false_positive": False,
|
|
"end_time": None,
|
|
"box": [10, 10, 50, 50],
|
|
"camera": "front",
|
|
}
|
|
|
|
with patch(
|
|
"frigate.data_processing.real_time.custom_classification.write_classification_attempt"
|
|
):
|
|
proc.process_frame(obj_data, frame)
|
|
|
|
time.sleep(0.1)
|
|
results = proc.drain_results()
|
|
self.assertEqual(len(results), 0)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
unittest.main()
|