Face recognition backend (#14495)

* Add basic config and face recognition table

* Reconfigure updates processing to handle face

* Crop frame to face box

* Implement face embedding calculation

* Get matching face embeddings

* Add support for face recognition based on existing faces

* Use arcface face embeddings instead of generic embeddings model

* Add apis for managing faces

* Implement face uploading API

* Build out more APIs

* Add min area config

* Handle larger images

* Add more debug logs

* Fix calculation

* Reduce timeout

* Small tweaks

* Use webp images

* Use facenet model
Nicolas Mowen 2024-10-22 16:05:48 -06:00
parent d6071b3d1b
commit 13e90fc6e0
13 changed files with 365 additions and 45 deletions

docker/main/requirements-wheels.txt

@@ -10,6 +10,8 @@ imutils == 0.5.*
 joserfc == 1.0.*
 pathvalidate == 3.2.*
 markupsafe == 2.1.*
+python-multipart == 0.0.12
+# General
 mypy == 1.6.1
 numpy == 1.26.*
 onvif_zeep == 0.2.12

frigate/api/classification.py

@@ -0,0 +1,56 @@
+"""Object classification APIs."""
+
+import logging
+
+from fastapi import APIRouter, Request, UploadFile
+from fastapi.responses import JSONResponse
+
+from frigate.api.defs.tags import Tags
+from frigate.embeddings import EmbeddingsContext
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter(tags=[Tags.events])
+
+
+@router.get("/faces")
+def get_faces():
+    return JSONResponse(content={"message": "there are faces"})
+
+
+@router.post("/faces/{name}")
+async def register_face(request: Request, name: str, file: UploadFile):
+    # if not file.content_type.startswith("image"):
+    #     return JSONResponse(
+    #         status_code=400,
+    #         content={
+    #             "success": False,
+    #             "message": "Only an image can be used to register a face.",
+    #         },
+    #     )
+
+    context: EmbeddingsContext = request.app.embeddings
+    context.register_face(name, await file.read())
+    return JSONResponse(
+        status_code=200,
+        content={"success": True, "message": "Successfully registered face."},
+    )
+
+
+@router.delete("/faces")
+def deregister_faces(request: Request, body: dict = None):
+    json: dict[str, any] = body or {}
+    list_of_ids = json.get("ids", "")
+
+    if not list_of_ids or len(list_of_ids) == 0:
+        return JSONResponse(
+            content=({"success": False, "message": "Not a valid list of ids"}),
+            status_code=404,
+        )
+
+    context: EmbeddingsContext = request.app.embeddings
+    context.delete_face_ids(list_of_ids)
+
+    return JSONResponse(
+        content=({"success": True, "message": "Successfully deleted faces."}),
+        status_code=200,
+    )
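
For reference, a minimal client-side sketch of exercising these new endpoints with the requests library. It assumes a local Frigate instance behind the default nginx /api prefix; the face id "alice-3k9x2p" is a made-up example of the <name>-<random> ids generated on registration:

import requests

BASE = "http://127.0.0.1:5000/api"  # assumed local instance, default /api prefix

# register a face image under the name "alice"
with open("alice.webp", "rb") as f:
    resp = requests.post(f"{BASE}/faces/alice", files={"file": f})
print(resp.json())  # {"success": True, "message": "Successfully registered face."}

# the GET endpoint is still a stub at this point in the branch
print(requests.get(f"{BASE}/faces").json())  # {"message": "there are faces"}

# delete stored faces by id (hypothetical id shown)
resp = requests.delete(f"{BASE}/faces", json={"ids": ["alice-3k9x2p"]})
print(resp.json())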

frigate/api/defs/tags.py

@@ -10,4 +10,5 @@ class Tags(Enum):
     review = "Review"
     export = "Export"
     events = "Events"
+    classification = "classification"
     auth = "Auth"

frigate/api/fastapi_app.py

@@ -11,7 +11,16 @@ from starlette_context import middleware, plugins
 from starlette_context.plugins import Plugin
 from frigate.api import app as main_app
-from frigate.api import auth, event, export, media, notification, preview, review
+from frigate.api import (
+    auth,
+    classification,
+    event,
+    export,
+    media,
+    notification,
+    preview,
+    review,
+)
 from frigate.api.auth import get_jwt_secret, limiter
 from frigate.comms.event_metadata_updater import (
     EventMetadataPublisher,
@@ -95,6 +104,7 @@ def create_fastapi_app(
     # Routes
     # Order of include_router matters: https://fastapi.tiangolo.com/tutorial/path-params/#order-matters
     app.include_router(auth.router)
+    app.include_router(classification.router)
     app.include_router(review.router)
     app.include_router(main_app.router)
     app.include_router(preview.router)

frigate/comms/embeddings_updater.py

@@ -12,6 +12,7 @@ class EmbeddingsRequestEnum(Enum):
     embed_description = "embed_description"
     embed_thumbnail = "embed_thumbnail"
     generate_search = "generate_search"
+    register_face = "register_face"


 class EmbeddingsResponder:
@@ -22,7 +23,7 @@ class EmbeddingsResponder:
     def check_for_request(self, process: Callable) -> None:
         while True:  # load all messages that are queued
-            has_message, _, _ = zmq.select([self.socket], [], [], 0.1)
+            has_message, _, _ = zmq.select([self.socket], [], [], 0.01)
             if not has_message:
                 break

frigate/config/semantic_search.py

@@ -4,7 +4,17 @@ from pydantic import Field

 from .base import FrigateBaseModel

-__all__ = ["SemanticSearchConfig"]
+__all__ = ["FaceRecognitionConfig", "SemanticSearchConfig"]
+
+
+class FaceRecognitionConfig(FrigateBaseModel):
+    enabled: bool = Field(default=False, title="Enable face recognition.")
+    threshold: float = Field(
+        default=0.9, title="Face similarity score required to be considered a match."
+    )
+    min_area: int = Field(
+        default=500, title="Min area of face box to consider running face recognition."
+    )


 class SemanticSearchConfig(FrigateBaseModel):
@@ -12,6 +22,9 @@ class SemanticSearchConfig(FrigateBaseModel):
     reindex: Optional[bool] = Field(
         default=False, title="Reindex all detections on startup."
     )
+    face_recognition: FaceRecognitionConfig = Field(
+        default_factory=FaceRecognitionConfig, title="Face recognition config."
+    )
     model_size: str = Field(
         default="small", title="The size of the embeddings model used."
     )
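
A quick sketch of how the new model parses, using only fields shown above; pydantic coerces the nested dict into a FaceRecognitionConfig, and the override values are illustrative:

from frigate.config.semantic_search import SemanticSearchConfig

# nested dict is validated into a FaceRecognitionConfig by pydantic
cfg = SemanticSearchConfig(face_recognition={"enabled": True, "min_area": 800})

assert cfg.face_recognition.enabled is True
assert cfg.face_recognition.threshold == 0.9  # default similarity cutoff
assert cfg.face_recognition.min_area == 800   # overrides the 500 default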

frigate/const.py

@@ -5,8 +5,9 @@ DEFAULT_DB_PATH = f"{CONFIG_DIR}/frigate.db"
 MODEL_CACHE_DIR = f"{CONFIG_DIR}/model_cache"
 BASE_DIR = "/media/frigate"
 CLIPS_DIR = f"{BASE_DIR}/clips"
-RECORD_DIR = f"{BASE_DIR}/recordings"
 EXPORT_DIR = f"{BASE_DIR}/exports"
+FACE_DIR = f"{CLIPS_DIR}/faces"
+RECORD_DIR = f"{BASE_DIR}/recordings"
 BIRDSEYE_PIPE = "/tmp/cache/birdseye"
 CACHE_DIR = "/tmp/cache"
 FRIGATE_LOCALHOST = "http://127.0.0.1:5000"

frigate/db/sqlitevecq.py

@@ -29,6 +29,10 @@ class SqliteVecQueueDatabase(SqliteQueueDatabase):
         ids = ",".join(["?" for _ in event_ids])
         self.execute_sql(f"DELETE FROM vec_descriptions WHERE id IN ({ids})", event_ids)

+    def delete_embeddings_face(self, face_ids: list[str]) -> None:
+        ids = ",".join(["?" for _ in face_ids])
+        self.execute_sql(f"DELETE FROM vec_faces WHERE id IN ({ids})", face_ids)
+
     def drop_embeddings_tables(self) -> None:
         self.execute_sql("""
             DROP TABLE vec_descriptions;
@@ -36,8 +40,11 @@ class SqliteVecQueueDatabase(SqliteQueueDatabase):
         self.execute_sql("""
             DROP TABLE vec_thumbnails;
         """)
+        self.execute_sql("""
+            DROP TABLE vec_faces;
+        """)

-    def create_embeddings_tables(self) -> None:
+    def create_embeddings_tables(self, face_recognition: bool) -> None:
         """Create vec0 virtual table for embeddings"""
         self.execute_sql("""
             CREATE VIRTUAL TABLE IF NOT EXISTS vec_thumbnails USING vec0(
@@ -51,3 +58,11 @@ class SqliteVecQueueDatabase(SqliteQueueDatabase):
                 description_embedding FLOAT[768] distance_metric=cosine
             );
         """)
+
+        if face_recognition:
+            self.execute_sql("""
+                CREATE VIRTUAL TABLE IF NOT EXISTS vec_faces USING vec0(
+                    id TEXT PRIMARY KEY,
+                    face_embedding FLOAT[128] distance_metric=cosine
+                );
+            """)

frigate/embeddings/__init__.py

@ -1,5 +1,6 @@
"""SQLite-vec embeddings database.""" """SQLite-vec embeddings database."""
import base64
import json import json
import logging import logging
import multiprocessing as mp import multiprocessing as mp
@ -189,6 +190,28 @@ class EmbeddingsContext:
return results return results
def register_face(self, face_name: str, image_data: bytes) -> None:
self.requestor.send_data(
EmbeddingsRequestEnum.register_face.value,
{
"face_name": face_name,
"image": base64.b64encode(image_data).decode("ASCII"),
},
)
def get_face_ids(self, name: str) -> list[str]:
sql_query = f"""
SELECT
id
FROM vec_descriptions
WHERE id LIKE '%{name}%'
"""
return self.db.execute_sql(sql_query).fetchall()
def delete_face_ids(self, ids: list[str]) -> None:
self.db.delete_embeddings_face(ids)
def update_description(self, event_id: str, description: str) -> None: def update_description(self, event_id: str, description: str) -> None:
self.requestor.send_data( self.requestor.send_data(
EmbeddingsRequestEnum.embed_description.value, EmbeddingsRequestEnum.embed_description.value,

frigate/embeddings/embeddings.py

@@ -3,6 +3,8 @@
 import base64
 import logging
 import os
+import random
+import string
 import time

 from numpy import ndarray
@@ -12,6 +14,7 @@ from frigate.comms.inter_process import InterProcessRequestor
 from frigate.config.semantic_search import SemanticSearchConfig
 from frigate.const import (
     CONFIG_DIR,
+    FACE_DIR,
     UPDATE_EMBEDDINGS_REINDEX_PROGRESS,
     UPDATE_MODEL_STATE,
 )
@@ -67,7 +70,7 @@ class Embeddings:
         self.requestor = InterProcessRequestor()

         # Create tables if they don't exist
-        self.db.create_embeddings_tables()
+        self.db.create_embeddings_tables(self.config.face_recognition.enabled)

         models = [
             "jinaai/jina-clip-v1-text_model_fp16.onnx",
@@ -121,6 +124,21 @@ class Embeddings:
             device="GPU" if config.model_size == "large" else "CPU",
         )

+        self.face_embedding = None
+
+        if self.config.face_recognition.enabled:
+            self.face_embedding = GenericONNXEmbedding(
+                model_name="facenet",
+                model_file="facenet.onnx",
+                download_urls={
+                    "facenet.onnx": "https://github.com/NicolasSM-001/faceNet.onnx-/raw/refs/heads/main/faceNet.onnx"
+                },
+                model_size="large",
+                model_type=ModelTypeEnum.face,
+                requestor=self.requestor,
+                device="GPU",
+            )
+
     def embed_thumbnail(
         self, event_id: str, thumbnail: bytes, upsert: bool = True
     ) -> ndarray:
@@ -215,12 +233,40 @@ class Embeddings:

         return embeddings

+    def embed_face(self, label: str, thumbnail: bytes, upsert: bool = False) -> ndarray:
+        embedding = self.face_embedding(thumbnail)[0]
+
+        if upsert:
+            rand_id = "".join(
+                random.choices(string.ascii_lowercase + string.digits, k=6)
+            )
+            id = f"{label}-{rand_id}"
+
+            # write face to library
+            folder = os.path.join(FACE_DIR, label)
+            file = os.path.join(folder, f"{id}.webp")
+            os.makedirs(folder, exist_ok=True)
+
+            # save face image
+            with open(file, "wb") as output:
+                output.write(thumbnail)
+
+            self.db.execute_sql(
+                """
+                INSERT OR REPLACE INTO vec_faces(id, face_embedding)
+                VALUES(?, ?)
+                """,
+                (id, serialize(embedding)),
+            )
+
+        return embedding
+
     def reindex(self) -> None:
         logger.info("Indexing tracked object embeddings...")

         self.db.drop_embeddings_tables()
         logger.debug("Dropped embeddings tables.")
-        self.db.create_embeddings_tables()
+        self.db.create_embeddings_tables(self.config.face_recognition.enabled)
         logger.debug("Created embeddings tables.")

         # Delete the saved stats file
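
The <label>-<random> id convention above is what later lets the maintainer recover a person's name from the nearest stored row. A self-contained sketch of that round trip (the helper names are ours, not Frigate's):

import random
import string

def make_face_id(label: str) -> str:
    # mirrors embed_face(): label plus a 6-char [a-z0-9] suffix
    rand_id = "".join(random.choices(string.ascii_lowercase + string.digits, k=6))
    return f"{label}-{rand_id}"

def label_from_id(face_id: str) -> str:
    # the maintainer later splits on "-" to turn an id back into a sub label
    return face_id.split("-")[0]

assert label_from_id(make_face_id("alice")) == "alice"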

frigate/embeddings/functions/onnx.py

@@ -31,6 +31,8 @@ warnings.filterwarnings(
 disable_progress_bar()
 logger = logging.getLogger(__name__)

+FACE_EMBEDDING_SIZE = 160
+

 class ModelTypeEnum(str, Enum):
     face = "face"
@@ -47,7 +49,7 @@ class GenericONNXEmbedding:
         model_file: str,
         download_urls: Dict[str, str],
         model_size: str,
-        model_type: str,
+        model_type: ModelTypeEnum,
         requestor: InterProcessRequestor,
         tokenizer_file: Optional[str] = None,
         device: str = "AUTO",
@@ -57,7 +59,7 @@ class GenericONNXEmbedding:
         self.tokenizer_file = tokenizer_file
         self.requestor = requestor
         self.download_urls = download_urls
-        self.model_type = model_type  # 'text' or 'vision'
+        self.model_type = model_type
         self.model_size = model_size
         self.device = device
         self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
@@ -93,6 +95,7 @@ class GenericONNXEmbedding:
     def _download_model(self, path: str):
         try:
             file_name = os.path.basename(path)
+
             if file_name in self.download_urls:
                 ModelDownloader.download_from_url(self.download_urls[file_name], path)
             elif (
@@ -101,6 +104,7 @@ class GenericONNXEmbedding:
             ):
                 if not os.path.exists(path + "/" + self.model_name):
                     logger.info(f"Downloading {self.model_name} tokenizer")
+
                 tokenizer = AutoTokenizer.from_pretrained(
                     self.model_name,
                     trust_remote_code=True,
@@ -131,8 +135,11 @@ class GenericONNXEmbedding:
             self.downloader.wait_for_download()
         if self.model_type == ModelTypeEnum.text:
             self.tokenizer = self._load_tokenizer()
-        else:
+        elif self.model_type == ModelTypeEnum.vision:
             self.feature_extractor = self._load_feature_extractor()
+        elif self.model_type == ModelTypeEnum.face:
+            self.feature_extractor = []
+
         self.runner = ONNXModelRunner(
             os.path.join(self.download_path, self.model_file),
             self.device,
@@ -172,16 +179,51 @@ class GenericONNXEmbedding:
                 self.feature_extractor(images=image, return_tensors="np")
                 for image in processed_images
             ]
+        elif self.model_type == ModelTypeEnum.face:
+            if isinstance(raw_inputs, list):
+                raise ValueError("Face embedding does not support batch inputs.")
+
+            pil = self._process_image(raw_inputs)
+
+            # handle images larger than input size
+            width, height = pil.size
+            if width != FACE_EMBEDDING_SIZE or height != FACE_EMBEDDING_SIZE:
+                if width > height:
+                    new_height = int(((height / width) * FACE_EMBEDDING_SIZE) // 4 * 4)
+                    pil = pil.resize((FACE_EMBEDDING_SIZE, new_height))
+                else:
+                    new_width = int(((width / height) * FACE_EMBEDDING_SIZE) // 4 * 4)
+                    pil = pil.resize((new_width, FACE_EMBEDDING_SIZE))

+            og = np.array(pil).astype(np.float32)
+
+            # Image must be FACE_EMBEDDING_SIZExFACE_EMBEDDING_SIZE
+            og_h, og_w, channels = og.shape
+            frame = np.full(
+                (FACE_EMBEDDING_SIZE, FACE_EMBEDDING_SIZE, channels),
+                (0, 0, 0),
+                dtype=np.float32,
+            )
+
+            # compute center offset
+            x_center = (FACE_EMBEDDING_SIZE - og_w) // 2
+            y_center = (FACE_EMBEDDING_SIZE - og_h) // 2
+
+            # copy image into center of result image
+            frame[y_center : y_center + og_h, x_center : x_center + og_w] = og
+
+            frame = np.expand_dims(frame, axis=0)
+            return [{"image_input": frame}]
         else:
             raise ValueError(f"Unable to preprocess inputs for {self.model_type}")

-    def _process_image(self, image):
+    def _process_image(self, image, output: str = "RGB") -> Image.Image:
         if isinstance(image, str):
             if image.startswith("http"):
                 response = requests.get(image)
-                image = Image.open(BytesIO(response.content)).convert("RGB")
+                image = Image.open(BytesIO(response.content)).convert(output)
         elif isinstance(image, bytes):
-            image = Image.open(BytesIO(image)).convert("RGB")
+            image = Image.open(BytesIO(image)).convert(output)

         return image
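
A standalone sketch of the face branch's preprocessing above: aspect-preserving resize, then zero-padding into a centered 160x160 canvas. It reproduces the same logic outside the class, with the NHWC float32 layout the facenet model is assumed to expect:

import numpy as np
from PIL import Image

FACE_EMBEDDING_SIZE = 160

def letterbox_face(pil: Image.Image) -> np.ndarray:
    width, height = pil.size
    if width != FACE_EMBEDDING_SIZE or height != FACE_EMBEDDING_SIZE:
        # scale the long side to 160, rounding the short side down to a multiple of 4
        if width > height:
            new_height = int(((height / width) * FACE_EMBEDDING_SIZE) // 4 * 4)
            pil = pil.resize((FACE_EMBEDDING_SIZE, new_height))
        else:
            new_width = int(((width / height) * FACE_EMBEDDING_SIZE) // 4 * 4)
            pil = pil.resize((new_width, FACE_EMBEDDING_SIZE))

    og = np.array(pil).astype(np.float32)
    og_h, og_w, channels = og.shape

    # paste the resized crop into the center of a black 160x160 canvas
    frame = np.zeros((FACE_EMBEDDING_SIZE, FACE_EMBEDDING_SIZE, channels), dtype=np.float32)
    y = (FACE_EMBEDDING_SIZE - og_h) // 2
    x = (FACE_EMBEDDING_SIZE - og_w) // 2
    frame[y : y + og_h, x : x + og_w] = og

    return np.expand_dims(frame, axis=0)  # batch dim -> (1, 160, 160, 3)

print(letterbox_face(Image.new("RGB", (320, 240))).shape)  # (1, 160, 160, 3)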

frigate/embeddings/maintainer.py

@@ -9,6 +9,7 @@ from typing import Optional

 import cv2
 import numpy as np
+import requests
 from peewee import DoesNotExist
 from playhouse.sqliteq import SqliteQueueDatabase

@@ -20,13 +21,13 @@ from frigate.comms.event_metadata_updater import (
 from frigate.comms.events_updater import EventEndSubscriber, EventUpdateSubscriber
 from frigate.comms.inter_process import InterProcessRequestor
 from frigate.config import FrigateConfig
-from frigate.const import CLIPS_DIR, UPDATE_EVENT_DESCRIPTION
+from frigate.const import CLIPS_DIR, FRIGATE_LOCALHOST, UPDATE_EVENT_DESCRIPTION
 from frigate.events.types import EventTypeEnum
 from frigate.genai import get_genai_client
 from frigate.models import Event
 from frigate.types import TrackedObjectUpdateTypesEnum
 from frigate.util.builtin import serialize
-from frigate.util.image import SharedMemoryFrameManager, calculate_region
+from frigate.util.image import SharedMemoryFrameManager, area, calculate_region

 from .embeddings import Embeddings

@@ -59,10 +60,17 @@ class EmbeddingMaintainer(threading.Thread):
         )
         self.embeddings_responder = EmbeddingsResponder()
         self.frame_manager = SharedMemoryFrameManager()
+
+        # set face recognition conditions
+        self.face_recognition_enabled = (
+            self.config.semantic_search.face_recognition.enabled
+        )
+        self.requires_face_detection = "face" not in self.config.model.all_attributes
+
         # create communication for updating event descriptions
         self.requestor = InterProcessRequestor()
         self.stop_event = stop_event
-        self.tracked_events = {}
+        self.tracked_events: dict[str, list[any]] = {}
         self.genai_client = get_genai_client(config)

     def run(self) -> None:
@@ -102,6 +110,13 @@ class EmbeddingMaintainer(threading.Thread):
                     return serialize(
                         self.embeddings.text_embedding([data])[0], pack=False
                     )
+                elif topic == EmbeddingsRequestEnum.register_face.value:
+                    self.embeddings.embed_face(
+                        data["face_name"],
+                        base64.b64decode(data["image"]),
+                        upsert=True,
+                    )
+                    return None
             except Exception as e:
                 logger.error(f"Unable to handle embeddings request {e}")

@@ -109,7 +124,7 @@ class EmbeddingMaintainer(threading.Thread):
     def _process_updates(self) -> None:
         """Process event updates"""
-        update = self.event_subscriber.check_for_update(timeout=0.1)
+        update = self.event_subscriber.check_for_update(timeout=0.01)

         if update is None:
             return
@@ -120,25 +135,32 @@ class EmbeddingMaintainer(threading.Thread):
             return

         camera_config = self.config.cameras[camera]
-        # no need to save our own thumbnails if genai is not enabled
-        # or if the object has become stationary
-        if (
-            not camera_config.genai.enabled
-            or self.genai_client is None
-            or data["stationary"]
-        ):
-            return

-        if data["id"] not in self.tracked_events:
-            self.tracked_events[data["id"]] = []
+        # no need to process updated objects if face recognition and genai are disabled
+        if not camera_config.genai.enabled and not self.face_recognition_enabled:
+            return

         # Create our own thumbnail based on the bounding box and the frame time
         try:
-            yuv_frame = self.frame_manager.get(
-                frame_name, camera_config.frame_shape_yuv
-            )
+            yuv_frame = self.frame_manager.get(frame_name, camera_config.frame_shape_yuv)
+        except FileNotFoundError:
+            pass

-            if yuv_frame is not None:
-                data["thumbnail"] = self._create_thumbnail(yuv_frame, data["box"])
+        if yuv_frame is None:
+            logger.debug(
+                "Unable to process object update because frame is unavailable."
+            )
+            return
+
+        if self.face_recognition_enabled:
+            self._process_face(data, yuv_frame)
+
+        # no need to save our own thumbnails if genai is not enabled
+        # or if the object has become stationary
+        if self.genai_client is not None and not data["stationary"]:
+            if data["id"] not in self.tracked_events:
+                self.tracked_events[data["id"]] = []
+
+            data["thumbnail"] = self._create_thumbnail(yuv_frame, data["box"])

             # Limit the number of thumbnails saved
@@ -149,13 +171,11 @@ class EmbeddingMaintainer(threading.Thread):
             self.tracked_events[data["id"]].append(data)

         self.frame_manager.close(frame_name)
-        except FileNotFoundError:
-            pass

     def _process_finalized(self) -> None:
         """Process the end of an event."""
         while True:
-            ended = self.event_end_subscriber.check_for_update(timeout=0.1)
+            ended = self.event_end_subscriber.check_for_update(timeout=0.01)
             if ended == None:
                 break
@@ -245,7 +265,7 @@ class EmbeddingMaintainer(threading.Thread):
     def _process_event_metadata(self):
         # Check for regenerate description requests
         (topic, event_id, source) = self.event_metadata_subscriber.check_for_update(
-            timeout=0.1
+            timeout=0.01
         )

         if topic is None:
@@ -254,6 +274,94 @@ class EmbeddingMaintainer(threading.Thread):
         if event_id:
             self.handle_regenerate_description(event_id, source)

+    def _search_face(self, query_embedding: bytes) -> list:
+        """Search for the face most closely matching the embedding."""
+        sql_query = """
+            SELECT
+                id,
+                distance
+            FROM vec_faces
+            WHERE face_embedding MATCH ?
+                AND k = 10 ORDER BY distance
+        """
+
+        return self.embeddings.db.execute_sql(sql_query, [query_embedding]).fetchall()
+
+    def _process_face(self, obj_data: dict[str, any], frame: np.ndarray) -> None:
+        """Look for faces in image."""
+        # don't run for non person objects
+        if obj_data.get("label") != "person":
+            logger.debug("Not processing face for non person object.")
+            return
+
+        # don't overwrite sub label for objects that have one
+        if obj_data.get("sub_label"):
+            logger.debug(
+                f"Not processing face due to existing sub label: {obj_data.get('sub_label')}."
+            )
+            return
+
+        face: Optional[dict[str, any]] = None
+
+        if self.requires_face_detection:
+            # TODO run cv2 face detection
+            pass
+        else:
+            # don't run for object without attributes
+            if not obj_data.get("current_attributes"):
+                logger.debug("No attributes to parse.")
+                return
+
+            attributes: list[dict[str, any]] = obj_data.get("current_attributes", [])
+            for attr in attributes:
+                if attr.get("label") != "face":
+                    continue
+
+                if face is None or attr.get("score", 0.0) > face.get("score", 0.0):
+                    face = attr
+
+        # no faces detected in this frame
+        if not face:
+            return
+
+        face_box = face.get("box")
+
+        # check that face is valid
+        if (
+            not face_box
+            or area(face_box) < self.config.semantic_search.face_recognition.min_area
+        ):
+            logger.debug(f"Invalid face box {face}")
+            return
+
+        face_frame = cv2.cvtColor(frame, cv2.COLOR_YUV2BGR_I420)
+        face_frame = face_frame[face_box[1] : face_box[3], face_box[0] : face_box[2]]
+        ret, jpg = cv2.imencode(
+            ".webp", face_frame, [int(cv2.IMWRITE_WEBP_QUALITY), 100]
+        )
+
+        if not ret:
+            logger.debug("Not processing face due to error creating cropped image.")
+            return
+
+        embedding = self.embeddings.embed_face("unknown", jpg.tobytes(), upsert=False)
+        query_embedding = serialize(embedding)
+        best_faces = self._search_face(query_embedding)
+        logger.debug(f"Detected best faces for person as: {best_faces}")
+
+        if not best_faces:
+            return
+
+        sub_label = str(best_faces[0][0]).split("-")[0]
+        score = 1.0 - best_faces[0][1]
+
+        if score < self.config.semantic_search.face_recognition.threshold:
+            return None
+
+        requests.post(
+            f"{FRIGATE_LOCALHOST}/api/events/{obj_data['id']}/sub_label",
+            json={"subLabel": sub_label, "subLabelScore": score},
+        )
+
     def _create_thumbnail(self, yuv_frame, box, height=500) -> Optional[bytes]:
         """Return jpg thumbnail of a region of the frame."""
         frame = cv2.cvtColor(yuv_frame, cv2.COLOR_YUV2BGR_I420)
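
The match logic above boils down to sqlite-vec's cosine distance plus the configured threshold; a tiny sketch of that arithmetic (the numbers are illustrative):

import numpy as np

def cosine_distance(a: np.ndarray, b: np.ndarray) -> float:
    # what distance_metric=cosine returns: 1 - cosine similarity
    return 1.0 - float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# the maintainer flips the nearest neighbor's distance back into a similarity score
distance = 0.07           # e.g. best_faces[0][1]
score = 1.0 - distance    # 0.93
accepted = score >= 0.9   # FaceRecognitionConfig.threshold default
print(score, accepted)    # 0.93 True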

frigate/util/downloader.py

@@ -101,7 +101,7 @@ class ModelDownloader:
         self.download_complete.set()

     @staticmethod
-    def download_from_url(url: str, save_path: str, silent: bool = False):
+    def download_from_url(url: str, save_path: str, silent: bool = False) -> Path:
         temporary_filename = Path(save_path).with_name(
             os.path.basename(save_path) + ".part"
         )
@@ -125,6 +125,8 @@ class ModelDownloader:
         if not silent:
             logger.info(f"Downloading complete: {url}")

+        return Path(save_path)
+
     @staticmethod
     def mark_files_state(
         requestor: InterProcessRequestor,