Update object_detection.py

This commit is contained in:
abinila siva 2025-04-11 14:22:45 -04:00 committed by Abinila Siva
parent 18ce86ce9c
commit b1562fdabd

View File

@ -121,11 +121,11 @@ class LocalObjectDetector(ObjectDetector):
# Step 6: Return the processed image as a contiguous array of type float32 # Step 6: Return the processed image as a contiguous array of type float32
return np.ascontiguousarray(concatenated_img).astype(np.float32) return np.ascontiguousarray(concatenated_img).astype(np.float32)
# if self.dtype == InputDTypeEnum.float: tensor_input = tensor_input.astype(np.float32) # Convert input to float32
tensor_input = tensor_input.astype(np.float32) tensor_input /= 255.0 # Normalize pixel values to [0, 1]
tensor_input /= 255
tensor_input = tensor_input.transpose(1, 2, 0, 3) # Convert from NHWC to HWNC (expected DFP input shape)
tensor_input = tensor_input.transpose(1,2,0,3) #NHWC --> HWNC(dfp input shape)
return tensor_input return tensor_input
@ -138,21 +138,25 @@ def async_run_detector(
start, start,
detector_config, detector_config,
): ):
# Set thread and process titles for logging and debugging
threading.current_thread().name = f"detector:{name}" threading.current_thread().name = f"detector:{name}"
logger.info(f"Starting MemryX Async detection process: {os.getpid()}") logger.info(f"Starting MemryX Async detection process: {os.getpid()}")
setproctitle(f"frigate.detector.{name}") setproctitle(f"frigate.detector.{name}")
stop_event = mp.Event() stop_event = mp.Event() # Used to gracefully stop threads on signal
def receiveSignal(signalNumber, frame): def receiveSignal(signalNumber, frame):
stop_event.set() stop_event.set()
# Register signal handlers for graceful shutdown
signal.signal(signal.SIGTERM, receiveSignal) signal.signal(signal.SIGTERM, receiveSignal)
signal.signal(signal.SIGINT, receiveSignal) signal.signal(signal.SIGINT, receiveSignal)
# Initialize shared memory and detector
frame_manager = SharedMemoryFrameManager() frame_manager = SharedMemoryFrameManager()
object_detector = LocalObjectDetector(detector_config=detector_config) object_detector = LocalObjectDetector(detector_config=detector_config)
# Create shared memory buffers for detector outputs
outputs = {} outputs = {}
for name in out_events.keys(): for name in out_events.keys():
out_shm = UntrackedSharedMemory(name=f"out-{name}", create=False) out_shm = UntrackedSharedMemory(name=f"out-{name}", create=False)
@ -160,14 +164,15 @@ def async_run_detector(
outputs[name] = {"shm": out_shm, "np": out_np} outputs[name] = {"shm": out_shm, "np": out_np}
def detect_worker(): def detect_worker():
""" Continuously fetch frames and send them to MemryX """ """Continuously fetch frames and send them to MemryX."""
logger.info(f"Starting Detect Worker Thread") logger.info("Starting Detect Worker Thread")
while not stop_event.is_set(): while not stop_event.is_set():
try: try:
connection_id = detection_queue.get(timeout=1) connection_id = detection_queue.get(timeout=1)
except queue.Empty: except queue.Empty:
continue continue
# Retrieve the input frame from shared memory
input_frame = frame_manager.get( input_frame = frame_manager.get(
connection_id, connection_id,
(1, detector_config.model.height, detector_config.model.width, 3), (1, detector_config.model.height, detector_config.model.width, 3),
@ -176,47 +181,45 @@ def async_run_detector(
if input_frame is None: if input_frame is None:
logger.warning(f"Failed to get frame {connection_id} from SHM") logger.warning(f"Failed to get frame {connection_id} from SHM")
continue continue
# Preprocess and send input to MemryX
input_frame = object_detector.detect_raw_memx(input_frame) input_frame = object_detector.detect_raw_memx(input_frame)
# Start measuring inference time
start.value = datetime.datetime.now().timestamp() start.value = datetime.datetime.now().timestamp()
# Send frame directly to MemryX processing
object_detector.detect_api.send_input(connection_id, input_frame) object_detector.detect_api.send_input(connection_id, input_frame)
def result_worker(): def result_worker():
""" Continuously fetch results from MemryX and update outputs """ """Continuously receive detection results from MemryX."""
logger.info(f"Starting Result Worker Thread") logger.info("Starting Result Worker Thread")
while not stop_event.is_set(): while not stop_event.is_set():
connection_id, detections = object_detector.detect_api.receive_output() connection_id, detections = object_detector.detect_api.receive_output()
# Calculate processing time
duration = datetime.datetime.now().timestamp() - start.value duration = datetime.datetime.now().timestamp() - start.value
frame_manager.close(connection_id) frame_manager.close(connection_id)
# Update average inference speed # Update moving average inference time
avg_speed.value = (avg_speed.value * 9 + duration) / 10 avg_speed.value = (avg_speed.value * 9 + duration) / 10
if connection_id in outputs and detections is not None: if connection_id in outputs and detections is not None:
outputs[connection_id]["np"][:] = detections[:] outputs[connection_id]["np"][:] = detections[:]
out_events[connection_id].set() out_events[connection_id].set()
# Initialize avg_speed # Initialize tracking variables
start.value = 0.0 start.value = 0.0
avg_speed.value = 0.0 # Start with an initial value avg_speed.value = 0.0
# Start worker threads # Start threads for detection input and result output
detect_thread = threading.Thread(target=detect_worker, daemon=True) detect_thread = threading.Thread(target=detect_worker, daemon=True)
result_thread = threading.Thread(target=result_worker, daemon=True) result_thread = threading.Thread(target=result_worker, daemon=True)
detect_thread.start() detect_thread.start()
result_thread.start() result_thread.start()
# Keep the main process alive while threads run
while not stop_event.is_set(): while not stop_event.is_set():
time.sleep(1) # Keep process alive time.sleep(1)
logger.info("Exited MemryX detection process...") logger.info("Exited MemryX detection process...")
def run_detector( def run_detector(
name: str, name: str,
detection_queue: mp.Queue, detection_queue: mp.Queue,
@ -310,7 +313,7 @@ class ObjectDetectProcess:
self.detection_start.value = 0.0 self.detection_start.value = 0.0
if (self.detect_process is not None) and self.detect_process.is_alive(): if (self.detect_process is not None) and self.detect_process.is_alive():
self.stop() self.stop()
if (self.detector_config.type == 'memryx'): if (self.detector_config.type == 'memryx'): # MemryX requires asynchronous detection handling using async_run_detector
self.detect_process = util.Process( self.detect_process = util.Process(
target=async_run_detector, target=async_run_detector,
name=f"detector:{self.name}", name=f"detector:{self.name}",
@ -385,3 +388,4 @@ class RemoteObjectDetector:
def cleanup(self): def cleanup(self):
self.shm.unlink() self.shm.unlink()
self.out_shm.unlink() self.out_shm.unlink()