mirror of
https://github.com/blakeblackshear/frigate.git
synced 2025-08-04 13:47:37 +02:00
Update object_detection.py
This commit is contained in:
parent
9930b2be91
commit
b42ac4efdc
@ -1,4 +1,5 @@
|
|||||||
import datetime
|
import datetime
|
||||||
|
import time
|
||||||
import logging
|
import logging
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
import os
|
import os
|
||||||
@ -8,8 +9,6 @@ import threading
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import cv2
|
|
||||||
import time
|
|
||||||
from setproctitle import setproctitle
|
from setproctitle import setproctitle
|
||||||
|
|
||||||
import frigate.util as util
|
import frigate.util as util
|
||||||
@ -18,7 +17,6 @@ from frigate.detectors.detector_config import (
|
|||||||
BaseDetectorConfig,
|
BaseDetectorConfig,
|
||||||
InputDTypeEnum,
|
InputDTypeEnum,
|
||||||
InputTensorEnum,
|
InputTensorEnum,
|
||||||
ModelTypeEnum
|
|
||||||
)
|
)
|
||||||
from frigate.util.builtin import EventsPerSecond, load_labels
|
from frigate.util.builtin import EventsPerSecond, load_labels
|
||||||
from frigate.util.image import SharedMemoryFrameManager, UntrackedSharedMemory
|
from frigate.util.image import SharedMemoryFrameManager, UntrackedSharedMemory
|
||||||
@ -52,8 +50,6 @@ class LocalObjectDetector(ObjectDetector):
|
|||||||
self.labels = {}
|
self.labels = {}
|
||||||
else:
|
else:
|
||||||
self.labels = load_labels(labels)
|
self.labels = load_labels(labels)
|
||||||
|
|
||||||
self.model_type = detector_config.model.model_type
|
|
||||||
|
|
||||||
if detector_config:
|
if detector_config:
|
||||||
self.input_transform = tensor_transform(detector_config.model.input_tensor)
|
self.input_transform = tensor_transform(detector_config.model.input_tensor)
|
||||||
@ -91,42 +87,6 @@ class LocalObjectDetector(ObjectDetector):
|
|||||||
tensor_input /= 255
|
tensor_input /= 255
|
||||||
|
|
||||||
return self.detect_api.detect_raw(tensor_input=tensor_input)
|
return self.detect_api.detect_raw(tensor_input=tensor_input)
|
||||||
|
|
||||||
def detect_raw_memx(self, tensor_input: np.ndarray):
|
|
||||||
|
|
||||||
if self.model_type == ModelTypeEnum.yolox:
|
|
||||||
|
|
||||||
tensor_input = tensor_input.squeeze(0)
|
|
||||||
|
|
||||||
padded_img = np.ones((640, 640, 3),
|
|
||||||
dtype=np.uint8) * 114
|
|
||||||
|
|
||||||
scale = min(640 / float(tensor_input.shape[0]),
|
|
||||||
640 / float(tensor_input.shape[1]))
|
|
||||||
sx,sy = int(tensor_input.shape[1] * scale), int(tensor_input.shape[0] * scale)
|
|
||||||
|
|
||||||
resized_img = cv2.resize(tensor_input, (sx,sy), interpolation=cv2.INTER_LINEAR)
|
|
||||||
padded_img[:sy, :sx] = resized_img.astype(np.uint8)
|
|
||||||
|
|
||||||
|
|
||||||
# Step 4: Slice the padded image into 4 quadrants and concatenate them into 12 channels
|
|
||||||
x0 = padded_img[0::2, 0::2, :] # Top-left
|
|
||||||
x1 = padded_img[1::2, 0::2, :] # Bottom-left
|
|
||||||
x2 = padded_img[0::2, 1::2, :] # Top-right
|
|
||||||
x3 = padded_img[1::2, 1::2, :] # Bottom-right
|
|
||||||
|
|
||||||
# Step 5: Concatenate along the channel dimension (axis 2)
|
|
||||||
concatenated_img = np.concatenate([x0, x1, x2, x3], axis=2)
|
|
||||||
|
|
||||||
# Step 6: Return the processed image as a contiguous array of type float32
|
|
||||||
return np.ascontiguousarray(concatenated_img).astype(np.float32)
|
|
||||||
|
|
||||||
tensor_input = tensor_input.astype(np.float32) # Convert input to float32
|
|
||||||
tensor_input /= 255.0 # Normalize pixel values to [0, 1]
|
|
||||||
|
|
||||||
tensor_input = tensor_input.transpose(1, 2, 0, 3) # Convert from NHWC to HWNC (expected DFP input shape)
|
|
||||||
|
|
||||||
return tensor_input
|
|
||||||
|
|
||||||
|
|
||||||
def run_detector(
|
def run_detector(
|
||||||
@ -186,6 +146,8 @@ def run_detector(
|
|||||||
avg_speed.value = (avg_speed.value * 9 + duration) / 10
|
avg_speed.value = (avg_speed.value * 9 + duration) / 10
|
||||||
|
|
||||||
logger.info("Exited detection process...")
|
logger.info("Exited detection process...")
|
||||||
|
return self.detect_api.detect_raw(tensor_input=tensor_input)
|
||||||
|
|
||||||
|
|
||||||
def async_run_detector(
|
def async_run_detector(
|
||||||
name: str,
|
name: str,
|
||||||
@ -197,7 +159,7 @@ def async_run_detector(
|
|||||||
):
|
):
|
||||||
# Set thread and process titles for logging and debugging
|
# Set thread and process titles for logging and debugging
|
||||||
threading.current_thread().name = f"detector:{name}"
|
threading.current_thread().name = f"detector:{name}"
|
||||||
logger.info(f"Starting async detection process: {os.getpid()}")
|
logger.info(f"Starting detection process: {os.getpid()}")
|
||||||
setproctitle(f"frigate.detector.{name}")
|
setproctitle(f"frigate.detector.{name}")
|
||||||
|
|
||||||
stop_event = mp.Event() # Used to gracefully stop threads on signal
|
stop_event = mp.Event() # Used to gracefully stop threads on signal
|
||||||
@ -221,7 +183,7 @@ def async_run_detector(
|
|||||||
outputs[name] = {"shm": out_shm, "np": out_np}
|
outputs[name] = {"shm": out_shm, "np": out_np}
|
||||||
|
|
||||||
def detect_worker():
|
def detect_worker():
|
||||||
"""Continuously fetch frames and send them to MemryX."""
|
# """Continuously fetch frames and send them to the detector accelerator."""
|
||||||
logger.info("Starting Detect Worker Thread")
|
logger.info("Starting Detect Worker Thread")
|
||||||
while not stop_event.is_set():
|
while not stop_event.is_set():
|
||||||
try:
|
try:
|
||||||
@ -239,13 +201,12 @@ def async_run_detector(
|
|||||||
logger.warning(f"Failed to get frame {connection_id} from SHM")
|
logger.warning(f"Failed to get frame {connection_id} from SHM")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Preprocess and send input to MemryX
|
#send input to Accelator
|
||||||
input_frame = object_detector.detect_raw_memx(input_frame)
|
|
||||||
start.value = datetime.datetime.now().timestamp()
|
start.value = datetime.datetime.now().timestamp()
|
||||||
object_detector.detect_api.send_input(connection_id, input_frame)
|
object_detector.detect_api.send_input(connection_id, input_frame)
|
||||||
|
|
||||||
def result_worker():
|
def result_worker():
|
||||||
"""Continuously receive detection results from MemryX."""
|
# """Continuously receive detection results from detector accelerator."""
|
||||||
logger.info("Starting Result Worker Thread")
|
logger.info("Starting Result Worker Thread")
|
||||||
while not stop_event.is_set():
|
while not stop_event.is_set():
|
||||||
connection_id, detections = object_detector.detect_api.receive_output()
|
connection_id, detections = object_detector.detect_api.receive_output()
|
||||||
@ -274,7 +235,8 @@ def async_run_detector(
|
|||||||
while not stop_event.is_set():
|
while not stop_event.is_set():
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
|
|
||||||
logger.info("Exited async detection process...")
|
logger.info("Exited detection process...")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ObjectDetectProcess:
|
class ObjectDetectProcess:
|
||||||
@ -386,4 +348,3 @@ class RemoteObjectDetector:
|
|||||||
def cleanup(self):
|
def cleanup(self):
|
||||||
self.shm.unlink()
|
self.shm.unlink()
|
||||||
self.out_shm.unlink()
|
self.out_shm.unlink()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user