mirror of
https://github.com/blakeblackshear/frigate.git
synced 2024-11-21 19:07:46 +01:00
Face recognition backend (#14495)
* Add basic config and face recognition table * Reconfigure updates processing to handle face * Crop frame to face box * Implement face embedding calculation * Get matching face embeddings * Add support face recognition based on existing faces * Use arcface face embeddings instead of generic embeddings model * Add apis for managing faces * Implement face uploading API * Build out more APIs * Add min area config * Handle larger images * Add more debug logs * fix calculation * Reduce timeout * Small tweaks * Use webp images * Use facenet model
This commit is contained in:
parent
3394942e92
commit
b7b7e1b78b
@ -8,6 +8,8 @@ imutils == 0.5.*
|
|||||||
joserfc == 1.0.*
|
joserfc == 1.0.*
|
||||||
pathvalidate == 3.2.*
|
pathvalidate == 3.2.*
|
||||||
markupsafe == 2.1.*
|
markupsafe == 2.1.*
|
||||||
|
python-multipart == 0.0.12
|
||||||
|
# General
|
||||||
mypy == 1.6.1
|
mypy == 1.6.1
|
||||||
numpy == 1.26.*
|
numpy == 1.26.*
|
||||||
onvif_zeep == 0.2.12
|
onvif_zeep == 0.2.12
|
||||||
|
56
frigate/api/classification.py
Normal file
56
frigate/api/classification.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
"""Object classification APIs."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Request, UploadFile
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
|
from frigate.api.defs.tags import Tags
|
||||||
|
from frigate.embeddings import EmbeddingsContext
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(tags=[Tags.events])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/faces")
def get_faces():
    """Return a fixed placeholder JSON message for the face listing route."""
    payload = {"message": "there are faces"}
    return JSONResponse(content=payload)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/faces/{name}")
async def register_face(request: Request, name: str, file: UploadFile):
    """Register an uploaded image as a face named ``name``.

    The upload is read fully and passed, together with the name, to the
    application's embeddings context.

    Returns:
        A 400 JSON response when the upload contains no data, otherwise a
        200 JSON response once the image has been handed to the embeddings
        context.
    """
    image_data = await file.read()

    # An empty upload can never be decoded into a face, so reject it here
    # instead of forwarding it and reporting success.
    if not image_data:
        return JSONResponse(
            status_code=400,
            content={
                "success": False,
                "message": "Only an image can be used to register a face.",
            },
        )

    context: EmbeddingsContext = request.app.embeddings
    context.register_face(name, image_data)
    return JSONResponse(
        status_code=200,
        content={"success": True, "message": "Successfully registered face."},
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/faces")
def deregister_faces(request: Request, body: dict = None):
    """Delete registered face embeddings by id.

    Expects a JSON body of the form ``{"ids": ["<face id>", ...]}``.

    Returns:
        A 404 JSON response when ``ids`` is missing, empty, or not a list of
        strings; otherwise a 200 JSON response after the ids were passed to
        the embeddings context for deletion.
    """
    payload = body or {}
    list_of_ids = payload.get("ids")

    # The ids end up as bound parameters of a DELETE statement, one
    # placeholder per element. A bare string would be iterated character by
    # character, so only a non-empty list of strings is accepted.
    is_valid = (
        isinstance(list_of_ids, list)
        and bool(list_of_ids)
        and all(isinstance(face_id, str) for face_id in list_of_ids)
    )

    if not is_valid:
        return JSONResponse(
            content=({"success": False, "message": "Not a valid list of ids"}),
            status_code=404,
        )

    context: EmbeddingsContext = request.app.embeddings
    context.delete_face_ids(list_of_ids)
    return JSONResponse(
        content=({"success": True, "message": "Successfully deleted faces."}),
        status_code=200,
    )
|
@ -10,4 +10,5 @@ class Tags(Enum):
|
|||||||
review = "Review"
|
review = "Review"
|
||||||
export = "Export"
|
export = "Export"
|
||||||
events = "Events"
|
events = "Events"
|
||||||
|
classification = "classification"
|
||||||
auth = "Auth"
|
auth = "Auth"
|
||||||
|
@ -11,7 +11,16 @@ from starlette_context import middleware, plugins
|
|||||||
from starlette_context.plugins import Plugin
|
from starlette_context.plugins import Plugin
|
||||||
|
|
||||||
from frigate.api import app as main_app
|
from frigate.api import app as main_app
|
||||||
from frigate.api import auth, event, export, media, notification, preview, review
|
from frigate.api import (
|
||||||
|
auth,
|
||||||
|
classification,
|
||||||
|
event,
|
||||||
|
export,
|
||||||
|
media,
|
||||||
|
notification,
|
||||||
|
preview,
|
||||||
|
review,
|
||||||
|
)
|
||||||
from frigate.api.auth import get_jwt_secret, limiter
|
from frigate.api.auth import get_jwt_secret, limiter
|
||||||
from frigate.comms.event_metadata_updater import (
|
from frigate.comms.event_metadata_updater import (
|
||||||
EventMetadataPublisher,
|
EventMetadataPublisher,
|
||||||
@ -95,6 +104,7 @@ def create_fastapi_app(
|
|||||||
# Routes
|
# Routes
|
||||||
# Order of include_router matters: https://fastapi.tiangolo.com/tutorial/path-params/#order-matters
|
# Order of include_router matters: https://fastapi.tiangolo.com/tutorial/path-params/#order-matters
|
||||||
app.include_router(auth.router)
|
app.include_router(auth.router)
|
||||||
|
app.include_router(classification.router)
|
||||||
app.include_router(review.router)
|
app.include_router(review.router)
|
||||||
app.include_router(main_app.router)
|
app.include_router(main_app.router)
|
||||||
app.include_router(preview.router)
|
app.include_router(preview.router)
|
||||||
|
@ -12,6 +12,7 @@ class EmbeddingsRequestEnum(Enum):
|
|||||||
embed_description = "embed_description"
|
embed_description = "embed_description"
|
||||||
embed_thumbnail = "embed_thumbnail"
|
embed_thumbnail = "embed_thumbnail"
|
||||||
generate_search = "generate_search"
|
generate_search = "generate_search"
|
||||||
|
register_face = "register_face"
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingsResponder:
|
class EmbeddingsResponder:
|
||||||
@ -22,7 +23,7 @@ class EmbeddingsResponder:
|
|||||||
|
|
||||||
def check_for_request(self, process: Callable) -> None:
|
def check_for_request(self, process: Callable) -> None:
|
||||||
while True: # load all messages that are queued
|
while True: # load all messages that are queued
|
||||||
has_message, _, _ = zmq.select([self.socket], [], [], 0.1)
|
has_message, _, _ = zmq.select([self.socket], [], [], 0.01)
|
||||||
|
|
||||||
if not has_message:
|
if not has_message:
|
||||||
break
|
break
|
||||||
|
@ -4,7 +4,17 @@ from pydantic import Field
|
|||||||
|
|
||||||
from .base import FrigateBaseModel
|
from .base import FrigateBaseModel
|
||||||
|
|
||||||
__all__ = ["SemanticSearchConfig"]
|
__all__ = ["FaceRecognitionConfig", "SemanticSearchConfig"]
|
||||||
|
|
||||||
|
|
||||||
|
class FaceRecognitionConfig(FrigateBaseModel):
    # Settings for the face recognition feature.

    enabled: bool = Field(default=False, title="Enable face recognition.")
    # The threshold is compared against `1.0 - cosine distance` of the best
    # matching face, so values outside (0, 1] can never be meaningful.
    threshold: float = Field(
        default=0.9,
        title="Face similarity score required to be considered a match.",
        gt=0.0,
        le=1.0,
    )
    # Face boxes with a smaller area than this are not run through recognition.
    min_area: int = Field(
        default=500, title="Min area of face box to consider running face recognition."
    )
|
||||||
|
|
||||||
|
|
||||||
class SemanticSearchConfig(FrigateBaseModel):
|
class SemanticSearchConfig(FrigateBaseModel):
|
||||||
@ -12,6 +22,9 @@ class SemanticSearchConfig(FrigateBaseModel):
|
|||||||
reindex: Optional[bool] = Field(
|
reindex: Optional[bool] = Field(
|
||||||
default=False, title="Reindex all detections on startup."
|
default=False, title="Reindex all detections on startup."
|
||||||
)
|
)
|
||||||
|
face_recognition: FaceRecognitionConfig = Field(
|
||||||
|
default_factory=FaceRecognitionConfig, title="Face recognition config."
|
||||||
|
)
|
||||||
model_size: str = Field(
|
model_size: str = Field(
|
||||||
default="small", title="The size of the embeddings model used."
|
default="small", title="The size of the embeddings model used."
|
||||||
)
|
)
|
||||||
|
@ -5,8 +5,9 @@ DEFAULT_DB_PATH = f"{CONFIG_DIR}/frigate.db"
|
|||||||
MODEL_CACHE_DIR = f"{CONFIG_DIR}/model_cache"
|
MODEL_CACHE_DIR = f"{CONFIG_DIR}/model_cache"
|
||||||
BASE_DIR = "/media/frigate"
|
BASE_DIR = "/media/frigate"
|
||||||
CLIPS_DIR = f"{BASE_DIR}/clips"
|
CLIPS_DIR = f"{BASE_DIR}/clips"
|
||||||
RECORD_DIR = f"{BASE_DIR}/recordings"
|
|
||||||
EXPORT_DIR = f"{BASE_DIR}/exports"
|
EXPORT_DIR = f"{BASE_DIR}/exports"
|
||||||
|
FACE_DIR = f"{CLIPS_DIR}/faces"
|
||||||
|
RECORD_DIR = f"{BASE_DIR}/recordings"
|
||||||
BIRDSEYE_PIPE = "/tmp/cache/birdseye"
|
BIRDSEYE_PIPE = "/tmp/cache/birdseye"
|
||||||
CACHE_DIR = "/tmp/cache"
|
CACHE_DIR = "/tmp/cache"
|
||||||
FRIGATE_LOCALHOST = "http://127.0.0.1:5000"
|
FRIGATE_LOCALHOST = "http://127.0.0.1:5000"
|
||||||
|
@ -29,6 +29,10 @@ class SqliteVecQueueDatabase(SqliteQueueDatabase):
|
|||||||
ids = ",".join(["?" for _ in event_ids])
|
ids = ",".join(["?" for _ in event_ids])
|
||||||
self.execute_sql(f"DELETE FROM vec_descriptions WHERE id IN ({ids})", event_ids)
|
self.execute_sql(f"DELETE FROM vec_descriptions WHERE id IN ({ids})", event_ids)
|
||||||
|
|
||||||
|
def delete_embeddings_face(self, face_ids: list[str]) -> None:
    """Delete the rows of ``vec_faces`` whose id is in ``face_ids``.

    Does nothing when ``face_ids`` is empty, so no statement is issued for a
    request that cannot delete anything.
    """
    if not face_ids:
        return

    # one bound placeholder per id keeps the ids out of the SQL text
    ids = ",".join(["?" for _ in face_ids])
    self.execute_sql(f"DELETE FROM vec_faces WHERE id IN ({ids})", face_ids)
|
||||||
|
|
||||||
def drop_embeddings_tables(self) -> None:
|
def drop_embeddings_tables(self) -> None:
|
||||||
self.execute_sql("""
|
self.execute_sql("""
|
||||||
DROP TABLE vec_descriptions;
|
DROP TABLE vec_descriptions;
|
||||||
@ -36,8 +40,11 @@ class SqliteVecQueueDatabase(SqliteQueueDatabase):
|
|||||||
self.execute_sql("""
|
self.execute_sql("""
|
||||||
DROP TABLE vec_thumbnails;
|
DROP TABLE vec_thumbnails;
|
||||||
""")
|
""")
|
||||||
|
self.execute_sql("""
|
||||||
|
DROP TABLE vec_faces;
|
||||||
|
""")
|
||||||
|
|
||||||
def create_embeddings_tables(self) -> None:
|
def create_embeddings_tables(self, face_recognition: bool) -> None:
|
||||||
"""Create vec0 virtual table for embeddings"""
|
"""Create vec0 virtual table for embeddings"""
|
||||||
self.execute_sql("""
|
self.execute_sql("""
|
||||||
CREATE VIRTUAL TABLE IF NOT EXISTS vec_thumbnails USING vec0(
|
CREATE VIRTUAL TABLE IF NOT EXISTS vec_thumbnails USING vec0(
|
||||||
@ -51,3 +58,11 @@ class SqliteVecQueueDatabase(SqliteQueueDatabase):
|
|||||||
description_embedding FLOAT[768] distance_metric=cosine
|
description_embedding FLOAT[768] distance_metric=cosine
|
||||||
);
|
);
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
if face_recognition:
|
||||||
|
self.execute_sql("""
|
||||||
|
CREATE VIRTUAL TABLE IF NOT EXISTS vec_faces USING vec0(
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
face_embedding FLOAT[128] distance_metric=cosine
|
||||||
|
);
|
||||||
|
""")
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
"""SQLite-vec embeddings database."""
|
"""SQLite-vec embeddings database."""
|
||||||
|
|
||||||
|
import base64
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
@ -189,6 +190,28 @@ class EmbeddingsContext:
|
|||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
def register_face(self, face_name: str, image_data: bytes) -> None:
    """Send a register_face request with the name and base64-encoded image."""
    encoded_image = base64.b64encode(image_data).decode("ASCII")
    payload = {
        "face_name": face_name,
        "image": encoded_image,
    }
    self.requestor.send_data(EmbeddingsRequestEnum.register_face.value, payload)
|
||||||
|
|
||||||
|
def get_face_ids(self, name: str) -> list[str]:
    """Return the stored face ids that contain ``name``.

    Face embeddings are stored in ``vec_faces`` under ids of the form
    ``<name>-<random suffix>``, so that is the table searched here. ``name``
    is passed as a bound parameter and never interpolated into the SQL text.

    Returns:
        The rows produced by the query, as returned by ``fetchall()``; each
        row holds a single ``id`` column.
    """
    sql_query = """
        SELECT
            id
        FROM vec_faces
        WHERE id LIKE ?
    """

    return self.db.execute_sql(sql_query, [f"%{name}%"]).fetchall()
|
||||||
|
|
||||||
|
def delete_face_ids(self, ids: list[str]) -> None:
    """Delete the face embeddings with the given ids through the database."""
    database = self.db
    database.delete_embeddings_face(ids)
|
||||||
|
|
||||||
def update_description(self, event_id: str, description: str) -> None:
|
def update_description(self, event_id: str, description: str) -> None:
|
||||||
self.requestor.send_data(
|
self.requestor.send_data(
|
||||||
EmbeddingsRequestEnum.embed_description.value,
|
EmbeddingsRequestEnum.embed_description.value,
|
||||||
|
@ -3,6 +3,8 @@
|
|||||||
import base64
|
import base64
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import random
|
||||||
|
import string
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from numpy import ndarray
|
from numpy import ndarray
|
||||||
@ -12,6 +14,7 @@ from frigate.comms.inter_process import InterProcessRequestor
|
|||||||
from frigate.config.semantic_search import SemanticSearchConfig
|
from frigate.config.semantic_search import SemanticSearchConfig
|
||||||
from frigate.const import (
|
from frigate.const import (
|
||||||
CONFIG_DIR,
|
CONFIG_DIR,
|
||||||
|
FACE_DIR,
|
||||||
UPDATE_EMBEDDINGS_REINDEX_PROGRESS,
|
UPDATE_EMBEDDINGS_REINDEX_PROGRESS,
|
||||||
UPDATE_MODEL_STATE,
|
UPDATE_MODEL_STATE,
|
||||||
)
|
)
|
||||||
@ -67,7 +70,7 @@ class Embeddings:
|
|||||||
self.requestor = InterProcessRequestor()
|
self.requestor = InterProcessRequestor()
|
||||||
|
|
||||||
# Create tables if they don't exist
|
# Create tables if they don't exist
|
||||||
self.db.create_embeddings_tables()
|
self.db.create_embeddings_tables(self.config.face_recognition.enabled)
|
||||||
|
|
||||||
models = [
|
models = [
|
||||||
"jinaai/jina-clip-v1-text_model_fp16.onnx",
|
"jinaai/jina-clip-v1-text_model_fp16.onnx",
|
||||||
@ -121,6 +124,21 @@ class Embeddings:
|
|||||||
device="GPU" if config.model_size == "large" else "CPU",
|
device="GPU" if config.model_size == "large" else "CPU",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.face_embedding = None
|
||||||
|
|
||||||
|
if self.config.face_recognition.enabled:
|
||||||
|
self.face_embedding = GenericONNXEmbedding(
|
||||||
|
model_name="facenet",
|
||||||
|
model_file="facenet.onnx",
|
||||||
|
download_urls={
|
||||||
|
"facenet.onnx": "https://github.com/NicolasSM-001/faceNet.onnx-/raw/refs/heads/main/faceNet.onnx"
|
||||||
|
},
|
||||||
|
model_size="large",
|
||||||
|
model_type=ModelTypeEnum.face,
|
||||||
|
requestor=self.requestor,
|
||||||
|
device="GPU",
|
||||||
|
)
|
||||||
|
|
||||||
def embed_thumbnail(
|
def embed_thumbnail(
|
||||||
self, event_id: str, thumbnail: bytes, upsert: bool = True
|
self, event_id: str, thumbnail: bytes, upsert: bool = True
|
||||||
) -> ndarray:
|
) -> ndarray:
|
||||||
@ -215,12 +233,40 @@ class Embeddings:
|
|||||||
|
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
def embed_face(self, label: str, thumbnail: bytes, upsert: bool = False) -> ndarray:
    """Compute the face embedding for an image and optionally store it.

    Args:
        label: Name the face is stored under; only used when ``upsert`` is
            true.
        thumbnail: Encoded image bytes. They are passed to the face embedding
            model and, on upsert, written to disk unchanged.
        upsert: When true, save the image under ``FACE_DIR/<label>/`` and
            insert the embedding into ``vec_faces`` under a generated id of
            the form ``<label>-<6 random characters>``.

    Returns:
        The first embedding produced by the face embedding model.

    Raises:
        RuntimeError: If no face embedding model is loaded.
        ValueError: If ``upsert`` is requested with an empty label or one
            that would resolve outside of its own folder in ``FACE_DIR``.
    """
    # face_embedding stays None unless face recognition is enabled; fail with
    # a clear message instead of "'NoneType' object is not callable".
    if self.face_embedding is None:
        raise RuntimeError("Face recognition is not enabled.")

    # The label becomes a directory name below FACE_DIR, so refuse anything
    # that is empty or could point outside of it.
    if upsert and (
        not label
        or label in (os.curdir, os.pardir)
        or os.sep in label
        or (os.altsep is not None and os.altsep in label)
    ):
        raise ValueError(f"Invalid face name: {label!r}")

    embedding = self.face_embedding(thumbnail)[0]

    if upsert:
        rand_id = "".join(
            random.choices(string.ascii_lowercase + string.digits, k=6)
        )
        face_id = f"{label}-{rand_id}"

        # write face to library
        folder = os.path.join(FACE_DIR, label)
        file_path = os.path.join(folder, f"{face_id}.webp")
        os.makedirs(folder, exist_ok=True)

        # save face image
        with open(file_path, "wb") as output:
            output.write(thumbnail)

        self.db.execute_sql(
            """
            INSERT OR REPLACE INTO vec_faces(id, face_embedding)
            VALUES(?, ?)
            """,
            (face_id, serialize(embedding)),
        )

    return embedding
|
||||||
|
|
||||||
def reindex(self) -> None:
|
def reindex(self) -> None:
|
||||||
logger.info("Indexing tracked object embeddings...")
|
logger.info("Indexing tracked object embeddings...")
|
||||||
|
|
||||||
self.db.drop_embeddings_tables()
|
self.db.drop_embeddings_tables()
|
||||||
logger.debug("Dropped embeddings tables.")
|
logger.debug("Dropped embeddings tables.")
|
||||||
self.db.create_embeddings_tables()
|
self.db.create_embeddings_tables(self.config.face_recognition.enabled)
|
||||||
logger.debug("Created embeddings tables.")
|
logger.debug("Created embeddings tables.")
|
||||||
|
|
||||||
# Delete the saved stats file
|
# Delete the saved stats file
|
||||||
|
@ -31,6 +31,8 @@ warnings.filterwarnings(
|
|||||||
disable_progress_bar()
|
disable_progress_bar()
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
FACE_EMBEDDING_SIZE = 160
|
||||||
|
|
||||||
|
|
||||||
class ModelTypeEnum(str, Enum):
|
class ModelTypeEnum(str, Enum):
|
||||||
face = "face"
|
face = "face"
|
||||||
@ -47,7 +49,7 @@ class GenericONNXEmbedding:
|
|||||||
model_file: str,
|
model_file: str,
|
||||||
download_urls: Dict[str, str],
|
download_urls: Dict[str, str],
|
||||||
model_size: str,
|
model_size: str,
|
||||||
model_type: str,
|
model_type: ModelTypeEnum,
|
||||||
requestor: InterProcessRequestor,
|
requestor: InterProcessRequestor,
|
||||||
tokenizer_file: Optional[str] = None,
|
tokenizer_file: Optional[str] = None,
|
||||||
device: str = "AUTO",
|
device: str = "AUTO",
|
||||||
@ -57,7 +59,7 @@ class GenericONNXEmbedding:
|
|||||||
self.tokenizer_file = tokenizer_file
|
self.tokenizer_file = tokenizer_file
|
||||||
self.requestor = requestor
|
self.requestor = requestor
|
||||||
self.download_urls = download_urls
|
self.download_urls = download_urls
|
||||||
self.model_type = model_type # 'text' or 'vision'
|
self.model_type = model_type
|
||||||
self.model_size = model_size
|
self.model_size = model_size
|
||||||
self.device = device
|
self.device = device
|
||||||
self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
|
self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
|
||||||
@ -93,6 +95,7 @@ class GenericONNXEmbedding:
|
|||||||
def _download_model(self, path: str):
|
def _download_model(self, path: str):
|
||||||
try:
|
try:
|
||||||
file_name = os.path.basename(path)
|
file_name = os.path.basename(path)
|
||||||
|
|
||||||
if file_name in self.download_urls:
|
if file_name in self.download_urls:
|
||||||
ModelDownloader.download_from_url(self.download_urls[file_name], path)
|
ModelDownloader.download_from_url(self.download_urls[file_name], path)
|
||||||
elif (
|
elif (
|
||||||
@ -101,6 +104,7 @@ class GenericONNXEmbedding:
|
|||||||
):
|
):
|
||||||
if not os.path.exists(path + "/" + self.model_name):
|
if not os.path.exists(path + "/" + self.model_name):
|
||||||
logger.info(f"Downloading {self.model_name} tokenizer")
|
logger.info(f"Downloading {self.model_name} tokenizer")
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
self.model_name,
|
self.model_name,
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
@ -131,8 +135,11 @@ class GenericONNXEmbedding:
|
|||||||
self.downloader.wait_for_download()
|
self.downloader.wait_for_download()
|
||||||
if self.model_type == ModelTypeEnum.text:
|
if self.model_type == ModelTypeEnum.text:
|
||||||
self.tokenizer = self._load_tokenizer()
|
self.tokenizer = self._load_tokenizer()
|
||||||
else:
|
elif self.model_type == ModelTypeEnum.vision:
|
||||||
self.feature_extractor = self._load_feature_extractor()
|
self.feature_extractor = self._load_feature_extractor()
|
||||||
|
elif self.model_type == ModelTypeEnum.face:
|
||||||
|
self.feature_extractor = []
|
||||||
|
|
||||||
self.runner = ONNXModelRunner(
|
self.runner = ONNXModelRunner(
|
||||||
os.path.join(self.download_path, self.model_file),
|
os.path.join(self.download_path, self.model_file),
|
||||||
self.device,
|
self.device,
|
||||||
@ -172,16 +179,51 @@ class GenericONNXEmbedding:
|
|||||||
self.feature_extractor(images=image, return_tensors="np")
|
self.feature_extractor(images=image, return_tensors="np")
|
||||||
for image in processed_images
|
for image in processed_images
|
||||||
]
|
]
|
||||||
|
elif self.model_type == ModelTypeEnum.face:
|
||||||
|
if isinstance(raw_inputs, list):
|
||||||
|
raise ValueError("Face embedding does not support batch inputs.")
|
||||||
|
|
||||||
|
pil = self._process_image(raw_inputs)
|
||||||
|
|
||||||
|
# handle images larger than input size
|
||||||
|
width, height = pil.size
|
||||||
|
if width != FACE_EMBEDDING_SIZE or height != FACE_EMBEDDING_SIZE:
|
||||||
|
if width > height:
|
||||||
|
new_height = int(((height / width) * FACE_EMBEDDING_SIZE) // 4 * 4)
|
||||||
|
pil = pil.resize((FACE_EMBEDDING_SIZE, new_height))
|
||||||
|
else:
|
||||||
|
new_width = int(((width / height) * FACE_EMBEDDING_SIZE) // 4 * 4)
|
||||||
|
pil = pil.resize((new_width, FACE_EMBEDDING_SIZE))
|
||||||
|
|
||||||
|
og = np.array(pil).astype(np.float32)
|
||||||
|
|
||||||
|
# Image must be FACE_EMBEDDING_SIZExFACE_EMBEDDING_SIZE
|
||||||
|
og_h, og_w, channels = og.shape
|
||||||
|
frame = np.full(
|
||||||
|
(FACE_EMBEDDING_SIZE, FACE_EMBEDDING_SIZE, channels),
|
||||||
|
(0, 0, 0),
|
||||||
|
dtype=np.float32,
|
||||||
|
)
|
||||||
|
|
||||||
|
# compute center offset
|
||||||
|
x_center = (FACE_EMBEDDING_SIZE - og_w) // 2
|
||||||
|
y_center = (FACE_EMBEDDING_SIZE - og_h) // 2
|
||||||
|
|
||||||
|
# copy img image into center of result image
|
||||||
|
frame[y_center : y_center + og_h, x_center : x_center + og_w] = og
|
||||||
|
|
||||||
|
frame = np.expand_dims(frame, axis=0)
|
||||||
|
return [{"image_input": frame}]
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unable to preprocess inputs for {self.model_type}")
|
raise ValueError(f"Unable to preprocess inputs for {self.model_type}")
|
||||||
|
|
||||||
def _process_image(self, image):
|
def _process_image(self, image, output: str = "RGB") -> Image.Image:
|
||||||
if isinstance(image, str):
|
if isinstance(image, str):
|
||||||
if image.startswith("http"):
|
if image.startswith("http"):
|
||||||
response = requests.get(image)
|
response = requests.get(image)
|
||||||
image = Image.open(BytesIO(response.content)).convert("RGB")
|
image = Image.open(BytesIO(response.content)).convert(output)
|
||||||
elif isinstance(image, bytes):
|
elif isinstance(image, bytes):
|
||||||
image = Image.open(BytesIO(image)).convert("RGB")
|
image = Image.open(BytesIO(image)).convert(output)
|
||||||
|
|
||||||
return image
|
return image
|
||||||
|
|
||||||
|
@ -9,6 +9,7 @@ from typing import Optional
|
|||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import requests
|
||||||
from peewee import DoesNotExist
|
from peewee import DoesNotExist
|
||||||
from playhouse.sqliteq import SqliteQueueDatabase
|
from playhouse.sqliteq import SqliteQueueDatabase
|
||||||
|
|
||||||
@ -20,12 +21,12 @@ from frigate.comms.event_metadata_updater import (
|
|||||||
from frigate.comms.events_updater import EventEndSubscriber, EventUpdateSubscriber
|
from frigate.comms.events_updater import EventEndSubscriber, EventUpdateSubscriber
|
||||||
from frigate.comms.inter_process import InterProcessRequestor
|
from frigate.comms.inter_process import InterProcessRequestor
|
||||||
from frigate.config import FrigateConfig
|
from frigate.config import FrigateConfig
|
||||||
from frigate.const import CLIPS_DIR, UPDATE_EVENT_DESCRIPTION
|
from frigate.const import CLIPS_DIR, FRIGATE_LOCALHOST, UPDATE_EVENT_DESCRIPTION
|
||||||
from frigate.events.types import EventTypeEnum
|
from frigate.events.types import EventTypeEnum
|
||||||
from frigate.genai import get_genai_client
|
from frigate.genai import get_genai_client
|
||||||
from frigate.models import Event
|
from frigate.models import Event
|
||||||
from frigate.util.builtin import serialize
|
from frigate.util.builtin import serialize
|
||||||
from frigate.util.image import SharedMemoryFrameManager, calculate_region
|
from frigate.util.image import SharedMemoryFrameManager, area, calculate_region
|
||||||
|
|
||||||
from .embeddings import Embeddings
|
from .embeddings import Embeddings
|
||||||
|
|
||||||
@ -58,10 +59,17 @@ class EmbeddingMaintainer(threading.Thread):
|
|||||||
)
|
)
|
||||||
self.embeddings_responder = EmbeddingsResponder()
|
self.embeddings_responder = EmbeddingsResponder()
|
||||||
self.frame_manager = SharedMemoryFrameManager()
|
self.frame_manager = SharedMemoryFrameManager()
|
||||||
|
|
||||||
|
# set face recognition conditions
|
||||||
|
self.face_recognition_enabled = (
|
||||||
|
self.config.semantic_search.face_recognition.enabled
|
||||||
|
)
|
||||||
|
self.requires_face_detection = "face" not in self.config.model.all_attributes
|
||||||
|
|
||||||
# create communication for updating event descriptions
|
# create communication for updating event descriptions
|
||||||
self.requestor = InterProcessRequestor()
|
self.requestor = InterProcessRequestor()
|
||||||
self.stop_event = stop_event
|
self.stop_event = stop_event
|
||||||
self.tracked_events = {}
|
self.tracked_events: dict[str, list[any]] = {}
|
||||||
self.genai_client = get_genai_client(config)
|
self.genai_client = get_genai_client(config)
|
||||||
|
|
||||||
def run(self) -> None:
|
def run(self) -> None:
|
||||||
@ -101,6 +109,13 @@ class EmbeddingMaintainer(threading.Thread):
|
|||||||
return serialize(
|
return serialize(
|
||||||
self.embeddings.text_embedding([data])[0], pack=False
|
self.embeddings.text_embedding([data])[0], pack=False
|
||||||
)
|
)
|
||||||
|
elif topic == EmbeddingsRequestEnum.register_face.value:
|
||||||
|
self.embeddings.embed_face(
|
||||||
|
data["face_name"],
|
||||||
|
base64.b64decode(data["image"]),
|
||||||
|
upsert=True,
|
||||||
|
)
|
||||||
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Unable to handle embeddings request {e}")
|
logger.error(f"Unable to handle embeddings request {e}")
|
||||||
|
|
||||||
@ -108,7 +123,7 @@ class EmbeddingMaintainer(threading.Thread):
|
|||||||
|
|
||||||
def _process_updates(self) -> None:
|
def _process_updates(self) -> None:
|
||||||
"""Process event updates"""
|
"""Process event updates"""
|
||||||
update = self.event_subscriber.check_for_update(timeout=0.1)
|
update = self.event_subscriber.check_for_update(timeout=0.01)
|
||||||
|
|
||||||
if update is None:
|
if update is None:
|
||||||
return
|
return
|
||||||
@ -119,24 +134,33 @@ class EmbeddingMaintainer(threading.Thread):
|
|||||||
return
|
return
|
||||||
|
|
||||||
camera_config = self.config.cameras[camera]
|
camera_config = self.config.cameras[camera]
|
||||||
# no need to save our own thumbnails if genai is not enabled
|
|
||||||
# or if the object has become stationary
|
|
||||||
if (
|
|
||||||
not camera_config.genai.enabled
|
|
||||||
or self.genai_client is None
|
|
||||||
or data["stationary"]
|
|
||||||
):
|
|
||||||
return
|
|
||||||
|
|
||||||
if data["id"] not in self.tracked_events:
|
# no need to process updated objects if face recognition and genai are disabled
|
||||||
self.tracked_events[data["id"]] = []
|
if not camera_config.genai.enabled and not self.face_recognition_enabled:
|
||||||
|
return
|
||||||
|
|
||||||
# Create our own thumbnail based on the bounding box and the frame time
|
# Create our own thumbnail based on the bounding box and the frame time
|
||||||
try:
|
try:
|
||||||
frame_id = f"{camera}{data['frame_time']}"
|
frame_id = f"{camera}{data['frame_time']}"
|
||||||
yuv_frame = self.frame_manager.get(frame_id, camera_config.frame_shape_yuv)
|
yuv_frame = self.frame_manager.get(frame_id, camera_config.frame_shape_yuv)
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if yuv_frame is None:
|
||||||
|
logger.debug(
|
||||||
|
"Unable to process object update because frame is unavailable."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.face_recognition_enabled:
|
||||||
|
self._process_face(data, yuv_frame)
|
||||||
|
|
||||||
|
# no need to save our own thumbnails if genai is not enabled
|
||||||
|
# or if the object has become stationary
|
||||||
|
if self.genai_client is not None and not data["stationary"]:
|
||||||
|
if data["id"] not in self.tracked_events:
|
||||||
|
self.tracked_events[data["id"]] = []
|
||||||
|
|
||||||
if yuv_frame is not None:
|
|
||||||
data["thumbnail"] = self._create_thumbnail(yuv_frame, data["box"])
|
data["thumbnail"] = self._create_thumbnail(yuv_frame, data["box"])
|
||||||
|
|
||||||
# Limit the number of thumbnails saved
|
# Limit the number of thumbnails saved
|
||||||
@ -147,13 +171,11 @@ class EmbeddingMaintainer(threading.Thread):
|
|||||||
self.tracked_events[data["id"]].append(data)
|
self.tracked_events[data["id"]].append(data)
|
||||||
|
|
||||||
self.frame_manager.close(frame_id)
|
self.frame_manager.close(frame_id)
|
||||||
except FileNotFoundError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _process_finalized(self) -> None:
|
def _process_finalized(self) -> None:
|
||||||
"""Process the end of an event."""
|
"""Process the end of an event."""
|
||||||
while True:
|
while True:
|
||||||
ended = self.event_end_subscriber.check_for_update(timeout=0.1)
|
ended = self.event_end_subscriber.check_for_update(timeout=0.01)
|
||||||
|
|
||||||
if ended == None:
|
if ended == None:
|
||||||
break
|
break
|
||||||
@ -243,7 +265,7 @@ class EmbeddingMaintainer(threading.Thread):
|
|||||||
def _process_event_metadata(self):
|
def _process_event_metadata(self):
|
||||||
# Check for regenerate description requests
|
# Check for regenerate description requests
|
||||||
(topic, event_id, source) = self.event_metadata_subscriber.check_for_update(
|
(topic, event_id, source) = self.event_metadata_subscriber.check_for_update(
|
||||||
timeout=0.1
|
timeout=0.01
|
||||||
)
|
)
|
||||||
|
|
||||||
if topic is None:
|
if topic is None:
|
||||||
@ -252,6 +274,94 @@ class EmbeddingMaintainer(threading.Thread):
|
|||||||
if event_id:
|
if event_id:
|
||||||
self.handle_regenerate_description(event_id, source)
|
self.handle_regenerate_description(event_id, source)
|
||||||
|
|
||||||
|
def _search_face(self, query_embedding: bytes) -> list:
    """Return up to ten (id, distance) rows from vec_faces, closest first."""
    database = self.embeddings.db
    sql_query = """
        SELECT
            id,
            distance
        FROM vec_faces
        WHERE face_embedding MATCH ?
            AND k = 10 ORDER BY distance
    """
    cursor = database.execute_sql(sql_query, [query_embedding])
    return cursor.fetchall()
|
||||||
|
|
||||||
|
def _process_face(self, obj_data: dict[str, any], frame: np.ndarray) -> None:
    """Look for a known face on a tracked person and set it as the sub label.

    The highest scoring ``face`` attribute of the object is cropped out of
    the frame, embedded, and compared against the stored face embeddings.
    When the best match scores at least the configured threshold, the
    matching name is posted to the events API as the object's sub label.

    Args:
        obj_data: Tracked object update; reads ``label``, ``sub_label``,
            ``current_attributes`` and ``id``.
        frame: Frame in YUV I420 layout (it is converted with
            ``cv2.COLOR_YUV2BGR_I420``).
    """
    # don't run for non person objects
    if obj_data.get("label") != "person":
        logger.debug("Not a processing face for non person object.")
        return

    # don't overwrite sub label for objects that have one
    if obj_data.get("sub_label"):
        logger.debug(
            f"Not processing face due to existing sub label: {obj_data.get('sub_label')}."
        )
        return

    face = None

    if self.requires_face_detection:
        # TODO run cv2 face detection
        pass
    else:
        attributes = obj_data.get("current_attributes")

        # don't run for object without attributes
        if not attributes:
            logger.debug("No attributes to parse.")
            return

        # keep the face attribute with the highest score
        for attr in attributes:
            if attr.get("label") != "face":
                continue

            if face is None or attr.get("score", 0.0) > face.get("score", 0.0):
                face = attr

    # no faces detected in this frame
    if not face:
        return

    face_box = face.get("box")

    # check that face is valid
    if (
        not face_box
        or area(face_box) < self.config.semantic_search.face_recognition.min_area
    ):
        logger.debug(f"Invalid face box {face}")
        return

    face_frame = cv2.cvtColor(frame, cv2.COLOR_YUV2BGR_I420)
    face_frame = face_frame[face_box[1] : face_box[3], face_box[0] : face_box[2]]
    ret, webp = cv2.imencode(
        ".webp", face_frame, [int(cv2.IMWRITE_WEBP_QUALITY), 100]
    )

    if not ret:
        logger.debug("Not processing face due to error creating cropped image.")
        return

    embedding = self.embeddings.embed_face("unknown", webp.tobytes(), upsert=False)
    query_embedding = serialize(embedding)
    best_faces = self._search_face(query_embedding)
    logger.debug(f"Detected best faces for person as: {best_faces}")

    if not best_faces:
        return

    best_id, best_distance = best_faces[0][0], best_faces[0][1]

    # Stored ids are "<name>-<random suffix>". Strip only the trailing suffix
    # so names that contain a hyphen themselves are kept intact.
    sub_label = str(best_id).rsplit("-", 1)[0]
    # vec_faces uses cosine distance, so 1 - distance is the similarity
    score = 1.0 - best_distance

    if score < self.config.semantic_search.face_recognition.threshold:
        return

    # Bound the request and tolerate failures so a slow or unavailable API
    # cannot block or raise out of the update processing.
    try:
        requests.post(
            f"{FRIGATE_LOCALHOST}/api/events/{obj_data['id']}/sub_label",
            json={"subLabel": sub_label, "subLabelScore": score},
            timeout=10,
        )
    except requests.RequestException as e:
        logger.warning(f"Unable to set face sub label for {obj_data['id']}: {e}")
|
||||||
|
|
||||||
def _create_thumbnail(self, yuv_frame, box, height=500) -> Optional[bytes]:
|
def _create_thumbnail(self, yuv_frame, box, height=500) -> Optional[bytes]:
|
||||||
"""Return jpg thumbnail of a region of the frame."""
|
"""Return jpg thumbnail of a region of the frame."""
|
||||||
frame = cv2.cvtColor(yuv_frame, cv2.COLOR_YUV2BGR_I420)
|
frame = cv2.cvtColor(yuv_frame, cv2.COLOR_YUV2BGR_I420)
|
||||||
|
@ -101,7 +101,7 @@ class ModelDownloader:
|
|||||||
self.download_complete.set()
|
self.download_complete.set()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def download_from_url(url: str, save_path: str, silent: bool = False):
|
def download_from_url(url: str, save_path: str, silent: bool = False) -> Path:
|
||||||
temporary_filename = Path(save_path).with_name(
|
temporary_filename = Path(save_path).with_name(
|
||||||
os.path.basename(save_path) + ".part"
|
os.path.basename(save_path) + ".part"
|
||||||
)
|
)
|
||||||
@ -125,6 +125,8 @@ class ModelDownloader:
|
|||||||
if not silent:
|
if not silent:
|
||||||
logger.info(f"Downloading complete: {url}")
|
logger.info(f"Downloading complete: {url}")
|
||||||
|
|
||||||
|
return Path(save_path)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def mark_files_state(
|
def mark_files_state(
|
||||||
requestor: InterProcessRequestor,
|
requestor: InterProcessRequestor,
|
||||||
|
Loading…
Reference in New Issue
Block a user