diff --git a/docker/main/requirements-wheels.txt b/docker/main/requirements-wheels.txt
index 4db88ccd2..804f6135e 100644
--- a/docker/main/requirements-wheels.txt
+++ b/docker/main/requirements-wheels.txt
@@ -10,6 +10,8 @@ imutils == 0.5.*
 joserfc == 1.0.*
 pathvalidate == 3.2.*
 markupsafe == 2.1.*
+python-multipart == 0.0.12
+# General
 mypy == 1.6.1
 numpy == 1.26.*
 onvif_zeep == 0.2.12
diff --git a/frigate/api/classification.py b/frigate/api/classification.py
new file mode 100644
index 000000000..d862008c8
--- /dev/null
+++ b/frigate/api/classification.py
@@ -0,0 +1,57 @@
+"""Object classification APIs."""
+
+import logging
+from typing import Any
+
+from fastapi import APIRouter, Request, UploadFile
+from fastapi.responses import JSONResponse
+
+from frigate.api.defs.tags import Tags
+from frigate.embeddings import EmbeddingsContext
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter(tags=[Tags.classification])
+
+
+@router.get("/faces")
+def get_faces():
+    return JSONResponse(content={"message": "there are faces"})
+
+
+@router.post("/faces/{name}")
+async def register_face(request: Request, name: str, file: UploadFile):
+    if file.content_type and not file.content_type.startswith("image"):
+        return JSONResponse(
+            status_code=400,
+            content={
+                "success": False,
+                "message": "Only an image can be used to register a face.",
+            },
+        )
+
+    context: EmbeddingsContext = request.app.embeddings
+    context.register_face(name, await file.read())
+    return JSONResponse(
+        status_code=200,
+        content={"success": True, "message": "Successfully registered face."},
+    )
+
+
+@router.delete("/faces")
+def deregister_faces(request: Request, body: dict = None):
+    json: dict[str, Any] = body or {}
+    list_of_ids = json.get("ids", [])
+
+    if not list_of_ids:
+        return JSONResponse(
+            content=({"success": False, "message": "Not a valid list of ids"}),
+            status_code=400,
+        )
+
+    context: EmbeddingsContext = request.app.embeddings
+    context.delete_face_ids(list_of_ids)
+    return JSONResponse(
+        content=({"success": True, "message": "Successfully deleted faces."}),
+        status_code=200,
+    )
diff --git a/frigate/api/defs/tags.py b/frigate/api/defs/tags.py
index 80faf255c..9e61da9e9 100644
--- a/frigate/api/defs/tags.py
+++ b/frigate/api/defs/tags.py
@@ -10,4 +10,5 @@ class Tags(Enum):
     review = "Review"
     export = "Export"
     events = "Events"
+    classification = "Classification"
    auth = "Auth"
diff --git a/frigate/api/fastapi_app.py b/frigate/api/fastapi_app.py
index e3542458e..942964d58 100644
--- a/frigate/api/fastapi_app.py
+++ b/frigate/api/fastapi_app.py
@@ -11,7 +11,16 @@ from starlette_context import middleware, plugins
 from starlette_context.plugins import Plugin
 
 from frigate.api import app as main_app
-from frigate.api import auth, event, export, media, notification, preview, review
+from frigate.api import (
+    auth,
+    classification,
+    event,
+    export,
+    media,
+    notification,
+    preview,
+    review,
+)
 from frigate.api.auth import get_jwt_secret, limiter
 from frigate.comms.event_metadata_updater import (
     EventMetadataPublisher,
@@ -95,6 +104,7 @@ def create_fastapi_app(
     # Routes
     # Order of include_router matters: https://fastapi.tiangolo.com/tutorial/path-params/#order-matters
     app.include_router(auth.router)
+    app.include_router(classification.router)
    app.include_router(review.router)
    app.include_router(main_app.router)
    app.include_router(preview.router)
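For reviewers who want to exercise the new router, here is a minimal sketch of the request shapes. It assumes a local instance reachable through the usual `/api` proxy prefix with auth disabled; the face id in the delete call is hypothetical, following the `{name}-{random}` scheme used at registration.

```python
import requests

BASE = "http://127.0.0.1:5000/api"

# register a face crop for "john" (multipart upload, field name "file")
with open("john.webp", "rb") as f:
    print(requests.post(f"{BASE}/faces/john", files={"file": f}).json())

# list faces (currently a placeholder response)
print(requests.get(f"{BASE}/faces").json())

# delete faces by id
print(requests.delete(f"{BASE}/faces", json={"ids": ["john-a1b2c3"]}).json())
```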
diff --git a/frigate/comms/embeddings_updater.py b/frigate/comms/embeddings_updater.py
index 9a13525f8..095f33fde 100644
--- a/frigate/comms/embeddings_updater.py
+++ b/frigate/comms/embeddings_updater.py
@@ -12,6 +12,7 @@ class EmbeddingsRequestEnum(Enum):
     embed_description = "embed_description"
     embed_thumbnail = "embed_thumbnail"
     generate_search = "generate_search"
+    register_face = "register_face"
 
 
 class EmbeddingsResponder:
@@ -22,7 +23,7 @@ class EmbeddingsResponder:
     def check_for_request(self, process: Callable) -> None:
         while True:
             # load all messages that are queued
-            has_message, _, _ = zmq.select([self.socket], [], [], 0.1)
+            has_message, _, _ = zmq.select([self.socket], [], [], 0.01)
 
             if not has_message:
                 break
diff --git a/frigate/config/semantic_search.py b/frigate/config/semantic_search.py
index 2891050a1..32ff8cf3c 100644
--- a/frigate/config/semantic_search.py
+++ b/frigate/config/semantic_search.py
@@ -4,7 +4,17 @@ from pydantic import Field
 
 from .base import FrigateBaseModel
 
-__all__ = ["SemanticSearchConfig"]
+__all__ = ["FaceRecognitionConfig", "SemanticSearchConfig"]
+
+
+class FaceRecognitionConfig(FrigateBaseModel):
+    enabled: bool = Field(default=False, title="Enable face recognition.")
+    threshold: float = Field(
+        default=0.9, title="Face similarity score required to be considered a match."
+    )
+    min_area: int = Field(
+        default=500, title="Min area of face box to consider running face recognition."
+    )
 
 
 class SemanticSearchConfig(FrigateBaseModel):
@@ -12,6 +22,9 @@ class SemanticSearchConfig(FrigateBaseModel):
     reindex: Optional[bool] = Field(
         default=False, title="Reindex all detections on startup."
     )
+    face_recognition: FaceRecognitionConfig = Field(
+        default_factory=FaceRecognitionConfig, title="Face recognition config."
+    )
     model_size: str = Field(
         default="small", title="The size of the embeddings model used."
     )
diff --git a/frigate/const.py b/frigate/const.py
index c83b10e73..41a2fbc15 100644
--- a/frigate/const.py
+++ b/frigate/const.py
@@ -5,8 +5,9 @@ DEFAULT_DB_PATH = f"{CONFIG_DIR}/frigate.db"
 MODEL_CACHE_DIR = f"{CONFIG_DIR}/model_cache"
 BASE_DIR = "/media/frigate"
 CLIPS_DIR = f"{BASE_DIR}/clips"
-RECORD_DIR = f"{BASE_DIR}/recordings"
 EXPORT_DIR = f"{BASE_DIR}/exports"
+FACE_DIR = f"{CLIPS_DIR}/faces"
+RECORD_DIR = f"{BASE_DIR}/recordings"
 BIRDSEYE_PIPE = "/tmp/cache/birdseye"
 CACHE_DIR = "/tmp/cache"
 FRIGATE_LOCALHOST = "http://127.0.0.1:5000"
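As a quick illustration of how the new nested model is consumed (the dict layout mirrors the Pydantic fields above; whether the user-facing YAML exposes exactly this shape is an assumption):

```python
from frigate.config.semantic_search import SemanticSearchConfig

config = SemanticSearchConfig(
    enabled=True,
    face_recognition={
        "enabled": True,
        "threshold": 0.9,  # minimum similarity (1 - cosine distance) for a match
        "min_area": 500,  # skip face boxes smaller than this many square pixels
    },
)
assert config.face_recognition.enabled
```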
diff --git a/frigate/db/sqlitevecq.py b/frigate/db/sqlitevecq.py
index ccb75ae54..1447fd48f 100644
--- a/frigate/db/sqlitevecq.py
+++ b/frigate/db/sqlitevecq.py
@@ -29,6 +29,10 @@ class SqliteVecQueueDatabase(SqliteQueueDatabase):
         ids = ",".join(["?" for _ in event_ids])
         self.execute_sql(f"DELETE FROM vec_descriptions WHERE id IN ({ids})", event_ids)
 
+    def delete_embeddings_face(self, face_ids: list[str]) -> None:
+        ids = ",".join(["?" for _ in face_ids])
+        self.execute_sql(f"DELETE FROM vec_faces WHERE id IN ({ids})", face_ids)
+
     def drop_embeddings_tables(self) -> None:
         self.execute_sql("""
             DROP TABLE vec_descriptions;
@@ -36,8 +40,11 @@ class SqliteVecQueueDatabase(SqliteQueueDatabase):
         self.execute_sql("""
             DROP TABLE vec_thumbnails;
         """)
+        self.execute_sql("""
+            DROP TABLE IF EXISTS vec_faces;
+        """)
 
-    def create_embeddings_tables(self) -> None:
+    def create_embeddings_tables(self, face_recognition: bool) -> None:
         """Create vec0 virtual table for embeddings"""
         self.execute_sql("""
             CREATE VIRTUAL TABLE IF NOT EXISTS vec_thumbnails USING vec0(
@@ -51,3 +58,11 @@ class SqliteVecQueueDatabase(SqliteQueueDatabase):
                 description_embedding FLOAT[768] distance_metric=cosine
             );
         """)
+
+        if face_recognition:
+            self.execute_sql("""
+                CREATE VIRTUAL TABLE IF NOT EXISTS vec_faces USING vec0(
+                    id TEXT PRIMARY KEY,
+                    face_embedding FLOAT[128] distance_metric=cosine
+                );
+            """)
diff --git a/frigate/embeddings/__init__.py b/frigate/embeddings/__init__.py
index 7f2e1a10c..235b15df3 100644
--- a/frigate/embeddings/__init__.py
+++ b/frigate/embeddings/__init__.py
@@ -1,5 +1,6 @@
 """SQLite-vec embeddings database."""
 
+import base64
 import json
 import logging
 import multiprocessing as mp
@@ -189,6 +190,28 @@ class EmbeddingsContext:
 
         return results
 
+    def register_face(self, face_name: str, image_data: bytes) -> None:
+        self.requestor.send_data(
+            EmbeddingsRequestEnum.register_face.value,
+            {
+                "face_name": face_name,
+                "image": base64.b64encode(image_data).decode("ASCII"),
+            },
+        )
+
+    def get_face_ids(self, name: str) -> list[str]:
+        sql_query = """
+            SELECT
+                id
+            FROM vec_faces
+            WHERE id LIKE ?
+        """
+
+        return self.db.execute_sql(sql_query, [f"%{name}%"]).fetchall()
+
+    def delete_face_ids(self, ids: list[str]) -> None:
+        self.db.delete_embeddings_face(ids)
+
     def update_description(self, event_id: str, description: str) -> None:
         self.requestor.send_data(
             EmbeddingsRequestEnum.embed_description.value,
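For context, here is the vec0 pattern these helpers rely on, sketched standalone with the `sqlite_vec` Python package (the package usage is an assumption for the sketch; Frigate wires the extension through `SqliteVecQueueDatabase` instead):

```python
import sqlite3
import struct

import sqlite_vec

db = sqlite3.connect(":memory:")
db.enable_load_extension(True)
sqlite_vec.load(db)
db.enable_load_extension(False)

db.execute(
    "CREATE VIRTUAL TABLE vec_faces USING vec0("
    "id TEXT PRIMARY KEY, face_embedding FLOAT[128] distance_metric=cosine)"
)


def serialize(vec: list[float]) -> bytes:
    # pack a float list into the raw little-endian f32 blob sqlite-vec expects
    return struct.pack(f"{len(vec)}f", *vec)


db.execute(
    "INSERT INTO vec_faces(id, face_embedding) VALUES (?, ?)",
    ("john-a1b2c3", serialize([0.1] * 128)),
)

# KNN query: cosine distance, smaller is more similar
rows = db.execute(
    "SELECT id, distance FROM vec_faces WHERE face_embedding MATCH ? AND k = 3 "
    "ORDER BY distance",
    (serialize([0.1] * 128),),
).fetchall()
print(rows)  # [('john-a1b2c3', ~0.0)]
```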
diff --git a/frigate/embeddings/embeddings.py b/frigate/embeddings/embeddings.py
index d77a9eecf..6b0f94ca9 100644
--- a/frigate/embeddings/embeddings.py
+++ b/frigate/embeddings/embeddings.py
@@ -3,6 +3,8 @@
 import base64
 import logging
 import os
+import random
+import string
 import time
 
 from numpy import ndarray
@@ -12,6 +14,7 @@ from frigate.comms.inter_process import InterProcessRequestor
 from frigate.config.semantic_search import SemanticSearchConfig
 from frigate.const import (
     CONFIG_DIR,
+    FACE_DIR,
     UPDATE_EMBEDDINGS_REINDEX_PROGRESS,
     UPDATE_MODEL_STATE,
 )
@@ -67,7 +70,7 @@ class Embeddings:
         self.requestor = InterProcessRequestor()
 
         # Create tables if they don't exist
-        self.db.create_embeddings_tables()
+        self.db.create_embeddings_tables(self.config.face_recognition.enabled)
 
         models = [
             "jinaai/jina-clip-v1-text_model_fp16.onnx",
@@ -121,6 +124,21 @@ class Embeddings:
             device="GPU" if config.model_size == "large" else "CPU",
         )
 
+        self.face_embedding = None
+
+        if self.config.face_recognition.enabled:
+            self.face_embedding = GenericONNXEmbedding(
+                model_name="facenet",
+                model_file="facenet.onnx",
+                download_urls={
+                    "facenet.onnx": "https://github.com/NicolasSM-001/faceNet.onnx-/raw/refs/heads/main/faceNet.onnx"
+                },
+                model_size="large",
+                model_type=ModelTypeEnum.face,
+                requestor=self.requestor,
+                device="GPU",
+            )
+
     def embed_thumbnail(
         self, event_id: str, thumbnail: bytes, upsert: bool = True
     ) -> ndarray:
@@ -215,12 +233,40 @@ class Embeddings:
 
         return embeddings
 
+    def embed_face(self, label: str, thumbnail: bytes, upsert: bool = False) -> ndarray:
+        embedding = self.face_embedding(thumbnail)[0]
+
+        if upsert:
+            rand_id = "".join(
+                random.choices(string.ascii_lowercase + string.digits, k=6)
+            )
+            id = f"{label}-{rand_id}"
+
+            # write face to library
+            folder = os.path.join(FACE_DIR, label)
+            file = os.path.join(folder, f"{id}.webp")
+            os.makedirs(folder, exist_ok=True)
+
+            # save face image
+            with open(file, "wb") as output:
+                output.write(thumbnail)
+
+            self.db.execute_sql(
+                """
+                INSERT OR REPLACE INTO vec_faces(id, face_embedding)
+                VALUES(?, ?)
+                """,
+                (id, serialize(embedding)),
+            )
+
+        return embedding
+
     def reindex(self) -> None:
         logger.info("Indexing tracked object embeddings...")
 
         self.db.drop_embeddings_tables()
         logger.debug("Dropped embeddings tables.")
-        self.db.create_embeddings_tables()
+        self.db.create_embeddings_tables(self.config.face_recognition.enabled)
         logger.debug("Created embeddings tables.")
 
         # Delete the saved stats file
diff --git a/frigate/embeddings/functions/onnx.py b/frigate/embeddings/functions/onnx.py
index 6ea495a30..9fc71d502 100644
--- a/frigate/embeddings/functions/onnx.py
+++ b/frigate/embeddings/functions/onnx.py
@@ -31,6 +31,8 @@ warnings.filterwarnings(
 disable_progress_bar()
 logger = logging.getLogger(__name__)
 
+FACE_EMBEDDING_SIZE = 160
+
 
 class ModelTypeEnum(str, Enum):
     face = "face"
@@ -47,7 +49,7 @@ class GenericONNXEmbedding:
         model_file: str,
         download_urls: Dict[str, str],
         model_size: str,
-        model_type: str,
+        model_type: ModelTypeEnum,
         requestor: InterProcessRequestor,
         tokenizer_file: Optional[str] = None,
         device: str = "AUTO",
@@ -57,7 +59,7 @@ class GenericONNXEmbedding:
         self.tokenizer_file = tokenizer_file
         self.requestor = requestor
         self.download_urls = download_urls
-        self.model_type = model_type  # 'text' or 'vision'
+        self.model_type = model_type
         self.model_size = model_size
         self.device = device
         self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
@@ -93,6 +95,7 @@ class GenericONNXEmbedding:
     def _download_model(self, path: str):
         try:
             file_name = os.path.basename(path)
+
             if file_name in self.download_urls:
                 ModelDownloader.download_from_url(self.download_urls[file_name], path)
             elif (
@@ -101,6 +104,7 @@ class GenericONNXEmbedding:
             ):
                 if not os.path.exists(path + "/" + self.model_name):
                     logger.info(f"Downloading {self.model_name} tokenizer")
+
                     tokenizer = AutoTokenizer.from_pretrained(
                         self.model_name,
                         trust_remote_code=True,
@@ -131,8 +135,11 @@ class GenericONNXEmbedding:
             self.downloader.wait_for_download()
         if self.model_type == ModelTypeEnum.text:
             self.tokenizer = self._load_tokenizer()
-        else:
+        elif self.model_type == ModelTypeEnum.vision:
             self.feature_extractor = self._load_feature_extractor()
+        elif self.model_type == ModelTypeEnum.face:
+            self.feature_extractor = []
+
         self.runner = ONNXModelRunner(
             os.path.join(self.download_path, self.model_file),
             self.device,
@@ -172,16 +179,51 @@ class GenericONNXEmbedding:
                 self.feature_extractor(images=image, return_tensors="np")
                 for image in processed_images
             ]
+        elif self.model_type == ModelTypeEnum.face:
+            if isinstance(raw_inputs, list):
+                raise ValueError("Face embedding does not support batch inputs.")
+
+            pil = self._process_image(raw_inputs)
+
+            # handle images larger than input size
+            width, height = pil.size
+            if width != FACE_EMBEDDING_SIZE or height != FACE_EMBEDDING_SIZE:
+                if width > height:
+                    new_height = int(((height / width) * FACE_EMBEDDING_SIZE) // 4 * 4)
+                    pil = pil.resize((FACE_EMBEDDING_SIZE, new_height))
+                else:
+                    new_width = int(((width / height) * FACE_EMBEDDING_SIZE) // 4 * 4)
+                    pil = pil.resize((new_width, FACE_EMBEDDING_SIZE))
+
+            og = np.array(pil).astype(np.float32)
+
+            # pad the shorter side so the input is always FACE_EMBEDDING_SIZE square
+            og_h, og_w, channels = og.shape
+            frame = np.full(
+                (FACE_EMBEDDING_SIZE, FACE_EMBEDDING_SIZE, channels),
+                (0, 0, 0),
+                dtype=np.float32,
+            )
+
+            # compute center offset
+            x_center = (FACE_EMBEDDING_SIZE - og_w) // 2
+            y_center = (FACE_EMBEDDING_SIZE - og_h) // 2
+
+            # copy the resized face into the center of the padded frame
+            frame[y_center : y_center + og_h, x_center : x_center + og_w] = og
+
+            frame = np.expand_dims(frame, axis=0)
+            return [{"image_input": frame}]
         else:
             raise ValueError(f"Unable to preprocess inputs for {self.model_type}")
 
-    def _process_image(self, image):
+    def _process_image(self, image, output: str = "RGB") -> Image.Image:
         if isinstance(image, str):
             if image.startswith("http"):
                 response = requests.get(image)
-                image = Image.open(BytesIO(response.content)).convert("RGB")
+                image = Image.open(BytesIO(response.content)).convert(output)
         elif isinstance(image, bytes):
-            image = Image.open(BytesIO(image)).convert("RGB")
+            image = Image.open(BytesIO(image)).convert(output)
 
         return image
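The face branch above resizes the longer side to 160 (rounded down to a multiple of 4) and zero-pads the shorter side so the model always sees a 160x160 float32 batch of one. A standalone restatement of that logic, assuming RGB input as the branch does:

```python
import numpy as np
from PIL import Image

SIZE = 160


def preprocess_face(pil: Image.Image) -> np.ndarray:
    # aspect-preserving resize so the longer side becomes SIZE
    w, h = pil.size
    if (w, h) != (SIZE, SIZE):
        if w > h:
            pil = pil.resize((SIZE, int((h / w * SIZE) // 4 * 4)))
        else:
            pil = pil.resize((int((w / h * SIZE) // 4 * 4), SIZE))
    og = np.array(pil, dtype=np.float32)
    # zero-pad the shorter side, centering the face
    frame = np.zeros((SIZE, SIZE, og.shape[2]), dtype=np.float32)
    y = (SIZE - og.shape[0]) // 2
    x = (SIZE - og.shape[1]) // 2
    frame[y : y + og.shape[0], x : x + og.shape[1]] = og
    return np.expand_dims(frame, axis=0)


print(preprocess_face(Image.new("RGB", (320, 200))).shape)  # (1, 160, 160, 3)
```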
diff --git a/frigate/embeddings/maintainer.py b/frigate/embeddings/maintainer.py
index d58a7f431..104d44bbc 100644
--- a/frigate/embeddings/maintainer.py
+++ b/frigate/embeddings/maintainer.py
@@ -9,6 +9,7 @@
-from typing import Optional
+from typing import Any, Optional
 
 import cv2
 import numpy as np
+import requests
 
 from peewee import DoesNotExist
 from playhouse.sqliteq import SqliteQueueDatabase
@@ -20,13 +21,13 @@ from frigate.comms.event_metadata_updater import (
 from frigate.comms.events_updater import EventEndSubscriber, EventUpdateSubscriber
 from frigate.comms.inter_process import InterProcessRequestor
 from frigate.config import FrigateConfig
-from frigate.const import CLIPS_DIR, UPDATE_EVENT_DESCRIPTION
+from frigate.const import CLIPS_DIR, FRIGATE_LOCALHOST, UPDATE_EVENT_DESCRIPTION
 from frigate.events.types import EventTypeEnum
 from frigate.genai import get_genai_client
 from frigate.models import Event
 from frigate.types import TrackedObjectUpdateTypesEnum
 from frigate.util.builtin import serialize
-from frigate.util.image import SharedMemoryFrameManager, calculate_region
+from frigate.util.image import SharedMemoryFrameManager, area, calculate_region
 
 from .embeddings import Embeddings
 
@@ -59,10 +60,17 @@ class EmbeddingMaintainer(threading.Thread):
         )
         self.embeddings_responder = EmbeddingsResponder()
         self.frame_manager = SharedMemoryFrameManager()
+
+        # set face recognition conditions
+        self.face_recognition_enabled = (
+            self.config.semantic_search.face_recognition.enabled
+        )
+        self.requires_face_detection = "face" not in self.config.model.all_attributes
+
         # create communication for updating event descriptions
         self.requestor = InterProcessRequestor()
         self.stop_event = stop_event
-        self.tracked_events = {}
+        self.tracked_events: dict[str, list[Any]] = {}
         self.genai_client = get_genai_client(config)
 
     def run(self) -> None:
@@ -102,6 +110,13 @@ class EmbeddingMaintainer(threading.Thread):
                 return serialize(
                     self.embeddings.text_embedding([data])[0], pack=False
                 )
+            elif topic == EmbeddingsRequestEnum.register_face.value:
+                self.embeddings.embed_face(
+                    data["face_name"],
+                    base64.b64decode(data["image"]),
+                    upsert=True,
+                )
+                return None
         except Exception as e:
             logger.error(f"Unable to handle embeddings request {e}")
 
@@ -109,7 +124,7 @@ class EmbeddingMaintainer(threading.Thread):
     def _process_updates(self) -> None:
         """Process event updates"""
-        update = self.event_subscriber.check_for_update(timeout=0.1)
+        update = self.event_subscriber.check_for_update(timeout=0.01)
 
         if update is None:
             return
@@ -120,42 +135,47 @@ class EmbeddingMaintainer(threading.Thread):
             return
 
         camera_config = self.config.cameras[camera]
-        # no need to save our own thumbnails if genai is not enabled
-        # or if the object has become stationary
-        if (
-            not camera_config.genai.enabled
-            or self.genai_client is None
-            or data["stationary"]
-        ):
-            return
 
-        if data["id"] not in self.tracked_events:
-            self.tracked_events[data["id"]] = []
+        # no need to process updated objects if face recognition and genai are disabled
+        if not camera_config.genai.enabled and not self.face_recognition_enabled:
+            return
 
         # Create our own thumbnail based on the bounding box and the frame time
         try:
             yuv_frame = self.frame_manager.get(
                 frame_name, camera_config.frame_shape_yuv
             )
-
-            if yuv_frame is not None:
-                data["thumbnail"] = self._create_thumbnail(yuv_frame, data["box"])
-
-                # Limit the number of thumbnails saved
-                if len(self.tracked_events[data["id"]]) >= MAX_THUMBNAILS:
-                    # Always keep the first thumbnail for the event
-                    self.tracked_events[data["id"]].pop(1)
-
-                self.tracked_events[data["id"]].append(data)
-
-            self.frame_manager.close(frame_name)
         except FileNotFoundError:
-            pass
+            yuv_frame = None
 
+        if yuv_frame is None:
+            logger.debug(
+                "Unable to process object update because frame is unavailable."
+            )
+            return
+
+        if self.face_recognition_enabled:
+            self._process_face(data, yuv_frame)
+
+        # no need to save our own thumbnails if genai is not enabled
+        # or if the object has become stationary
+        if self.genai_client is not None and not data["stationary"]:
+            if data["id"] not in self.tracked_events:
+                self.tracked_events[data["id"]] = []
+
+            data["thumbnail"] = self._create_thumbnail(yuv_frame, data["box"])
+
+            # Limit the number of thumbnails saved
+            if len(self.tracked_events[data["id"]]) >= MAX_THUMBNAILS:
+                # Always keep the first thumbnail for the event
+                self.tracked_events[data["id"]].pop(1)
+
+            self.tracked_events[data["id"]].append(data)
+
+        self.frame_manager.close(frame_name)
+
     def _process_finalized(self) -> None:
         """Process the end of an event."""
         while True:
-            ended = self.event_end_subscriber.check_for_update(timeout=0.1)
+            ended = self.event_end_subscriber.check_for_update(timeout=0.01)
 
-            if ended == None:
+            if ended is None:
                 break
@@ -245,7 +265,7 @@ class EmbeddingMaintainer(threading.Thread):
     def _process_event_metadata(self):
         # Check for regenerate description requests
         (topic, event_id, source) = self.event_metadata_subscriber.check_for_update(
-            timeout=0.1
+            timeout=0.01
         )
 
         if topic is None:
@@ -254,6 +274,94 @@ class EmbeddingMaintainer(threading.Thread):
         if event_id:
             self.handle_regenerate_description(event_id, source)
 
+    def _search_face(self, query_embedding: bytes) -> list:
+        """Search for the face most closely matching the embedding."""
+        sql_query = """
+            SELECT
+                id,
+                distance
+            FROM vec_faces
+            WHERE face_embedding MATCH ?
+                AND k = 10 ORDER BY distance
+        """
+        return self.embeddings.db.execute_sql(sql_query, [query_embedding]).fetchall()
+
+    def _process_face(self, obj_data: dict[str, Any], frame: np.ndarray) -> None:
+        """Look for faces in image."""
+        # don't run for non-person objects
+        if obj_data.get("label") != "person":
+            logger.debug("Not processing face for non-person object.")
+            return
+
+        # don't overwrite sub label for objects that have one
+        if obj_data.get("sub_label"):
+            logger.debug(
+                f"Not processing face due to existing sub label: {obj_data.get('sub_label')}."
+            )
+            return
+
+        face: Optional[dict[str, Any]] = None
+
+        if self.requires_face_detection:
+            # TODO run cv2 face detection
+            pass
+        else:
+            # don't run for objects without attributes
+            if not obj_data.get("current_attributes"):
+                logger.debug("No attributes to parse.")
+                return
+
+            attributes: list[dict[str, Any]] = obj_data.get("current_attributes", [])
+            for attr in attributes:
+                if attr.get("label") != "face":
+                    continue
+
+                if face is None or attr.get("score", 0.0) > face.get("score", 0.0):
+                    face = attr
+
+        # no faces detected in this frame
+        if not face:
+            return
+
+        face_box = face.get("box")
+
+        # check that face is valid
+        if (
+            not face_box
+            or area(face_box) < self.config.semantic_search.face_recognition.min_area
+        ):
+            logger.debug(f"Invalid face box {face}")
+            return
+
+        face_frame = cv2.cvtColor(frame, cv2.COLOR_YUV2BGR_I420)
+        face_frame = face_frame[face_box[1] : face_box[3], face_box[0] : face_box[2]]
+        ret, webp = cv2.imencode(
+            ".webp", face_frame, [int(cv2.IMWRITE_WEBP_QUALITY), 100]
+        )
+
+        if not ret:
+            logger.debug("Not processing face due to error creating cropped image.")
+            return
+
+        embedding = self.embeddings.embed_face("unknown", webp.tobytes(), upsert=False)
+        query_embedding = serialize(embedding)
+        best_faces = self._search_face(query_embedding)
+        logger.debug(f"Detected best faces for person as: {best_faces}")
+
+        if not best_faces:
+            return
+
+        # face ids are "{name}-{random}", so split from the right to recover the name
+        sub_label = str(best_faces[0][0]).rsplit("-", 1)[0]
+        score = 1.0 - best_faces[0][1]
+
+        if score < self.config.semantic_search.face_recognition.threshold:
+            return
+
+        requests.post(
+            f"{FRIGATE_LOCALHOST}/api/events/{obj_data['id']}/sub_label",
+            json={"subLabel": sub_label, "subLabelScore": score},
+        )
+
     def _create_thumbnail(self, yuv_frame, box, height=500) -> Optional[bytes]:
         """Return jpg thumbnail of a region of the frame."""
         frame = cv2.cvtColor(yuv_frame, cv2.COLOR_YUV2BGR_I420)
diff --git a/frigate/util/downloader.py b/frigate/util/downloader.py
index 6685b0bb8..18c577fb0 100644
--- a/frigate/util/downloader.py
+++ b/frigate/util/downloader.py
@@ -101,7 +101,7 @@ class ModelDownloader:
         self.download_complete.set()
 
     @staticmethod
-    def download_from_url(url: str, save_path: str, silent: bool = False):
+    def download_from_url(url: str, save_path: str, silent: bool = False) -> Path:
         temporary_filename = Path(save_path).with_name(
             os.path.basename(save_path) + ".part"
         )
@@ -125,6 +125,8 @@ class ModelDownloader:
         if not silent:
             logger.info(f"Downloading complete: {url}")
 
+        return Path(save_path)
+
     @staticmethod
     def mark_files_state(
         requestor: InterProcessRequestor,
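To make the scoring in `_process_face` concrete: the KNN rows are `(id, cosine_distance)` pairs, distance is converted to similarity as `1 - distance`, and the match is accepted only at or above the configured threshold. A small sketch of that decision (`match_face` is a hypothetical helper, not part of the PR):

```python
def match_face(
    best_faces: list[tuple[str, float]], threshold: float = 0.9
) -> tuple[str, float] | None:
    if not best_faces:
        return None
    face_id, distance = best_faces[0]
    # ids look like "{name}-{random}"; strip the suffix added at registration
    name = face_id.rsplit("-", 1)[0]
    score = 1.0 - distance  # cosine distance -> similarity
    return (name, score) if score >= threshold else None


print(match_face([("john-a1b2c3", 0.05)]))  # ('john', 0.95)
print(match_face([("john-a1b2c3", 0.30)]))  # None (0.70 < 0.9)
```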