Update object_detection.py

This commit is contained in:
abinila siva 2025-04-11 14:22:45 -04:00 committed by GitHub
parent ed104ba1a7
commit fa84512e1b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -121,11 +121,11 @@ class LocalObjectDetector(ObjectDetector):
# Step 6: Return the processed image as a contiguous array of type float32
return np.ascontiguousarray(concatenated_img).astype(np.float32)
# if self.dtype == InputDTypeEnum.float:
tensor_input = tensor_input.astype(np.float32)
tensor_input /= 255
tensor_input = tensor_input.astype(np.float32) # Convert input to float32
tensor_input /= 255.0 # Normalize pixel values to [0, 1]
tensor_input = tensor_input.transpose(1, 2, 0, 3) # Convert from NHWC to HWNC (expected DFP input shape)
tensor_input = tensor_input.transpose(1,2,0,3) #NHWC --> HWNC(dfp input shape)
return tensor_input
@ -138,21 +138,25 @@ def async_run_detector(
start,
detector_config,
):
# Set thread and process titles for logging and debugging
threading.current_thread().name = f"detector:{name}"
logger.info(f"Starting MemryX Async detection process: {os.getpid()}")
setproctitle(f"frigate.detector.{name}")
stop_event = mp.Event()
stop_event = mp.Event() # Used to gracefully stop threads on signal
def receiveSignal(signalNumber, frame):
stop_event.set()
# Register signal handlers for graceful shutdown
signal.signal(signal.SIGTERM, receiveSignal)
signal.signal(signal.SIGINT, receiveSignal)
# Initialize shared memory and detector
frame_manager = SharedMemoryFrameManager()
object_detector = LocalObjectDetector(detector_config=detector_config)
# Create shared memory buffers for detector outputs
outputs = {}
for name in out_events.keys():
out_shm = UntrackedSharedMemory(name=f"out-{name}", create=False)
@ -160,14 +164,15 @@ def async_run_detector(
outputs[name] = {"shm": out_shm, "np": out_np}
def detect_worker():
""" Continuously fetch frames and send them to MemryX """
logger.info(f"Starting Detect Worker Thread")
"""Continuously fetch frames and send them to MemryX."""
logger.info("Starting Detect Worker Thread")
while not stop_event.is_set():
try:
connection_id = detection_queue.get(timeout=1)
except queue.Empty:
continue
# Retrieve the input frame from shared memory
input_frame = frame_manager.get(
connection_id,
(1, detector_config.model.height, detector_config.model.width, 3),
@ -176,47 +181,45 @@ def async_run_detector(
if input_frame is None:
logger.warning(f"Failed to get frame {connection_id} from SHM")
continue
# Preprocess and send input to MemryX
input_frame = object_detector.detect_raw_memx(input_frame)
# Start measuring inference time
start.value = datetime.datetime.now().timestamp()
# Send frame directly to MemryX processing
object_detector.detect_api.send_input(connection_id, input_frame)
def result_worker():
""" Continuously fetch results from MemryX and update outputs """
logger.info(f"Starting Result Worker Thread")
"""Continuously receive detection results from MemryX."""
logger.info("Starting Result Worker Thread")
while not stop_event.is_set():
connection_id, detections = object_detector.detect_api.receive_output()
# Calculate processing time
duration = datetime.datetime.now().timestamp() - start.value
frame_manager.close(connection_id)
# Update average inference speed
# Update moving average inference time
avg_speed.value = (avg_speed.value * 9 + duration) / 10
if connection_id in outputs and detections is not None:
outputs[connection_id]["np"][:] = detections[:]
out_events[connection_id].set()
# Initialize avg_speed
# Initialize tracking variables
start.value = 0.0
avg_speed.value = 0.0 # Start with an initial value
avg_speed.value = 0.0
# Start worker threads
# Start threads for detection input and result output
detect_thread = threading.Thread(target=detect_worker, daemon=True)
result_thread = threading.Thread(target=result_worker, daemon=True)
detect_thread.start()
result_thread.start()
# Keep the main process alive while threads run
while not stop_event.is_set():
time.sleep(1) # Keep process alive
time.sleep(1)
logger.info("Exited MemryX detection process...")
def run_detector(
name: str,
detection_queue: mp.Queue,
@ -310,7 +313,7 @@ class ObjectDetectProcess:
self.detection_start.value = 0.0
if (self.detect_process is not None) and self.detect_process.is_alive():
self.stop()
if (self.detector_config.type == 'memryx'):
if (self.detector_config.type == 'memryx'): # MemryX requires asynchronous detection handling using async_run_detector
self.detect_process = util.Process(
target=async_run_detector,
name=f"detector:{self.name}",
@ -385,3 +388,4 @@ class RemoteObjectDetector:
def cleanup(self):
self.shm.unlink()
self.out_shm.unlink()