Mirror of https://github.com/blakeblackshear/frigate.git (synced 2025-09-23 17:52:05 +02:00)

Get correct ports running

parent 8b78c85bda
commit de960285f6
@@ -31,7 +31,7 @@ class _RouterDealerRunner(threading.Thread):
         frontend.bind(REQ_ROUTER_ENDPOINT)
 
         backend = self.context.socket(zmq.DEALER)
-        backend.connect(self.backend_endpoint)
+        backend.bind(self.backend_endpoint)
 
         try:
             zmq.proxy(frontend, backend)
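For context on the change itself: the proxy's DEALER backend now binds the backend address instead of connecting to it, which matches the detector workers added in the new file below, whose REP sockets connect to that address. A minimal pyzmq sketch of the bind/connect relationship, using a stand-in address rather than Frigate's real REQ_ROUTER_ENDPOINT and backend_endpoint:

import zmq

BACKEND = "tcp://127.0.0.1:5555"  # stand-in for the proxy's backend_endpoint

ctx = zmq.Context()

# Frigate side: the DEALER backend of the ROUTER/DEALER proxy binds the
# well-known address (this commit's connect -> bind change)...
backend = ctx.socket(zmq.DEALER)
backend.bind(BACKEND)

# ...so each detector worker can simply connect its REP socket, exactly as
# ZmqOnnxWorker._create_socket() does in the new file below.
worker = ctx.socket(zmq.REP)
worker.connect(BACKEND)

# One round trip across the bind/connect pair. A DEALER talks to a REP with an
# empty delimiter frame; in Frigate the ROUTER frontend supplies that envelope.
backend.send_multipart([b"", b"ping"])
print(worker.recv())             # b"ping"
worker.send(b"pong")
print(backend.recv_multipart())  # [b"", b"pong"]

for sock in (backend, worker):
    sock.close(linger=0)
ctx.term()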
frigate/detectors/zmq_client.py (new file, 838 lines)
@@ -0,0 +1,838 @@
#!/usr/bin/env python3
"""
ZMQ TCP ONNX Runtime Client

This client connects to the ZMQ TCP proxy, accepts tensor inputs,
runs inference via ONNX Runtime, and returns detection results.

Protocol:
- Receives multipart messages: [header_json_bytes, tensor_bytes]
- Header contains shape and dtype information
- Runs ONNX inference on the tensor
- Returns results in the expected format: [20, 6] float32 array

Note: Timeouts are normal when Frigate has no motion to detect.
The server will continue running and waiting for requests.
"""

import json
import logging
import os
import threading
import time
from typing import Dict, List, Optional, Tuple

import numpy as np
import onnxruntime as ort
import zmq
from model_util import post_process_dfine, post_process_rfdetr, post_process_yolo

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

class ZmqOnnxWorker(threading.Thread):
    """
    A worker thread that connects a REP socket to the endpoint and processes
    requests using a shared model session map. This mirrors the single-runner
    logic, but the ONNX Runtime session is fetched from the shared map, and
    created on-demand if missing.
    """

    def __init__(
        self,
        worker_id: int,
        context: zmq.Context,
        endpoint: str,
        models_dir: str,
        model_sessions: Dict[str, ort.InferenceSession],
        model_lock: threading.Lock,
        providers: Optional[List[str]],
        zero_result: np.ndarray,
    ) -> None:
        super().__init__(name=f"onnx_worker_{worker_id}", daemon=True)
        self.worker_id = worker_id
        self.context = context
        self.endpoint = self._normalize_endpoint(endpoint)
        self.models_dir = models_dir
        self.model_sessions = model_sessions
        self.model_lock = model_lock
        self.providers = providers
        self.zero_result = zero_result
        self.socket: Optional[zmq.Socket] = None

    def _normalize_endpoint(self, endpoint: str) -> str:
        if endpoint.startswith("tcp://*:"):
            port = endpoint.split(":", 2)[-1]
            return f"tcp://127.0.0.1:{port}"
        return endpoint

    # --- ZMQ helpers ---
    def _create_socket(self) -> zmq.Socket:
        sock = self.context.socket(zmq.REP)
        sock.setsockopt(zmq.RCVTIMEO, 5000)
        sock.setsockopt(zmq.SNDTIMEO, 5000)
        sock.setsockopt(zmq.LINGER, 0)
        sock.connect(self.endpoint)
        return sock

    def _decode_request(self, frames: List[bytes]) -> Tuple[Optional[np.ndarray], dict]:
        if len(frames) < 1:
            raise ValueError(f"Expected at least 1 frame, got {len(frames)}")

        header_bytes = frames[0]
        header = json.loads(header_bytes.decode("utf-8"))

        if "model_request" in header:
            return None, header
        if "model_data" in header:
            return None, header
        if len(frames) < 2:
            raise ValueError(f"Tensor request expected 2 frames, got {len(frames)}")

        tensor_bytes = frames[1]
        shape = tuple(header.get("shape", []))
        dtype_str = header.get("dtype", "uint8")

        dtype = np.dtype(dtype_str)
        tensor = np.frombuffer(tensor_bytes, dtype=dtype).reshape(shape)
        return tensor, header

    def _build_response(self, result: np.ndarray) -> List[bytes]:
        header = {
            "shape": list(result.shape),
            "dtype": str(result.dtype.name),
            "timestamp": time.time(),
        }
        return [json.dumps(header).encode("utf-8"), result.tobytes(order="C")]

    def _build_error_response(self, error_msg: str) -> List[bytes]:
        error_header = {"shape": [20, 6], "dtype": "float32", "error": error_msg}
        return [
            json.dumps(error_header).encode("utf-8"),
            self.zero_result.tobytes(order="C"),
        ]

    # --- Model/session helpers ---
    def _check_model_exists(self, model_name: str) -> bool:
        return os.path.exists(os.path.join(self.models_dir, model_name))

    def _save_model(self, model_name: str, model_data: bytes) -> bool:
        try:
            os.makedirs(self.models_dir, exist_ok=True)
            with open(os.path.join(self.models_dir, model_name), "wb") as f:
                f.write(model_data)
            return True
        except Exception as e:
            logger.error(
                f"Worker {self.worker_id} failed to save model {model_name}: {e}"
            )
            return False

    def _get_or_create_session(self, model_name: str) -> Optional[ort.InferenceSession]:
        with self.model_lock:
            session = self.model_sessions.get(model_name)
            if session is not None:
                return session
            try:
                providers = self.providers or ["CoreMLExecutionProvider"]
                session = ort.InferenceSession(
                    os.path.join(self.models_dir, model_name), providers=providers
                )
                self.model_sessions[model_name] = session
                return session
            except Exception as e:
                logger.error(
                    f"Worker {self.worker_id} failed to load model {model_name}: {e}"
                )
                return None

    # --- Inference helpers ---
    def _extract_input_hw(self, header: dict) -> Tuple[int, int]:
        try:
            if "width" in header and "height" in header:
                return int(header["width"]), int(header["height"])
            shape = tuple(header.get("shape", []))
            layout = header.get("layout") or header.get("order")
            if layout and shape:
                layout = str(layout).upper()
                if len(shape) == 4:
                    if layout == "NCHW":
                        return int(shape[3]), int(shape[2])
                    if layout == "NHWC":
                        return int(shape[2]), int(shape[1])
                if len(shape) == 3:
                    if layout == "CHW":
                        return int(shape[2]), int(shape[1])
                    if layout == "HWC":
                        return int(shape[1]), int(shape[0])
            if shape:
                if len(shape) == 4:
                    _, d1, d2, d3 = shape
                    if d1 in (1, 3):
                        return int(d3), int(d2)
                    if d3 in (1, 3):
                        return int(d2), int(d1)
                    return int(d2), int(d1)
                if len(shape) == 3:
                    d0, d1, d2 = shape
                    if d0 in (1, 3):
                        return int(d2), int(d1)
                    if d2 in (1, 3):
                        return int(d1), int(d0)
                    return int(d1), int(d0)
                if len(shape) == 2:
                    h, w = shape
                    return int(w), int(h)
        except Exception:
            pass
        return 320, 320

    def _run_inference(
        self, session: ort.InferenceSession, tensor: np.ndarray, header: dict
    ) -> np.ndarray:
        try:
            model_type = header.get("model_type")
            width, height = self._extract_input_hw(header)

            if model_type == "dfine":
                input_data = {
                    "images": tensor.astype(np.float32),
                    "orig_target_sizes": np.array([[height, width]], dtype=np.int64),
                }
            else:
                input_name = session.get_inputs()[0].name
                input_data = {input_name: tensor}

            if logger.isEnabledFor(logging.DEBUG):
                t_start = time.perf_counter()

            outputs = session.run(None, input_data)

            if logger.isEnabledFor(logging.DEBUG):
                t_after_onnx = time.perf_counter()

            if model_type == "yolo-generic" or model_type == "yologeneric":
                result = post_process_yolo(outputs, width, height)
            elif model_type == "dfine":
                result = post_process_dfine(outputs, width, height)
            elif model_type == "rfdetr":
                result = post_process_rfdetr(outputs)
            else:
                result = np.zeros((20, 6), dtype=np.float32)

            if logger.isEnabledFor(logging.DEBUG):
                t_after_post = time.perf_counter()
                onnx_ms = (t_after_onnx - t_start) * 1000.0
                post_ms = (t_after_post - t_after_onnx) * 1000.0
                total_ms = (t_after_post - t_start) * 1000.0
                logger.debug(
                    f"Worker {self.worker_id} timing: onnx={onnx_ms:.2f}ms, post={post_ms:.2f}ms, total={total_ms:.2f}ms"
                )

            return result.astype(np.float32)
        except Exception as e:
            logger.error(f"Worker {self.worker_id} ONNX inference failed: {e}")
            return self.zero_result

    # --- Message handlers ---
    def _handle_model_request(self, header: dict) -> List[bytes]:
        model_name = header.get("model_name")
        if not model_name:
            return self._build_error_response("Model request missing model_name")
        if self._check_model_exists(model_name):
            # Ensure session exists
            if self._get_or_create_session(model_name) is not None:
                response_header = {
                    "model_available": True,
                    "model_loaded": True,
                    "model_name": model_name,
                    "message": f"Model {model_name} loaded successfully",
                }
            else:
                response_header = {
                    "model_available": True,
                    "model_loaded": False,
                    "model_name": model_name,
                    "message": f"Model {model_name} exists but failed to load",
                }
        else:
            response_header = {
                "model_available": False,
                "model_name": model_name,
                "message": f"Model {model_name} not found, please send model data",
            }
        return [json.dumps(response_header).encode("utf-8")]

    def _handle_model_data(self, header: dict, model_data: bytes) -> List[bytes]:
        model_name = header.get("model_name")
        if not model_name:
            return self._build_error_response("Model data missing model_name")
        if self._save_model(model_name, model_data):
            # Ensure session is created
            if self._get_or_create_session(model_name) is not None:
                response_header = {
                    "model_saved": True,
                    "model_loaded": True,
                    "model_name": model_name,
                    "message": f"Model {model_name} saved and loaded successfully",
                }
            else:
                response_header = {
                    "model_saved": True,
                    "model_loaded": False,
                    "model_name": model_name,
                    "message": f"Model {model_name} saved but failed to load",
                }
        else:
            response_header = {
                "model_saved": False,
                "model_loaded": False,
                "model_name": model_name,
                "message": f"Failed to save model {model_name}",
            }
        return [json.dumps(response_header).encode("utf-8")]

    # --- Thread run ---
    def run(self) -> None:  # pragma: no cover - runtime loop
        try:
            self.socket = self._create_socket()
            logger.info(
                f"Worker {self.worker_id} connected REP to endpoint: {self.endpoint}"
            )
            while True:
                try:
                    frames = self.socket.recv_multipart()
                    tensor, header = self._decode_request(frames)

                    if "model_request" in header:
                        response = self._handle_model_request(header)
                        self.socket.send_multipart(response)
                        continue
                    if "model_data" in header and len(frames) >= 2:
                        model_data = frames[1]
                        response = self._handle_model_data(header, model_data)
                        self.socket.send_multipart(response)
                        continue
                    if tensor is not None:
                        model_name = header.get("model_name")
                        session = None
                        if model_name:
                            session = self._get_or_create_session(model_name)
                        if session is None:
                            result = self.zero_result
                        else:
                            result = self._run_inference(session, tensor, header)
                        self.socket.send_multipart(self._build_response(result))
                        continue

                    # Unknown message: reply with zeros
                    self.socket.send_multipart(self._build_response(self.zero_result))
                except zmq.Again:
                    continue
                except zmq.ZMQError as e:
                    logger.error(f"Worker {self.worker_id} ZMQ error: {e}")
                    # Recreate socket on transient errors
                    try:
                        if self.socket:
                            self.socket.close(linger=0)
                    finally:
                        self.socket = self._create_socket()
                except Exception as e:
                    logger.error(f"Worker {self.worker_id} unexpected error: {e}")
                    try:
                        self.socket.send_multipart(self._build_error_response(str(e)))
                    except Exception:
                        pass
        finally:
            try:
                if self.socket:
                    self.socket.close(linger=0)
            except Exception:
                pass

class ZmqOnnxClient:
    """
    ZMQ TCP client that runs ONNX inference on received tensors.
    """

    def __init__(
        self,
        endpoint: str = "tcp://*:5555",
        model_path: Optional[str] = "AUTO",
        providers: Optional[List[str]] = None,
        session_options: Optional[ort.SessionOptions] = None,
    ):
        """
        Initialize the ZMQ ONNX client.

        Args:
            endpoint: ZMQ TCP endpoint to bind to
            model_path: Path to ONNX model file or "AUTO" for automatic model management
            providers: ONNX Runtime execution providers
            session_options: ONNX Runtime session options
        """
        self.endpoint = endpoint
        self.model_path = model_path
        self.current_model = None
        self.model_ready = False
        self.models_dir = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), "models"
        )

        # Shared ZMQ context and shared session map across workers
        self.context = zmq.Context()
        self.model_sessions: Dict[str, ort.InferenceSession] = {}
        self.model_lock = threading.Lock()
        self.providers = providers

        # Single-socket state used only by the legacy (non-worker) helpers below
        self.socket: Optional[zmq.Socket] = None
        self.session: Optional[ort.InferenceSession] = None

        # Preallocate zero result for error cases
        self.zero_result = np.zeros((20, 6), dtype=np.float32)

        logger.info(f"ZMQ ONNX client will start workers on endpoint: {endpoint}")

    def start_server(self, num_workers: int = 4) -> None:
        workers: list[ZmqOnnxWorker] = []
        for i in range(num_workers):
            w = ZmqOnnxWorker(
                worker_id=i,
                context=self.context,
                endpoint=self.endpoint,
                models_dir=self.models_dir,
                model_sessions=self.model_sessions,
                model_lock=self.model_lock,
                providers=self.providers,
                zero_result=self.zero_result,
            )
            w.start()
            workers.append(w)
        logger.info(f"Started {num_workers} ZMQ REP workers on backend {self.endpoint}")
        try:
            for w in workers:
                w.join()
        except KeyboardInterrupt:
            logger.info("Shutting down workers...")

    def _check_model_exists(self, model_name: str) -> bool:
        """
        Check if a model exists in the models directory.

        Args:
            model_name: Name of the model file to check

        Returns:
            True if model exists, False otherwise
        """
        model_path = os.path.join(self.models_dir, model_name)
        return os.path.exists(model_path)

    # These methods remain for compatibility but are unused in worker mode

    def _save_model(self, model_name: str, model_data: bytes) -> bool:
        """
        Save model data to the models directory.

        Args:
            model_name: Name of the model file to save
            model_data: Binary model data

        Returns:
            True if model saved successfully, False otherwise
        """
        try:
            # Ensure models directory exists
            os.makedirs(self.models_dir, exist_ok=True)

            model_path = os.path.join(self.models_dir, model_name)
            logger.info(f"Saving model to: {model_path}")

            with open(model_path, "wb") as f:
                f.write(model_data)

            logger.info(f"Model saved successfully: {model_name}")
            return True

        except Exception as e:
            logger.error(f"Failed to save model {model_name}: {e}")
            return False

    def _decode_request(self, frames: List[bytes]) -> Tuple[np.ndarray, dict]:
        """
        Decode the incoming request frames.

        Args:
            frames: List of message frames

        Returns:
            Tuple of (tensor, header_dict)
        """
        try:
            if len(frames) < 1:
                raise ValueError(f"Expected at least 1 frame, got {len(frames)}")

            # Parse header
            header_bytes = frames[0]
            header = json.loads(header_bytes.decode("utf-8"))

            if "model_request" in header:
                return None, header

            if "model_data" in header:
                if len(frames) < 2:
                    raise ValueError(
                        f"Model data request expected 2 frames, got {len(frames)}"
                    )
                return None, header

            if len(frames) < 2:
                raise ValueError(f"Tensor request expected 2 frames, got {len(frames)}")

            tensor_bytes = frames[1]
            shape = tuple(header.get("shape", []))
            dtype_str = header.get("dtype", "uint8")

            dtype = np.dtype(dtype_str)
            tensor = np.frombuffer(tensor_bytes, dtype=dtype).reshape(shape)
            return tensor, header

        except Exception as e:
            logger.error(f"Failed to decode request: {e}")
            raise

    def _run_inference(self, tensor: np.ndarray, header: dict) -> np.ndarray:
        """
        Run ONNX inference on the input tensor.

        Args:
            tensor: Input tensor
            header: Request header containing metadata (e.g., shape, layout)

        Returns:
            Detection results as numpy array

        Raises:
            RuntimeError: If no ONNX session is available or inference fails
        """
        if self.session is None:
            logger.warning("No ONNX session available, returning zero results")
            return self.zero_result

        try:
            # Prepare input for ONNX Runtime
            # Determine input spatial size (W, H) from header/shape/layout
            model_type = header.get("model_type")
            width, height = self._extract_input_hw(header)

            if model_type == "dfine":
                # DFine model requires both images and orig_target_sizes inputs
                input_data = {
                    "images": tensor.astype(np.float32),
                    "orig_target_sizes": np.array([[height, width]], dtype=np.int64),
                }
            else:
                # Other models use single input
                input_name = self.session.get_inputs()[0].name
                input_data = {input_name: tensor}

            # Run inference
            if logger.isEnabledFor(logging.DEBUG):
                t_start = time.perf_counter()

            outputs = self.session.run(None, input_data)

            if logger.isEnabledFor(logging.DEBUG):
                t_after_onnx = time.perf_counter()

            if model_type == "yolo-generic" or model_type == "yologeneric":
                result = post_process_yolo(outputs, width, height)
            elif model_type == "dfine":
                result = post_process_dfine(outputs, width, height)
            elif model_type == "rfdetr":
                result = post_process_rfdetr(outputs)
            else:
                # Unknown model type: fall back to zeros (mirrors the worker path)
                result = np.zeros((20, 6), dtype=np.float32)

            if logger.isEnabledFor(logging.DEBUG):
                t_after_post = time.perf_counter()
                onnx_ms = (t_after_onnx - t_start) * 1000.0
                post_ms = (t_after_post - t_after_onnx) * 1000.0
                total_ms = (t_after_post - t_start) * 1000.0
                logger.debug(
                    f"Inference timing: onnx={onnx_ms:.2f}ms, post={post_ms:.2f}ms, total={total_ms:.2f}ms"
                )

            # Ensure float32 dtype
            result = result.astype(np.float32)

            return result

        except Exception as e:
            logger.error(f"ONNX inference failed: {e}")
            return self.zero_result

    def _extract_input_hw(self, header: dict) -> Tuple[int, int]:
        """
        Extract (width, height) from the header and/or tensor shape, supporting
        NHWC/NCHW as well as 3D/4D inputs. Falls back to 320x320 if unknown.

        Preference order:
        1) Explicit header keys: width/height
        2) Use provided layout to interpret shape
        3) Heuristics on shape
        """
        try:
            if "width" in header and "height" in header:
                return int(header["width"]), int(header["height"])

            shape = tuple(header.get("shape", []))
            layout = header.get("layout") or header.get("order")

            if layout and shape:
                layout = str(layout).upper()
                if len(shape) == 4:
                    if layout == "NCHW":
                        return int(shape[3]), int(shape[2])
                    if layout == "NHWC":
                        return int(shape[2]), int(shape[1])
                if len(shape) == 3:
                    if layout == "CHW":
                        return int(shape[2]), int(shape[1])
                    if layout == "HWC":
                        return int(shape[1]), int(shape[0])

            if shape:
                if len(shape) == 4:
                    _, d1, d2, d3 = shape
                    if d1 in (1, 3):
                        return int(d3), int(d2)
                    if d3 in (1, 3):
                        return int(d2), int(d1)
                    return int(d2), int(d1)
                if len(shape) == 3:
                    d0, d1, d2 = shape
                    if d0 in (1, 3):
                        return int(d2), int(d1)
                    if d2 in (1, 3):
                        return int(d1), int(d0)
                    return int(d1), int(d0)
                if len(shape) == 2:
                    h, w = shape
                    return int(w), int(h)
        except Exception as e:
            logger.debug(f"Failed to extract input size from header: {e}")

        logger.debug("Falling back to default input size (320x320)")
        return 320, 320

    def _build_response(self, result: np.ndarray) -> List[bytes]:
        """
        Build the response message.

        Args:
            result: Detection results

        Returns:
            List of response frames
        """
        try:
            # Build header
            header = {
                "shape": list(result.shape),
                "dtype": str(result.dtype.name),
                "timestamp": time.time(),
            }
            header_bytes = json.dumps(header).encode("utf-8")

            # Convert result to bytes
            result_bytes = result.tobytes(order="C")

            return [header_bytes, result_bytes]

        except Exception as e:
            logger.error(f"Failed to build response: {e}")
            # Return zero result as fallback
            header = {
                "shape": [20, 6],
                "dtype": "float32",
                "error": "Failed to build response",
            }
            header_bytes = json.dumps(header).encode("utf-8")
            result_bytes = self.zero_result.tobytes(order="C")
            return [header_bytes, result_bytes]

    def _handle_model_request(self, header: dict) -> List[bytes]:
        """
        Handle model availability request.

        Args:
            header: Request header containing model information

        Returns:
            Response message indicating model availability
        """
        model_name = header.get("model_name")

        if not model_name:
            logger.error("Model request missing model_name")
            return self._build_error_response("Model request missing model_name")

        logger.info(f"Model availability request for: {model_name}")

        if self._check_model_exists(model_name):
            logger.info(f"Model {model_name} exists locally")
            # Try to load the model
            if self._load_model(model_name):
                response_header = {
                    "model_available": True,
                    "model_loaded": True,
                    "model_name": model_name,
                    "message": f"Model {model_name} loaded successfully",
                }
            else:
                response_header = {
                    "model_available": True,
                    "model_loaded": False,
                    "model_name": model_name,
                    "message": f"Model {model_name} exists but failed to load",
                }
        else:
            logger.info(f"Model {model_name} not found, requesting transfer")
            response_header = {
                "model_available": False,
                "model_name": model_name,
                "message": f"Model {model_name} not found, please send model data",
            }

        return [json.dumps(response_header).encode("utf-8")]

    def _handle_model_data(self, header: dict, model_data: bytes) -> List[bytes]:
        """
        Handle model data transfer.

        Args:
            header: Request header containing model information
            model_data: Binary model data

        Returns:
            Response message indicating save success/failure
        """
        model_name = header.get("model_name")

        if not model_name:
            logger.error("Model data missing model_name")
            return self._build_error_response("Model data missing model_name")

        logger.info(f"Received model data for: {model_name}")

        if self._save_model(model_name, model_data):
            # Try to load the model
            if self._load_model(model_name):
                response_header = {
                    "model_saved": True,
                    "model_loaded": True,
                    "model_name": model_name,
                    "message": f"Model {model_name} saved and loaded successfully",
                }
            else:
                response_header = {
                    "model_saved": True,
                    "model_loaded": False,
                    "model_name": model_name,
                    "message": f"Model {model_name} saved but failed to load",
                }
        else:
            response_header = {
                "model_saved": False,
                "model_loaded": False,
                "model_name": model_name,
                "message": f"Failed to save model {model_name}",
            }

        return [json.dumps(response_header).encode("utf-8")]

    def _build_error_response(self, error_msg: str) -> List[bytes]:
        """Build an error response message."""
        error_header = {"error": error_msg}
        return [json.dumps(error_header).encode("utf-8")]

    # Removed legacy single-thread start_server implementation in favor of worker pool

    def _send_error_response(self, error_msg: str):
        """Send an error response to the client."""
        try:
            error_header = {"shape": [20, 6], "dtype": "float32", "error": error_msg}
            error_response = [
                json.dumps(error_header).encode("utf-8"),
                self.zero_result.tobytes(order="C"),
            ]
            self.socket.send_multipart(error_response)
        except Exception as send_error:
            logger.error(f"Failed to send error response: {send_error}")

    def cleanup(self):
        """Clean up resources."""
        try:
            if self.socket:
                self.socket.close()
                self.socket = None
            if self.context:
                self.context.term()
                self.context = None
            logger.info("Cleanup completed")
        except Exception as e:
            logger.error(f"Cleanup error: {e}")

def main():
    """Main function to run the ZMQ ONNX client."""
    import argparse

    parser = argparse.ArgumentParser(description="ZMQ TCP ONNX Runtime Client")
    parser.add_argument(
        "--endpoint",
        default="tcp://*:5555",
        help="ZMQ TCP endpoint (default: tcp://*:5555)",
    )
    parser.add_argument(
        "--model",
        default="AUTO",
        help="Path to ONNX model file or AUTO for automatic model management",
    )
    parser.add_argument(
        "--providers",
        nargs="+",
        default=["CoreMLExecutionProvider"],
        help="ONNX Runtime execution providers",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=4,
        help="Number of REP worker threads",
    )
    parser.add_argument(
        "--verbose", "-v", action="store_true", help="Enable verbose logging"
    )

    args = parser.parse_args()

    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    # Create and start client
    client = ZmqOnnxClient(
        endpoint=args.endpoint, model_path=args.model, providers=args.providers
    )

    try:
        client.start_server(num_workers=args.workers)
    except KeyboardInterrupt:
        logger.info("Interrupted by user")
    finally:
        client.cleanup()

if __name__ == "__main__":
|
||||||
|
main()
|
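The module docstring above describes the request protocol in prose; as a concrete illustration, this is what one request/response round trip looks like at the frame level. It is a sketch only: the header values (model_name, model_type, width/height) are illustrative, and the send/recv calls assume a REQ socket already connected to the proxy frontend.

import json

import numpy as np

# Build a tensor request as the protocol expects: [header_json_bytes, tensor_bytes].
tensor = np.zeros((1, 3, 320, 320), dtype=np.float32)
header = {
    "shape": list(tensor.shape),
    "dtype": tensor.dtype.name,   # the server defaults to uint8 if omitted
    "model_type": "yolo-generic",
    "model_name": "model.onnx",
    "width": 320,
    "height": 320,
}
request_frames = [json.dumps(header).encode("utf-8"), tensor.tobytes(order="C")]

# A REQ socket attached to the proxy frontend would then do:
#   sock.send_multipart(request_frames)
#   reply = sock.recv_multipart()

# Decoding a reply mirrors _build_response(): a JSON header plus raw bytes.
reply = [
    json.dumps({"shape": [20, 6], "dtype": "float32"}).encode("utf-8"),
    np.zeros((20, 6), dtype=np.float32).tobytes(order="C"),
]
reply_header = json.loads(reply[0].decode("utf-8"))
detections = np.frombuffer(reply[1], dtype=np.dtype(reply_header["dtype"])).reshape(
    reply_header["shape"]
)
print(detections.shape)  # (20, 6)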
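The model negotiation handled by _handle_model_request / _handle_model_data can be driven from the requester side roughly as follows. ensure_model and its socket argument are hypothetical helpers, but the header keys (model_request, model_data, model_name) and reply fields (model_available, model_saved, model_loaded) are the ones the handlers above produce and consume.

import json
from typing import Optional

import zmq


def ensure_model(sock: zmq.Socket, model_name: str, model_path: Optional[str]) -> dict:
    """Probe the workers for model_name and upload the file if it is missing.

    sock is assumed to be a REQ socket already connected to the proxy frontend;
    ensure_model itself is a hypothetical helper, not part of the file above.
    """
    # Any header containing the key "model_request" is treated as an availability probe.
    probe = {"model_request": True, "model_name": model_name}
    sock.send_multipart([json.dumps(probe).encode("utf-8")])
    reply = json.loads(sock.recv_multipart()[0].decode("utf-8"))

    if reply.get("model_available"):
        # Already on disk; model_loaded says whether an ONNX session was created.
        return reply

    # Not available: resend with "model_data" and the raw ONNX bytes as frame 2.
    push = {"model_data": True, "model_name": model_name}
    with open(model_path, "rb") as f:
        sock.send_multipart([json.dumps(push).encode("utf-8"), f.read()])
    return json.loads(sock.recv_multipart()[0].decode("utf-8"))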
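Run standalone, the argparse defaults above amount to python3 zmq_client.py --endpoint tcp://*:5555 --model AUTO --providers CoreMLExecutionProvider --workers 4. A programmatic equivalent, assuming the module is importable as zmq_client:

from zmq_client import ZmqOnnxClient  # import path is an assumption

client = ZmqOnnxClient(
    endpoint="tcp://*:5555",                # same default as --endpoint
    model_path="AUTO",                      # same default as --model
    providers=["CoreMLExecutionProvider"],  # same default as --providers
)
try:
    client.start_server(num_workers=4)      # blocks, joining the REP worker threads
finally:
    client.cleanup()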