mirror of
https://github.com/blakeblackshear/frigate.git
synced 2024-11-26 19:06:11 +01:00
Use SVC to normalize and classify faces for recognition (#14835)
* Add margin to detected faces for embeddings * Standardize pixel values for face input * Use SVC to classify faces * Clear classifier when new face is added * Formatting * Add dependency
This commit is contained in:
parent
9d5d8ddbb2
commit
e65fb27f2d
@ -13,9 +13,7 @@ markupsafe == 2.1.*
|
||||
python-multipart == 0.0.12
|
||||
# General
|
||||
mypy == 1.6.1
|
||||
numpy == 1.26.*
|
||||
onvif_zeep == 0.2.12
|
||||
opencv-python-headless == 4.9.0.*
|
||||
paho-mqtt == 2.1.*
|
||||
pandas == 2.2.*
|
||||
peewee == 3.17.*
|
||||
@ -29,11 +27,15 @@ ruamel.yaml == 0.18.*
|
||||
tzlocal == 5.2
|
||||
requests == 2.32.*
|
||||
types-requests == 2.32.*
|
||||
scipy == 1.13.*
|
||||
norfair == 2.2.*
|
||||
setproctitle == 1.3.*
|
||||
ws4py == 0.5.*
|
||||
unidecode == 1.3.*
|
||||
# Image Manipulation
|
||||
numpy == 1.26.*
|
||||
opencv-python-headless == 4.9.0.*
|
||||
scipy == 1.13.*
|
||||
scikit-learn == 1.5.*
|
||||
# OpenVino & ONNX
|
||||
openvino == 2024.3.*
|
||||
onnxruntime-openvino == 1.19.* ; platform_machine == 'x86_64'
|
||||
|
@ -221,6 +221,9 @@ class GenericONNXEmbedding:
|
||||
# copy img image into center of result image
|
||||
frame[y_center : y_center + og_h, x_center : x_center + og_w] = og
|
||||
|
||||
# standardize pixel values across channels
|
||||
mean, std = frame.mean(), frame.std()
|
||||
frame = (frame - mean) / std
|
||||
frame = np.expand_dims(frame, axis=0)
|
||||
return [{"input_2": frame}]
|
||||
elif self.model_type == ModelTypeEnum.lpr_detect:
|
||||
|
@ -30,12 +30,12 @@ from frigate.models import Event
|
||||
from frigate.types import TrackedObjectUpdateTypesEnum
|
||||
from frigate.util.builtin import serialize
|
||||
from frigate.util.image import SharedMemoryFrameManager, area, calculate_region
|
||||
from frigate.util.model import FaceClassificationModel
|
||||
|
||||
from .embeddings import Embeddings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
REQUIRED_FACES = 2
|
||||
MAX_THUMBNAILS = 10
|
||||
|
||||
|
||||
@ -68,6 +68,9 @@ class EmbeddingMaintainer(threading.Thread):
|
||||
self.face_recognition_enabled = self.config.face_recognition.enabled
|
||||
self.requires_face_detection = "face" not in self.config.objects.all_objects
|
||||
self.detected_faces: dict[str, float] = {}
|
||||
self.face_classifier = (
|
||||
FaceClassificationModel(db) if self.face_recognition_enabled else None
|
||||
)
|
||||
|
||||
# create communication for updating event descriptions
|
||||
self.requestor = InterProcessRequestor()
|
||||
@ -138,13 +141,15 @@ class EmbeddingMaintainer(threading.Thread):
|
||||
self.embeddings.text_embedding([data])[0], pack=False
|
||||
)
|
||||
elif topic == EmbeddingsRequestEnum.register_face.value:
|
||||
if not self.face_recognition_enabled:
|
||||
return False
|
||||
|
||||
if data.get("cropped"):
|
||||
self.embeddings.embed_face(
|
||||
data["face_name"],
|
||||
base64.b64decode(data["image"]),
|
||||
upsert=True,
|
||||
)
|
||||
return True
|
||||
else:
|
||||
img = cv2.imdecode(
|
||||
np.frombuffer(
|
||||
@ -165,7 +170,8 @@ class EmbeddingMaintainer(threading.Thread):
|
||||
data["face_name"], webp.tobytes(), upsert=True
|
||||
)
|
||||
|
||||
return False
|
||||
self.face_classifier.clear_classifier()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Unable to handle embeddings request {e}")
|
||||
|
||||
@ -336,18 +342,6 @@ class EmbeddingMaintainer(threading.Thread):
|
||||
if event_id:
|
||||
self.handle_regenerate_description(event_id, source)
|
||||
|
||||
def _search_face(self, query_embedding: bytes) -> list[tuple[str, float]]:
|
||||
"""Search for the face most closely matching the embedding."""
|
||||
sql_query = f"""
|
||||
SELECT
|
||||
id,
|
||||
distance
|
||||
FROM vec_faces
|
||||
WHERE face_embedding MATCH ?
|
||||
AND k = {REQUIRED_FACES} ORDER BY distance
|
||||
"""
|
||||
return self.embeddings.db.execute_sql(sql_query, [query_embedding]).fetchall()
|
||||
|
||||
def _detect_face(self, input: np.ndarray) -> tuple[int, int, int, int]:
|
||||
"""Detect faces in input image."""
|
||||
self.face_detector.setInputSize((input.shape[1], input.shape[0]))
|
||||
@ -400,13 +394,21 @@ class EmbeddingMaintainer(threading.Thread):
|
||||
rgb = cv2.cvtColor(frame, cv2.COLOR_YUV2RGB_I420)
|
||||
left, top, right, bottom = person_box
|
||||
person = rgb[top:bottom, left:right]
|
||||
face = self._detect_face(person)
|
||||
face_box = self._detect_face(person)
|
||||
|
||||
if not face:
|
||||
if not face_box:
|
||||
logger.debug("Detected no faces for person object.")
|
||||
return
|
||||
|
||||
face_frame = person[face[1] : face[3], face[0] : face[2]]
|
||||
margin = int((face_box[2] - face_box[0]) * 0.25)
|
||||
face_frame = person[
|
||||
max(0, face_box[1] - margin) : min(
|
||||
frame.shape[0], face_box[3] + margin
|
||||
),
|
||||
max(0, face_box[0] - margin) : min(
|
||||
frame.shape[1], face_box[2] + margin
|
||||
),
|
||||
]
|
||||
face_frame = cv2.cvtColor(face_frame, cv2.COLOR_RGB2BGR)
|
||||
else:
|
||||
# don't run for object without attributes
|
||||
@ -434,8 +436,15 @@ class EmbeddingMaintainer(threading.Thread):
|
||||
return
|
||||
|
||||
face_frame = cv2.cvtColor(frame, cv2.COLOR_YUV2BGR_I420)
|
||||
margin = int((face_box[2] - face_box[0]) * 0.25)
|
||||
|
||||
face_frame = face_frame[
|
||||
face_box[1] : face_box[3], face_box[0] : face_box[2]
|
||||
max(0, face_box[1] - margin) : min(
|
||||
frame.shape[0], face_box[3] + margin
|
||||
),
|
||||
max(0, face_box[0] - margin) : min(
|
||||
frame.shape[1], face_box[2] + margin
|
||||
),
|
||||
]
|
||||
|
||||
ret, webp = cv2.imencode(
|
||||
@ -446,34 +455,23 @@ class EmbeddingMaintainer(threading.Thread):
|
||||
logger.debug("Not processing face due to error creating cropped image.")
|
||||
return
|
||||
|
||||
embedding = self.embeddings.embed_face("unknown", webp.tobytes(), upsert=False)
|
||||
query_embedding = serialize(embedding)
|
||||
best_faces = self._search_face(query_embedding)
|
||||
logger.debug(f"Detected best faces for person as: {best_faces}")
|
||||
embedding = self.embeddings.embed_face("nick", webp.tobytes(), upsert=True)
|
||||
res = self.face_classifier.classify_face(embedding)
|
||||
|
||||
if not best_faces or len(best_faces) < REQUIRED_FACES:
|
||||
logger.debug(f"{len(best_faces)} < {REQUIRED_FACES} min required faces.")
|
||||
if not res:
|
||||
return
|
||||
|
||||
sub_label = str(best_faces[0][0]).split("-")[0]
|
||||
avg_score = 0
|
||||
sub_label, score = res
|
||||
|
||||
for face in best_faces:
|
||||
score = 1.0 - face[1]
|
||||
logger.debug(
|
||||
f"Detected best face for person as: {sub_label} with score {score}"
|
||||
)
|
||||
|
||||
if face[0].split("-")[0] != sub_label:
|
||||
logger.debug("Detected multiple faces, result is not valid.")
|
||||
return
|
||||
|
||||
avg_score += score
|
||||
|
||||
avg_score = round(avg_score / REQUIRED_FACES, 2)
|
||||
|
||||
if avg_score < self.config.face_recognition.threshold or (
|
||||
id in self.detected_faces and avg_score <= self.detected_faces[id]
|
||||
if score < self.config.face_recognition.threshold or (
|
||||
id in self.detected_faces and score <= self.detected_faces[id]
|
||||
):
|
||||
logger.debug(
|
||||
f"Recognized face score {avg_score} is less than threshold ({self.config.face_recognition.threshold}) / previous face score ({self.detected_faces.get(id)})."
|
||||
f"Recognized face score {score} is less than threshold ({self.config.face_recognition.threshold}) / previous face score ({self.detected_faces.get(id)})."
|
||||
)
|
||||
return
|
||||
|
||||
@ -482,12 +480,12 @@ class EmbeddingMaintainer(threading.Thread):
|
||||
json={
|
||||
"camera": obj_data.get("camera"),
|
||||
"subLabel": sub_label,
|
||||
"subLabelScore": avg_score,
|
||||
"subLabelScore": score,
|
||||
},
|
||||
)
|
||||
|
||||
if resp.status_code == 200:
|
||||
self.detected_faces[id] = avg_score
|
||||
self.detected_faces[id] = score
|
||||
|
||||
def _detect_license_plate(self, input: np.ndarray) -> tuple[int, int, int, int]:
|
||||
"""Return the dimensions of the input image as [x, y, width, height]."""
|
||||
|
@ -2,9 +2,15 @@
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
from playhouse.sqliteq import SqliteQueueDatabase
|
||||
from sklearn.preprocessing import LabelEncoder, Normalizer
|
||||
from sklearn.svm import SVC
|
||||
|
||||
from frigate.util.builtin import deserialize
|
||||
|
||||
try:
|
||||
import openvino as ov
|
||||
@ -148,3 +154,41 @@ class ONNXModelRunner:
|
||||
return [infer_request.get_output_tensor().data]
|
||||
elif self.type == "ort":
|
||||
return self.ort.run(None, input)
|
||||
|
||||
|
||||
class FaceClassificationModel:
|
||||
def __init__(self, db: SqliteQueueDatabase):
|
||||
self.db = db
|
||||
self.labeler: Optional[LabelEncoder] = None
|
||||
self.classifier: Optional[SVC] = None
|
||||
|
||||
def __build_classifier(self) -> None:
|
||||
faces: list[tuple[str, bytes]] = self.db.execute_sql(
|
||||
"SELECT id, face_embedding FROM vec_faces"
|
||||
).fetchall()
|
||||
embeddings = np.array([deserialize(f[1]) for f in faces])
|
||||
self.labeler = LabelEncoder()
|
||||
norms = Normalizer(norm="l2").transform(embeddings)
|
||||
labels = self.labeler.fit_transform([f[0].split("-")[0] for f in faces])
|
||||
self.classifier = SVC(kernel="linear", probability=True)
|
||||
self.classifier.fit(norms, labels)
|
||||
|
||||
def clear_classifier(self) -> None:
|
||||
self.classifier = None
|
||||
self.labeler = None
|
||||
|
||||
def classify_face(self, embedding: np.ndarray) -> Optional[tuple[str, float]]:
|
||||
if not self.classifier:
|
||||
self.__build_classifier()
|
||||
|
||||
res = self.classifier.predict([embedding])
|
||||
|
||||
if not res:
|
||||
return None
|
||||
|
||||
label = res[0]
|
||||
probabilities = self.classifier.predict_proba([embedding])[0]
|
||||
return (
|
||||
self.labeler.inverse_transform([label])[0],
|
||||
round(probabilities[label], 2),
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user