Improve async object detector support (#17712)

* Move object detection to folder

* Add input store type

* Add hwnc

* Add hwcn

* Fix test
This commit is contained in:
Nicolas Mowen 2025-04-15 07:55:38 -06:00 committed by GitHub
parent 721f33c857
commit 15fe79178b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 107 additions and 72 deletions

View File

@ -6,7 +6,7 @@ import numpy as np
import frigate.util as util import frigate.util as util
from frigate.config import DetectorTypeEnum from frigate.config import DetectorTypeEnum
from frigate.object_detection import ( from frigate.object_detection.base import (
ObjectDetectProcess, ObjectDetectProcess,
RemoteObjectDetector, RemoteObjectDetector,
load_labels, load_labels,

View File

@ -55,7 +55,7 @@ from frigate.models import (
Timeline, Timeline,
User, User,
) )
from frigate.object_detection import ObjectDetectProcess from frigate.object_detection.base import ObjectDetectProcess
from frigate.output.output import output_frames from frigate.output.output import output_frames
from frigate.ptz.autotrack import PtzAutoTrackerThread from frigate.ptz.autotrack import PtzAutoTrackerThread
from frigate.ptz.onvif import OnvifController from frigate.ptz.onvif import OnvifController

View File

@ -25,6 +25,8 @@ class PixelFormatEnum(str, Enum):
class InputTensorEnum(str, Enum): class InputTensorEnum(str, Enum):
nchw = "nchw" nchw = "nchw"
nhwc = "nhwc" nhwc = "nhwc"
hwnc = "hwnc"
hwcn = "hwcn"
class InputDTypeEnum(str, Enum): class InputDTypeEnum(str, Enum):

View File

@ -1,6 +1,5 @@
import logging import logging
import os import os
import queue
import subprocess import subprocess
import threading import threading
import urllib.request import urllib.request
@ -28,37 +27,11 @@ from frigate.detectors.detection_api import DetectionApi
from frigate.detectors.detector_config import ( from frigate.detectors.detector_config import (
BaseDetectorConfig, BaseDetectorConfig,
) )
from frigate.object_detection.util import RequestStore, ResponseStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# ----------------- ResponseStore Class ----------------- #
class ResponseStore:
"""
A thread-safe hash-based response store that maps request IDs
to their results. Threads can wait on the condition variable until
their request's result appears.
"""
def __init__(self):
self.responses = {} # Maps request_id -> (original_input, infer_results)
self.lock = threading.Lock()
self.cond = threading.Condition(self.lock)
def put(self, request_id, response):
with self.cond:
self.responses[request_id] = response
self.cond.notify_all()
def get(self, request_id, timeout=None):
with self.cond:
if not self.cond.wait_for(
lambda: request_id in self.responses, timeout=timeout
):
raise TimeoutError(f"Timeout waiting for response {request_id}")
return self.responses.pop(request_id)
# ----------------- Utility Functions ----------------- # # ----------------- Utility Functions ----------------- #
@ -122,14 +95,14 @@ class HailoAsyncInference:
def __init__( def __init__(
self, self,
hef_path: str, hef_path: str,
input_queue: queue.Queue, input_store: RequestStore,
output_store: ResponseStore, output_store: ResponseStore,
batch_size: int = 1, batch_size: int = 1,
input_type: Optional[str] = None, input_type: Optional[str] = None,
output_type: Optional[Dict[str, str]] = None, output_type: Optional[Dict[str, str]] = None,
send_original_frame: bool = False, send_original_frame: bool = False,
) -> None: ) -> None:
self.input_queue = input_queue self.input_store = input_store
self.output_store = output_store self.output_store = output_store
params = VDevice.create_params() params = VDevice.create_params()
@ -204,9 +177,11 @@ class HailoAsyncInference:
def run(self) -> None: def run(self) -> None:
with self.infer_model.configure() as configured_infer_model: with self.infer_model.configure() as configured_infer_model:
while True: while True:
batch_data = self.input_queue.get() batch_data = self.input_store.get()
if batch_data is None: if batch_data is None:
break break
request_id, frame_data = batch_data request_id, frame_data = batch_data
preprocessed_batch = [frame_data] preprocessed_batch = [frame_data]
request_ids = [request_id] request_ids = [request_id]
@ -274,16 +249,14 @@ class HailoDetector(DetectionApi):
self.working_model_path = self.check_and_prepare() self.working_model_path = self.check_and_prepare()
self.batch_size = 1 self.batch_size = 1
self.input_queue = queue.Queue() self.input_store = RequestStore()
self.response_store = ResponseStore() self.response_store = ResponseStore()
self.request_counter = 0
self.request_counter_lock = threading.Lock()
try: try:
logger.debug(f"[INIT] Loading HEF model from {self.working_model_path}") logger.debug(f"[INIT] Loading HEF model from {self.working_model_path}")
self.inference_engine = HailoAsyncInference( self.inference_engine = HailoAsyncInference(
self.working_model_path, self.working_model_path,
self.input_queue, self.input_store,
self.response_store, self.response_store,
self.batch_size, self.batch_size,
) )
@ -364,26 +337,16 @@ class HailoDetector(DetectionApi):
raise FileNotFoundError(f"Model file not found at: {self.model_path}") raise FileNotFoundError(f"Model file not found at: {self.model_path}")
return cached_model_path return cached_model_path
def _get_request_id(self) -> int:
with self.request_counter_lock:
request_id = self.request_counter
self.request_counter += 1
if self.request_counter > 1000000:
self.request_counter = 0
return request_id
def detect_raw(self, tensor_input): def detect_raw(self, tensor_input):
request_id = self._get_request_id()
tensor_input = self.preprocess(tensor_input) tensor_input = self.preprocess(tensor_input)
if isinstance(tensor_input, np.ndarray) and len(tensor_input.shape) == 3: if isinstance(tensor_input, np.ndarray) and len(tensor_input.shape) == 3:
tensor_input = np.expand_dims(tensor_input, axis=0) tensor_input = np.expand_dims(tensor_input, axis=0)
self.input_queue.put((request_id, tensor_input)) request_id = self.input_store.put(tensor_input)
try: try:
original_input, infer_results = self.response_store.get( _, infer_results = self.response_store.get(request_id, timeout=10.0)
request_id, timeout=10.0
)
except TimeoutError: except TimeoutError:
logger.error( logger.error(
f"Timeout waiting for inference results for request {request_id}" f"Timeout waiting for inference results for request {request_id}"

View File

@ -29,7 +29,7 @@ from frigate.const import (
) )
from frigate.ffmpeg_presets import parse_preset_input from frigate.ffmpeg_presets import parse_preset_input
from frigate.log import LogPipe from frigate.log import LogPipe
from frigate.object_detection import load_labels from frigate.object_detection.base import load_labels
from frigate.util.builtin import get_ffmpeg_arg_list from frigate.util.builtin import get_ffmpeg_arg_list
from frigate.video import start_or_restart_ffmpeg, stop_ffmpeg from frigate.video import start_or_restart_ffmpeg, stop_ffmpeg

View File

@ -15,12 +15,13 @@ from frigate.detectors import create_detector
from frigate.detectors.detector_config import ( from frigate.detectors.detector_config import (
BaseDetectorConfig, BaseDetectorConfig,
InputDTypeEnum, InputDTypeEnum,
InputTensorEnum,
) )
from frigate.util.builtin import EventsPerSecond, load_labels from frigate.util.builtin import EventsPerSecond, load_labels
from frigate.util.image import SharedMemoryFrameManager, UntrackedSharedMemory from frigate.util.image import SharedMemoryFrameManager, UntrackedSharedMemory
from frigate.util.services import listen from frigate.util.services import listen
from .util import tensor_transform
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -30,14 +31,6 @@ class ObjectDetector(ABC):
pass pass
def tensor_transform(desired_shape: InputTensorEnum):
# Currently this function only supports BHWC permutations
if desired_shape == InputTensorEnum.nhwc:
return None
elif desired_shape == InputTensorEnum.nchw:
return (0, 3, 1, 2)
class LocalObjectDetector(ObjectDetector): class LocalObjectDetector(ObjectDetector):
def __init__( def __init__(
self, self,

View File

@ -0,0 +1,77 @@
"""Object detection utilities."""
import queue
import threading
from numpy import ndarray
from frigate.detectors.detector_config import InputTensorEnum
class RequestStore:
"""
A thread-safe hash-based response store that handles creating requests.
"""
def __init__(self):
self.request_counter = 0
self.request_counter_lock = threading.Lock()
self.input_queue = queue.Queue()
def __get_request_id(self) -> int:
with self.request_counter_lock:
request_id = self.request_counter
self.request_counter += 1
if self.request_counter > 1000000:
self.request_counter = 0
return request_id
def put(self, tensor_input: ndarray) -> int:
request_id = self.__get_request_id()
self.input_queue.get((request_id, tensor_input))
return request_id
def get(self) -> tuple[int, ndarray] | None:
try:
return self.input_queue.get_nowait()
except Exception:
return None
class ResponseStore:
"""
A thread-safe hash-based response store that maps request IDs
to their results. Threads can wait on the condition variable until
their request's result appears.
"""
def __init__(self):
self.responses = {} # Maps request_id -> (original_input, infer_results)
self.lock = threading.Lock()
self.cond = threading.Condition(self.lock)
def put(self, request_id: int, response: ndarray):
with self.cond:
self.responses[request_id] = response
self.cond.notify_all()
def get(self, request_id: int, timeout=None) -> ndarray:
with self.cond:
if not self.cond.wait_for(
lambda: request_id in self.responses, timeout=timeout
):
raise TimeoutError(f"Timeout waiting for response {request_id}")
return self.responses.pop(request_id)
def tensor_transform(desired_shape: InputTensorEnum):
# Currently this function only supports BHWC permutations
if desired_shape == InputTensorEnum.nhwc:
return None
elif desired_shape == InputTensorEnum.nchw:
return (0, 3, 1, 2)
elif desired_shape == InputTensorEnum.hwnc:
return (1, 2, 0, 3)
elif desired_shape == InputTensorEnum.hwcn:
return (1, 2, 3, 0)

View File

@ -15,7 +15,7 @@ from frigate.camera import CameraMetrics
from frigate.config import FrigateConfig from frigate.config import FrigateConfig
from frigate.const import CACHE_DIR, CLIPS_DIR, RECORD_DIR from frigate.const import CACHE_DIR, CLIPS_DIR, RECORD_DIR
from frigate.data_processing.types import DataProcessorMetrics from frigate.data_processing.types import DataProcessorMetrics
from frigate.object_detection import ObjectDetectProcess from frigate.object_detection.base import ObjectDetectProcess
from frigate.types import StatsTrackingTypes from frigate.types import StatsTrackingTypes
from frigate.util.services import ( from frigate.util.services import (
get_amd_gpu_stats, get_amd_gpu_stats,

View File

@ -5,7 +5,7 @@ import numpy as np
from pydantic import parse_obj_as from pydantic import parse_obj_as
import frigate.detectors as detectors import frigate.detectors as detectors
import frigate.object_detection import frigate.object_detection.base
from frigate.config import DetectorConfig, ModelConfig from frigate.config import DetectorConfig, ModelConfig
from frigate.detectors import DetectorTypeEnum from frigate.detectors import DetectorTypeEnum
from frigate.detectors.detector_config import InputTensorEnum from frigate.detectors.detector_config import InputTensorEnum
@ -23,7 +23,7 @@ class TestLocalObjectDetector(unittest.TestCase):
DetectorConfig, ({"type": det_type, "model": {}}) DetectorConfig, ({"type": det_type, "model": {}})
) )
test_cfg.model.path = "/test/modelpath" test_cfg.model.path = "/test/modelpath"
test_obj = frigate.object_detection.LocalObjectDetector( test_obj = frigate.object_detection.base.LocalObjectDetector(
detector_config=test_cfg detector_config=test_cfg
) )
@ -43,7 +43,7 @@ class TestLocalObjectDetector(unittest.TestCase):
TEST_DATA = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] TEST_DATA = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
TEST_DETECT_RESULT = np.ndarray([1, 2, 4, 8, 16, 32]) TEST_DETECT_RESULT = np.ndarray([1, 2, 4, 8, 16, 32])
test_obj_detect = frigate.object_detection.LocalObjectDetector( test_obj_detect = frigate.object_detection.base.LocalObjectDetector(
detector_config=parse_obj_as(DetectorConfig, {"type": "cpu", "model": {}}) detector_config=parse_obj_as(DetectorConfig, {"type": "cpu", "model": {}})
) )
@ -70,7 +70,7 @@ class TestLocalObjectDetector(unittest.TestCase):
test_cfg = parse_obj_as(DetectorConfig, {"type": "cpu", "model": {}}) test_cfg = parse_obj_as(DetectorConfig, {"type": "cpu", "model": {}})
test_cfg.model.input_tensor = InputTensorEnum.nchw test_cfg.model.input_tensor = InputTensorEnum.nchw
test_obj_detect = frigate.object_detection.LocalObjectDetector( test_obj_detect = frigate.object_detection.base.LocalObjectDetector(
detector_config=test_cfg detector_config=test_cfg
) )
@ -91,7 +91,7 @@ class TestLocalObjectDetector(unittest.TestCase):
"frigate.detectors.api_types", "frigate.detectors.api_types",
{det_type: Mock() for det_type in DetectorTypeEnum}, {det_type: Mock() for det_type in DetectorTypeEnum},
) )
@patch("frigate.object_detection.load_labels") @patch("frigate.object_detection.base.load_labels")
def test_detect_given_tensor_input_should_return_lfiltered_detections( def test_detect_given_tensor_input_should_return_lfiltered_detections(
self, mock_load_labels self, mock_load_labels
): ):
@ -118,7 +118,7 @@ class TestLocalObjectDetector(unittest.TestCase):
test_cfg = parse_obj_as(DetectorConfig, {"type": "cpu", "model": {}}) test_cfg = parse_obj_as(DetectorConfig, {"type": "cpu", "model": {}})
test_cfg.model = ModelConfig() test_cfg.model = ModelConfig()
test_obj_detect = frigate.object_detection.LocalObjectDetector( test_obj_detect = frigate.object_detection.base.LocalObjectDetector(
detector_config=test_cfg, detector_config=test_cfg,
labels=TEST_LABEL_FILE, labels=TEST_LABEL_FILE,
) )

View File

@ -3,7 +3,7 @@ from typing import TypedDict
from frigate.camera import CameraMetrics from frigate.camera import CameraMetrics
from frigate.data_processing.types import DataProcessorMetrics from frigate.data_processing.types import DataProcessorMetrics
from frigate.object_detection import ObjectDetectProcess from frigate.object_detection.base import ObjectDetectProcess
class StatsTrackingTypes(TypedDict): class StatsTrackingTypes(TypedDict):

View File

@ -24,7 +24,7 @@ from frigate.const import (
from frigate.log import LogPipe from frigate.log import LogPipe
from frigate.motion import MotionDetector from frigate.motion import MotionDetector
from frigate.motion.improved_motion import ImprovedMotionDetector from frigate.motion.improved_motion import ImprovedMotionDetector
from frigate.object_detection import RemoteObjectDetector from frigate.object_detection.base import RemoteObjectDetector
from frigate.ptz.autotrack import ptz_moving_at_frame_time from frigate.ptz.autotrack import ptz_moving_at_frame_time
from frigate.track import ObjectTracker from frigate.track import ObjectTracker
from frigate.track.norfair_tracker import NorfairTracker from frigate.track.norfair_tracker import NorfairTracker

View File

@ -4,7 +4,7 @@ import threading
import time import time
from multiprocessing.synchronize import Event as MpEvent from multiprocessing.synchronize import Event as MpEvent
from frigate.object_detection import ObjectDetectProcess from frigate.object_detection.base import ObjectDetectProcess
from frigate.util.services import restart_frigate from frigate.util.services import restart_frigate
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -14,7 +14,7 @@ sys.path.append("/workspace/frigate")
from frigate.config import FrigateConfig # noqa: E402 from frigate.config import FrigateConfig # noqa: E402
from frigate.motion import MotionDetector # noqa: E402 from frigate.motion import MotionDetector # noqa: E402
from frigate.object_detection import LocalObjectDetector # noqa: E402 from frigate.object_detection.base import LocalObjectDetector # noqa: E402
from frigate.track.centroid_tracker import CentroidTracker # noqa: E402 from frigate.track.centroid_tracker import CentroidTracker # noqa: E402
from frigate.track.object_processing import CameraState # noqa: E402 from frigate.track.object_processing import CameraState # noqa: E402
from frigate.util import ( # noqa: E402 from frigate.util import ( # noqa: E402