Enrichments: Allow targeting a specific GPU ID (#19342)

This commit is contained in:
baudneo
2025-08-18 17:43:53 -06:00
committed by GitHub
parent 83e9ae616a
commit 33f3ea3b59
11 changed files with 43 additions and 19 deletions

View File

@@ -112,9 +112,8 @@ class Embeddings:
self.embedding = JinaV2Embedding(
model_size=self.config.semantic_search.model_size,
requestor=self.requestor,
device="GPU"
if self.config.semantic_search.model_size == "large"
else "CPU",
device=config.semantic_search.device
or ("GPU" if config.semantic_search.model_size == "large" else "CPU"),
)
self.text_embedding = lambda input_data: self.embedding(
input_data, embedding_type="text"
@@ -131,7 +130,8 @@ class Embeddings:
self.vision_embedding = JinaV1ImageEmbedding(
model_size=config.semantic_search.model_size,
requestor=self.requestor,
device="GPU" if config.semantic_search.model_size == "large" else "CPU",
device=config.semantic_search.device
or ("GPU" if config.semantic_search.model_size == "large" else "CPU"),
)
def update_stats(self) -> None:

View File

@@ -9,6 +9,7 @@ from frigate.const import MODEL_CACHE_DIR
from frigate.log import redirect_output_to_logger
from frigate.util.downloader import ModelDownloader
from ...config import FaceRecognitionConfig
from .base_embedding import BaseEmbedding
from .runner import ONNXModelRunner
@@ -111,7 +112,7 @@ class FaceNetEmbedding(BaseEmbedding):
class ArcfaceEmbedding(BaseEmbedding):
def __init__(self):
def __init__(self, config: FaceRecognitionConfig):
super().__init__(
model_name="facedet",
model_file="arcface.onnx",
@@ -119,6 +120,7 @@ class ArcfaceEmbedding(BaseEmbedding):
"arcface.onnx": "https://github.com/NickM-27/facenet-onnx/releases/download/v1.0/arcface.onnx",
},
)
self.config = config
self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
self.tokenizer = None
self.feature_extractor = None
@@ -148,7 +150,7 @@ class ArcfaceEmbedding(BaseEmbedding):
self.runner = ONNXModelRunner(
os.path.join(self.download_path, self.model_file),
"GPU",
device=self.config.device or "GPU",
)
def _preprocess_inputs(self, raw_inputs):

View File

@@ -128,7 +128,6 @@ class JinaV1TextEmbedding(BaseEmbedding):
self.runner = ONNXModelRunner(
os.path.join(self.download_path, self.model_file),
self.device,
self.model_size,
)
def _preprocess_inputs(self, raw_inputs):
@@ -207,7 +206,6 @@ class JinaV1ImageEmbedding(BaseEmbedding):
self.runner = ONNXModelRunner(
os.path.join(self.download_path, self.model_file),
self.device,
self.model_size,
)
def _preprocess_inputs(self, raw_inputs):

View File

@@ -128,7 +128,6 @@ class JinaV2Embedding(BaseEmbedding):
self.runner = ONNXModelRunner(
os.path.join(self.download_path, self.model_file),
self.device,
self.model_size,
)
def _preprocess_image(self, image_data: bytes | Image.Image) -> np.ndarray: