Enrichments: Allow targeting a specific GPU ID (#19342)

This commit is contained in:
baudneo
2025-08-18 17:43:53 -06:00
committed by GitHub
parent 83e9ae616a
commit 33f3ea3b59
11 changed files with 43 additions and 19 deletions

View File

@@ -130,6 +130,11 @@ class SemanticSearchConfig(FrigateBaseModel):
model_size: str = Field(
default="small", title="The size of the embeddings model used."
)
device: Optional[str] = Field(
default=None,
title="The device key to use for semantic search.",
description="This is an override, to target a specific device. See https://onnxruntime.ai/docs/execution-providers/ for more information",
)
class TriggerConfig(FrigateBaseModel):
@@ -196,6 +201,11 @@ class FaceRecognitionConfig(FrigateBaseModel):
blur_confidence_filter: bool = Field(
default=True, title="Apply blur quality filter to face confidence."
)
device: Optional[str] = Field(
default=None,
title="The device key to use for face recognition.",
description="This is an override, to target a specific device. See https://onnxruntime.ai/docs/execution-providers/ for more information",
)
class CameraFaceRecognitionConfig(FrigateBaseModel):
@@ -209,10 +219,6 @@ class CameraFaceRecognitionConfig(FrigateBaseModel):
class LicensePlateRecognitionConfig(FrigateBaseModel):
enabled: bool = Field(default=False, title="Enable license plate recognition.")
device: Optional[EnrichmentsDeviceEnum] = Field(
default=EnrichmentsDeviceEnum.CPU,
title="The device used for license plate recognition.",
)
model_size: str = Field(
default="small", title="The size of the embeddings model used."
)
@@ -258,6 +264,11 @@ class LicensePlateRecognitionConfig(FrigateBaseModel):
default=False,
title="Save plates captured for LPR for debugging purposes.",
)
device: Optional[str] = Field(
default=None,
title="The device key to use for LPR.",
description="This is an override, to target a specific device. See https://onnxruntime.ai/docs/execution-providers/ for more information",
)
class CameraLicensePlateRecognitionConfig(FrigateBaseModel):

View File

@@ -269,7 +269,7 @@ class ArcFaceRecognizer(FaceRecognizer):
def __init__(self, config: FrigateConfig):
super().__init__(config)
self.mean_embs: dict[int, np.ndarray] = {}
self.face_embedder: ArcfaceEmbedding = ArcfaceEmbedding()
self.face_embedder: ArcfaceEmbedding = ArcfaceEmbedding(config.face_recognition)
self.model_builder_queue: queue.Queue | None = None
def clear(self) -> None:

View File

@@ -171,7 +171,7 @@ class FaceRealTimeProcessor(RealTimeProcessorApi):
# don't run for non person objects
if obj_data.get("label") != "person":
logger.debug("Not a processing face for non person object.")
logger.debug("Not processing face for a non person object.")
return
# don't overwrite sub label for objects that have a sub label

View File

@@ -112,9 +112,8 @@ class Embeddings:
self.embedding = JinaV2Embedding(
model_size=self.config.semantic_search.model_size,
requestor=self.requestor,
device="GPU"
if self.config.semantic_search.model_size == "large"
else "CPU",
device=config.semantic_search.device
or ("GPU" if config.semantic_search.model_size == "large" else "CPU"),
)
self.text_embedding = lambda input_data: self.embedding(
input_data, embedding_type="text"
@@ -131,7 +130,8 @@ class Embeddings:
self.vision_embedding = JinaV1ImageEmbedding(
model_size=config.semantic_search.model_size,
requestor=self.requestor,
device="GPU" if config.semantic_search.model_size == "large" else "CPU",
device=config.semantic_search.device
or ("GPU" if config.semantic_search.model_size == "large" else "CPU"),
)
def update_stats(self) -> None:

View File

@@ -9,6 +9,7 @@ from frigate.const import MODEL_CACHE_DIR
from frigate.log import redirect_output_to_logger
from frigate.util.downloader import ModelDownloader
from ...config import FaceRecognitionConfig
from .base_embedding import BaseEmbedding
from .runner import ONNXModelRunner
@@ -111,7 +112,7 @@ class FaceNetEmbedding(BaseEmbedding):
class ArcfaceEmbedding(BaseEmbedding):
def __init__(self):
def __init__(self, config: FaceRecognitionConfig):
super().__init__(
model_name="facedet",
model_file="arcface.onnx",
@@ -119,6 +120,7 @@ class ArcfaceEmbedding(BaseEmbedding):
"arcface.onnx": "https://github.com/NickM-27/facenet-onnx/releases/download/v1.0/arcface.onnx",
},
)
self.config = config
self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
self.tokenizer = None
self.feature_extractor = None
@@ -148,7 +150,7 @@ class ArcfaceEmbedding(BaseEmbedding):
self.runner = ONNXModelRunner(
os.path.join(self.download_path, self.model_file),
"GPU",
device=self.config.device or "GPU",
)
def _preprocess_inputs(self, raw_inputs):

View File

@@ -128,7 +128,6 @@ class JinaV1TextEmbedding(BaseEmbedding):
self.runner = ONNXModelRunner(
os.path.join(self.download_path, self.model_file),
self.device,
self.model_size,
)
def _preprocess_inputs(self, raw_inputs):
@@ -207,7 +206,6 @@ class JinaV1ImageEmbedding(BaseEmbedding):
self.runner = ONNXModelRunner(
os.path.join(self.download_path, self.model_file),
self.device,
self.model_size,
)
def _preprocess_inputs(self, raw_inputs):

View File

@@ -128,7 +128,6 @@ class JinaV2Embedding(BaseEmbedding):
self.runner = ONNXModelRunner(
os.path.join(self.download_path, self.model_file),
self.device,
self.model_size,
)
def _preprocess_image(self, image_data: bytes | Image.Image) -> np.ndarray: