blakeblackshear.frigate/frigate/embeddings/embeddings.py
Josh Hawkins 24ac9f3e5a
Use sqlite-vec extension instead of chromadb for embeddings (#14163)
* swap sqlite_vec for chroma in requirements

* load sqlite_vec in embeddings manager

* remove chroma and revamp Embeddings class for sqlite_vec

* manual minilm onnx inference

* remove chroma in clip model

* migrate api from chroma to sqlite_vec

* migrate event cleanup from chroma to sqlite_vec

* migrate embedding maintainer from chroma to sqlite_vec

* genai description for sqlite_vec

* load sqlite_vec in main thread db

* extend the SqliteQueueDatabase class and use peewee db.execute_sql

* search with Event type for similarity

* fix similarity search

* install and add comment about transformers

* fix normalization

* add id filter

* clean up

* clean up

* fully remove chroma and add transformers env var

* readd uvicorn for fastapi

* readd tokenizer parallelism env var

* remove chroma from docs

* remove chroma from UI

* try removing custom pysqlite3 build

* hard code limit

* optimize queries

* revert explore query

* fix query

* keep building pysqlite3

* single pass fetch and process

* remove unnecessary re-embed

* update deps

* move SqliteVecQueueDatabase to db directory

* make search thumbnail take up full size of results box

* improve typing

* improve model downloading and add status screen

* daemon downloading thread

* catch case when semantic search is disabled

* fix typing

* build sqlite_vec from source

* resolve conflict

* file permissions

* try build deps

* remove sources

* sources

* fix thread start

* include git in build

* reorder embeddings after detectors are started

* build with sqlite amalgamation

* non-platform specific

* use wget instead of curl

* remove unzip -d

* remove sqlite_vec from requirements and load the compiled version

* fix build

* avoid race in db connection

* add scale_factor and bias to description zscore normalization
2024-10-07 14:30:45 -06:00

294 lines
9.2 KiB
Python

"""SQLite-vec embeddings database."""
import base64
import io
import logging
import struct
import time
from typing import List, Tuple, Union
from PIL import Image
from playhouse.shortcuts import model_to_dict
from frigate.comms.inter_process import InterProcessRequestor
from frigate.const import UPDATE_MODEL_STATE
from frigate.db.sqlitevecq import SqliteVecQueueDatabase
from frigate.models import Event
from frigate.types import ModelStatusTypesEnum
from .functions.clip import ClipEmbedding
from .functions.minilm_l6_v2 import MiniLMEmbedding
logger = logging.getLogger(__name__)
def get_metadata(event: Event) -> dict:
"""Extract valid event metadata."""
event_dict = model_to_dict(event)
return (
{
k: v
for k, v in event_dict.items()
if k not in ["thumbnail"]
and v is not None
and isinstance(v, (str, int, float, bool))
}
| {
k: v
for k, v in event_dict["data"].items()
if k not in ["description"]
and v is not None
and isinstance(v, (str, int, float, bool))
}
| {
# Metadata search doesn't support $contains
# and an event can have multiple zones, so
# we need to create a key for each zone
f"{k}_{x}": True
for k, v in event_dict.items()
if isinstance(v, list) and len(v) > 0
for x in v
if isinstance(x, str)
}
)
def serialize(vector: List[float]) -> bytes:
"""Serializes a list of floats into a compact "raw bytes" format"""
return struct.pack("%sf" % len(vector), *vector)
def deserialize(bytes_data: bytes) -> List[float]:
"""Deserializes a compact "raw bytes" format into a list of floats"""
return list(struct.unpack("%sf" % (len(bytes_data) // 4), bytes_data))
class Embeddings:
"""SQLite-vec embeddings database."""
def __init__(self, db: SqliteVecQueueDatabase) -> None:
self.db = db
self.requestor = InterProcessRequestor()
# Create tables if they don't exist
self._create_tables()
models = [
"sentence-transformers/all-MiniLM-L6-v2-model.onnx",
"sentence-transformers/all-MiniLM-L6-v2-tokenizer",
"clip-clip_image_model_vitb32.onnx",
"clip-clip_text_model_vitb32.onnx",
]
for model in models:
self.requestor.send_data(
UPDATE_MODEL_STATE,
{
"model": model,
"state": ModelStatusTypesEnum.not_downloaded,
},
)
self.clip_embedding = ClipEmbedding(
preferred_providers=["CPUExecutionProvider"]
)
self.minilm_embedding = MiniLMEmbedding(
preferred_providers=["CPUExecutionProvider"],
)
def _create_tables(self):
# Create vec0 virtual table for thumbnail embeddings
self.db.execute_sql("""
CREATE VIRTUAL TABLE IF NOT EXISTS vec_thumbnails USING vec0(
id TEXT PRIMARY KEY,
thumbnail_embedding FLOAT[512]
);
""")
# Create vec0 virtual table for description embeddings
self.db.execute_sql("""
CREATE VIRTUAL TABLE IF NOT EXISTS vec_descriptions USING vec0(
id TEXT PRIMARY KEY,
description_embedding FLOAT[384]
);
""")
def upsert_thumbnail(self, event_id: str, thumbnail: bytes):
# Convert thumbnail bytes to PIL Image
image = Image.open(io.BytesIO(thumbnail)).convert("RGB")
# Generate embedding using CLIP
embedding = self.clip_embedding([image])[0]
self.db.execute_sql(
"""
INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding)
VALUES(?, ?)
""",
(event_id, serialize(embedding)),
)
return embedding
def upsert_description(self, event_id: str, description: str):
# Generate embedding using MiniLM
embedding = self.minilm_embedding([description])[0]
self.db.execute_sql(
"""
INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
VALUES(?, ?)
""",
(event_id, serialize(embedding)),
)
return embedding
def delete_thumbnail(self, event_ids: List[str]) -> None:
ids = ",".join(["?" for _ in event_ids])
self.db.execute_sql(
f"DELETE FROM vec_thumbnails WHERE id IN ({ids})", event_ids
)
def delete_description(self, event_ids: List[str]) -> None:
ids = ",".join(["?" for _ in event_ids])
self.db.execute_sql(
f"DELETE FROM vec_descriptions WHERE id IN ({ids})", event_ids
)
def search_thumbnail(
self, query: Union[Event, str], event_ids: List[str] = None
) -> List[Tuple[str, float]]:
if query.__class__ == Event:
cursor = self.db.execute_sql(
"""
SELECT thumbnail_embedding FROM vec_thumbnails WHERE id = ?
""",
[query.id],
)
row = cursor.fetchone() if cursor else None
if row:
query_embedding = deserialize(
row[0]
) # Deserialize the thumbnail embedding
else:
# If no embedding found, generate it and return it
thumbnail = base64.b64decode(query.thumbnail)
query_embedding = self.upsert_thumbnail(query.id, thumbnail)
else:
query_embedding = self.clip_embedding([query])[0]
sql_query = """
SELECT
id,
distance
FROM vec_thumbnails
WHERE thumbnail_embedding MATCH ?
AND k = 100
"""
# Add the IN clause if event_ids is provided and not empty
# this is the only filter supported by sqlite-vec as of 0.1.3
# but it seems to be broken in this version
if event_ids:
sql_query += " AND id IN ({})".format(",".join("?" * len(event_ids)))
# order by distance DESC is not implemented in this version of sqlite-vec
# when it's implemented, we can use cosine similarity
sql_query += " ORDER BY distance"
parameters = (
[serialize(query_embedding)] + event_ids
if event_ids
else [serialize(query_embedding)]
)
results = self.db.execute_sql(sql_query, parameters).fetchall()
return results
def search_description(
self, query_text: str, event_ids: List[str] = None
) -> List[Tuple[str, float]]:
query_embedding = self.minilm_embedding([query_text])[0]
# Prepare the base SQL query
sql_query = """
SELECT
id,
distance
FROM vec_descriptions
WHERE description_embedding MATCH ?
AND k = 100
"""
# Add the IN clause if event_ids is provided and not empty
# this is the only filter supported by sqlite-vec as of 0.1.3
# but it seems to be broken in this version
if event_ids:
sql_query += " AND id IN ({})".format(",".join("?" * len(event_ids)))
# order by distance DESC is not implemented in this version of sqlite-vec
# when it's implemented, we can use cosine similarity
sql_query += " ORDER BY distance"
parameters = (
[serialize(query_embedding)] + event_ids
if event_ids
else [serialize(query_embedding)]
)
results = self.db.execute_sql(sql_query, parameters).fetchall()
return results
def reindex(self) -> None:
logger.info("Indexing event embeddings...")
st = time.time()
totals = {
"thumb": 0,
"desc": 0,
}
batch_size = 100
current_page = 1
events = (
Event.select()
.where(
(Event.has_clip == True | Event.has_snapshot == True)
& Event.thumbnail.is_null(False)
)
.order_by(Event.start_time.desc())
.paginate(current_page, batch_size)
)
while len(events) > 0:
event: Event
for event in events:
thumbnail = base64.b64decode(event.thumbnail)
self.upsert_thumbnail(event.id, thumbnail)
totals["thumb"] += 1
if description := event.data.get("description", "").strip():
totals["desc"] += 1
self.upsert_description(event.id, description)
current_page += 1
events = (
Event.select()
.where(
(Event.has_clip == True | Event.has_snapshot == True)
& Event.thumbnail.is_null(False)
)
.order_by(Event.start_time.desc())
.paginate(current_page, batch_size)
)
logger.info(
"Embedded %d thumbnails and %d descriptions in %s seconds",
totals["thumb"],
totals["desc"],
time.time() - st,
)