mirror of
https://github.com/blakeblackshear/frigate.git
synced 2024-11-21 19:07:46 +01:00
1ec459ea3a
* fixes * more readable loops * more robust key check and warning message * ensure we get reindex progress on mount * use correct var for length
305 lines
9.7 KiB
Python
305 lines
9.7 KiB
Python
"""SQLite-vec embeddings database."""
|
|
|
|
import base64
|
|
import io
|
|
import logging
|
|
import os
|
|
import time
|
|
|
|
from numpy import ndarray
|
|
from PIL import Image
|
|
from playhouse.shortcuts import model_to_dict
|
|
|
|
from frigate.comms.inter_process import InterProcessRequestor
|
|
from frigate.config.semantic_search import SemanticSearchConfig
|
|
from frigate.const import (
|
|
CONFIG_DIR,
|
|
UPDATE_EMBEDDINGS_REINDEX_PROGRESS,
|
|
UPDATE_MODEL_STATE,
|
|
)
|
|
from frigate.db.sqlitevecq import SqliteVecQueueDatabase
|
|
from frigate.models import Event
|
|
from frigate.types import ModelStatusTypesEnum
|
|
from frigate.util.builtin import serialize
|
|
|
|
from .functions.onnx import GenericONNXEmbedding
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def get_metadata(event: Event) -> dict:
    """Extract valid event metadata.

    Flattens an event into a single-level dict suitable for metadata
    filtering. Three groups of keys are produced, later groups overriding
    earlier ones on a name clash:

    1. scalar (str/int/float/bool) top-level fields, except ``thumbnail``;
    2. scalar entries of the event's ``data`` dict, except ``description``;
    3. one ``"<field>_<item>": True`` flag per string item of every
       list-valued top-level field.
    """
    scalar_types = (str, int, float, bool)
    fields = model_to_dict(event)
    metadata = {}

    # Top-level scalar columns. ``None`` is never an instance of the scalar
    # types, so a single isinstance check also filters out missing values.
    for name, value in fields.items():
        if name == "thumbnail":
            continue
        if isinstance(value, scalar_types):
            metadata[name] = value

    # Scalar entries nested under "data"; these win over top-level columns
    # of the same name.
    for name, value in fields["data"].items():
        if name == "description":
            continue
        if isinstance(value, scalar_types):
            metadata[name] = value

    # Metadata search doesn't support $contains
    # and an event can have multiple zones, so
    # we need to create a key for each zone
    for name, value in fields.items():
        if not isinstance(value, list):
            continue
        for item in value:
            if isinstance(item, str):
                metadata[f"{name}_{item}"] = True

    return metadata
|
|
|
|
|
|
class Embeddings:
    """SQLite-vec embeddings database.

    Owns the text and vision embedding models and writes the vectors they
    produce into the ``vec_thumbnails`` and ``vec_descriptions`` tables.
    """

    def __init__(
        self, config: SemanticSearchConfig, db: SqliteVecQueueDatabase
    ) -> None:
        """Create the embeddings tables and construct both embedding models.

        Args:
            config: Semantic search settings; ``model_size`` selects the
                vision model file and the device the vision model runs on.
            db: Database the embeddings are written to.
        """
        self.config = config
        self.db = db
        self.requestor = InterProcessRequestor()

        # Create tables if they don't exist
        self.db.create_embeddings_tables()

        # "large" selects the fp16 vision model, any other size the
        # quantized one.
        vision_model_file = (
            "vision_model_fp16.onnx"
            if config.model_size == "large"
            else "vision_model_quantized.onnx"
        )

        models = [
            "jinaai/jina-clip-v1-text_model_fp16.onnx",
            "jinaai/jina-clip-v1-tokenizer",
            f"jinaai/jina-clip-v1-{vision_model_file}",
            "jinaai/jina-clip-v1-preprocessor_config.json",
        ]

        # Publish every model file as "not downloaded" before the embedding
        # objects are constructed.
        for model in models:
            self.requestor.send_data(
                UPDATE_MODEL_STATE,
                {
                    "model": model,
                    "state": ModelStatusTypesEnum.not_downloaded,
                },
            )

        self.text_embedding = GenericONNXEmbedding(
            model_name="jinaai/jina-clip-v1",
            model_file="text_model_fp16.onnx",
            tokenizer_file="tokenizer",
            download_urls={
                "text_model_fp16.onnx": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/text_model_fp16.onnx",
            },
            model_size=config.model_size,
            model_type="text",
            requestor=self.requestor,
            device="CPU",
        )

        download_urls = {
            vision_model_file: f"https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/{vision_model_file}",
            "preprocessor_config.json": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/preprocessor_config.json",
        }

        self.vision_embedding = GenericONNXEmbedding(
            model_name="jinaai/jina-clip-v1",
            model_file=vision_model_file,
            download_urls=download_urls,
            model_size=config.model_size,
            model_type="vision",
            requestor=self.requestor,
            device="GPU" if config.model_size == "large" else "CPU",
        )

    def _upsert_embeddings(
        self, table: str, column: str, ids: list, embeddings: list
    ) -> None:
        """Insert or replace one row per ``(id, embedding)`` pair in a single statement.

        ``table`` and ``column`` are fixed names supplied by this class, never
        external input; the ids and serialized vectors are bound as
        parameters.
        """
        params = []
        for event_id, embedding in zip(ids, embeddings):
            params.append(event_id)
            params.append(serialize(embedding))

        placeholders = ", ".join(["(?, ?)"] * len(ids))
        self.db.execute_sql(
            f"INSERT OR REPLACE INTO {table}(id, {column}) VALUES {placeholders}",
            params,
        )

    def upsert_thumbnail(self, event_id: str, thumbnail: bytes) -> ndarray:
        """Embed one thumbnail and store it under ``event_id``.

        Args:
            event_id: Id of the event the thumbnail belongs to.
            thumbnail: Encoded image bytes readable by PIL.

        Returns:
            The embedding produced by the vision model.
        """
        # Convert thumbnail bytes to PIL Image
        image = Image.open(io.BytesIO(thumbnail)).convert("RGB")
        embedding = self.vision_embedding([image])[0]

        self.db.execute_sql(
            """
            INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding)
            VALUES(?, ?)
            """,
            (event_id, serialize(embedding)),
        )

        return embedding

    def batch_upsert_thumbnail(self, event_thumbs: dict[str, bytes]) -> list[ndarray]:
        """Embed several thumbnails in one model call and store them all.

        Args:
            event_thumbs: Mapping of event id to encoded image bytes.

        Returns:
            The embeddings in the iteration order of ``event_thumbs``; an
            empty list when the mapping is empty.
        """
        # An empty batch would otherwise build "VALUES" with no tuples,
        # which is not valid SQL.
        if not event_thumbs:
            return []

        ids = list(event_thumbs)
        images = [
            Image.open(io.BytesIO(thumb)).convert("RGB")
            for thumb in event_thumbs.values()
        ]
        embeddings = self.vision_embedding(images)

        self._upsert_embeddings("vec_thumbnails", "thumbnail_embedding", ids, embeddings)
        return embeddings

    def upsert_description(self, event_id: str, description: str) -> ndarray:
        """Embed one description and store it under ``event_id``.

        Returns:
            The embedding produced by the text model.
        """
        embedding = self.text_embedding([description])[0]
        self.db.execute_sql(
            """
            INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
            VALUES(?, ?)
            """,
            (event_id, serialize(embedding)),
        )

        return embedding

    def batch_upsert_description(
        self, event_descriptions: dict[str, str]
    ) -> list[ndarray]:
        """Embed several descriptions in one model call and store them all.

        Args:
            event_descriptions: Mapping of event id to description text.

        Returns:
            The embeddings in the iteration order of ``event_descriptions``;
            an empty list when the mapping is empty.
        """
        # Same guard as the thumbnail variant: no rows means no valid SQL.
        if not event_descriptions:
            return []

        ids = list(event_descriptions)
        embeddings = self.text_embedding(list(event_descriptions.values()))

        self._upsert_embeddings(
            "vec_descriptions", "description_embedding", ids, embeddings
        )
        return embeddings

    @staticmethod
    def _reindexable_events():
        """Build the query selecting events that take part in a reindex.

        An event qualifies when it has a clip or a snapshot and its thumbnail
        is not null. Each comparison is parenthesised on purpose: ``|`` and
        ``&`` bind tighter than ``==`` in Python, so without the parentheses
        the expression is parsed as a chained comparison and the
        clip/snapshot condition is not applied as written.
        """
        return Event.select().where(
            ((Event.has_clip == True) | (Event.has_snapshot == True))  # noqa: E712
            & Event.thumbnail.is_null(False)
        )

    def reindex(self) -> None:
        """Drop and rebuild all embeddings from the stored events.

        Recreates the embeddings tables, removes the saved search stats file,
        then embeds qualifying events in pages of 32 (newest first),
        publishing progress after every page and a final "completed" message.
        """
        logger.info("Indexing tracked object embeddings...")

        self.db.drop_embeddings_tables()
        logger.debug("Dropped embeddings tables.")
        self.db.create_embeddings_tables()
        logger.debug("Created embeddings tables.")

        # Delete the saved stats file (nothing to do if it is already gone)
        try:
            os.remove(os.path.join(CONFIG_DIR, ".search_stats.json"))
        except FileNotFoundError:
            pass

        st = time.time()

        # Get total count of events to process
        total_events = self._reindexable_events().count()

        batch_size = 32
        current_page = 1

        totals = {
            "thumbnails": 0,
            "descriptions": 0,
            "processed_objects": 0,
            "total_objects": total_events,
            "time_remaining": 0 if total_events < batch_size else -1,
            "status": "indexing",
        }

        # When everything fits in a single batch the first message reports
        # the run as almost finished. That value is only a placeholder for
        # this one message; the real counter starts at zero so the per-event
        # increments below never exceed total_events.
        initial_progress = dict(totals)
        if total_events < batch_size:
            initial_progress["processed_objects"] = max(total_events - 1, 0)
        self.requestor.send_data(UPDATE_EMBEDDINGS_REINDEX_PROGRESS, initial_progress)

        while events := list(
            self._reindexable_events()
            .order_by(Event.start_time.desc())
            .paginate(current_page, batch_size)
        ):
            event: Event
            batch_thumbs = {}
            batch_descs = {}
            for event in events:
                batch_thumbs[event.id] = base64.b64decode(event.thumbnail)
                totals["thumbnails"] += 1

                # "description" may be absent or stored as None.
                if description := (event.data.get("description") or "").strip():
                    batch_descs[event.id] = description
                    totals["descriptions"] += 1

                totals["processed_objects"] += 1

            # run batch embedding
            self.batch_upsert_thumbnail(batch_thumbs)

            if batch_descs:
                self.batch_upsert_description(batch_descs)

            processed = totals["processed_objects"]

            # report progress every batch so we don't spam the logs
            # (total_events can be 0 if events appeared after the count)
            progress = (processed / total_events) * 100 if total_events else 100.0
            logger.debug(
                "Processed %d/%d events (%.2f%% complete) | Thumbnails: %d, Descriptions: %d",
                processed,
                total_events,
                progress,
                totals["thumbnails"],
                totals["descriptions"],
            )

            # Calculate time remaining; never negative, even if more events
            # were processed than were counted up front.
            elapsed_time = time.time() - st
            avg_time_per_event = elapsed_time / processed
            remaining_events = max(total_events - processed, 0)
            totals["time_remaining"] = int(avg_time_per_event * remaining_events)

            self.requestor.send_data(UPDATE_EMBEDDINGS_REINDEX_PROGRESS, totals)

            # Move to the next page
            current_page += 1

        logger.info(
            "Embedded %d thumbnails and %d descriptions in %s seconds",
            totals["thumbnails"],
            totals["descriptions"],
            round(time.time() - st, 1),
        )
        totals["status"] = "completed"

        self.requestor.send_data(UPDATE_EMBEDDINGS_REINDEX_PROGRESS, totals)
|