mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-04-28 23:06:13 +02:00
Enrichments: Allow targeting a specific GPU ID (#19342)
This commit is contained in:
@@ -112,9 +112,8 @@ class Embeddings:
|
||||
self.embedding = JinaV2Embedding(
|
||||
model_size=self.config.semantic_search.model_size,
|
||||
requestor=self.requestor,
|
||||
device="GPU"
|
||||
if self.config.semantic_search.model_size == "large"
|
||||
else "CPU",
|
||||
device=config.semantic_search.device
|
||||
or ("GPU" if config.semantic_search.model_size == "large" else "CPU"),
|
||||
)
|
||||
self.text_embedding = lambda input_data: self.embedding(
|
||||
input_data, embedding_type="text"
|
||||
@@ -131,7 +130,8 @@ class Embeddings:
|
||||
self.vision_embedding = JinaV1ImageEmbedding(
|
||||
model_size=config.semantic_search.model_size,
|
||||
requestor=self.requestor,
|
||||
device="GPU" if config.semantic_search.model_size == "large" else "CPU",
|
||||
device=config.semantic_search.device
|
||||
or ("GPU" if config.semantic_search.model_size == "large" else "CPU"),
|
||||
)
|
||||
|
||||
def update_stats(self) -> None:
|
||||
|
||||
@@ -9,6 +9,7 @@ from frigate.const import MODEL_CACHE_DIR
|
||||
from frigate.log import redirect_output_to_logger
|
||||
from frigate.util.downloader import ModelDownloader
|
||||
|
||||
from ...config import FaceRecognitionConfig
|
||||
from .base_embedding import BaseEmbedding
|
||||
from .runner import ONNXModelRunner
|
||||
|
||||
@@ -111,7 +112,7 @@ class FaceNetEmbedding(BaseEmbedding):
|
||||
|
||||
|
||||
class ArcfaceEmbedding(BaseEmbedding):
|
||||
def __init__(self):
|
||||
def __init__(self, config: FaceRecognitionConfig):
|
||||
super().__init__(
|
||||
model_name="facedet",
|
||||
model_file="arcface.onnx",
|
||||
@@ -119,6 +120,7 @@ class ArcfaceEmbedding(BaseEmbedding):
|
||||
"arcface.onnx": "https://github.com/NickM-27/facenet-onnx/releases/download/v1.0/arcface.onnx",
|
||||
},
|
||||
)
|
||||
self.config = config
|
||||
self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
|
||||
self.tokenizer = None
|
||||
self.feature_extractor = None
|
||||
@@ -148,7 +150,7 @@ class ArcfaceEmbedding(BaseEmbedding):
|
||||
|
||||
self.runner = ONNXModelRunner(
|
||||
os.path.join(self.download_path, self.model_file),
|
||||
"GPU",
|
||||
device=self.config.device or "GPU",
|
||||
)
|
||||
|
||||
def _preprocess_inputs(self, raw_inputs):
|
||||
|
||||
@@ -128,7 +128,6 @@ class JinaV1TextEmbedding(BaseEmbedding):
|
||||
self.runner = ONNXModelRunner(
|
||||
os.path.join(self.download_path, self.model_file),
|
||||
self.device,
|
||||
self.model_size,
|
||||
)
|
||||
|
||||
def _preprocess_inputs(self, raw_inputs):
|
||||
@@ -207,7 +206,6 @@ class JinaV1ImageEmbedding(BaseEmbedding):
|
||||
self.runner = ONNXModelRunner(
|
||||
os.path.join(self.download_path, self.model_file),
|
||||
self.device,
|
||||
self.model_size,
|
||||
)
|
||||
|
||||
def _preprocess_inputs(self, raw_inputs):
|
||||
|
||||
@@ -128,7 +128,6 @@ class JinaV2Embedding(BaseEmbedding):
|
||||
self.runner = ONNXModelRunner(
|
||||
os.path.join(self.download_path, self.model_file),
|
||||
self.device,
|
||||
self.model_size,
|
||||
)
|
||||
|
||||
def _preprocess_image(self, image_data: bytes | Image.Image) -> np.ndarray:
|
||||
|
||||
Reference in New Issue
Block a user