Mirror of https://github.com/blakeblackshear/frigate.git (synced 2025-08-08 13:51:01 +02:00)

Commit 18f1cc1638: Add AsyncLocalObjectDetector class
Parent: d4e9de000e
@@ -19,7 +19,7 @@ from pydantic import BaseModel, Field
 from typing_extensions import Literal
 from frigate.detectors.detection_api import DetectionApi
 from frigate.detectors.detector_config import BaseDetectorConfig, ModelTypeEnum
-from frigate.util.model import __post_process_multipart_yolo
+from frigate.util.model import post_process_yolo
 
 logger = logging.getLogger(__name__)
 
@@ -70,7 +70,7 @@ class MemryXDetector(DetectionApi):
 
         if self.memx_model_type == ModelTypeEnum.yologeneric:
             self.model_url = (
-                "https://developer.memryx.com/example_files/1p2_frigate/yolov9.zip"
+                "https://developer.memryx.com/example_files/1p2_frigate/yolo-generic.zip"
             )
             self.expected_dfp_model = (
                 "YOLO_v9_small_640_640_3_onnx.dfp"
@@ -214,6 +214,7 @@ class MemryXDetector(DetectionApi):
         if self.memx_model_type == ModelTypeEnum.yolox:
             tensor_input = tensor_input.squeeze(0)
 
+            tensor_input = tensor_input * 255.0
             padded_img = np.ones((640, 640, 3), dtype=np.uint8) * 114
 
             scale = min(
@@ -238,10 +239,11 @@ class MemryXDetector(DetectionApi):
             # Step 5: Concatenate along the channel dimension (axis 2)
             concatenated_img = np.concatenate([x0, x1, x2, x3], axis=2)
             processed_input = concatenated_img.astype(np.float32)
-        else:
-            processed_input = tensor_input.astype(np.float32) / 255.0  # Normalize
-            # Assuming original input is always NHWC and MemryX wants HWNC:
-            processed_input = processed_input.transpose(1, 2, 0, 3)  # NHWC -> HWNC
+
+        else:
+            tensor_input = tensor_input.squeeze(0)  # (H, W, C)
+            # Add axis=2 to create Z=1: (H, W, Z=1, C)
+            processed_input = np.expand_dims(tensor_input, axis=2)  # Now (H, W, 1, 3)
 
         # Send frame to MemryX for processing
         self.capture_queue.put(processed_input)
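For reference, a minimal standalone numpy sketch (shapes illustrative, not taken from the commit) of the layout change in the non-yolox branch above: the batched NHWC frame is squeezed to HWC and then given a singleton axis at position 2, producing the (H, W, 1, C) layout handed to the accelerator.

import numpy as np

# Illustrative 640x640 RGB frame in NHWC layout, as the detector receives it.
tensor_input = np.zeros((1, 640, 640, 3), dtype=np.uint8)

hwc = tensor_input.squeeze(0)        # (640, 640, 3)
hwnc = np.expand_dims(hwc, axis=2)   # (640, 640, 1, 3)

assert hwnc.shape == (640, 640, 1, 3)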
@@ -596,7 +598,7 @@ class MemryXDetector(DetectionApi):
         sigmoid_output = self.sigmoid(split_1)
         outputs = self.onnx_concat([mul_output, sigmoid_output], axis=1)
 
-        final_detections = __post_process_multipart_yolo(
+        final_detections = post_process_yolo(
             outputs, self.memx_model_width, self.memx_model_height
         )
         self.output_queue.put(final_detections)
@@ -617,4 +619,4 @@ class MemryXDetector(DetectionApi):
 
     def detect_raw(self, tensor_input: np.ndarray):
         """Removed synchronous detect_raw() function so that we only use async"""
         return 0
@@ -7,6 +7,8 @@ import queue
 import signal
 import threading
 from abc import ABC, abstractmethod
+from multiprocessing import Queue, Value
+from multiprocessing.synchronize import Event as MpEvent
 
 import numpy as np
 from setproctitle import setproctitle
@@ -16,6 +18,7 @@ from frigate.detectors import create_detector
 from frigate.detectors.detector_config import (
     BaseDetectorConfig,
     InputDTypeEnum,
+    ModelConfig,
 )
 from frigate.util.builtin import EventsPerSecond, load_labels
 from frigate.util.image import SharedMemoryFrameManager, UntrackedSharedMemory
@@ -32,7 +35,7 @@ class ObjectDetector(ABC):
         pass
 
 
-class LocalObjectDetector(ObjectDetector):
+class BaseLocalDetector(ObjectDetector):
     def __init__(
         self,
         detector_config: BaseDetectorConfig = None,
@@ -54,6 +57,18 @@ class LocalObjectDetector(ObjectDetector):
 
         self.detect_api = create_detector(detector_config)
 
+    def _transform_input(self, tensor_input: np.ndarray) -> np.ndarray:
+        if self.input_transform:
+            tensor_input = np.transpose(tensor_input, self.input_transform)
+
+        if self.dtype == InputDTypeEnum.float:
+            tensor_input = tensor_input.astype(np.float32)
+            tensor_input /= 255
+        elif self.dtype == InputDTypeEnum.float_denorm:
+            tensor_input = tensor_input.astype(np.float32)
+
+        return tensor_input
+
     def detect(self, tensor_input: np.ndarray, threshold=0.4):
         detections = []
 
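A minimal standalone sketch of what the new _transform_input helper does, using a stand-in enum and plain arguments instead of the real BaseDetectorConfig wiring (the stand-ins are assumptions for illustration only):

import numpy as np
from enum import Enum


class InputDTypeEnum(str, Enum):  # stand-in for frigate.detectors.detector_config
    int = "int"
    float = "float"
    float_denorm = "float_denorm"


def transform_input(tensor_input, input_transform, dtype):
    # Optional axis reorder (e.g. NHWC -> NCHW), then dtype handling.
    if input_transform:
        tensor_input = np.transpose(tensor_input, input_transform)

    if dtype == InputDTypeEnum.float:
        # Float input normalized to [0, 1].
        tensor_input = tensor_input.astype(np.float32)
        tensor_input /= 255
    elif dtype == InputDTypeEnum.float_denorm:
        # Float input that keeps the original 0-255 range.
        tensor_input = tensor_input.astype(np.float32)

    return tensor_input


frame = np.random.randint(0, 255, (1, 320, 320, 3), dtype=np.uint8)
print(transform_input(frame, (0, 3, 1, 2), InputDTypeEnum.float).shape)  # (1, 3, 320, 320)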
@@ -71,25 +86,30 @@ class LocalObjectDetector(ObjectDetector):
         self.fps.update()
         return detections
 
+
+class LocalObjectDetector(BaseLocalDetector):
     def detect_raw(self, tensor_input: np.ndarray):
-        if self.input_transform:
-            tensor_input = np.transpose(tensor_input, self.input_transform)
-
-        if self.dtype == InputDTypeEnum.float:
-            tensor_input = tensor_input.astype(np.float32)
-            tensor_input /= 255
-
+        tensor_input = self._transform_input(tensor_input)
         return self.detect_api.detect_raw(tensor_input=tensor_input)
 
 
-def prepare_detector(name, detector_config, out_events):
+class AsyncLocalObjectDetector(BaseLocalDetector):
+    def async_send_input(self, tensor_input: np.ndarray, connection_id):
+        tensor_input = self._transform_input(tensor_input)
+        return self.detect_api.send_input(connection_id, tensor_input)
+
+    def async_receive_output(self):
+        return self.detect_api.receive_output()
+
+
+def prepare_detector(name, out_events):
     threading.current_thread().name = f"detector:{name}"
     logger = logging.getLogger(f"detector.{name}")
     logger.info(f"Starting detection process: {os.getpid()}")
     setproctitle(f"frigate.detector.{name}")
     listen()
 
-    stop_event = mp.Event()
+    stop_event: MpEvent = mp.Event()
 
     def receiveSignal(signalNumber, frame):
         stop_event.set()
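To make the new class layout concrete, here is a minimal self-contained sketch (with a fake detect API and a simplified transform, none of which comes from the commit) of how the synchronous and asynchronous wrappers reuse the base class's pre-processing:

import numpy as np


class _FakeDetectApi:
    # Stand-in for a detector plugin; not part of the commit.
    def detect_raw(self, tensor_input):
        return np.zeros((20, 6), dtype=np.float32)

    def send_input(self, connection_id, tensor_input):
        self._pending = connection_id

    def receive_output(self):
        return self._pending, np.zeros((20, 6), dtype=np.float32)


class BaseDetector:
    def __init__(self):
        self.detect_api = _FakeDetectApi()

    def _transform_input(self, tensor_input):
        return tensor_input.astype(np.float32) / 255  # shared pre-processing


class SyncDetector(BaseDetector):
    def detect_raw(self, tensor_input):
        return self.detect_api.detect_raw(self._transform_input(tensor_input))


class AsyncDetector(BaseDetector):
    def async_send_input(self, tensor_input, connection_id):
        return self.detect_api.send_input(connection_id, self._transform_input(tensor_input))

    def async_receive_output(self):
        return self.detect_api.receive_output()


frame = np.zeros((1, 320, 320, 3), dtype=np.uint8)
print(SyncDetector().detect_raw(frame).shape)  # (20, 6)
async_det = AsyncDetector()
async_det.async_send_input(frame, "cam-1")
print(async_det.async_receive_output()[0])     # cam-1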
@@ -98,7 +118,6 @@ def prepare_detector(name, detector_config, out_events):
     signal.signal(signal.SIGINT, receiveSignal)
 
     frame_manager = SharedMemoryFrameManager()
-    object_detector = LocalObjectDetector(detector_config=detector_config)
 
     outputs = {}
     for name in out_events.keys():
@@ -106,22 +125,24 @@ def prepare_detector(name, detector_config, out_events):
         out_np = np.ndarray((20, 6), dtype=np.float32, buffer=out_shm.buf)
         outputs[name] = {"shm": out_shm, "np": out_np}
 
-    return stop_event, frame_manager, object_detector, outputs, logger
+    return stop_event, frame_manager, outputs, logger
 
 
 def run_detector(
     name: str,
-    detection_queue: mp.Queue,
-    out_events: dict[str, mp.Event],
-    avg_speed,
-    start,
-    detector_config,
+    detection_queue: Queue,
+    out_events: dict[str, MpEvent],
+    avg_speed: Value,
+    start: Value,
+    detector_config: BaseDetectorConfig,
 ):
-    stop_event, frame_manager, object_detector, outputs, logger = prepare_detector(
-        name, detector_config, out_events
+    stop_event, frame_manager, outputs, logger = prepare_detector(
+        name, out_events
     )
 
+    object_detector = LocalObjectDetector(detector_config=detector_config)
+
     while not stop_event.is_set():
         try:
             connection_id = detection_queue.get(timeout=1)
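The per-camera output buffers built in prepare_detector map a fixed (20, 6) float32 array onto shared memory so results can be written in place for each consumer. A minimal standalone sketch of that pattern, using multiprocessing.shared_memory rather than Frigate's UntrackedSharedMemory wrapper (an assumption made here to keep the example self-contained):

import numpy as np
from multiprocessing import shared_memory

# 20 detection rows of 6 floats each.
shape, dtype = (20, 6), np.float32
shm = shared_memory.SharedMemory(
    create=True, size=int(np.prod(shape)) * np.dtype(dtype).itemsize
)
out_np = np.ndarray(shape, dtype=dtype, buffer=shm.buf)

out_np[:] = 0.0  # the detector writes detections in place; readers attach by name
print(shm.name, out_np.shape)

shm.close()
shm.unlink()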
@@ -152,17 +173,19 @@ def run_detector(
 
 def async_run_detector(
     name: str,
-    detection_queue: mp.Queue,
-    out_events: dict[str, mp.Event],
-    avg_speed,
-    start,
-    detector_config,
+    detection_queue: Queue,
+    out_events: dict[str, MpEvent],
+    avg_speed: Value,
+    start: Value,
+    detector_config: BaseDetectorConfig,
 ):
-    stop_event, frame_manager, object_detector, outputs, logger = prepare_detector(
-        name, detector_config, out_events
+    stop_event, frame_manager, outputs, logger = prepare_detector(
+        name, out_events
     )
 
+    object_detector = AsyncLocalObjectDetector(detector_config=detector_config)
+
     def detect_worker():
         # Continuously fetch frames and send them to the async detector
         logger.info("Starting Detect Worker Thread")
@@ -184,13 +207,13 @@ def async_run_detector(
 
             # send input to Accelator
             start.value = datetime.datetime.now().timestamp()
-            object_detector.detect_api.send_input(connection_id, input_frame)
+            object_detector.async_send_input(input_frame, connection_id)
 
     def result_worker():
         # Continuously receive detection results from the async detector
         logger.info("Starting Result Worker Thread")
         while not stop_event.is_set():
-            connection_id, detections = object_detector.detect_api.receive_output()
+            connection_id, detections = object_detector.async_receive_output()
             duration = datetime.datetime.now().timestamp() - start.value
 
             frame_manager.close(connection_id)
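The async path splits submission and collection into two long-running loops. A small self-contained sketch of that producer/consumer shape, with plain queue.Queue objects standing in for the detection queue and for the accelerator's send_input/receive_output pair (all names here are illustrative, not from the commit):

import queue
import threading
import time

frames_in: queue.Queue = queue.Queue()
results: queue.Queue = queue.Queue()
stop_event = threading.Event()


def detect_worker():
    # Pull frames and hand them to the accelerator without waiting for results.
    while not stop_event.is_set():
        try:
            connection_id, frame = frames_in.get(timeout=1)
        except queue.Empty:
            continue
        results.put((connection_id, f"detections for {frame}"))  # send_input stand-in


def result_worker():
    # Independently drain finished inferences and publish them.
    while not stop_event.is_set():
        try:
            connection_id, detections = results.get(timeout=1)  # receive_output stand-in
        except queue.Empty:
            continue
        print(connection_id, detections)


threads = [threading.Thread(target=detect_worker), threading.Thread(target=result_worker)]
for t in threads:
    t.start()
frames_in.put(("cam-1", "frame-0"))
time.sleep(0.5)
stop_event.set()
for t in threads:
    t.join()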
@@ -222,17 +245,17 @@ def async_run_detector(
 class ObjectDetectProcess:
     def __init__(
         self,
-        name,
-        detection_queue,
-        out_events,
-        detector_config,
+        name: str,
+        detection_queue: Queue,
+        out_events: dict[str, MpEvent],
+        detector_config: BaseDetectorConfig,
     ):
         self.name = name
         self.out_events = out_events
         self.detection_queue = detection_queue
-        self.avg_inference_speed = mp.Value("d", 0.01)
-        self.detection_start = mp.Value("d", 0.0)
-        self.detect_process = None
+        self.avg_inference_speed = Value("d", 0.01)
+        self.detection_start = Value("d", 0.0)
+        self.detect_process: util.Process | None = None
         self.detector_config = detector_config
         self.start_or_restart()
 
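The changes above swap mp.Value for the directly imported Value and add type annotations; the shared doubles behave the same. For context, a tiny sketch of how such a shared double is updated from another process (the moving-average formula and names are illustrative, not taken from the commit):

from multiprocessing import Process, Value


def record_inference(avg_speed, duration: float) -> None:
    # Simple exponential moving average of inference time, written by the child.
    avg_speed.value = (avg_speed.value * 9 + duration) / 10


if __name__ == "__main__":
    avg_inference_speed = Value("d", 0.01)  # "d" = C double, shared between processes
    p = Process(target=record_inference, args=(avg_inference_speed, 0.02))
    p.start()
    p.join()
    print(round(avg_inference_speed.value, 4))  # 0.011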
@@ -285,7 +308,15 @@ class ObjectDetectProcess:
 
 
 class RemoteObjectDetector:
-    def __init__(self, name, labels, detection_queue, event, model_config, stop_event):
+    def __init__(
+        self,
+        name: str,
+        labels: dict[int, str],
+        detection_queue: Queue,
+        event: MpEvent,
+        model_config: ModelConfig,
+        stop_event: MpEvent,
+    ):
         self.labels = labels
         self.name = name
         self.fps = EventsPerSecond()
@@ -328,4 +359,4 @@ class RemoteObjectDetector:
 
     def cleanup(self):
         self.shm.unlink()
         self.out_shm.unlink()