mirror of
https://github.com/blakeblackshear/frigate.git
synced 2025-08-13 13:47:36 +02:00
face recognition: use configured device
This commit is contained in:
parent
58f509b379
commit
82bb8daf67
@ -267,7 +267,7 @@ class ArcFaceRecognizer(FaceRecognizer):
|
||||
def __init__(self, config: FrigateConfig):
|
||||
super().__init__(config)
|
||||
self.mean_embs: dict[int, np.ndarray] = {}
|
||||
self.face_embedder: ArcfaceEmbedding = ArcfaceEmbedding()
|
||||
self.face_embedder: ArcfaceEmbedding = ArcfaceEmbedding(config)
|
||||
self.model_builder_queue: queue.Queue | None = None
|
||||
|
||||
def clear(self) -> None:
|
||||
@ -368,4 +368,4 @@ class ArcFaceRecognizer(FaceRecognizer):
|
||||
score = confidence
|
||||
label = name
|
||||
|
||||
return label, round(score - blur_reduction, 2)
|
||||
return label, round(score - blur_reduction, 2)
|
@ -10,6 +10,7 @@ from frigate.util.downloader import ModelDownloader
|
||||
|
||||
from .base_embedding import BaseEmbedding
|
||||
from .runner import ONNXModelRunner
|
||||
from ...config import FrigateConfig
|
||||
|
||||
try:
|
||||
from tflite_runtime.interpreter import Interpreter
|
||||
@ -109,7 +110,7 @@ class FaceNetEmbedding(BaseEmbedding):
|
||||
|
||||
|
||||
class ArcfaceEmbedding(BaseEmbedding):
|
||||
def __init__(self):
|
||||
def __init__(self, config: FrigateConfig):
|
||||
super().__init__(
|
||||
model_name="facedet",
|
||||
model_file="arcface.onnx",
|
||||
@ -117,6 +118,7 @@ class ArcfaceEmbedding(BaseEmbedding):
|
||||
"arcface.onnx": "https://github.com/NickM-27/facenet-onnx/releases/download/v1.0/arcface.onnx",
|
||||
},
|
||||
)
|
||||
self.config = config
|
||||
self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
|
||||
self.tokenizer = None
|
||||
self.feature_extractor = None
|
||||
@ -146,7 +148,7 @@ class ArcfaceEmbedding(BaseEmbedding):
|
||||
|
||||
self.runner = ONNXModelRunner(
|
||||
os.path.join(self.download_path, self.model_file),
|
||||
"GPU",
|
||||
device=self.config.face_recognition.device,
|
||||
)
|
||||
|
||||
def _preprocess_inputs(self, raw_inputs):
|
||||
@ -182,4 +184,4 @@ class ArcfaceEmbedding(BaseEmbedding):
|
||||
|
||||
frame = np.transpose(frame, (2, 0, 1))
|
||||
frame = np.expand_dims(frame, axis=0)
|
||||
return [{"data": frame}]
|
||||
return [{"data": frame}]
|
Loading…
Reference in New Issue
Block a user