blakeblackshear.frigate/frigate/embeddings/embeddings.py

417 lines
14 KiB
Python
Use sqlite-vec extension instead of chromadb for embeddings (#14163) * swap sqlite_vec for chroma in requirements * load sqlite_vec in embeddings manager * remove chroma and revamp Embeddings class for sqlite_vec * manual minilm onnx inference * remove chroma in clip model * migrate api from chroma to sqlite_vec * migrate event cleanup from chroma to sqlite_vec * migrate embedding maintainer from chroma to sqlite_vec * genai description for sqlite_vec * load sqlite_vec in main thread db * extend the SqliteQueueDatabase class and use peewee db.execute_sql * search with Event type for similarity * fix similarity search * install and add comment about transformers * fix normalization * add id filter * clean up * clean up * fully remove chroma and add transformers env var * readd uvicorn for fastapi * readd tokenizer parallelism env var * remove chroma from docs * remove chroma from UI * try removing custom pysqlite3 build * hard code limit * optimize queries * revert explore query * fix query * keep building pysqlite3 * single pass fetch and process * remove unnecessary re-embed * update deps * move SqliteVecQueueDatabase to db directory * make search thumbnail take up full size of results box * improve typing * improve model downloading and add status screen * daemon downloading thread * catch case when semantic search is disabled * fix typing * build sqlite_vec from source * resolve conflict * file permissions * try build deps * remove sources * sources * fix thread start * include git in build * reorder embeddings after detectors are started * build with sqlite amalgamation * non-platform specific * use wget instead of curl * remove unzip -d * remove sqlite_vec from requirements and load the compiled version * fix build * avoid race in db connection * add scale_factor and bias to description zscore normalization
2024-10-07 22:30:45 +02:00
"""SQLite-vec embeddings database."""

import base64
import logging
import os
import random
import string
import time

from numpy import ndarray
from playhouse.shortcuts import model_to_dict
from frigate.comms.inter_process import InterProcessRequestor
from frigate.config import FrigateConfig
from frigate.const import (
    CONFIG_DIR,
    FACE_DIR,
    UPDATE_EMBEDDINGS_REINDEX_PROGRESS,
    UPDATE_MODEL_STATE,
)
from frigate.db.sqlitevecq import SqliteVecQueueDatabase
from frigate.models import Event
from frigate.types import ModelStatusTypesEnum
from frigate.util.builtin import serialize
from .functions.onnx import GenericONNXEmbedding, ModelTypeEnum
logger = logging.getLogger(__name__)


def get_metadata(event: Event) -> dict:
    """Extract valid event metadata."""
    event_dict = model_to_dict(event)
    return (
        {
            k: v
            for k, v in event_dict.items()
            if k not in ["thumbnail"]
            and v is not None
            and isinstance(v, (str, int, float, bool))
        }
        | {
            k: v
            for k, v in event_dict["data"].items()
            if k not in ["description"]
            and v is not None
            and isinstance(v, (str, int, float, bool))
        }
        | {
            # Metadata search doesn't support $contains
            # and an event can have multiple zones, so
            # we need to create a key for each zone
            f"{k}_{x}": True
            for k, v in event_dict.items()
            if isinstance(v, list) and len(v) > 0
            for x in v
            if isinstance(x, str)
        }
    )


class Embeddings:
    """SQLite-vec embeddings database."""

    def __init__(self, config: FrigateConfig, db: SqliteVecQueueDatabase) -> None:
        self.config = config
        self.db = db
        self.requestor = InterProcessRequestor()

        # Create tables if they don't exist
        self.db.create_embeddings_tables(self.config.face_recognition.enabled)
        models = [
            "jinaai/jina-clip-v1-text_model_fp16.onnx",
            "jinaai/jina-clip-v1-tokenizer",
            "jinaai/jina-clip-v1-vision_model_fp16.onnx"
            if config.semantic_search.model_size == "large"
            else "jinaai/jina-clip-v1-vision_model_quantized.onnx",
            "jinaai/jina-clip-v1-preprocessor_config.json",
            "facenet-facenet.onnx",
            "paddleocr-onnx-detection.onnx",
            "paddleocr-onnx-classification.onnx",
            "paddleocr-onnx-recognition.onnx",
        ]

        for model in models:
            self.requestor.send_data(
                UPDATE_MODEL_STATE,
                {
                    "model": model,
                    "state": ModelStatusTypesEnum.not_downloaded,
                },
            )

        self.text_embedding = GenericONNXEmbedding(
            model_name="jinaai/jina-clip-v1",
            model_file="text_model_fp16.onnx",
            tokenizer_file="tokenizer",
            download_urls={
                "text_model_fp16.onnx": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/text_model_fp16.onnx",
            },
            model_size=config.semantic_search.model_size,
            model_type=ModelTypeEnum.text,
            requestor=self.requestor,
            device="CPU",
        )

        model_file = (
            "vision_model_fp16.onnx"
            if self.config.semantic_search.model_size == "large"
            else "vision_model_quantized.onnx"
        )

        download_urls = {
            model_file: f"https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/{model_file}",
            "preprocessor_config.json": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/preprocessor_config.json",
        }

        self.vision_embedding = GenericONNXEmbedding(
            model_name="jinaai/jina-clip-v1",
            model_file=model_file,
            download_urls=download_urls,
            model_size=config.semantic_search.model_size,
            model_type=ModelTypeEnum.vision,
            requestor=self.requestor,
            device="GPU" if config.semantic_search.model_size == "large" else "CPU",
        )

        self.face_embedding = None

        if self.config.face_recognition.enabled:
            self.face_embedding = GenericONNXEmbedding(
                model_name="facenet",
                model_file="facenet.onnx",
                download_urls={
                    "facenet.onnx": "https://github.com/NicolasSM-001/faceNet.onnx-/raw/refs/heads/main/faceNet.onnx",
                    "facedet.onnx": "https://github.com/opencv/opencv_zoo/raw/refs/heads/main/models/face_detection_yunet/face_detection_yunet_2023mar_int8.onnx",
                },
                model_size="large",
                model_type=ModelTypeEnum.face,
                requestor=self.requestor,
                device="GPU",
            )

        self.lpr_detection_model = None
        self.lpr_classification_model = None
        self.lpr_recognition_model = None

        if self.config.lpr.enabled:
            self.lpr_detection_model = GenericONNXEmbedding(
                model_name="paddleocr-onnx",
                model_file="detection.onnx",
                download_urls={
                    "detection.onnx": "https://github.com/hawkeye217/paddleocr-onnx/raw/refs/heads/master/models/detection.onnx"
                },
                model_size="large",
                model_type=ModelTypeEnum.alpr_detect,
                requestor=self.requestor,
                device="CPU",
            )

            self.lpr_classification_model = GenericONNXEmbedding(
                model_name="paddleocr-onnx",
                model_file="classification.onnx",
                download_urls={
                    "classification.onnx": "https://github.com/hawkeye217/paddleocr-onnx/raw/refs/heads/master/models/classification.onnx"
                },
                model_size="large",
                model_type=ModelTypeEnum.alpr_classify,
                requestor=self.requestor,
                device="CPU",
            )

            self.lpr_recognition_model = GenericONNXEmbedding(
                model_name="paddleocr-onnx",
                model_file="recognition.onnx",
                download_urls={
                    "recognition.onnx": "https://github.com/hawkeye217/paddleocr-onnx/raw/refs/heads/master/models/recognition.onnx"
                },
                model_size="large",
                model_type=ModelTypeEnum.alpr_recognize,
                requestor=self.requestor,
                device="CPU",
            )

    def embed_thumbnail(
        self, event_id: str, thumbnail: bytes, upsert: bool = True
    ) -> ndarray:
        """Embed thumbnail and optionally insert into DB.

        @param: event_id in Events DB
        @param: thumbnail bytes in jpg format
        @param: upsert If embedding should be upserted into vec DB
        """
        # Embed the thumbnail; the raw JPEG bytes are passed straight to the vision model
        embedding = self.vision_embedding([thumbnail])[0]

        if upsert:
            self.db.execute_sql(
                """
                INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding)
                VALUES(?, ?)
                """,
                (event_id, serialize(embedding)),
            )

        return embedding

    def batch_embed_thumbnail(
        self, event_thumbs: dict[str, bytes], upsert: bool = True
    ) -> list[ndarray]:
        """Embed thumbnails and optionally insert into DB.

        @param: event_thumbs Map of Event IDs in DB to thumbnail bytes in jpg format
        @param: upsert If embedding should be upserted into vec DB
        """
        ids = list(event_thumbs.keys())
        embeddings = self.vision_embedding(list(event_thumbs.values()))

        if upsert:
            # Flatten (id, embedding) pairs to match the "(?, ?)" placeholders
            items = []

            for i in range(len(ids)):
                items.append(ids[i])
                items.append(serialize(embeddings[i]))

            self.db.execute_sql(
                """
                INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding)
                VALUES {}
                """.format(", ".join(["(?, ?)"] * len(ids))),
                items,
            )

        return embeddings

    def embed_description(
        self, event_id: str, description: str, upsert: bool = True
    ) -> ndarray:
        embedding = self.text_embedding([description])[0]

        if upsert:
            self.db.execute_sql(
                """
                INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
                VALUES(?, ?)
                """,
                (event_id, serialize(embedding)),
            )

        return embedding

    def batch_embed_description(
        self, event_descriptions: dict[str, str], upsert: bool = True
    ) -> list[ndarray]:
        # embed descriptions one by one to avoid token limit
        embeddings = []

        for desc in event_descriptions.values():
            embeddings.append(self.text_embedding([desc])[0])

        if upsert:
            ids = list(event_descriptions.keys())
            # Flatten (id, embedding) pairs to match the "(?, ?)" placeholders
            items = []

            for i in range(len(ids)):
                items.append(ids[i])
                items.append(serialize(embeddings[i]))

            self.db.execute_sql(
                """
                INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
                VALUES {}
                """.format(", ".join(["(?, ?)"] * len(ids))),
                items,
            )

        return embeddings

    def embed_face(self, label: str, thumbnail: bytes, upsert: bool = False) -> ndarray:
        embedding = self.face_embedding(thumbnail)[0]

        if upsert:
            rand_id = "".join(
                random.choices(string.ascii_lowercase + string.digits, k=6)
            )
            id = f"{label}-{rand_id}"

            # write face to library
            folder = os.path.join(FACE_DIR, label)
            file = os.path.join(folder, f"{id}.webp")
            os.makedirs(folder, exist_ok=True)

            # save face image
            with open(file, "wb") as output:
                output.write(thumbnail)

            self.db.execute_sql(
                """
                INSERT OR REPLACE INTO vec_faces(id, face_embedding)
                VALUES(?, ?)
                """,
                (id, serialize(embedding)),
            )

        return embedding

    def reindex(self) -> None:
        logger.info("Indexing tracked object embeddings...")

        self.db.drop_embeddings_tables()
        logger.debug("Dropped embeddings tables.")
        self.db.create_embeddings_tables(self.config.face_recognition.enabled)
        logger.debug("Created embeddings tables.")

        # Delete the saved stats file
        if os.path.exists(os.path.join(CONFIG_DIR, ".search_stats.json")):
            os.remove(os.path.join(CONFIG_DIR, ".search_stats.json"))

        st = time.time()

        # Get total count of events to process.
        # Each comparison is parenthesized because `|` binds tighter than `==`.
        total_events = (
            Event.select()
            .where(
                ((Event.has_clip == True) | (Event.has_snapshot == True))
                & Event.thumbnail.is_null(False)
            )
            .count()
        )

        batch_size = 32
        current_page = 1

        totals = {
            "thumbnails": 0,
            "descriptions": 0,
            "processed_objects": total_events - 1 if total_events < batch_size else 0,
            "total_objects": total_events,
            "time_remaining": 0 if total_events < batch_size else -1,
            "status": "indexing",
        }

        self.requestor.send_data(UPDATE_EMBEDDINGS_REINDEX_PROGRESS, totals)

        # Parenthesize each comparison: `|` binds tighter than `==`.
        events = (
            Event.select()
            .where(
                ((Event.has_clip == True) | (Event.has_snapshot == True))
                & Event.thumbnail.is_null(False)
            )
            .order_by(Event.start_time.desc())
            .paginate(current_page, batch_size)
        )

        while len(events) > 0:
            event: Event
            batch_thumbs = {}
            batch_descs = {}
            for event in events:
                batch_thumbs[event.id] = base64.b64decode(event.thumbnail)
                totals["thumbnails"] += 1

                if description := event.data.get("description", "").strip():
                    batch_descs[event.id] = description
                    totals["descriptions"] += 1

                totals["processed_objects"] += 1

            # run batch embedding
            self.batch_embed_thumbnail(batch_thumbs)

            if batch_descs:
                self.batch_embed_description(batch_descs)

            # report progress every batch so we don't spam the logs
            progress = (totals["processed_objects"] / total_events) * 100
            logger.debug(
                "Processed %d/%d events (%.2f%% complete) | Thumbnails: %d, Descriptions: %d",
                totals["processed_objects"],
                total_events,
                progress,
                totals["thumbnails"],
                totals["descriptions"],
            )

            # Calculate time remaining
            elapsed_time = time.time() - st
            avg_time_per_event = elapsed_time / totals["processed_objects"]
            remaining_events = total_events - totals["processed_objects"]
            time_remaining = avg_time_per_event * remaining_events
            totals["time_remaining"] = int(time_remaining)

            self.requestor.send_data(UPDATE_EMBEDDINGS_REINDEX_PROGRESS, totals)

            # Move to the next page
            # (each comparison is parenthesized: `|` binds tighter than `==`)
            current_page += 1
            events = (
                Event.select()
                .where(
                    ((Event.has_clip == True) | (Event.has_snapshot == True))
                    & Event.thumbnail.is_null(False)
                )
                .order_by(Event.start_time.desc())
                .paginate(current_page, batch_size)
            )

        logger.info(
            "Embedded %d thumbnails and %d descriptions in %s seconds",
            totals["thumbnails"],
            totals["descriptions"],
            round(time.time() - st, 1),
        )
        totals["status"] = "completed"

        self.requestor.send_data(UPDATE_EMBEDDINGS_REINDEX_PROGRESS, totals)