mirror of
https://github.com/blakeblackshear/frigate.git
synced 2025-08-04 13:47:37 +02:00
Update object_detection.py
This commit is contained in:
parent
18ce86ce9c
commit
b1562fdabd
@ -121,11 +121,11 @@ class LocalObjectDetector(ObjectDetector):
|
|||||||
# Step 6: Return the processed image as a contiguous array of type float32
|
# Step 6: Return the processed image as a contiguous array of type float32
|
||||||
return np.ascontiguousarray(concatenated_img).astype(np.float32)
|
return np.ascontiguousarray(concatenated_img).astype(np.float32)
|
||||||
|
|
||||||
# if self.dtype == InputDTypeEnum.float:
|
tensor_input = tensor_input.astype(np.float32) # Convert input to float32
|
||||||
tensor_input = tensor_input.astype(np.float32)
|
tensor_input /= 255.0 # Normalize pixel values to [0, 1]
|
||||||
tensor_input /= 255
|
|
||||||
|
tensor_input = tensor_input.transpose(1, 2, 0, 3) # Convert from NHWC to HWNC (expected DFP input shape)
|
||||||
|
|
||||||
tensor_input = tensor_input.transpose(1,2,0,3) #NHWC --> HWNC(dfp input shape)
|
|
||||||
|
|
||||||
return tensor_input
|
return tensor_input
|
||||||
|
|
||||||
@ -138,21 +138,25 @@ def async_run_detector(
|
|||||||
start,
|
start,
|
||||||
detector_config,
|
detector_config,
|
||||||
):
|
):
|
||||||
|
# Set thread and process titles for logging and debugging
|
||||||
threading.current_thread().name = f"detector:{name}"
|
threading.current_thread().name = f"detector:{name}"
|
||||||
logger.info(f"Starting MemryX Async detection process: {os.getpid()}")
|
logger.info(f"Starting MemryX Async detection process: {os.getpid()}")
|
||||||
setproctitle(f"frigate.detector.{name}")
|
setproctitle(f"frigate.detector.{name}")
|
||||||
|
|
||||||
stop_event = mp.Event()
|
stop_event = mp.Event() # Used to gracefully stop threads on signal
|
||||||
|
|
||||||
def receiveSignal(signalNumber, frame):
|
def receiveSignal(signalNumber, frame):
|
||||||
stop_event.set()
|
stop_event.set()
|
||||||
|
|
||||||
|
# Register signal handlers for graceful shutdown
|
||||||
signal.signal(signal.SIGTERM, receiveSignal)
|
signal.signal(signal.SIGTERM, receiveSignal)
|
||||||
signal.signal(signal.SIGINT, receiveSignal)
|
signal.signal(signal.SIGINT, receiveSignal)
|
||||||
|
|
||||||
|
# Initialize shared memory and detector
|
||||||
frame_manager = SharedMemoryFrameManager()
|
frame_manager = SharedMemoryFrameManager()
|
||||||
object_detector = LocalObjectDetector(detector_config=detector_config)
|
object_detector = LocalObjectDetector(detector_config=detector_config)
|
||||||
|
|
||||||
|
# Create shared memory buffers for detector outputs
|
||||||
outputs = {}
|
outputs = {}
|
||||||
for name in out_events.keys():
|
for name in out_events.keys():
|
||||||
out_shm = UntrackedSharedMemory(name=f"out-{name}", create=False)
|
out_shm = UntrackedSharedMemory(name=f"out-{name}", create=False)
|
||||||
@ -160,14 +164,15 @@ def async_run_detector(
|
|||||||
outputs[name] = {"shm": out_shm, "np": out_np}
|
outputs[name] = {"shm": out_shm, "np": out_np}
|
||||||
|
|
||||||
def detect_worker():
|
def detect_worker():
|
||||||
""" Continuously fetch frames and send them to MemryX """
|
"""Continuously fetch frames and send them to MemryX."""
|
||||||
logger.info(f"Starting Detect Worker Thread")
|
logger.info("Starting Detect Worker Thread")
|
||||||
while not stop_event.is_set():
|
while not stop_event.is_set():
|
||||||
try:
|
try:
|
||||||
connection_id = detection_queue.get(timeout=1)
|
connection_id = detection_queue.get(timeout=1)
|
||||||
except queue.Empty:
|
except queue.Empty:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Retrieve the input frame from shared memory
|
||||||
input_frame = frame_manager.get(
|
input_frame = frame_manager.get(
|
||||||
connection_id,
|
connection_id,
|
||||||
(1, detector_config.model.height, detector_config.model.width, 3),
|
(1, detector_config.model.height, detector_config.model.width, 3),
|
||||||
@ -176,47 +181,45 @@ def async_run_detector(
|
|||||||
if input_frame is None:
|
if input_frame is None:
|
||||||
logger.warning(f"Failed to get frame {connection_id} from SHM")
|
logger.warning(f"Failed to get frame {connection_id} from SHM")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Preprocess and send input to MemryX
|
||||||
input_frame = object_detector.detect_raw_memx(input_frame)
|
input_frame = object_detector.detect_raw_memx(input_frame)
|
||||||
|
|
||||||
# Start measuring inference time
|
|
||||||
start.value = datetime.datetime.now().timestamp()
|
start.value = datetime.datetime.now().timestamp()
|
||||||
|
|
||||||
# Send frame directly to MemryX processing
|
|
||||||
object_detector.detect_api.send_input(connection_id, input_frame)
|
object_detector.detect_api.send_input(connection_id, input_frame)
|
||||||
|
|
||||||
def result_worker():
|
def result_worker():
|
||||||
""" Continuously fetch results from MemryX and update outputs """
|
"""Continuously receive detection results from MemryX."""
|
||||||
logger.info(f"Starting Result Worker Thread")
|
logger.info("Starting Result Worker Thread")
|
||||||
while not stop_event.is_set():
|
while not stop_event.is_set():
|
||||||
connection_id, detections = object_detector.detect_api.receive_output()
|
connection_id, detections = object_detector.detect_api.receive_output()
|
||||||
|
|
||||||
# Calculate processing time
|
|
||||||
duration = datetime.datetime.now().timestamp() - start.value
|
duration = datetime.datetime.now().timestamp() - start.value
|
||||||
|
|
||||||
frame_manager.close(connection_id)
|
frame_manager.close(connection_id)
|
||||||
|
|
||||||
# Update average inference speed
|
# Update moving average inference time
|
||||||
avg_speed.value = (avg_speed.value * 9 + duration) / 10
|
avg_speed.value = (avg_speed.value * 9 + duration) / 10
|
||||||
|
|
||||||
if connection_id in outputs and detections is not None:
|
if connection_id in outputs and detections is not None:
|
||||||
outputs[connection_id]["np"][:] = detections[:]
|
outputs[connection_id]["np"][:] = detections[:]
|
||||||
out_events[connection_id].set()
|
out_events[connection_id].set()
|
||||||
|
|
||||||
# Initialize avg_speed
|
# Initialize tracking variables
|
||||||
start.value = 0.0
|
start.value = 0.0
|
||||||
avg_speed.value = 0.0 # Start with an initial value
|
avg_speed.value = 0.0
|
||||||
|
|
||||||
# Start worker threads
|
# Start threads for detection input and result output
|
||||||
detect_thread = threading.Thread(target=detect_worker, daemon=True)
|
detect_thread = threading.Thread(target=detect_worker, daemon=True)
|
||||||
result_thread = threading.Thread(target=result_worker, daemon=True)
|
result_thread = threading.Thread(target=result_worker, daemon=True)
|
||||||
detect_thread.start()
|
detect_thread.start()
|
||||||
result_thread.start()
|
result_thread.start()
|
||||||
|
|
||||||
|
# Keep the main process alive while threads run
|
||||||
while not stop_event.is_set():
|
while not stop_event.is_set():
|
||||||
time.sleep(1) # Keep process alive
|
time.sleep(1)
|
||||||
|
|
||||||
logger.info("Exited MemryX detection process...")
|
logger.info("Exited MemryX detection process...")
|
||||||
|
|
||||||
|
|
||||||
def run_detector(
|
def run_detector(
|
||||||
name: str,
|
name: str,
|
||||||
detection_queue: mp.Queue,
|
detection_queue: mp.Queue,
|
||||||
@ -310,7 +313,7 @@ class ObjectDetectProcess:
|
|||||||
self.detection_start.value = 0.0
|
self.detection_start.value = 0.0
|
||||||
if (self.detect_process is not None) and self.detect_process.is_alive():
|
if (self.detect_process is not None) and self.detect_process.is_alive():
|
||||||
self.stop()
|
self.stop()
|
||||||
if (self.detector_config.type == 'memryx'):
|
if (self.detector_config.type == 'memryx'): # MemryX requires asynchronous detection handling using async_run_detector
|
||||||
self.detect_process = util.Process(
|
self.detect_process = util.Process(
|
||||||
target=async_run_detector,
|
target=async_run_detector,
|
||||||
name=f"detector:{self.name}",
|
name=f"detector:{self.name}",
|
||||||
@ -385,3 +388,4 @@ class RemoteObjectDetector:
|
|||||||
def cleanup(self):
|
def cleanup(self):
|
||||||
self.shm.unlink()
|
self.shm.unlink()
|
||||||
self.out_shm.unlink()
|
self.out_shm.unlink()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user