Face recognition backend (#14495)

* Add basic config and face recognition table

* Reconfigure updates processing to handle face

* Crop frame to face box

* Implement face embedding calculation

* Get matching face embeddings

* Add support face recognition based on existing faces

* Use arcface face embeddings instead of generic embeddings model

* Add apis for managing faces

* Implement face uploading API

* Build out more APIs

* Add min area config

* Handle larger images

* Add more debug logs

* fix calculation

* Reduce timeout

* Small tweaks

* Use webp images

* Use facenet model
This commit is contained in:
Nicolas Mowen 2024-10-22 16:05:48 -06:00
parent 3394942e92
commit b7b7e1b78b
13 changed files with 364 additions and 42 deletions

View File

@ -8,6 +8,8 @@ imutils == 0.5.*
joserfc == 1.0.*
pathvalidate == 3.2.*
markupsafe == 2.1.*
python-multipart == 0.0.12
# General
mypy == 1.6.1
numpy == 1.26.*
onvif_zeep == 0.2.12

View File

@ -0,0 +1,56 @@
"""Object classification APIs."""
import logging
from fastapi import APIRouter, Request, UploadFile
from fastapi.responses import JSONResponse
from frigate.api.defs.tags import Tags
from frigate.embeddings import EmbeddingsContext
logger = logging.getLogger(__name__)
router = APIRouter(tags=[Tags.events])
@router.get("/faces")
def get_faces():
return JSONResponse(content={"message": "there are faces"})
@router.post("/faces/{name}")
async def register_face(request: Request, name: str, file: UploadFile):
# if not file.content_type.startswith("image"):
# return JSONResponse(
# status_code=400,
# content={
# "success": False,
# "message": "Only an image can be used to register a face.",
# },
# )
context: EmbeddingsContext = request.app.embeddings
context.register_face(name, await file.read())
return JSONResponse(
status_code=200,
content={"success": True, "message": "Successfully registered face."},
)
@router.delete("/faces")
def deregister_faces(request: Request, body: dict = None):
json: dict[str, any] = body or {}
list_of_ids = json.get("ids", "")
if not list_of_ids or len(list_of_ids) == 0:
return JSONResponse(
content=({"success": False, "message": "Not a valid list of ids"}),
status_code=404,
)
context: EmbeddingsContext = request.app.embeddings
context.delete_face_ids(list_of_ids)
return JSONResponse(
content=({"success": True, "message": "Successfully deleted faces."}),
status_code=200,
)

View File

@ -10,4 +10,5 @@ class Tags(Enum):
review = "Review"
export = "Export"
events = "Events"
classification = "classification"
auth = "Auth"

View File

@ -11,7 +11,16 @@ from starlette_context import middleware, plugins
from starlette_context.plugins import Plugin
from frigate.api import app as main_app
from frigate.api import auth, event, export, media, notification, preview, review
from frigate.api import (
auth,
classification,
event,
export,
media,
notification,
preview,
review,
)
from frigate.api.auth import get_jwt_secret, limiter
from frigate.comms.event_metadata_updater import (
EventMetadataPublisher,
@ -95,6 +104,7 @@ def create_fastapi_app(
# Routes
# Order of include_router matters: https://fastapi.tiangolo.com/tutorial/path-params/#order-matters
app.include_router(auth.router)
app.include_router(classification.router)
app.include_router(review.router)
app.include_router(main_app.router)
app.include_router(preview.router)

View File

@ -12,6 +12,7 @@ class EmbeddingsRequestEnum(Enum):
embed_description = "embed_description"
embed_thumbnail = "embed_thumbnail"
generate_search = "generate_search"
register_face = "register_face"
class EmbeddingsResponder:
@ -22,7 +23,7 @@ class EmbeddingsResponder:
def check_for_request(self, process: Callable) -> None:
while True: # load all messages that are queued
has_message, _, _ = zmq.select([self.socket], [], [], 0.1)
has_message, _, _ = zmq.select([self.socket], [], [], 0.01)
if not has_message:
break

View File

@ -4,7 +4,17 @@ from pydantic import Field
from .base import FrigateBaseModel
__all__ = ["SemanticSearchConfig"]
__all__ = ["FaceRecognitionConfig", "SemanticSearchConfig"]
class FaceRecognitionConfig(FrigateBaseModel):
enabled: bool = Field(default=False, title="Enable face recognition.")
threshold: float = Field(
default=0.9, title="Face similarity score required to be considered a match."
)
min_area: int = Field(
default=500, title="Min area of face box to consider running face recognition."
)
class SemanticSearchConfig(FrigateBaseModel):
@ -12,6 +22,9 @@ class SemanticSearchConfig(FrigateBaseModel):
reindex: Optional[bool] = Field(
default=False, title="Reindex all detections on startup."
)
face_recognition: FaceRecognitionConfig = Field(
default_factory=FaceRecognitionConfig, title="Face recognition config."
)
model_size: str = Field(
default="small", title="The size of the embeddings model used."
)

View File

@ -5,8 +5,9 @@ DEFAULT_DB_PATH = f"{CONFIG_DIR}/frigate.db"
MODEL_CACHE_DIR = f"{CONFIG_DIR}/model_cache"
BASE_DIR = "/media/frigate"
CLIPS_DIR = f"{BASE_DIR}/clips"
RECORD_DIR = f"{BASE_DIR}/recordings"
EXPORT_DIR = f"{BASE_DIR}/exports"
FACE_DIR = f"{CLIPS_DIR}/faces"
RECORD_DIR = f"{BASE_DIR}/recordings"
BIRDSEYE_PIPE = "/tmp/cache/birdseye"
CACHE_DIR = "/tmp/cache"
FRIGATE_LOCALHOST = "http://127.0.0.1:5000"

View File

@ -29,6 +29,10 @@ class SqliteVecQueueDatabase(SqliteQueueDatabase):
ids = ",".join(["?" for _ in event_ids])
self.execute_sql(f"DELETE FROM vec_descriptions WHERE id IN ({ids})", event_ids)
def delete_embeddings_face(self, face_ids: list[str]) -> None:
ids = ",".join(["?" for _ in face_ids])
self.execute_sql(f"DELETE FROM vec_faces WHERE id IN ({ids})", face_ids)
def drop_embeddings_tables(self) -> None:
self.execute_sql("""
DROP TABLE vec_descriptions;
@ -36,8 +40,11 @@ class SqliteVecQueueDatabase(SqliteQueueDatabase):
self.execute_sql("""
DROP TABLE vec_thumbnails;
""")
self.execute_sql("""
DROP TABLE vec_faces;
""")
def create_embeddings_tables(self) -> None:
def create_embeddings_tables(self, face_recognition: bool) -> None:
"""Create vec0 virtual table for embeddings"""
self.execute_sql("""
CREATE VIRTUAL TABLE IF NOT EXISTS vec_thumbnails USING vec0(
@ -51,3 +58,11 @@ class SqliteVecQueueDatabase(SqliteQueueDatabase):
description_embedding FLOAT[768] distance_metric=cosine
);
""")
if face_recognition:
self.execute_sql("""
CREATE VIRTUAL TABLE IF NOT EXISTS vec_faces USING vec0(
id TEXT PRIMARY KEY,
face_embedding FLOAT[128] distance_metric=cosine
);
""")

View File

@ -1,5 +1,6 @@
"""SQLite-vec embeddings database."""
import base64
import json
import logging
import multiprocessing as mp
@ -189,6 +190,28 @@ class EmbeddingsContext:
return results
def register_face(self, face_name: str, image_data: bytes) -> None:
self.requestor.send_data(
EmbeddingsRequestEnum.register_face.value,
{
"face_name": face_name,
"image": base64.b64encode(image_data).decode("ASCII"),
},
)
def get_face_ids(self, name: str) -> list[str]:
sql_query = f"""
SELECT
id
FROM vec_descriptions
WHERE id LIKE '%{name}%'
"""
return self.db.execute_sql(sql_query).fetchall()
def delete_face_ids(self, ids: list[str]) -> None:
self.db.delete_embeddings_face(ids)
def update_description(self, event_id: str, description: str) -> None:
self.requestor.send_data(
EmbeddingsRequestEnum.embed_description.value,

View File

@ -3,6 +3,8 @@
import base64
import logging
import os
import random
import string
import time
from numpy import ndarray
@ -12,6 +14,7 @@ from frigate.comms.inter_process import InterProcessRequestor
from frigate.config.semantic_search import SemanticSearchConfig
from frigate.const import (
CONFIG_DIR,
FACE_DIR,
UPDATE_EMBEDDINGS_REINDEX_PROGRESS,
UPDATE_MODEL_STATE,
)
@ -67,7 +70,7 @@ class Embeddings:
self.requestor = InterProcessRequestor()
# Create tables if they don't exist
self.db.create_embeddings_tables()
self.db.create_embeddings_tables(self.config.face_recognition.enabled)
models = [
"jinaai/jina-clip-v1-text_model_fp16.onnx",
@ -121,6 +124,21 @@ class Embeddings:
device="GPU" if config.model_size == "large" else "CPU",
)
self.face_embedding = None
if self.config.face_recognition.enabled:
self.face_embedding = GenericONNXEmbedding(
model_name="facenet",
model_file="facenet.onnx",
download_urls={
"facenet.onnx": "https://github.com/NicolasSM-001/faceNet.onnx-/raw/refs/heads/main/faceNet.onnx"
},
model_size="large",
model_type=ModelTypeEnum.face,
requestor=self.requestor,
device="GPU",
)
def embed_thumbnail(
self, event_id: str, thumbnail: bytes, upsert: bool = True
) -> ndarray:
@ -215,12 +233,40 @@ class Embeddings:
return embeddings
def embed_face(self, label: str, thumbnail: bytes, upsert: bool = False) -> ndarray:
embedding = self.face_embedding(thumbnail)[0]
if upsert:
rand_id = "".join(
random.choices(string.ascii_lowercase + string.digits, k=6)
)
id = f"{label}-{rand_id}"
# write face to library
folder = os.path.join(FACE_DIR, label)
file = os.path.join(folder, f"{id}.webp")
os.makedirs(folder, exist_ok=True)
# save face image
with open(file, "wb") as output:
output.write(thumbnail)
self.db.execute_sql(
"""
INSERT OR REPLACE INTO vec_faces(id, face_embedding)
VALUES(?, ?)
""",
(id, serialize(embedding)),
)
return embedding
def reindex(self) -> None:
logger.info("Indexing tracked object embeddings...")
self.db.drop_embeddings_tables()
logger.debug("Dropped embeddings tables.")
self.db.create_embeddings_tables()
self.db.create_embeddings_tables(self.config.face_recognition.enabled)
logger.debug("Created embeddings tables.")
# Delete the saved stats file

View File

@ -31,6 +31,8 @@ warnings.filterwarnings(
disable_progress_bar()
logger = logging.getLogger(__name__)
FACE_EMBEDDING_SIZE = 160
class ModelTypeEnum(str, Enum):
face = "face"
@ -47,7 +49,7 @@ class GenericONNXEmbedding:
model_file: str,
download_urls: Dict[str, str],
model_size: str,
model_type: str,
model_type: ModelTypeEnum,
requestor: InterProcessRequestor,
tokenizer_file: Optional[str] = None,
device: str = "AUTO",
@ -57,7 +59,7 @@ class GenericONNXEmbedding:
self.tokenizer_file = tokenizer_file
self.requestor = requestor
self.download_urls = download_urls
self.model_type = model_type # 'text' or 'vision'
self.model_type = model_type
self.model_size = model_size
self.device = device
self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
@ -93,6 +95,7 @@ class GenericONNXEmbedding:
def _download_model(self, path: str):
try:
file_name = os.path.basename(path)
if file_name in self.download_urls:
ModelDownloader.download_from_url(self.download_urls[file_name], path)
elif (
@ -101,6 +104,7 @@ class GenericONNXEmbedding:
):
if not os.path.exists(path + "/" + self.model_name):
logger.info(f"Downloading {self.model_name} tokenizer")
tokenizer = AutoTokenizer.from_pretrained(
self.model_name,
trust_remote_code=True,
@ -131,8 +135,11 @@ class GenericONNXEmbedding:
self.downloader.wait_for_download()
if self.model_type == ModelTypeEnum.text:
self.tokenizer = self._load_tokenizer()
else:
elif self.model_type == ModelTypeEnum.vision:
self.feature_extractor = self._load_feature_extractor()
elif self.model_type == ModelTypeEnum.face:
self.feature_extractor = []
self.runner = ONNXModelRunner(
os.path.join(self.download_path, self.model_file),
self.device,
@ -172,16 +179,51 @@ class GenericONNXEmbedding:
self.feature_extractor(images=image, return_tensors="np")
for image in processed_images
]
elif self.model_type == ModelTypeEnum.face:
if isinstance(raw_inputs, list):
raise ValueError("Face embedding does not support batch inputs.")
pil = self._process_image(raw_inputs)
# handle images larger than input size
width, height = pil.size
if width != FACE_EMBEDDING_SIZE or height != FACE_EMBEDDING_SIZE:
if width > height:
new_height = int(((height / width) * FACE_EMBEDDING_SIZE) // 4 * 4)
pil = pil.resize((FACE_EMBEDDING_SIZE, new_height))
else:
new_width = int(((width / height) * FACE_EMBEDDING_SIZE) // 4 * 4)
pil = pil.resize((new_width, FACE_EMBEDDING_SIZE))
og = np.array(pil).astype(np.float32)
# Image must be FACE_EMBEDDING_SIZExFACE_EMBEDDING_SIZE
og_h, og_w, channels = og.shape
frame = np.full(
(FACE_EMBEDDING_SIZE, FACE_EMBEDDING_SIZE, channels),
(0, 0, 0),
dtype=np.float32,
)
# compute center offset
x_center = (FACE_EMBEDDING_SIZE - og_w) // 2
y_center = (FACE_EMBEDDING_SIZE - og_h) // 2
# copy img image into center of result image
frame[y_center : y_center + og_h, x_center : x_center + og_w] = og
frame = np.expand_dims(frame, axis=0)
return [{"image_input": frame}]
else:
raise ValueError(f"Unable to preprocess inputs for {self.model_type}")
def _process_image(self, image):
def _process_image(self, image, output: str = "RGB") -> Image.Image:
if isinstance(image, str):
if image.startswith("http"):
response = requests.get(image)
image = Image.open(BytesIO(response.content)).convert("RGB")
image = Image.open(BytesIO(response.content)).convert(output)
elif isinstance(image, bytes):
image = Image.open(BytesIO(image)).convert("RGB")
image = Image.open(BytesIO(image)).convert(output)
return image

View File

@ -9,6 +9,7 @@ from typing import Optional
import cv2
import numpy as np
import requests
from peewee import DoesNotExist
from playhouse.sqliteq import SqliteQueueDatabase
@ -20,12 +21,12 @@ from frigate.comms.event_metadata_updater import (
from frigate.comms.events_updater import EventEndSubscriber, EventUpdateSubscriber
from frigate.comms.inter_process import InterProcessRequestor
from frigate.config import FrigateConfig
from frigate.const import CLIPS_DIR, UPDATE_EVENT_DESCRIPTION
from frigate.const import CLIPS_DIR, FRIGATE_LOCALHOST, UPDATE_EVENT_DESCRIPTION
from frigate.events.types import EventTypeEnum
from frigate.genai import get_genai_client
from frigate.models import Event
from frigate.util.builtin import serialize
from frigate.util.image import SharedMemoryFrameManager, calculate_region
from frigate.util.image import SharedMemoryFrameManager, area, calculate_region
from .embeddings import Embeddings
@ -58,10 +59,17 @@ class EmbeddingMaintainer(threading.Thread):
)
self.embeddings_responder = EmbeddingsResponder()
self.frame_manager = SharedMemoryFrameManager()
# set face recognition conditions
self.face_recognition_enabled = (
self.config.semantic_search.face_recognition.enabled
)
self.requires_face_detection = "face" not in self.config.model.all_attributes
# create communication for updating event descriptions
self.requestor = InterProcessRequestor()
self.stop_event = stop_event
self.tracked_events = {}
self.tracked_events: dict[str, list[any]] = {}
self.genai_client = get_genai_client(config)
def run(self) -> None:
@ -101,6 +109,13 @@ class EmbeddingMaintainer(threading.Thread):
return serialize(
self.embeddings.text_embedding([data])[0], pack=False
)
elif topic == EmbeddingsRequestEnum.register_face.value:
self.embeddings.embed_face(
data["face_name"],
base64.b64decode(data["image"]),
upsert=True,
)
return None
except Exception as e:
logger.error(f"Unable to handle embeddings request {e}")
@ -108,7 +123,7 @@ class EmbeddingMaintainer(threading.Thread):
def _process_updates(self) -> None:
"""Process event updates"""
update = self.event_subscriber.check_for_update(timeout=0.1)
update = self.event_subscriber.check_for_update(timeout=0.01)
if update is None:
return
@ -119,24 +134,33 @@ class EmbeddingMaintainer(threading.Thread):
return
camera_config = self.config.cameras[camera]
# no need to save our own thumbnails if genai is not enabled
# or if the object has become stationary
if (
not camera_config.genai.enabled
or self.genai_client is None
or data["stationary"]
):
return
if data["id"] not in self.tracked_events:
self.tracked_events[data["id"]] = []
# no need to process updated objects if face recognition and genai are disabled
if not camera_config.genai.enabled and not self.face_recognition_enabled:
return
# Create our own thumbnail based on the bounding box and the frame time
try:
frame_id = f"{camera}{data['frame_time']}"
yuv_frame = self.frame_manager.get(frame_id, camera_config.frame_shape_yuv)
except FileNotFoundError:
pass
if yuv_frame is None:
logger.debug(
"Unable to process object update because frame is unavailable."
)
return
if self.face_recognition_enabled:
self._process_face(data, yuv_frame)
# no need to save our own thumbnails if genai is not enabled
# or if the object has become stationary
if self.genai_client is not None and not data["stationary"]:
if data["id"] not in self.tracked_events:
self.tracked_events[data["id"]] = []
if yuv_frame is not None:
data["thumbnail"] = self._create_thumbnail(yuv_frame, data["box"])
# Limit the number of thumbnails saved
@ -147,13 +171,11 @@ class EmbeddingMaintainer(threading.Thread):
self.tracked_events[data["id"]].append(data)
self.frame_manager.close(frame_id)
except FileNotFoundError:
pass
def _process_finalized(self) -> None:
"""Process the end of an event."""
while True:
ended = self.event_end_subscriber.check_for_update(timeout=0.1)
ended = self.event_end_subscriber.check_for_update(timeout=0.01)
if ended == None:
break
@ -243,7 +265,7 @@ class EmbeddingMaintainer(threading.Thread):
def _process_event_metadata(self):
# Check for regenerate description requests
(topic, event_id, source) = self.event_metadata_subscriber.check_for_update(
timeout=0.1
timeout=0.01
)
if topic is None:
@ -252,6 +274,94 @@ class EmbeddingMaintainer(threading.Thread):
if event_id:
self.handle_regenerate_description(event_id, source)
def _search_face(self, query_embedding: bytes) -> list:
"""Search for the face most closely matching the embedding."""
sql_query = """
SELECT
id,
distance
FROM vec_faces
WHERE face_embedding MATCH ?
AND k = 10 ORDER BY distance
"""
return self.embeddings.db.execute_sql(sql_query, [query_embedding]).fetchall()
def _process_face(self, obj_data: dict[str, any], frame: np.ndarray) -> None:
"""Look for faces in image."""
# don't run for non person objects
if obj_data.get("label") != "person":
logger.debug("Not a processing face for non person object.")
return
# don't overwrite sub label for objects that have one
if obj_data.get("sub_label"):
logger.debug(
f"Not processing face due to existing sub label: {obj_data.get('sub_label')}."
)
return
face: Optional[dict[str, any]] = None
if self.requires_face_detection:
# TODO run cv2 face detection
pass
else:
# don't run for object without attributes
if not obj_data.get("current_attributes"):
logger.debug("No attributes to parse.")
return
attributes: list[dict[str, any]] = obj_data.get("current_attributes", [])
for attr in attributes:
if attr.get("label") != "face":
continue
if face is None or attr.get("score", 0.0) > face.get("score", 0.0):
face = attr
# no faces detected in this frame
if not face:
return
face_box = face.get("box")
# check that face is valid
if (
not face_box
or area(face_box) < self.config.semantic_search.face_recognition.min_area
):
logger.debug(f"Invalid face box {face}")
return
face_frame = cv2.cvtColor(frame, cv2.COLOR_YUV2BGR_I420)
face_frame = face_frame[face_box[1] : face_box[3], face_box[0] : face_box[2]]
ret, jpg = cv2.imencode(
".webp", face_frame, [int(cv2.IMWRITE_WEBP_QUALITY), 100]
)
if not ret:
logger.debug("Not processing face due to error creating cropped image.")
return
embedding = self.embeddings.embed_face("unknown", jpg.tobytes(), upsert=False)
query_embedding = serialize(embedding)
best_faces = self._search_face(query_embedding)
logger.debug(f"Detected best faces for person as: {best_faces}")
if not best_faces:
return
sub_label = str(best_faces[0][0]).split("-")[0]
score = 1.0 - best_faces[0][1]
if score < self.config.semantic_search.face_recognition.threshold:
return None
requests.post(
f"{FRIGATE_LOCALHOST}/api/events/{obj_data['id']}/sub_label",
json={"subLabel": sub_label, "subLabelScore": score},
)
def _create_thumbnail(self, yuv_frame, box, height=500) -> Optional[bytes]:
"""Return jpg thumbnail of a region of the frame."""
frame = cv2.cvtColor(yuv_frame, cv2.COLOR_YUV2BGR_I420)

View File

@ -101,7 +101,7 @@ class ModelDownloader:
self.download_complete.set()
@staticmethod
def download_from_url(url: str, save_path: str, silent: bool = False):
def download_from_url(url: str, save_path: str, silent: bool = False) -> Path:
temporary_filename = Path(save_path).with_name(
os.path.basename(save_path) + ".part"
)
@ -125,6 +125,8 @@ class ModelDownloader:
if not silent:
logger.info(f"Downloading complete: {url}")
return Path(save_path)
@staticmethod
def mark_files_state(
requestor: InterProcessRequestor,