blakeblackshear.frigate/frigate/embeddings/maintainer.py
Josh Hawkins 24ac9f3e5a
Use sqlite-vec extension instead of chromadb for embeddings (#14163)
* swap sqlite_vec for chroma in requirements

* load sqlite_vec in embeddings manager

* remove chroma and revamp Embeddings class for sqlite_vec

* manual minilm onnx inference

* remove chroma in clip model

* migrate api from chroma to sqlite_vec

* migrate event cleanup from chroma to sqlite_vec

* migrate embedding maintainer from chroma to sqlite_vec

* genai description for sqlite_vec

* load sqlite_vec in main thread db

* extend the SqliteQueueDatabase class and use peewee db.execute_sql

* search with Event type for similarity

* fix similarity search

* install and add comment about transformers

* fix normalization

* add id filter

* clean up

* clean up

* fully remove chroma and add transformers env var

* readd uvicorn for fastapi

* readd tokenizer parallelism env var

* remove chroma from docs

* remove chroma from UI

* try removing custom pysqlite3 build

* hard code limit

* optimize queries

* revert explore query

* fix query

* keep building pysqlite3

* single pass fetch and process

* remove unnecessary re-embed

* update deps

* move SqliteVecQueueDatabase to db directory

* make search thumbnail take up full size of results box

* improve typing

* improve model downloading and add status screen

* daemon downloading thread

* catch case when semantic search is disabled

* fix typing

* build sqlite_vec from source

* resolve conflict

* file permissions

* try build deps

* remove sources

* sources

* fix thread start

* include git in build

* reorder embeddings after detectors are started

* build with sqlite amalgamation

* non-platform specific

* use wget instead of curl

* remove unzip -d

* remove sqlite_vec from requirements and load the compiled version

* fix build

* avoid race in db connection

* add scale_factor and bias to description zscore normalization
2024-10-07 14:30:45 -06:00

303 lines
11 KiB
Python

"""Maintain embeddings in SQLite-vec."""
import base64
import logging
import os
import threading
from multiprocessing.synchronize import Event as MpEvent
from typing import Optional
import cv2
import numpy as np
from peewee import DoesNotExist
from playhouse.sqliteq import SqliteQueueDatabase
from frigate.comms.event_metadata_updater import (
EventMetadataSubscriber,
EventMetadataTypeEnum,
)
from frigate.comms.events_updater import EventEndSubscriber, EventUpdateSubscriber
from frigate.comms.inter_process import InterProcessRequestor
from frigate.config import FrigateConfig
from frigate.const import CLIPS_DIR, UPDATE_EVENT_DESCRIPTION
from frigate.events.types import EventTypeEnum
from frigate.genai import get_genai_client
from frigate.models import Event
from frigate.util.image import SharedMemoryFrameManager, calculate_region
from .embeddings import Embeddings
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class EmbeddingMaintainer(threading.Thread):
    """Handle embedding queue and post event updates.

    The thread polls three subscribers in a loop until ``stop_event`` is set:
    in-progress event updates (buffered as per-frame thumbnails), ended
    events (thumbnail embedding plus an optional GenAI description) and
    description-regeneration requests.
    """

    def __init__(
        self,
        db: SqliteQueueDatabase,
        config: FrigateConfig,
        stop_event: MpEvent,
    ) -> None:
        """Create the subscribers, frame manager, requestor and GenAI client.

        Args:
            db: Database handle handed to ``Embeddings``.
            config: Frigate configuration; ``config.cameras`` and
                ``config.genai`` are read by this class.
            stop_event: Event that ends the ``run`` loop once it is set.
        """
        super().__init__()
        self.name = "embeddings_maintainer"
        self.config = config
        self.embeddings = Embeddings(db)
        self.event_subscriber = EventUpdateSubscriber()
        self.event_end_subscriber = EventEndSubscriber()
        self.event_metadata_subscriber = EventMetadataSubscriber(
            EventMetadataTypeEnum.regenerate_description
        )
        self.frame_manager = SharedMemoryFrameManager()
        # create communication for updating event descriptions
        self.requestor = InterProcessRequestor()
        self.stop_event = stop_event
        # event id -> update payloads received while the event is in
        # progress; each stored payload carries a "thumbnail" entry
        self.tracked_events: dict[str, list[dict]] = {}
        self.genai_client = get_genai_client(config.genai)

    def run(self) -> None:
        """Maintain a SQLite-vec database for semantic search."""
        while not self.stop_event.is_set():
            self._process_updates()
            self._process_finalized()
            self._process_event_metadata()

        self.event_subscriber.stop()
        self.event_end_subscriber.stop()
        self.event_metadata_subscriber.stop()
        self.requestor.stop()
        logger.info("Exiting embeddings maintenance...")

    def _process_updates(self) -> None:
        """Process event updates.

        Reads at most one update from the event subscriber. For tracked
        objects, a thumbnail is cut from the frame referenced by the update
        and the payload is buffered in ``self.tracked_events`` under its
        event id.
        """
        update = self.event_subscriber.check_for_update()

        if update is None:
            return

        source_type, _, camera, data = update

        if not camera or source_type != EventTypeEnum.tracked_object:
            return

        camera_config = self.config.cameras[camera]

        if data["id"] not in self.tracked_events:
            self.tracked_events[data["id"]] = []

        # Create our own thumbnail based on the bounding box and the frame time
        try:
            frame_id = f"{camera}{data['frame_time']}"
            yuv_frame = self.frame_manager.get(frame_id, camera_config.frame_shape_yuv)

            if yuv_frame is not None:
                data["thumbnail"] = self._create_thumbnail(yuv_frame, data["box"])
                self.tracked_events[data["id"]].append(data)

            self.frame_manager.close(frame_id)
        except FileNotFoundError:
            # The frame is no longer available; skip this update.
            pass

    def _process_finalized(self) -> None:
        """Process the end of an event.

        Drains every pending end-of-event notification. Events that were
        written to the database are embedded (and optionally described);
        the frames buffered for the event are released in all cases.
        """
        while True:
            ended = self.event_end_subscriber.check_for_update()

            if ended is None:
                break

            event_id, camera, updated_db = ended

            try:
                if updated_db:
                    self._embed_ended_event(event_id, camera)
            finally:
                # Release the buffered frames on every path, including
                # missing or non-object events, so entries cannot accumulate.
                self.tracked_events.pop(event_id, None)

    def _embed_ended_event(self, event_id: str, camera: str) -> None:
        """Embed the thumbnail of an ended event and start description generation."""
        camera_config = self.config.cameras[camera]

        try:
            event: Event = Event.get(Event.id == event_id)
        except DoesNotExist:
            return

        # Skip the event if not an object
        if event.data.get("type") != "object":
            return

        # Extract valid thumbnail
        thumbnail = base64.b64decode(event.thumbnail)

        # Embed the thumbnail
        self._embed_thumbnail(event_id, thumbnail)

        if not self._wants_description(camera_config, event):
            return

        logger.debug(
            f"Description generation for {event}, has_snapshot: {event.has_snapshot}"
        )

        embed_image = self._description_images(
            event_id,
            event,
            thumbnail,
            bool(event.has_snapshot and camera_config.genai.use_snapshot),
        )

        # Generate the description. Call happens in a thread since it is network bound.
        threading.Thread(
            target=self._embed_description,
            name=f"_embed_description_{event.id}",
            daemon=True,
            args=(
                event,
                embed_image,
            ),
        ).start()

    def _wants_description(self, camera_config, event: Event) -> bool:
        """Return True when a GenAI description should be generated for the event."""
        genai = camera_config.genai

        if not genai.enabled or self.genai_client is None:
            return False

        # Never overwrite an existing description here.
        if event.data.get("description") is not None:
            return False

        # An empty objects list means "all labels".
        if genai.objects and event.label not in genai.objects:
            return False

        # An empty required_zones list means "no zone restriction".
        if genai.required_zones and not (
            set(event.zones) & set(genai.required_zones)
        ):
            return False

        return True

    def _description_images(
        self,
        event_id: str,
        event: Event,
        thumbnail: bytes,
        use_snapshot: bool,
    ) -> list[bytes]:
        """Pick the images sent to GenAI: cropped snapshot, else thumbnails.

        When the snapshot is requested but cannot be loaded, the thumbnails
        are used instead of raising.
        """
        if use_snapshot:
            snapshot_image = self._load_cropped_snapshot(event)

            if snapshot_image is not None:
                return [snapshot_image]

        return self._tracked_thumbnails(event_id, thumbnail)

    def _tracked_thumbnails(self, event_id: str, fallback: bytes) -> list[bytes]:
        """Return the per-frame thumbnails buffered for the event, or ``[fallback]``."""
        thumbnails = [
            data["thumbnail"]
            for data in self.tracked_events.get(event_id, [])
            # _create_thumbnail returns None when JPEG encoding fails
            if data.get("thumbnail") is not None
        ]
        return thumbnails or [fallback]

    def _load_cropped_snapshot(self, event: Event) -> Optional[bytes]:
        """Return the event snapshot cropped to its region as JPEG bytes, or None."""
        snapshot_path = os.path.join(CLIPS_DIR, f"{event.camera}-{event.id}.jpg")

        try:
            with open(snapshot_path, "rb") as image_file:
                snapshot_image = image_file.read()
        except OSError:
            logger.warning(
                "Unable to read snapshot %s for event %s", snapshot_path, event.id
            )
            return None

        # Encoded image data is a buffer of unsigned bytes.
        img = cv2.imdecode(
            np.frombuffer(snapshot_image, dtype=np.uint8),
            cv2.IMREAD_COLOR,
        )

        if img is None:
            logger.warning(
                "Unable to decode snapshot %s for event %s", snapshot_path, event.id
            )
            return None

        # crop snapshot based on region before sending off to genai;
        # the region values are scaled by the image width and height
        height, width = img.shape[:2]
        x1_rel, y1_rel, width_rel, height_rel = event.data["region"]
        x1, y1 = int(x1_rel * width), int(y1_rel * height)
        cropped_image = img[
            y1 : y1 + int(height_rel * height),
            x1 : x1 + int(width_rel * width),
        ]

        ok, buffer = cv2.imencode(".jpg", cropped_image)
        return buffer.tobytes() if ok else None

    def _process_event_metadata(self):
        """Handle one pending description-regeneration request, if any."""
        # Check for regenerate description requests
        (topic, event_id, source) = self.event_metadata_subscriber.check_for_update(
            timeout=1
        )

        if topic is None:
            return

        if event_id:
            self.handle_regenerate_description(event_id, source)

    def _create_thumbnail(self, yuv_frame, box, height=500) -> Optional[bytes]:
        """Return jpg thumbnail of a region of the frame.

        Args:
            yuv_frame: Frame converted with ``cv2.COLOR_YUV2BGR_I420``.
            box: Bounding box; its first four items are passed to
                ``calculate_region``.
            height: Height in pixels of the resulting thumbnail.

        Returns:
            The JPEG bytes, or None when encoding fails.
        """
        frame = cv2.cvtColor(yuv_frame, cv2.COLOR_YUV2BGR_I420)
        region = calculate_region(
            frame.shape, box[0], box[1], box[2], box[3], height, multiplier=1.4
        )
        frame = frame[region[1] : region[3], region[0] : region[2]]
        # Scale the width so the crop keeps its aspect ratio.
        width = int(height * frame.shape[1] / frame.shape[0])
        frame = cv2.resize(frame, dsize=(width, height), interpolation=cv2.INTER_AREA)
        ret, jpg = cv2.imencode(".jpg", frame, [int(cv2.IMWRITE_JPEG_QUALITY), 100])

        if ret:
            return jpg.tobytes()

        return None

    def _embed_thumbnail(self, event_id: str, thumbnail: bytes) -> None:
        """Embed the thumbnail for an event."""
        self.embeddings.upsert_thumbnail(event_id, thumbnail)

    def _embed_description(self, event: Event, thumbnails: list[bytes]) -> None:
        """Embed the description for an event.

        Asks the GenAI client for a description of ``thumbnails``, publishes
        it through the requestor and stores its embedding. Nothing is stored
        when no description is returned.
        """
        camera_config = self.config.cameras[event.camera]

        description = self.genai_client.generate_description(
            camera_config, thumbnails, event.label
        )

        if not description:
            logger.debug("Failed to generate description for %s", event.id)
            return

        # fire and forget description update
        self.requestor.send_data(
            UPDATE_EVENT_DESCRIPTION,
            {"id": event.id, "description": description},
        )

        # Encode the description
        self.embeddings.upsert_description(event.id, description)

        logger.debug(
            "Generated description for %s (%d images): %s",
            event.id,
            len(thumbnails),
            description,
        )

    def handle_regenerate_description(self, event_id: str, source: str) -> None:
        """Regenerate the GenAI description of an existing event.

        Args:
            event_id: Id of the event to describe again.
            source: ``"snapshot"`` to describe the cropped snapshot (when the
                event has one); any other value uses thumbnails.
        """
        try:
            event: Event = Event.get(Event.id == event_id)
        except DoesNotExist:
            logger.error(f"Event {event_id} not found for description regeneration")
            return

        camera_config = self.config.cameras[event.camera]

        if not camera_config.genai.enabled or self.genai_client is None:
            logger.error(f"GenAI not enabled for camera {event.camera}")
            return

        thumbnail = base64.b64decode(event.thumbnail)

        logger.debug(
            f"Trying {source} regeneration for {event}, has_snapshot: {event.has_snapshot}"
        )

        embed_image = self._description_images(
            event_id,
            event,
            thumbnail,
            bool(event.has_snapshot and source == "snapshot"),
        )

        self._embed_description(event, embed_image)