Use sqlite-vec extension instead of chromadb for embeddings (#14163)

* swap sqlite_vec for chroma in requirements

* load sqlite_vec in embeddings manager

* remove chroma and revamp Embeddings class for sqlite_vec

* manual minilm onnx inference

* remove chroma in clip model

* migrate api from chroma to sqlite_vec

* migrate event cleanup from chroma to sqlite_vec

* migrate embedding maintainer from chroma to sqlite_vec

* genai description for sqlite_vec

* load sqlite_vec in main thread db

* extend the SqliteQueueDatabase class and use peewee db.execute_sql

* search with Event type for similarity

* fix similarity search

* install and add comment about transformers

* fix normalization

* add id filter

* clean up

* clean up

* fully remove chroma and add transformers env var

* readd uvicorn for fastapi

* readd tokenizer parallelism env var

* remove chroma from docs

* remove chroma from UI

* try removing custom pysqlite3 build

* hard code limit

* optimize queries

* revert explore query

* fix query

* keep building pysqlite3

* single pass fetch and process

* remove unnecessary re-embed

* update deps

* move SqliteVecQueueDatabase to db directory

* make search thumbnail take up full size of results box

* improve typing

* improve model downloading and add status screen

* daemon downloading thread

* catch case when semantic search is disabled

* fix typing

* build sqlite_vec from source

* resolve conflict

* file permissions

* try build deps

* remove sources

* sources

* fix thread start

* include git in build

* reorder embeddings after detectors are started

* build with sqlite amalgamation

* non-platform specific

* use wget instead of curl

* remove unzip -d

* remove sqlite_vec from requirements and load the compiled version

* fix build

* avoid race in db connection

* add scale_factor and bias to description zscore normalization
Author: Josh Hawkins
Date: 2024-10-07 15:30:45 -05:00 (committed by GitHub)
Parent: 757150dec1
Commit: 24ac9f3e5a
42 changed files with 951 additions and 533 deletions

@@ -30,6 +30,16 @@ RUN --mount=type=tmpfs,target=/tmp --mount=type=tmpfs,target=/var/cache/apt \
     --mount=type=cache,target=/root/.ccache \
     /deps/build_nginx.sh
 
+FROM wget AS sqlite-vec
+ARG DEBIAN_FRONTEND
+
+# Build sqlite_vec from source
+COPY docker/main/build_sqlite_vec.sh /deps/build_sqlite_vec.sh
+RUN --mount=type=tmpfs,target=/tmp --mount=type=tmpfs,target=/var/cache/apt \
+    --mount=type=bind,source=docker/main/build_sqlite_vec.sh,target=/deps/build_sqlite_vec.sh \
+    --mount=type=cache,target=/root/.ccache \
+    /deps/build_sqlite_vec.sh
+
 FROM scratch AS go2rtc
 ARG TARGETARCH
 WORKDIR /rootfs/usr/local/go2rtc/bin
@@ -163,7 +173,7 @@ RUN wget -q https://bootstrap.pypa.io/get-pip.py -O get-pip.py \
 COPY docker/main/requirements.txt /requirements.txt
 RUN pip3 install -r /requirements.txt
 
-# Build pysqlite3 from source to support ChromaDB
+# Build pysqlite3 from source
 COPY docker/main/build_pysqlite3.sh /build_pysqlite3.sh
 RUN /build_pysqlite3.sh
@@ -177,6 +187,7 @@ RUN pip3 wheel --no-deps --wheel-dir=/wheels-post -r /requirements-wheels-post.t
 # Collect deps in a single layer
 FROM scratch AS deps-rootfs
 COPY --from=nginx /usr/local/nginx/ /usr/local/nginx/
+COPY --from=sqlite-vec /usr/local/lib/ /usr/local/lib/
 COPY --from=go2rtc /rootfs/ /
 COPY --from=libusb-build /usr/local/lib /usr/local/lib
 COPY --from=tempio /rootfs/ /
@@ -197,12 +208,11 @@ ARG APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=DontWarn
 ENV NVIDIA_VISIBLE_DEVICES=all
 ENV NVIDIA_DRIVER_CAPABILITIES="compute,video,utility"
 
-# Turn off Chroma Telemetry: https://docs.trychroma.com/telemetry#opting-out
-ENV ANONYMIZED_TELEMETRY=False
-# Allow resetting the chroma database
-ENV ALLOW_RESET=True
 # Disable tokenizer parallelism warning
+# https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning/72926996#72926996
 ENV TOKENIZERS_PARALLELISM=true
+# https://github.com/huggingface/transformers/issues/27214
+ENV TRANSFORMERS_NO_ADVISORY_WARNINGS=1
 
 ENV PATH="/usr/local/go2rtc/bin:/usr/local/tempio/bin:/usr/local/nginx/sbin:${PATH}"
 ENV LIBAVFORMAT_VERSION_MAJOR=60

docker/main/build_sqlite_vec.sh (new executable file, 31 lines)

@@ -0,0 +1,31 @@
+#!/bin/bash
+
+set -euxo pipefail
+
+SQLITE_VEC_VERSION="0.1.3"
+
+cp /etc/apt/sources.list /etc/apt/sources.list.d/sources-src.list
+sed -i 's|deb http|deb-src http|g' /etc/apt/sources.list.d/sources-src.list
+apt-get update
+apt-get -yqq build-dep sqlite3 gettext git
+
+mkdir /tmp/sqlite_vec
+# Grab the sqlite_vec source code.
+wget -nv https://github.com/asg017/sqlite-vec/archive/refs/tags/v${SQLITE_VEC_VERSION}.tar.gz
+tar -zxf v${SQLITE_VEC_VERSION}.tar.gz -C /tmp/sqlite_vec
+
+cd /tmp/sqlite_vec/sqlite-vec-${SQLITE_VEC_VERSION}
+
+mkdir -p vendor
+wget -O sqlite-amalgamation.zip https://www.sqlite.org/2024/sqlite-amalgamation-3450300.zip
+unzip sqlite-amalgamation.zip
+mv sqlite-amalgamation-3450300/* vendor/
+rmdir sqlite-amalgamation-3450300
+rm sqlite-amalgamation.zip
+
+# build loadable module
+make loadable
+
+# install it
+cp dist/vec0.* /usr/local/lib
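
A quick way to sanity-check the artifact this script produces — a sketch, not part of the commit — is to load it into a stock sqlite3 connection, assuming the interpreter's SQLite build permits loading extensions; the path and loading calls mirror frigate/db/sqlitevecq.py further down:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.enable_load_extension(True)
conn.load_extension("/usr/local/lib/vec0")  # module built by this script
conn.enable_load_extension(False)
# vec_version() is a SQL function provided by sqlite-vec itself
print(conn.execute("SELECT vec_version()").fetchone()[0])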

@@ -2,6 +2,7 @@ click == 8.1.*
 # FastAPI
 starlette-context == 0.3.6
 fastapi == 0.115.0
+uvicorn == 0.30.*
 slowapi == 0.1.9
 imutils == 0.5.*
 joserfc == 1.0.*
@@ -32,12 +33,12 @@ unidecode == 1.3.*
 # OpenVino (ONNX installed in wheels-post)
 openvino == 2024.3.*
 # Embeddings
-chromadb == 0.5.7
+transformers == 4.45.*
 onnx_clip == 4.0.*
 # Generative AI
-google-generativeai == 0.6.*
-ollama == 0.2.*
-openai == 1.30.*
+google-generativeai == 0.8.*
+ollama == 0.3.*
+openai == 1.51.*
 # push notifications
 py-vapid == 1.9.*
 pywebpush == 2.0.*

@@ -1 +0,0 @@
-chroma-pipeline

@@ -1,4 +0,0 @@
-#!/command/with-contenv bash
-# shellcheck shell=bash
-
-exec logutil-service /dev/shm/logs/chroma

@@ -1,28 +0,0 @@
-#!/command/with-contenv bash
-# shellcheck shell=bash
-# Take down the S6 supervision tree when the service exits
-
-set -o errexit -o nounset -o pipefail
-
-# Logs should be sent to stdout so that s6 can collect them
-
-declare exit_code_container
-exit_code_container=$(cat /run/s6-linux-init-container-results/exitcode)
-readonly exit_code_container
-readonly exit_code_service="${1}"
-readonly exit_code_signal="${2}"
-readonly service="ChromaDB"
-
-echo "[INFO] Service ${service} exited with code ${exit_code_service} (by signal ${exit_code_signal})"
-
-if [[ "${exit_code_service}" -eq 256 ]]; then
-    if [[ "${exit_code_container}" -eq 0 ]]; then
-        echo $((128 + exit_code_signal)) >/run/s6-linux-init-container-results/exitcode
-    fi
-elif [[ "${exit_code_service}" -ne 0 ]]; then
-    if [[ "${exit_code_container}" -eq 0 ]]; then
-        echo "${exit_code_service}" >/run/s6-linux-init-container-results/exitcode
-    fi
-fi
-
-exec /run/s6/basedir/bin/halt

@@ -1,27 +0,0 @@
-#!/command/with-contenv bash
-# shellcheck shell=bash
-# Start the Frigate service
-
-set -o errexit -o nounset -o pipefail
-
-# Logs should be sent to stdout so that s6 can collect them
-
-# Tell S6-Overlay not to restart this service
-s6-svc -O .
-
-search_enabled=`python3 /usr/local/semantic_search/get_search_settings.py | jq -r .enabled`
-
-# Replace the bash process with the Frigate process, redirecting stderr to stdout
-exec 2>&1
-
-if [[ "$search_enabled" == 'true' ]]; then
-    echo "[INFO] Starting ChromaDB..."
-    exec /usr/local/chroma run --path /config/chroma --host 127.0.0.1
-else
-    while true
-    do
-        sleep 9999
-        continue
-    done
-    exit 0
-fi

@@ -1 +0,0 @@
-longrun

@@ -4,7 +4,7 @@
 
 set -o errexit -o nounset -o pipefail
 
-dirs=(/dev/shm/logs/frigate /dev/shm/logs/go2rtc /dev/shm/logs/nginx /dev/shm/logs/certsync /dev/shm/logs/chroma)
+dirs=(/dev/shm/logs/frigate /dev/shm/logs/go2rtc /dev/shm/logs/nginx /dev/shm/logs/certsync)
 
 mkdir -p "${dirs[@]}"
 chown nobody:nogroup "${dirs[@]}"

@@ -1,14 +0,0 @@
-#!/usr/bin/python3
-# -*- coding: utf-8 -*-s
-__import__("pysqlite3")
-
-import re
-import sys
-
-sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
-
-from chromadb.cli.cli import app
-
-if __name__ == "__main__":
-    sys.argv[0] = re.sub(r"(-script\.pyw|\.exe)?$", "", sys.argv[0])
-    sys.exit(app())

@@ -1,30 +0,0 @@
-"""Prints the semantic_search config as json to stdout."""
-
-import json
-import os
-
-from ruamel.yaml import YAML
-
-yaml = YAML()
-
-config_file = os.environ.get("CONFIG_FILE", "/config/config.yml")
-
-# Check if we can use .yaml instead of .yml
-config_file_yaml = config_file.replace(".yml", ".yaml")
-if os.path.isfile(config_file_yaml):
-    config_file = config_file_yaml
-
-try:
-    with open(config_file) as f:
-        raw_config = f.read()
-
-    if config_file.endswith((".yaml", ".yml")):
-        config: dict[str, any] = yaml.load(raw_config)
-    elif config_file.endswith(".json"):
-        config: dict[str, any] = json.loads(raw_config)
-except FileNotFoundError:
-    config: dict[str, any] = {}
-
-search_config: dict[str, any] = config.get("semantic_search", {"enabled": False})
-print(json.dumps(search_config))

@@ -5,7 +5,7 @@ title: Using Semantic Search
 
 Semantic Search in Frigate allows you to find tracked objects within your review items using either the image itself, a user-defined text description, or an automatically generated one. This feature works by creating _embeddings_ — numerical vector representations — for both the images and text descriptions of your tracked objects. By comparing these embeddings, Frigate assesses their similarities to deliver relevant search results.
 
-Frigate has support for two models to create embeddings, both of which run locally: [OpenAI CLIP](https://openai.com/research/clip) and [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2). Embeddings are then saved to a local instance of [ChromaDB](https://trychroma.com).
+Frigate has support for two models to create embeddings, both of which run locally: [OpenAI CLIP](https://openai.com/research/clip) and [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2). Embeddings are then saved to Frigate's database.
 
 Semantic Search is accessed via the _Explore_ view in the Frigate UI.
 
@@ -29,7 +29,7 @@ If you are enabling the Search feature for the first time, be advised that Friga
 ### OpenAI CLIP
 
-This model is able to embed both images and text into the same vector space, which allows `image -> image` and `text -> image` similarity searches. Frigate uses this model on tracked objects to encode the thumbnail image and store it in Chroma. When searching for tracked objects via text in the search box, Frigate will perform a `text -> image` similarity search against this embedding. When clicking "Find Similar" in the tracked object detail pane, Frigate will perform an `image -> image` similarity search to retrieve the closest matching thumbnails.
+This model is able to embed both images and text into the same vector space, which allows `image -> image` and `text -> image` similarity searches. Frigate uses this model on tracked objects to encode the thumbnail image and store it in the database. When searching for tracked objects via text in the search box, Frigate will perform a `text -> image` similarity search against this embedding. When clicking "Find Similar" in the tracked object detail pane, Frigate will perform an `image -> image` similarity search to retrieve the closest matching thumbnails.
 
 ### all-MiniLM-L6-v2
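
Illustrative sketch (not part of this commit): the same onnx_clip package Frigate uses can demonstrate a text -> image similarity score directly; the file name and prompt below are placeholders.

import numpy as np
from onnx_clip import OnnxClip
from PIL import Image

clip = OnnxClip(batch_size=1)
image_embedding = clip.get_image_embeddings([Image.open("thumbnail.jpg")])[0]
text_embedding = clip.get_text_embeddings(["a person walking a dog"])[0]

# Cosine similarity; larger values mean a closer text -> image match.
score = np.dot(image_embedding, text_embedding) / (
    np.linalg.norm(image_embedding) * np.linalg.norm(text_embedding)
)
print(score)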

@@ -384,12 +384,12 @@ def vainfo():
 @router.get("/logs/{service}", tags=[Tags.logs])
 def logs(
-    service: str = Path(enum=["frigate", "nginx", "go2rtc", "chroma"]),
+    service: str = Path(enum=["frigate", "nginx", "go2rtc"]),
     download: Optional[str] = None,
     start: Optional[int] = 0,
     end: Optional[int] = None,
 ):
-    """Get logs for the requested service (frigate/nginx/go2rtc/chroma)"""
+    """Get logs for the requested service (frigate/nginx/go2rtc)"""
 
     def download_logs(service_location: str):
         try:
@@ -408,7 +408,6 @@ def logs(
         "frigate": "/dev/shm/logs/frigate/current",
         "go2rtc": "/dev/shm/logs/go2rtc/current",
         "nginx": "/dev/shm/logs/nginx/current",
-        "chroma": "/dev/shm/logs/chroma/current",
     }
     service_location = log_locations.get(service)

@@ -1,8 +1,6 @@
 """Event apis."""
 
-import base64
 import datetime
-import io
 import logging
 import os
 from functools import reduce
@@ -10,12 +8,10 @@ from pathlib import Path
 from urllib.parse import unquote
 
 import cv2
-import numpy as np
 from fastapi import APIRouter, Request
 from fastapi.params import Depends
 from fastapi.responses import JSONResponse
 from peewee import JOIN, DoesNotExist, fn, operator
-from PIL import Image
 from playhouse.shortcuts import model_to_dict
 
 from frigate.api.defs.events_body import (
@@ -39,7 +35,6 @@ from frigate.const import (
     CLIPS_DIR,
 )
 from frigate.embeddings import EmbeddingsContext
-from frigate.embeddings.embeddings import get_metadata
 from frigate.models import Event, ReviewSegment, Timeline
 from frigate.object_processing import TrackedObject
 from frigate.util.builtin import get_tz_modifiers
@@ -411,16 +406,12 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
     event_filters = []
 
     if cameras != "all":
-        camera_list = cameras.split(",")
-        event_filters.append((Event.camera << camera_list))
+        event_filters.append((Event.camera << cameras.split(",")))
 
     if labels != "all":
-        label_list = labels.split(",")
-        event_filters.append((Event.label << label_list))
+        event_filters.append((Event.label << labels.split(",")))
 
     if zones != "all":
-        # use matching so events with multiple zones
-        # still match on a search where any zone matches
         zone_clauses = []
         filtered_zones = zones.split(",")
@@ -431,8 +422,7 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
         for zone in filtered_zones:
             zone_clauses.append((Event.zones.cast("text") % f'*"{zone}"*'))
 
-        zone_clause = reduce(operator.or_, zone_clauses)
-        event_filters.append((zone_clause))
+        event_filters.append((reduce(operator.or_, zone_clauses)))
 
     if after:
         event_filters.append((Event.start_time > after))
@@ -441,13 +431,11 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
         event_filters.append((Event.start_time < before))
 
     if time_range != DEFAULT_TIME_RANGE:
-        # get timezone arg to ensure browser times are used
         tz_name = params.timezone
         hour_modifier, minute_modifier, _ = get_tz_modifiers(tz_name)
 
         times = time_range.split(",")
-        time_after = times[0]
-        time_before = times[1]
+        time_after, time_before = times
 
         start_hour_fun = fn.strftime(
             "%H:%M",
@@ -470,132 +458,113 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
         event_filters.append((start_hour_fun > time_after))
         event_filters.append((start_hour_fun < time_before))
 
-    if event_filters:
-        filtered_event_ids = (
-            Event.select(Event.id)
-            .where(reduce(operator.and_, event_filters))
-            .tuples()
-            .iterator()
-        )
-        event_ids = [event_id[0] for event_id in filtered_event_ids]
-        if not event_ids:
-            return JSONResponse(content=[])  # No events to search on
-    else:
-        event_ids = []
-
-    # Build the Chroma where clause based on the event IDs
-    where = {"id": {"$in": event_ids}} if event_ids else {}
-
-    thumb_ids = {}
-    desc_ids = {}
+    # Perform semantic search
+    search_results = {}
 
     if search_type == "similarity":
-        # Grab the ids of events that match the thumbnail image embeddings
         try:
             search_event: Event = Event.get(Event.id == event_id)
         except DoesNotExist:
             return JSONResponse(
-                content=(
-                    {
-                        "success": False,
-                        "message": "Event not found",
-                    }
-                ),
+                content={
+                    "success": False,
+                    "message": "Event not found",
+                },
                 status_code=404,
             )
-        thumbnail = base64.b64decode(search_event.thumbnail)
-        img = np.array(Image.open(io.BytesIO(thumbnail)).convert("RGB"))
-        thumb_result = context.embeddings.thumbnail.query(
-            query_images=[img],
-            n_results=limit,
-            where=where,
-        )
+
+        thumb_result = context.embeddings.search_thumbnail(search_event)
         thumb_ids = dict(
             zip(
-                thumb_result["ids"][0],
-                context.thumb_stats.normalize(thumb_result["distances"][0]),
+                [result[0] for result in thumb_result],
+                context.thumb_stats.normalize([result[1] for result in thumb_result]),
             )
         )
+        search_results = {
+            event_id: {"distance": distance, "source": "thumbnail"}
+            for event_id, distance in thumb_ids.items()
+        }
     else:
         search_types = search_type.split(",")
 
         if "thumbnail" in search_types:
-            thumb_result = context.embeddings.thumbnail.query(
-                query_texts=[query],
-                n_results=limit,
-                where=where,
-            )
-            # Do a rudimentary normalization of the difference in distances returned by CLIP and MiniLM.
+            thumb_result = context.embeddings.search_thumbnail(query)
             thumb_ids = dict(
                 zip(
-                    thumb_result["ids"][0],
-                    context.thumb_stats.normalize(thumb_result["distances"][0]),
+                    [result[0] for result in thumb_result],
+                    context.thumb_stats.normalize(
+                        [result[1] for result in thumb_result]
+                    ),
                 )
             )
+            search_results.update(
+                {
+                    event_id: {"distance": distance, "source": "thumbnail"}
+                    for event_id, distance in thumb_ids.items()
+                }
+            )
 
         if "description" in search_types:
-            desc_result = context.embeddings.description.query(
-                query_texts=[query],
-                n_results=limit,
-                where=where,
-            )
+            desc_result = context.embeddings.search_description(query)
             desc_ids = dict(
                 zip(
-                    desc_result["ids"][0],
-                    context.desc_stats.normalize(desc_result["distances"][0]),
+                    [result[0] for result in desc_result],
+                    context.desc_stats.normalize([result[1] for result in desc_result]),
                 )
             )
+            for event_id, distance in desc_ids.items():
+                if (
+                    event_id not in search_results
+                    or distance < search_results[event_id]["distance"]
+                ):
+                    search_results[event_id] = {
+                        "distance": distance,
+                        "source": "description",
+                    }
 
-    results = {}
-    for event_id in thumb_ids.keys() | desc_ids:
-        min_distance = min(
-            i
-            for i in (thumb_ids.get(event_id), desc_ids.get(event_id))
-            if i is not None
-        )
-        results[event_id] = {
-            "distance": min_distance,
-            "source": "thumbnail"
-            if min_distance == thumb_ids.get(event_id)
-            else "description",
-        }
-
-    if not results:
+    if not search_results:
         return JSONResponse(content=[])
 
-    # Get the event data
-    events = (
-        Event.select(*selected_columns)
-        .join(
-            ReviewSegment,
-            JOIN.LEFT_OUTER,
-            on=(fn.json_extract(ReviewSegment.data, "$.detections").contains(Event.id)),
-        )
-        .where(Event.id << list(results.keys()))
-        .dicts()
-        .iterator()
+    # Fetch events in a single query
+    events_query = Event.select(*selected_columns).join(
+        ReviewSegment,
+        JOIN.LEFT_OUTER,
+        on=(fn.json_extract(ReviewSegment.data, "$.detections").contains(Event.id)),
     )
-    events = list(events)
 
-    events = [
-        {k: v for k, v in event.items() if k != "data"}
-        | {
-            "data": {
-                k: v
-                for k, v in event["data"].items()
-                if k in ["type", "score", "top_score", "description"]
-            }
-        }
-        | {
-            "search_distance": results[event["id"]]["distance"],
-            "search_source": results[event["id"]]["source"],
-        }
-        for event in events
-    ]
-    events = sorted(events, key=lambda x: x["search_distance"])[:limit]
+    # Apply filters, if any
+    if event_filters:
+        events_query = events_query.where(reduce(operator.and_, event_filters))
 
-    return JSONResponse(content=events)
+    # If we did a similarity search, limit events to those in search_results
+    if search_results:
+        events_query = events_query.where(Event.id << list(search_results.keys()))
+
+    # Fetch events and process them in a single pass
+    processed_events = []
+    for event in events_query.dicts():
+        processed_event = {k: v for k, v in event.items() if k != "data"}
+        processed_event["data"] = {
+            k: v
+            for k, v in event["data"].items()
+            if k in ["type", "score", "top_score", "description"]
+        }
+        if event["id"] in search_results:
+            processed_event["search_distance"] = search_results[event["id"]]["distance"]
+            processed_event["search_source"] = search_results[event["id"]]["source"]
+        processed_events.append(processed_event)
+
+    # Sort by search distance if search_results are available, otherwise by start_time
+    if search_results:
+        processed_events.sort(key=lambda x: x.get("search_distance", float("inf")))
+    else:
+        processed_events.sort(key=lambda x: x["start_time"], reverse=True)
+
+    # Limit the number of events returned
+    processed_events = processed_events[:limit]
+
+    return JSONResponse(content=processed_events)
 
 @router.get("/events/summary")
@@ -975,10 +944,9 @@ def set_description(
     # If semantic search is enabled, update the index
     if request.app.frigate_config.semantic_search.enabled:
         context: EmbeddingsContext = request.app.embeddings
-        context.embeddings.description.upsert(
-            documents=[new_description],
-            metadatas=[get_metadata(event)],
-            ids=[event_id],
+        context.embeddings.upsert_description(
+            event_id=event_id,
+            description=new_description,
         )
 
     response_message = (
@@ -1065,8 +1033,8 @@ def delete_event(request: Request, event_id: str):
     # If semantic search is enabled, update the index
     if request.app.frigate_config.semantic_search.enabled:
         context: EmbeddingsContext = request.app.embeddings
-        context.embeddings.thumbnail.delete(ids=[event_id])
-        context.embeddings.description.delete(ids=[event_id])
+        context.embeddings.delete_thumbnail(event_ids=[event_id])
+        context.embeddings.delete_description(event_ids=[event_id])
     return JSONResponse(
         content=({"success": True, "message": "Event " + event_id + " deleted"}),
         status_code=200,
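
The merge rule in the rewritten endpoint is easiest to see in isolation: each search source yields (event_id, distance) pairs, and an event keeps the smaller distance together with the source that produced it. A distilled sketch of that logic, with names assumed for illustration:

def merge_search_results(
    thumb_ids: dict[str, float], desc_ids: dict[str, float]
) -> dict[str, dict]:
    search_results = {
        event_id: {"distance": distance, "source": "thumbnail"}
        for event_id, distance in thumb_ids.items()
    }
    for event_id, distance in desc_ids.items():
        # keep the closer of the two matches for events found by both sources
        if (
            event_id not in search_results
            or distance < search_results[event_id]["distance"]
        ):
            search_results[event_id] = {"distance": distance, "source": "description"}
    return search_results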

@@ -12,7 +12,6 @@ import psutil
 import uvicorn
 from peewee_migrate import Router
 from playhouse.sqlite_ext import SqliteExtDatabase
-from playhouse.sqliteq import SqliteQueueDatabase
 
 import frigate.util as util
 from frigate.api.auth import hash_password
@@ -38,6 +37,7 @@ from frigate.const import (
     MODEL_CACHE_DIR,
     RECORD_DIR,
 )
+from frigate.db.sqlitevecq import SqliteVecQueueDatabase
 from frigate.embeddings import EmbeddingsContext, manage_embeddings
 from frigate.events.audio import AudioProcessor
 from frigate.events.cleanup import EventCleanup
@@ -88,6 +88,7 @@ class FrigateApp:
         self.camera_metrics: dict[str, CameraMetrics] = {}
         self.ptz_metrics: dict[str, PTZMetrics] = {}
         self.processes: dict[str, int] = {}
+        self.embeddings: Optional[EmbeddingsContext] = None
         self.region_grids: dict[str, list[list[dict[str, int]]]] = {}
         self.config = config
@@ -220,11 +221,8 @@
     def init_embeddings_manager(self) -> None:
         if not self.config.semantic_search.enabled:
-            self.embeddings = None
             return
 
-        # Create a client for other processes to use
-        self.embeddings = EmbeddingsContext()
         embedding_process = util.Process(
             target=manage_embeddings,
             name="embeddings_manager",
@@ -239,7 +237,7 @@
     def bind_database(self) -> None:
         """Bind db to the main process."""
         # NOTE: all db accessing processes need to be created before the db can be bound to the main process
-        self.db = SqliteQueueDatabase(
+        self.db = SqliteVecQueueDatabase(
             self.config.database.path,
             pragmas={
                 "auto_vacuum": "FULL",  # Does not defragment database
@@ -249,6 +247,7 @@
             timeout=max(
                 60, 10 * len([c for c in self.config.cameras.values() if c.enabled])
             ),
+            load_vec_extension=self.config.semantic_search.enabled,
         )
         models = [
             Event,
@@ -274,6 +273,11 @@
             migrate_exports(self.config.ffmpeg, list(self.config.cameras.keys()))
 
+    def init_embeddings_client(self) -> None:
+        if self.config.semantic_search.enabled:
+            # Create a client for other processes to use
+            self.embeddings = EmbeddingsContext(self.db)
+
     def init_external_event_processor(self) -> None:
         self.external_event_processor = ExternalEventProcessor(self.config)
@@ -464,7 +468,7 @@
         self.event_processor.start()
 
     def start_event_cleanup(self) -> None:
-        self.event_cleanup = EventCleanup(self.config, self.stop_event)
+        self.event_cleanup = EventCleanup(self.config, self.stop_event, self.db)
         self.event_cleanup.start()
 
     def start_record_cleanup(self) -> None:
@@ -576,13 +580,14 @@
         self.init_onvif()
         self.init_recording_manager()
         self.init_review_segment_manager()
-        self.init_embeddings_manager()
         self.init_go2rtc()
         self.bind_database()
         self.check_db_data_migrations()
         self.init_inter_process_communicator()
         self.init_dispatcher()
         self.start_detectors()
+        self.init_embeddings_manager()
+        self.init_embeddings_client()
         self.start_video_output_processor()
         self.start_ptz_autotracker()
         self.init_historical_regions()

@@ -16,10 +16,12 @@ from frigate.const import (
     REQUEST_REGION_GRID,
     UPDATE_CAMERA_ACTIVITY,
     UPDATE_EVENT_DESCRIPTION,
+    UPDATE_MODEL_STATE,
     UPSERT_REVIEW_SEGMENT,
 )
 from frigate.models import Event, Previews, Recordings, ReviewSegment
 from frigate.ptz.onvif import OnvifCommandEnum, OnvifController
+from frigate.types import ModelStatusTypesEnum
 from frigate.util.object import get_camera_regions_grid
 from frigate.util.services import restart_frigate
@@ -83,6 +85,7 @@ class Dispatcher:
             comm.subscribe(self._receive)
 
         self.camera_activity = {}
+        self.model_state = {}
 
     def _receive(self, topic: str, payload: str) -> Optional[Any]:
         """Handle receiving of payload from communicators."""
@@ -144,6 +147,14 @@
                 "event_update",
                 json.dumps({"id": event.id, "description": event.data["description"]}),
             )
+        elif topic == UPDATE_MODEL_STATE:
+            model = payload["model"]
+            state = payload["state"]
+            self.model_state[model] = ModelStatusTypesEnum[state]
+            self.publish("model_state", json.dumps(self.model_state))
+        elif topic == "modelState":
+            model_state = self.model_state.copy()
+            self.publish("model_state", json.dumps(model_state))
         elif topic == "onConnect":
             camera_status = self.camera_activity.copy()

@@ -84,6 +84,7 @@ UPSERT_REVIEW_SEGMENT = "upsert_review_segment"
 CLEAR_ONGOING_REVIEW_SEGMENTS = "clear_ongoing_review_segments"
 UPDATE_CAMERA_ACTIVITY = "update_camera_activity"
 UPDATE_EVENT_DESCRIPTION = "update_event_description"
+UPDATE_MODEL_STATE = "update_model_state"
 
 # Stats Values

frigate/db/sqlitevecq.py (new file, 23 lines)

@@ -0,0 +1,23 @@
+import sqlite3
+
+from playhouse.sqliteq import SqliteQueueDatabase
+
+
+class SqliteVecQueueDatabase(SqliteQueueDatabase):
+    def __init__(self, *args, load_vec_extension: bool = False, **kwargs) -> None:
+        self.load_vec_extension: bool = load_vec_extension
+        super().__init__(*args, **kwargs)
+        # no extension necessary, sqlite will load correctly for each platform
+        self.sqlite_vec_path = "/usr/local/lib/vec0"
+
+    def _connect(self, *args, **kwargs) -> sqlite3.Connection:
+        conn: sqlite3.Connection = super()._connect(*args, **kwargs)
+        if self.load_vec_extension:
+            self._load_vec_extension(conn)
+        return conn
+
+    def _load_vec_extension(self, conn: sqlite3.Connection) -> None:
+        conn.enable_load_extension(True)
+        conn.load_extension(self.sqlite_vec_path)
+        conn.enable_load_extension(False)
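
A minimal usage sketch for the new subclass, mirroring how frigate/app.py constructs it (pragmas and timeout elided; not part of the commit):

from frigate.db.sqlitevecq import SqliteVecQueueDatabase

db = SqliteVecQueueDatabase(
    "/config/frigate.db",
    load_vec_extension=True,  # only load vec0 when semantic search is enabled
)
# every connection the queue database opens now exposes the vec0
# virtual-table module and its SQL functions
print(db.execute_sql("SELECT vec_version()").fetchone()[0])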

@@ -1,18 +1,19 @@
-"""ChromaDB embeddings database."""
+"""SQLite-vec embeddings database."""
 
 import json
 import logging
 import multiprocessing as mp
+import os
 import signal
 import threading
 from types import FrameType
 from typing import Optional
 
-from playhouse.sqliteq import SqliteQueueDatabase
 from setproctitle import setproctitle
 
 from frigate.config import FrigateConfig
 from frigate.const import CONFIG_DIR
+from frigate.db.sqlitevecq import SqliteVecQueueDatabase
 from frigate.models import Event
 from frigate.util.services import listen
@@ -41,7 +42,7 @@ def manage_embeddings(config: FrigateConfig) -> None:
     listen()
 
     # Configure Frigate DB
-    db = SqliteQueueDatabase(
+    db = SqliteVecQueueDatabase(
         config.database.path,
         pragmas={
             "auto_vacuum": "FULL",  # Does not defragment database
@@ -49,17 +50,19 @@
             "synchronous": "NORMAL",  # Safe when using WAL https://www.sqlite.org/pragma.html#pragma_synchronous
         },
         timeout=max(60, 10 * len([c for c in config.cameras.values() if c.enabled])),
+        load_vec_extension=True,
     )
     models = [Event]
     db.bind(models)
 
-    embeddings = Embeddings()
+    embeddings = Embeddings(db)
 
     # Check if we need to re-index events
     if config.semantic_search.reindex:
         embeddings.reindex()
 
     maintainer = EmbeddingMaintainer(
+        db,
         config,
         stop_event,
     )
@@ -67,14 +70,14 @@
 
 class EmbeddingsContext:
-    def __init__(self):
-        self.embeddings = Embeddings()
+    def __init__(self, db: SqliteVecQueueDatabase):
+        self.embeddings = Embeddings(db)
         self.thumb_stats = ZScoreNormalization()
-        self.desc_stats = ZScoreNormalization()
+        self.desc_stats = ZScoreNormalization(scale_factor=2.5, bias=0.5)
 
         # load stats from disk
         try:
-            with open(f"{CONFIG_DIR}/.search_stats.json", "r") as f:
+            with open(os.path.join(CONFIG_DIR, ".search_stats.json"), "r") as f:
                 data = json.loads(f.read())
                 self.thumb_stats.from_dict(data["thumb_stats"])
                 self.desc_stats.from_dict(data["desc_stats"])
@@ -87,5 +90,5 @@
             "thumb_stats": self.thumb_stats.to_dict(),
             "desc_stats": self.desc_stats.to_dict(),
         }
-        with open(f"{CONFIG_DIR}/.search_stats.json", "w") as f:
-            f.write(json.dumps(contents))
+        with open(os.path.join(CONFIG_DIR, ".search_stats.json"), "w") as f:
+            json.dump(contents, f)
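
The ZScoreNormalization implementation itself is not part of this diff; a plausible reading of the new scale_factor and bias arguments is that they stretch and shift the z-scored distance so MiniLM description distances land on a range comparable to CLIP thumbnail distances. An assumed sketch, not the shipped class:

def normalize(distance: float, mean: float, stddev: float,
              scale_factor: float = 1.0, bias: float = 0.0) -> float:
    stddev = stddev or 1.0  # guard against a degenerate distribution
    # z-score the raw distance, then rescale it
    return ((distance - mean) / stddev) * scale_factor + bias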

@@ -1,37 +1,23 @@
-"""ChromaDB embeddings database."""
+"""SQLite-vec embeddings database."""
 
 import base64
 import io
 import logging
-import sys
+import struct
 import time
+from typing import List, Tuple, Union
 
-import numpy as np
 from PIL import Image
 from playhouse.shortcuts import model_to_dict
 
+from frigate.comms.inter_process import InterProcessRequestor
+from frigate.const import UPDATE_MODEL_STATE
+from frigate.db.sqlitevecq import SqliteVecQueueDatabase
 from frigate.models import Event
+from frigate.types import ModelStatusTypesEnum
 
-# Squelch posthog logging
-logging.getLogger("chromadb.telemetry.product.posthog").setLevel(logging.CRITICAL)
-
-# Hot-swap the sqlite3 module for Chroma compatibility
-try:
-    from chromadb import Collection
-    from chromadb import HttpClient as ChromaClient
-    from chromadb.config import Settings
-
-    from .functions.clip import ClipEmbedding
-    from .functions.minilm_l6_v2 import MiniLMEmbedding
-except RuntimeError:
-    __import__("pysqlite3")
-    sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
-    from chromadb import Collection
-    from chromadb import HttpClient as ChromaClient
-    from chromadb.config import Settings
-
-    from .functions.clip import ClipEmbedding
-    from .functions.minilm_l6_v2 import MiniLMEmbedding
+from .functions.clip import ClipEmbedding
+from .functions.minilm_l6_v2 import MiniLMEmbedding
 
 logger = logging.getLogger(__name__)
@@ -67,34 +53,198 @@ def get_metadata(event: Event) -> dict:
     )
 
+def serialize(vector: List[float]) -> bytes:
+    """Serializes a list of floats into a compact "raw bytes" format"""
+    return struct.pack("%sf" % len(vector), *vector)
+
+
+def deserialize(bytes_data: bytes) -> List[float]:
+    """Deserializes a compact "raw bytes" format into a list of floats"""
+    return list(struct.unpack("%sf" % (len(bytes_data) // 4), bytes_data))
+
+
 class Embeddings:
-    """ChromaDB embeddings database."""
+    """SQLite-vec embeddings database."""
 
-    def __init__(self) -> None:
-        self.client: ChromaClient = ChromaClient(
-            host="127.0.0.1",
-            settings=Settings(anonymized_telemetry=False),
-        )
+    def __init__(self, db: SqliteVecQueueDatabase) -> None:
+        self.db = db
+        self.requestor = InterProcessRequestor()
+
+        # Create tables if they don't exist
+        self._create_tables()
+
+        models = [
+            "sentence-transformers/all-MiniLM-L6-v2-model.onnx",
+            "sentence-transformers/all-MiniLM-L6-v2-tokenizer",
+            "clip-clip_image_model_vitb32.onnx",
+            "clip-clip_text_model_vitb32.onnx",
+        ]
+
+        for model in models:
+            self.requestor.send_data(
+                UPDATE_MODEL_STATE,
+                {
+                    "model": model,
+                    "state": ModelStatusTypesEnum.not_downloaded,
+                },
+            )
+
+        self.clip_embedding = ClipEmbedding(
+            preferred_providers=["CPUExecutionProvider"]
+        )
+        self.minilm_embedding = MiniLMEmbedding(
+            preferred_providers=["CPUExecutionProvider"],
+        )
 
-    @property
-    def thumbnail(self) -> Collection:
-        return self.client.get_or_create_collection(
-            name="event_thumbnail", embedding_function=ClipEmbedding()
-        )
+    def _create_tables(self):
+        # Create vec0 virtual table for thumbnail embeddings
+        self.db.execute_sql("""
+            CREATE VIRTUAL TABLE IF NOT EXISTS vec_thumbnails USING vec0(
+                id TEXT PRIMARY KEY,
+                thumbnail_embedding FLOAT[512]
+            );
+        """)
+
+        # Create vec0 virtual table for description embeddings
+        self.db.execute_sql("""
+            CREATE VIRTUAL TABLE IF NOT EXISTS vec_descriptions USING vec0(
+                id TEXT PRIMARY KEY,
+                description_embedding FLOAT[384]
+            );
+        """)
+
+    def upsert_thumbnail(self, event_id: str, thumbnail: bytes):
+        # Convert thumbnail bytes to PIL Image
+        image = Image.open(io.BytesIO(thumbnail)).convert("RGB")
+        # Generate embedding using CLIP
+        embedding = self.clip_embedding([image])[0]
+
+        self.db.execute_sql(
+            """
+            INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding)
+            VALUES(?, ?)
+            """,
+            (event_id, serialize(embedding)),
+        )
+
+        return embedding
 
-    @property
-    def description(self) -> Collection:
-        return self.client.get_or_create_collection(
-            name="event_description",
-            embedding_function=MiniLMEmbedding(
-                preferred_providers=["CPUExecutionProvider"]
-            ),
-        )
+    def upsert_description(self, event_id: str, description: str):
+        # Generate embedding using MiniLM
+        embedding = self.minilm_embedding([description])[0]
+
+        self.db.execute_sql(
+            """
+            INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
+            VALUES(?, ?)
+            """,
+            (event_id, serialize(embedding)),
+        )
+
+        return embedding
+
+    def delete_thumbnail(self, event_ids: List[str]) -> None:
+        ids = ",".join(["?" for _ in event_ids])
+        self.db.execute_sql(
+            f"DELETE FROM vec_thumbnails WHERE id IN ({ids})", event_ids
+        )
+
+    def delete_description(self, event_ids: List[str]) -> None:
+        ids = ",".join(["?" for _ in event_ids])
+        self.db.execute_sql(
+            f"DELETE FROM vec_descriptions WHERE id IN ({ids})", event_ids
+        )
+
+    def search_thumbnail(
+        self, query: Union[Event, str], event_ids: List[str] = None
+    ) -> List[Tuple[str, float]]:
+        if query.__class__ == Event:
+            cursor = self.db.execute_sql(
+                """
+                SELECT thumbnail_embedding FROM vec_thumbnails WHERE id = ?
+                """,
+                [query.id],
+            )
+
+            row = cursor.fetchone() if cursor else None
+
+            if row:
+                query_embedding = deserialize(
+                    row[0]
+                )  # Deserialize the thumbnail embedding
+            else:
+                # If no embedding found, generate it and return it
+                thumbnail = base64.b64decode(query.thumbnail)
+                query_embedding = self.upsert_thumbnail(query.id, thumbnail)
+        else:
+            query_embedding = self.clip_embedding([query])[0]
+
+        sql_query = """
+            SELECT
+                id,
+                distance
+            FROM vec_thumbnails
+            WHERE thumbnail_embedding MATCH ?
+                AND k = 100
+        """
+
+        # Add the IN clause if event_ids is provided and not empty
+        # this is the only filter supported by sqlite-vec as of 0.1.3
+        # but it seems to be broken in this version
+        if event_ids:
+            sql_query += " AND id IN ({})".format(",".join("?" * len(event_ids)))
+
+        # order by distance DESC is not implemented in this version of sqlite-vec
+        # when it's implemented, we can use cosine similarity
+        sql_query += " ORDER BY distance"
+
+        parameters = (
+            [serialize(query_embedding)] + event_ids
+            if event_ids
+            else [serialize(query_embedding)]
+        )
+
+        results = self.db.execute_sql(sql_query, parameters).fetchall()
+
+        return results
+
+    def search_description(
+        self, query_text: str, event_ids: List[str] = None
+    ) -> List[Tuple[str, float]]:
+        query_embedding = self.minilm_embedding([query_text])[0]
+
+        # Prepare the base SQL query
+        sql_query = """
+            SELECT
+                id,
+                distance
+            FROM vec_descriptions
+            WHERE description_embedding MATCH ?
+                AND k = 100
+        """
+
+        # Add the IN clause if event_ids is provided and not empty
+        # this is the only filter supported by sqlite-vec as of 0.1.3
+        # but it seems to be broken in this version
+        if event_ids:
+            sql_query += " AND id IN ({})".format(",".join("?" * len(event_ids)))
+
+        # order by distance DESC is not implemented in this version of sqlite-vec
+        # when it's implemented, we can use cosine similarity
+        sql_query += " ORDER BY distance"
+
+        parameters = (
+            [serialize(query_embedding)] + event_ids
+            if event_ids
+            else [serialize(query_embedding)]
+        )
+
+        results = self.db.execute_sql(sql_query, parameters).fetchall()
+
+        return results
 
     def reindex(self) -> None:
+        """Reindex all event embeddings."""
         logger.info("Indexing event embeddings...")
-        self.client.reset()
 
         st = time.time()
         totals = {
@@ -115,37 +265,14 @@
         )
 
         while len(events) > 0:
-            thumbnails = {"ids": [], "images": [], "metadatas": []}
-            descriptions = {"ids": [], "documents": [], "metadatas": []}
-
             event: Event
             for event in events:
-                metadata = get_metadata(event)
                 thumbnail = base64.b64decode(event.thumbnail)
-                img = np.array(Image.open(io.BytesIO(thumbnail)).convert("RGB"))
-                thumbnails["ids"].append(event.id)
-                thumbnails["images"].append(img)
-                thumbnails["metadatas"].append(metadata)
+                self.upsert_thumbnail(event.id, thumbnail)
+                totals["thumb"] += 1
                 if description := event.data.get("description", "").strip():
-                    descriptions["ids"].append(event.id)
-                    descriptions["documents"].append(description)
-                    descriptions["metadatas"].append(metadata)
-
-            if len(thumbnails["ids"]) > 0:
-                totals["thumb"] += len(thumbnails["ids"])
-                self.thumbnail.upsert(
-                    images=thumbnails["images"],
-                    metadatas=thumbnails["metadatas"],
-                    ids=thumbnails["ids"],
-                )
-
-            if len(descriptions["ids"]) > 0:
-                totals["desc"] += len(descriptions["ids"])
-                self.description.upsert(
-                    documents=descriptions["documents"],
-                    metadatas=descriptions["metadatas"],
-                    ids=descriptions["ids"],
-                )
+                    totals["desc"] += 1
+                    self.upsert_description(event.id, description)
 
             current_page += 1
             events = (
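
To make the serialize/deserialize byte format and the vec0 query shape concrete, a small standalone sketch (assuming a db handle wired up through SqliteVecQueueDatabase and a 512-dimensional query_embedding from the CLIP model; not part of the commit):

vector = [0.5, -1.5, 2.0]  # values exactly representable as float32
assert deserialize(serialize(vector)) == vector

rows = db.execute_sql(
    """
    SELECT id, distance
    FROM vec_thumbnails
    WHERE thumbnail_embedding MATCH ?
        AND k = 100
    ORDER BY distance
    """,
    [serialize(query_embedding)],
).fetchall()  # -> [(event_id, distance), ...], nearest first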

@@ -1,35 +1,59 @@
 """CLIP Embeddings for Frigate."""
 
-import errno
 import logging
 import os
-from pathlib import Path
-from typing import Tuple, Union
+from typing import List, Optional, Union
 
+import numpy as np
 import onnxruntime as ort
-import requests
-from chromadb import EmbeddingFunction, Embeddings
-from chromadb.api.types import (
-    Documents,
-    Images,
-    is_document,
-    is_image,
-)
-from onnx_clip import OnnxClip
+from onnx_clip import OnnxClip, Preprocessor, Tokenizer
+from PIL import Image
 
-from frigate.const import MODEL_CACHE_DIR
+from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE
+from frigate.types import ModelStatusTypesEnum
+from frigate.util.downloader import ModelDownloader
+
+logger = logging.getLogger(__name__)
 
 class Clip(OnnxClip):
-    """Override load models to download to cache directory."""
+    """Override load models to use pre-downloaded models from cache directory."""
+
+    def __init__(
+        self,
+        model: str = "ViT-B/32",
+        batch_size: Optional[int] = None,
+        providers: List[str] = ["CPUExecutionProvider"],
+    ):
+        """
+        Instantiates the model and required encoding classes.
+
+        Args:
+            model: The model to utilize. Currently ViT-B/32 and RN50 are
+                allowed.
+            batch_size: If set, splits the lists in `get_image_embeddings`
+                and `get_text_embeddings` into batches of this size before
+                passing them to the model. The embeddings are then concatenated
+                back together before being returned. This is necessary when
+                passing large amounts of data (perhaps ~100 or more).
+        """
+        allowed_models = ["ViT-B/32", "RN50"]
+        if model not in allowed_models:
+            raise ValueError(f"`model` must be in {allowed_models}. Got {model}.")
+        if model == "ViT-B/32":
+            self.embedding_size = 512
+        elif model == "RN50":
+            self.embedding_size = 1024
+        self.image_model, self.text_model = self._load_models(model, providers)
+        self._tokenizer = Tokenizer()
+        self._preprocessor = Preprocessor()
+        self._batch_size = batch_size
 
     @staticmethod
     def _load_models(
         model: str,
-        silent: bool,
-    ) -> Tuple[ort.InferenceSession, ort.InferenceSession]:
+        providers: List[str],
+    ) -> tuple[ort.InferenceSession, ort.InferenceSession]:
         """
-        These models are a part of the container. Treat as as such.
+        Load models from cache directory.
         """
         if model == "ViT-B/32":
             IMAGE_MODEL_FILE = "clip_image_model_vitb32.onnx"
@@ -43,64 +67,100 @@
         models = []
         for model_file in [IMAGE_MODEL_FILE, TEXT_MODEL_FILE]:
             path = os.path.join(MODEL_CACHE_DIR, "clip", model_file)
-            models.append(Clip._load_model(path, silent))
+            models.append(Clip._load_model(path, providers))
 
         return models[0], models[1]
 
     @staticmethod
-    def _load_model(path: str, silent: bool):
-        providers = ["CPUExecutionProvider"]
-
-        try:
-            if os.path.exists(path):
-                return ort.InferenceSession(path, providers=providers)
-            else:
-                raise FileNotFoundError(
-                    errno.ENOENT,
-                    os.strerror(errno.ENOENT),
-                    path,
-                )
-        except Exception:
-            s3_url = f"https://lakera-clip.s3.eu-west-1.amazonaws.com/{os.path.basename(path)}"
-            if not silent:
-                logging.info(
-                    f"The model file ({path}) doesn't exist "
-                    f"or it is invalid. Downloading it from the public S3 "
-                    f"bucket: {s3_url}."  # noqa: E501
-                )
-
-            # Download from S3
-            # Saving to a temporary file first to avoid corrupting the file
-            temporary_filename = Path(path).with_name(os.path.basename(path) + ".part")
-            # Create any missing directories in the path
-            temporary_filename.parent.mkdir(parents=True, exist_ok=True)
-            with requests.get(s3_url, stream=True) as r:
-                r.raise_for_status()
-                with open(temporary_filename, "wb") as f:
-                    for chunk in r.iter_content(chunk_size=8192):
-                        f.write(chunk)
-                    f.flush()
-            # Finally move the temporary file to the correct location
-            temporary_filename.rename(path)
+    def _load_model(path: str, providers: List[str]):
+        if os.path.exists(path):
             return ort.InferenceSession(path, providers=providers)
+        else:
+            logger.warning(f"CLIP model file {path} not found.")
+            return None
 
-class ClipEmbedding(EmbeddingFunction):
-    """Embedding function for CLIP model used in Chroma."""
+class ClipEmbedding:
+    """Embedding function for CLIP model."""
 
-    def __init__(self, model: str = "ViT-B/32"):
-        """Initialize CLIP Embedding function."""
-        self.model = Clip(model)
+    def __init__(
+        self,
+        model: str = "ViT-B/32",
+        silent: bool = False,
+        preferred_providers: List[str] = ["CPUExecutionProvider"],
+    ):
+        self.model_name = model
+        self.silent = silent
+        self.preferred_providers = preferred_providers
+        self.model_files = self._get_model_files()
+        self.model = None
 
-    def __call__(self, input: Union[Documents, Images]) -> Embeddings:
-        embeddings: Embeddings = []
+        self.downloader = ModelDownloader(
+            model_name="clip",
+            download_path=os.path.join(MODEL_CACHE_DIR, "clip"),
+            file_names=self.model_files,
+            download_func=self._download_model,
+            silent=self.silent,
+        )
+        self.downloader.ensure_model_files()
+
+    def _get_model_files(self):
+        if self.model_name == "ViT-B/32":
+            return ["clip_image_model_vitb32.onnx", "clip_text_model_vitb32.onnx"]
+        elif self.model_name == "RN50":
+            return ["clip_image_model_rn50.onnx", "clip_text_model_rn50.onnx"]
+        else:
+            raise ValueError(
+                f"Unexpected model {self.model_name}. No `.onnx` file found."
+            )
+
+    def _download_model(self, path: str):
+        s3_url = (
+            f"https://lakera-clip.s3.eu-west-1.amazonaws.com/{os.path.basename(path)}"
+        )
+        try:
+            ModelDownloader.download_from_url(s3_url, path, self.silent)
+            self.downloader.requestor.send_data(
+                UPDATE_MODEL_STATE,
+                {
+                    "model": f"{self.model_name}-{os.path.basename(path)}",
+                    "state": ModelStatusTypesEnum.downloaded,
+                },
+            )
+        except Exception:
+            self.downloader.requestor.send_data(
+                UPDATE_MODEL_STATE,
+                {
+                    "model": f"{self.model_name}-{os.path.basename(path)}",
+                    "state": ModelStatusTypesEnum.error,
+                },
+            )
+
+    def _load_model(self):
+        if self.model is None:
+            self.downloader.wait_for_download()
+            self.model = Clip(self.model_name, providers=self.preferred_providers)
+
+    def __call__(self, input: Union[List[str], List[Image.Image]]) -> List[np.ndarray]:
+        self._load_model()
+
+        if (
+            self.model is None
+            or self.model.image_model is None
+            or self.model.text_model is None
+        ):
+            logger.info(
+                "CLIP model is not fully loaded. Please wait for the download to complete."
+            )
+            return []
+
+        embeddings = []
         for item in input:
-            if is_image(item):
+            if isinstance(item, Image.Image):
                 result = self.model.get_image_embeddings([item])
-                embeddings.append(result[0, :].tolist())
-            elif is_document(item):
+                embeddings.append(result[0])
+            elif isinstance(item, str):
                 result = self.model.get_text_embeddings([item])
-                embeddings.append(result[0, :].tolist())
+                embeddings.append(result[0])
+            else:
+                raise ValueError(f"Unsupported input type: {type(item)}")
         return embeddings
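
Hypothetical usage of the rewritten embedding function: a single call now accepts a mix of PIL images and strings and dispatches on type, as the isinstance checks above show (the file name is a placeholder):

from PIL import Image

embedder = ClipEmbedding(preferred_providers=["CPUExecutionProvider"])
vectors = embedder([Image.open("thumb.jpg"), "a red car"])
# two 512-dimensional vectors in the same embedding space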

@@ -1,11 +1,107 @@
-"""Embedding function for ONNX MiniLM-L6 model used in Chroma."""
+import logging
+import os
+from typing import List
 
-from chromadb.utils.embedding_functions import ONNXMiniLM_L6_V2
+import numpy as np
+import onnxruntime as ort
 
-from frigate.const import MODEL_CACHE_DIR
+# importing this without pytorch or others causes a warning
+# https://github.com/huggingface/transformers/issues/27214
+# suppressed by setting env TRANSFORMERS_NO_ADVISORY_WARNINGS=1
+from transformers import AutoTokenizer
+
+from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE
+from frigate.types import ModelStatusTypesEnum
+from frigate.util.downloader import ModelDownloader
+
+logger = logging.getLogger(__name__)
 
-class MiniLMEmbedding(ONNXMiniLM_L6_V2):
-    """Override DOWNLOAD_PATH to download to cache directory."""
+class MiniLMEmbedding:
+    """Embedding function for ONNX MiniLM-L6 model."""
 
     DOWNLOAD_PATH = f"{MODEL_CACHE_DIR}/all-MiniLM-L6-v2"
+    MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
+    IMAGE_MODEL_FILE = "model.onnx"
+    TOKENIZER_FILE = "tokenizer"
+
+    def __init__(self, preferred_providers=["CPUExecutionProvider"]):
+        self.preferred_providers = preferred_providers
+        self.tokenizer = None
+        self.session = None
+
+        self.downloader = ModelDownloader(
+            model_name=self.MODEL_NAME,
+            download_path=self.DOWNLOAD_PATH,
+            file_names=[self.IMAGE_MODEL_FILE, self.TOKENIZER_FILE],
+            download_func=self._download_model,
+        )
+        self.downloader.ensure_model_files()
+
+    def _download_model(self, path: str):
+        try:
+            if os.path.basename(path) == self.IMAGE_MODEL_FILE:
+                s3_url = f"https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/{self.IMAGE_MODEL_FILE}"
+                ModelDownloader.download_from_url(s3_url, path)
+            elif os.path.basename(path) == self.TOKENIZER_FILE:
+                logger.info("Downloading MiniLM tokenizer")
+                tokenizer = AutoTokenizer.from_pretrained(
+                    self.MODEL_NAME, clean_up_tokenization_spaces=True
+                )
+                tokenizer.save_pretrained(path)
+
+            self.downloader.requestor.send_data(
+                UPDATE_MODEL_STATE,
+                {
+                    "model": f"{self.MODEL_NAME}-{os.path.basename(path)}",
+                    "state": ModelStatusTypesEnum.downloaded,
+                },
+            )
+        except Exception:
+            self.downloader.requestor.send_data(
+                UPDATE_MODEL_STATE,
+                {
+                    "model": f"{self.MODEL_NAME}-{os.path.basename(path)}",
+                    "state": ModelStatusTypesEnum.error,
+                },
+            )
+
+    def _load_model_and_tokenizer(self):
+        if self.tokenizer is None or self.session is None:
+            self.downloader.wait_for_download()
+            self.tokenizer = self._load_tokenizer()
+            self.session = self._load_model(
+                os.path.join(self.DOWNLOAD_PATH, self.IMAGE_MODEL_FILE),
+                self.preferred_providers,
+            )
+
+    def _load_tokenizer(self):
+        tokenizer_path = os.path.join(self.DOWNLOAD_PATH, self.TOKENIZER_FILE)
+        return AutoTokenizer.from_pretrained(
+            tokenizer_path, clean_up_tokenization_spaces=True
+        )
+
+    def _load_model(self, path: str, providers: List[str]):
+        if os.path.exists(path):
+            return ort.InferenceSession(path, providers=providers)
+        else:
+            logger.warning(f"MiniLM model file {path} not found.")
+            return None
+
+    def __call__(self, texts: List[str]) -> List[np.ndarray]:
+        self._load_model_and_tokenizer()
+
+        if self.session is None or self.tokenizer is None:
+            logger.error("MiniLM model or tokenizer is not loaded.")
+            return []
+
+        inputs = self.tokenizer(
+            texts, padding=True, truncation=True, return_tensors="np"
+        )
+        input_names = [input.name for input in self.session.get_inputs()]
+        onnx_inputs = {name: inputs[name] for name in input_names if name in inputs}
+
+        outputs = self.session.run(None, onnx_inputs)
+        embeddings = outputs[0].mean(axis=1)
+
+        return [embedding for embedding in embeddings]
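
One caveat worth noting: the __call__ above mean-pools over every token position, padding included, and does not L2-normalize. The reference sentence-transformers recipe for all-MiniLM-L6-v2 masks out padding and normalizes; a sketch of that variant for comparison (not what this commit ships):

import numpy as np

def masked_mean_pool(token_embeddings: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    mask = attention_mask[..., None].astype(token_embeddings.dtype)  # (batch, seq, 1)
    summed = (token_embeddings * mask).sum(axis=1)  # ignore padding tokens
    counts = np.clip(mask.sum(axis=1), 1e-9, None)
    pooled = summed / counts
    # unit-length vectors make the inner product equal cosine similarity
    return pooled / np.linalg.norm(pooled, axis=1, keepdims=True)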

@@ -1,7 +1,6 @@
-"""Maintain embeddings in Chroma."""
+"""Maintain embeddings in SQLite-vec."""

 import base64
-import io
 import logging
 import os
 import threading
@@ -11,7 +10,7 @@ from typing import Optional
 import cv2
 import numpy as np
 from peewee import DoesNotExist
-from PIL import Image
+from playhouse.sqliteq import SqliteQueueDatabase

 from frigate.comms.event_metadata_updater import (
     EventMetadataSubscriber,
@@ -26,7 +25,7 @@ from frigate.genai import get_genai_client
 from frigate.models import Event
 from frigate.util.image import SharedMemoryFrameManager, calculate_region

-from .embeddings import Embeddings, get_metadata
+from .embeddings import Embeddings

 logger = logging.getLogger(__name__)

@@ -36,13 +35,14 @@ class EmbeddingMaintainer(threading.Thread):
     def __init__(
         self,
+        db: SqliteQueueDatabase,
         config: FrigateConfig,
         stop_event: MpEvent,
     ) -> None:
         threading.Thread.__init__(self)
         self.name = "embeddings_maintainer"
         self.config = config
-        self.embeddings = Embeddings()
+        self.embeddings = Embeddings(db)
         self.event_subscriber = EventUpdateSubscriber()
         self.event_end_subscriber = EventEndSubscriber()
         self.event_metadata_subscriber = EventMetadataSubscriber(
@@ -56,7 +56,7 @@ class EmbeddingMaintainer(threading.Thread):
         self.genai_client = get_genai_client(config.genai)

     def run(self) -> None:
-        """Maintain a Chroma vector database for semantic search."""
+        """Maintain a SQLite-vec database for semantic search."""
         while not self.stop_event.is_set():
             self._process_updates()
             self._process_finalized()
@@ -117,12 +117,11 @@ class EmbeddingMaintainer(threading.Thread):
             if event.data.get("type") != "object":
                 continue

-            # Extract valid event metadata
-            metadata = get_metadata(event)
+            # Extract valid thumbnail
             thumbnail = base64.b64decode(event.thumbnail)

             # Embed the thumbnail
-            self._embed_thumbnail(event_id, thumbnail, metadata)
+            self._embed_thumbnail(event_id, thumbnail)

             if (
                 camera_config.genai.enabled
@@ -183,7 +182,6 @@ class EmbeddingMaintainer(threading.Thread):
                 args=(
                     event,
                     embed_image,
-                    metadata,
                 ),
             ).start()
@@ -219,25 +217,16 @@ class EmbeddingMaintainer(threading.Thread):
         return None

-    def _embed_thumbnail(self, event_id: str, thumbnail: bytes, metadata: dict) -> None:
+    def _embed_thumbnail(self, event_id: str, thumbnail: bytes) -> None:
         """Embed the thumbnail for an event."""
-
-        # Encode the thumbnail
-        img = np.array(Image.open(io.BytesIO(thumbnail)).convert("RGB"))
-        self.embeddings.thumbnail.upsert(
-            images=[img],
-            metadatas=[metadata],
-            ids=[event_id],
-        )
+        self.embeddings.upsert_thumbnail(event_id, thumbnail)

-    def _embed_description(
-        self, event: Event, thumbnails: list[bytes], metadata: dict
-    ) -> None:
+    def _embed_description(self, event: Event, thumbnails: list[bytes]) -> None:
         """Embed the description for an event."""
         camera_config = self.config.cameras[event.camera]
         description = self.genai_client.generate_description(
-            camera_config, thumbnails, metadata
+            camera_config, thumbnails, event.label
         )

         if not description:
@@ -251,11 +240,7 @@ class EmbeddingMaintainer(threading.Thread):
         )

         # Encode the description
-        self.embeddings.description.upsert(
-            documents=[description],
-            metadatas=[metadata],
-            ids=[event.id],
-        )
+        self.embeddings.upsert_description(event.id, description)

         logger.debug(
             "Generated description for %s (%d images): %s",
@@ -276,7 +261,6 @@ class EmbeddingMaintainer(threading.Thread):
             logger.error(f"GenAI not enabled for camera {event.camera}")
             return

-        metadata = get_metadata(event)
         thumbnail = base64.b64decode(event.thumbnail)

         logger.debug(
@@ -315,4 +299,4 @@ class EmbeddingMaintainer(threading.Thread):
             )
         )

-        self._embed_description(event, embed_image, metadata)
+        self._embed_description(event, embed_image)
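
The upsert_thumbnail and upsert_description helpers called above live in the revamped Embeddings class, whose schema is not shown in this diff. For orientation, a sqlite-vec upsert generally reduces to a sketch like the following; the table and column names are assumptions for illustration, not a confirmed copy of Frigate's schema:

import sqlite3
import struct

import sqlite_vec  # the loadable extension this PR builds from source

def serialize_f32(vector: list[float]) -> bytes:
    # sqlite-vec accepts vectors as little-endian float32 blobs
    return struct.pack(f"{len(vector)}f", *vector)

db = sqlite3.connect(":memory:")
db.enable_load_extension(True)
sqlite_vec.load(db)
db.enable_load_extension(False)

# hypothetical table; 768 matches CLIP ViT-B/32 image embedding size
db.execute(
    "CREATE VIRTUAL TABLE IF NOT EXISTS vec_thumbnails USING vec0("
    "id TEXT PRIMARY KEY, thumbnail_embedding FLOAT[768])"
)

# an upsert is just INSERT OR REPLACE keyed on the event id
db.execute(
    "INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding) VALUES (?, ?)",
    ("event-id-123", serialize_f32([0.0] * 768)),
)

Compared with the removed Chroma collections, the ids, embeddings, and metadata all collapse into ordinary SQLite rows that ride along in the same database the rest of Frigate already uses.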

View File

@@ -4,12 +4,15 @@ import math

 class ZScoreNormalization:
-    """Running Z-score normalization for search distance."""
-
-    def __init__(self):
+    def __init__(self, scale_factor: float = 1.0, bias: float = 0.0):
+        """Initialize with optional scaling and bias adjustments."""
+        """scale_factor adjusts the magnitude of each score"""
+        """bias will artificially shift the entire distribution upwards"""
         self.n = 0
         self.mean = 0
         self.m2 = 0
+        self.scale_factor = scale_factor
+        self.bias = bias

     @property
     def variance(self):
@@ -23,7 +26,10 @@ class ZScoreNormalization:
         self._update(distances)
         if self.stddev == 0:
             return distances
-        return [(x - self.mean) / self.stddev for x in distances]
+        return [
+            (x - self.mean) / self.stddev * self.scale_factor + self.bias
+            for x in distances
+        ]

     def _update(self, distances: list[float]):
         for x in distances:
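
To make the scale_factor/bias change concrete, here is the same arithmetic worked through in plain Python; the input distances and parameter values are illustrative, not Frigate's defaults:

import statistics

distances = [0.40, 0.50, 0.60]  # illustrative similarity distances
mean = statistics.mean(distances)      # 0.5
stddev = statistics.pstdev(distances)  # ~0.0816

scale_factor, bias = 3.0, -2.5  # example values only
scores = [(x - mean) / stddev * scale_factor + bias for x in distances]
# plain z-scores would be ~[-1.22, 0.0, 1.22]; scaled and shifted they
# become ~[-6.17, -2.5, 1.17], ready for the sigmoid confidence mapping
# that appears in SearchView later in this diff
print(scores)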

View File

@@ -8,6 +8,8 @@ from enum import Enum
 from multiprocessing.synchronize import Event as MpEvent
 from pathlib import Path

+from playhouse.sqliteq import SqliteQueueDatabase
+
 from frigate.config import FrigateConfig
 from frigate.const import CLIPS_DIR
 from frigate.embeddings.embeddings import Embeddings
@@ -22,16 +24,19 @@ class EventCleanupType(str, Enum):

 class EventCleanup(threading.Thread):
-    def __init__(self, config: FrigateConfig, stop_event: MpEvent):
+    def __init__(
+        self, config: FrigateConfig, stop_event: MpEvent, db: SqliteQueueDatabase
+    ):
         super().__init__(name="event_cleanup")
         self.config = config
         self.stop_event = stop_event
+        self.db = db
         self.camera_keys = list(self.config.cameras.keys())
         self.removed_camera_labels: list[str] = None
         self.camera_labels: dict[str, dict[str, any]] = {}

         if self.config.semantic_search.enabled:
-            self.embeddings = Embeddings()
+            self.embeddings = Embeddings(self.db)

     def get_removed_camera_labels(self) -> list[Event]:
         """Get a list of distinct labels for removed cameras."""
@@ -229,15 +234,8 @@ class EventCleanup(threading.Thread):
                 Event.delete().where(Event.id << chunk).execute()

                 if self.config.semantic_search.enabled:
-                    for collection in [
-                        self.embeddings.thumbnail,
-                        self.embeddings.description,
-                    ]:
-                        existing_ids = collection.get(ids=chunk, include=[])["ids"]
-                        if existing_ids:
-                            collection.delete(ids=existing_ids)
-                            logger.debug(
-                                f"Deleted {len(existing_ids)} embeddings from {collection.__class__.__name__}"
-                            )
+                    self.embeddings.delete_description(chunk)
+                    self.embeddings.delete_thumbnail(chunk)
+                    logger.debug(f"Deleted {len(events_to_delete)} embeddings")

         logger.info("Exiting event cleanup...")

View File

@@ -31,12 +31,12 @@ class GenAIClient:
         self,
         camera_config: CameraConfig,
         thumbnails: list[bytes],
-        metadata: dict[str, any],
+        label: str,
     ) -> Optional[str]:
         """Generate a description for the frame."""
         prompt = camera_config.genai.object_prompts.get(
-            metadata["label"], camera_config.genai.prompt
-        ).format(**metadata)
+            label, camera_config.genai.prompt
+        )
         return self._send(prompt, thumbnails)

     def _init_provider(self):
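
With the metadata dict gone, the GenAI prompt is selected purely by object label, and the old .format(**metadata) substitution is dropped along with it. A minimal sketch of the new lookup (dict contents are illustrative):

object_prompts = {"person": "Describe the person in these images."}
default_prompt = "Describe the scene in these images."

# a label with an override gets its specific prompt...
assert object_prompts.get("person", default_prompt) == "Describe the person in these images."
# ...any other label falls back to the camera-wide prompt
assert object_prompts.get("car", default_prompt) == "Describe the scene in these images."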

View File

@@ -1,3 +1,4 @@
+from enum import Enum
 from typing import TypedDict

 from frigate.camera import CameraMetrics
@@ -11,3 +12,10 @@ class StatsTrackingTypes(TypedDict):
     latest_frigate_version: str
     last_updated: int
     processes: dict[str, int]
+
+
+class ModelStatusTypesEnum(str, Enum):
+    not_downloaded = "not_downloaded"
+    downloading = "downloading"
+    downloaded = "downloaded"
+    error = "error"

frigate/util/downloader.py (new file, 123 lines)
View File

@@ -0,0 +1,123 @@
import logging
import os
import threading
import time
from pathlib import Path
from typing import Callable, List

import requests

from frigate.comms.inter_process import InterProcessRequestor
from frigate.const import UPDATE_MODEL_STATE
from frigate.types import ModelStatusTypesEnum

logger = logging.getLogger(__name__)


class FileLock:
    def __init__(self, path):
        self.path = path
        self.lock_file = f"{path}.lock"

    def acquire(self):
        parent_dir = os.path.dirname(self.lock_file)
        os.makedirs(parent_dir, exist_ok=True)

        while True:
            try:
                # open with "x" fails if the lock file already exists,
                # so only one process can hold the lock at a time
                with open(self.lock_file, "x"):
                    return
            except FileExistsError:
                time.sleep(0.1)

    def release(self):
        try:
            os.remove(self.lock_file)
        except FileNotFoundError:
            pass


class ModelDownloader:
    def __init__(
        self,
        model_name: str,
        download_path: str,
        file_names: List[str],
        download_func: Callable[[str], None],
        silent: bool = False,
    ):
        self.model_name = model_name
        self.download_path = download_path
        self.file_names = file_names
        self.download_func = download_func
        self.silent = silent
        self.requestor = InterProcessRequestor()
        self.download_thread = None
        self.download_complete = threading.Event()

    def ensure_model_files(self):
        for file in self.file_names:
            self.requestor.send_data(
                UPDATE_MODEL_STATE,
                {
                    "model": f"{self.model_name}-{file}",
                    "state": ModelStatusTypesEnum.downloading,
                },
            )
        self.download_thread = threading.Thread(
            target=self._download_models,
            name=f"_download_model_{self.model_name}",
            daemon=True,
        )
        self.download_thread.start()

    def _download_models(self):
        for file_name in self.file_names:
            path = os.path.join(self.download_path, file_name)
            lock = FileLock(path)

            if not os.path.exists(path):
                lock.acquire()
                try:
                    # re-check after acquiring: another process may have
                    # finished the download while we waited on the lock
                    if not os.path.exists(path):
                        self.download_func(path)
                finally:
                    lock.release()

            self.requestor.send_data(
                UPDATE_MODEL_STATE,
                {
                    "model": f"{self.model_name}-{file_name}",
                    "state": ModelStatusTypesEnum.downloaded,
                },
            )

        self.download_complete.set()

    @staticmethod
    def download_from_url(url: str, save_path: str, silent: bool = False):
        # download to a .part file and rename at the end so a partial
        # download never masquerades as a complete model file
        temporary_filename = Path(save_path).with_name(
            os.path.basename(save_path) + ".part"
        )
        temporary_filename.parent.mkdir(parents=True, exist_ok=True)

        if not silent:
            logger.info(f"Downloading model file from: {url}")

        try:
            with requests.get(url, stream=True, allow_redirects=True) as r:
                r.raise_for_status()
                with open(temporary_filename, "wb") as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)
            temporary_filename.rename(save_path)
        except Exception as e:
            logger.error(f"Error downloading model: {str(e)}")
            raise

        if not silent:
            logger.info(f"Downloading complete: {url}")

    def wait_for_download(self):
        self.download_complete.wait()
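
A hypothetical usage sketch of this class, wiring download_from_url in as the per-file callback; the model name, path, and URL are made up for illustration:

def fetch_file(path: str) -> None:
    # the callback receives only the destination path; URL is illustrative
    ModelDownloader.download_from_url("https://example.com/model.onnx", path)

downloader = ModelDownloader(
    model_name="example-model",
    download_path="/config/model_cache/example-model",
    file_names=["model.onnx"],
    download_func=fetch_file,
)
downloader.ensure_model_files()  # publishes "downloading" states, starts daemon thread

# later, on the inference path, block until every file is present
downloader.wait_for_download()

Because the download thread is a daemon, a shutdown does not hang on an in-flight download, and the FileLock plus the exists re-check keep multiple processes from fetching the same file twice.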

View File

@@ -5,6 +5,7 @@ import {
   FrigateCameraState,
   FrigateEvent,
   FrigateReview,
+  ModelState,
   ToggleableSetting,
 } from "@/types/ws";
 import { FrigateStats } from "@/types/stats";
@@ -266,6 +267,41 @@ export function useInitialCameraState(
   return { payload: data ? data[camera] : undefined };
 }

+export function useModelState(
+  model: string,
+  revalidateOnFocus: boolean = true,
+): { payload: ModelState } {
+  const {
+    value: { payload },
+    send: sendCommand,
+  } = useWs("model_state", "modelState");
+
+  const data = useDeepMemo(JSON.parse(payload as string));
+
+  useEffect(() => {
+    let listener = undefined;
+    if (revalidateOnFocus) {
+      sendCommand("modelState");
+      listener = () => {
+        if (document.visibilityState == "visible") {
+          sendCommand("modelState");
+        }
+      };
+      addEventListener("visibilitychange", listener);
+    }
+    return () => {
+      if (listener) {
+        removeEventListener("visibilitychange", listener);
+      }
+    };
+    // we know that these deps are correct
+    // eslint-disable-next-line react-hooks/exhaustive-deps
+  }, [revalidateOnFocus]);
+
+  return { payload: data ? data[model] : undefined };
+}
+
 export function useMotionActivity(camera: string): { payload: string } {
   const {
     value: { payload },

View File

@@ -52,12 +52,11 @@ export default function SearchThumbnail({
         className="absolute inset-0"
         imgLoaded={imgLoaded}
       />
-      <div className={`${imgLoaded ? "visible" : "invisible"}`}>
+      <div className={`size-full ${imgLoaded ? "visible" : "invisible"}`}>
         <img
           ref={imgRef}
           className={cn(
-            "size-full select-none opacity-100 transition-opacity",
-            searchResult.search_source == "thumbnail" && "object-contain",
+            "size-full select-none object-cover object-center opacity-100 transition-opacity",
           )}
           style={
             isIOS

View File

@@ -1,11 +1,15 @@
-import { useEventUpdate } from "@/api/ws";
+import { useEventUpdate, useModelState } from "@/api/ws";
+import ActivityIndicator from "@/components/indicators/activity-indicator";
 import { useApiFilterArgs } from "@/hooks/use-api-filter";
 import { useTimezone } from "@/hooks/use-date-utils";
 import { FrigateConfig } from "@/types/frigateConfig";
 import { SearchFilter, SearchQuery, SearchResult } from "@/types/search";
+import { ModelState } from "@/types/ws";
 import SearchView from "@/views/search/SearchView";
 import { useCallback, useEffect, useMemo, useState } from "react";
+import { LuCheck, LuExternalLink, LuX } from "react-icons/lu";
 import { TbExclamationCircle } from "react-icons/tb";
+import { Link } from "react-router-dom";
 import useSWR from "swr";
 import useSWRInfinite from "swr/infinite";
@@ -111,14 +115,10 @@ export default function Explore() {
   // paging

-  // usually slow only on first run while downloading models
-  const [isSlowLoading, setIsSlowLoading] = useState(false);
-
   const getKey = (
     pageIndex: number,
     previousPageData: SearchResult[] | null,
   ): SearchQuery => {
-    if (isSlowLoading && !similaritySearch) return null;
     if (previousPageData && !previousPageData.length) return null; // reached the end
     if (!searchQuery) return null;
@@ -143,12 +143,6 @@ export default function Explore() {
     revalidateFirstPage: true,
     revalidateOnFocus: true,
     revalidateAll: false,
-    onLoadingSlow: () => {
-      if (!similaritySearch) {
-        setIsSlowLoading(true);
-      }
-    },
-    loadingTimeout: 15000,
   });

   const searchResults = useMemo(
@@ -168,7 +162,7 @@ export default function Explore() {
     if (searchQuery) {
       const [url] = searchQuery;

-      // for chroma, only load 100 results for description and similarity
+      // for embeddings, only load 100 results for description and similarity
       if (url === "events/search" && searchResults.length >= 100) {
         return;
       }
@@ -188,17 +182,113 @@ export default function Explore() {
     // eslint-disable-next-line react-hooks/exhaustive-deps
   }, [eventUpdate]);

+  // model states
+
+  const { payload: minilmModelState } = useModelState(
+    "sentence-transformers/all-MiniLM-L6-v2-model.onnx",
+  );
+  const { payload: minilmTokenizerState } = useModelState(
+    "sentence-transformers/all-MiniLM-L6-v2-tokenizer",
+  );
+  const { payload: clipImageModelState } = useModelState(
+    "clip-clip_image_model_vitb32.onnx",
+  );
+  const { payload: clipTextModelState } = useModelState(
+    "clip-clip_text_model_vitb32.onnx",
+  );
+
+  const allModelsLoaded = useMemo(() => {
+    return (
+      minilmModelState === "downloaded" &&
+      minilmTokenizerState === "downloaded" &&
+      clipImageModelState === "downloaded" &&
+      clipTextModelState === "downloaded"
+    );
+  }, [
+    minilmModelState,
+    minilmTokenizerState,
+    clipImageModelState,
+    clipTextModelState,
+  ]);
+
+  const renderModelStateIcon = (modelState: ModelState) => {
+    if (modelState === "downloading") {
+      return <ActivityIndicator className="size-5" />;
+    }
+    if (modelState === "downloaded") {
+      return <LuCheck className="size-5 text-success" />;
+    }
+    if (modelState === "not_downloaded" || modelState === "error") {
+      return <LuX className="size-5 text-danger" />;
+    }
+    return null;
+  };
+
+  if (
+    !minilmModelState ||
+    !minilmTokenizerState ||
+    !clipImageModelState ||
+    !clipTextModelState
+  ) {
+    return (
+      <ActivityIndicator className="absolute left-1/2 top-1/2 -translate-x-1/2 -translate-y-1/2" />
+    );
+  }
+
   return (
     <>
-      {isSlowLoading && !similaritySearch ? (
+      {!allModelsLoaded ? (
         <div className="absolute inset-0 left-1/2 top-1/2 flex h-96 w-96 -translate-x-1/2 -translate-y-1/2">
-          <div className="flex flex-col items-center justify-center rounded-lg bg-background/50 p-5">
-            <p className="my-5 text-lg">Search Unavailable</p>
-            <TbExclamationCircle className="mb-3 size-10" />
-            <p className="max-w-96 text-center">
-              If this is your first time using Search, be patient while Frigate
-              downloads the necessary embeddings models. Check Frigate logs.
-            </p>
+          <div className="flex flex-col items-center justify-center space-y-3 rounded-lg bg-background/50 p-5">
+            <div className="my-5 flex flex-col items-center gap-2 text-xl">
+              <TbExclamationCircle className="mb-3 size-10" />
+              <div>Search Unavailable</div>
+            </div>
+            <div className="max-w-96 text-center">
+              Frigate is downloading the necessary embeddings models to support
+              semantic searching. This may take several minutes depending on the
+              speed of your network connection.
+            </div>
+            <div className="flex w-96 flex-col gap-2 py-5">
+              <div className="flex flex-row items-center justify-center gap-2">
+                {renderModelStateIcon(clipImageModelState)}
+                CLIP image model
+              </div>
+              <div className="flex flex-row items-center justify-center gap-2">
+                {renderModelStateIcon(clipTextModelState)}
+                CLIP text model
+              </div>
+              <div className="flex flex-row items-center justify-center gap-2">
+                {renderModelStateIcon(minilmModelState)}
+                MiniLM sentence model
+              </div>
+              <div className="flex flex-row items-center justify-center gap-2">
+                {renderModelStateIcon(minilmTokenizerState)}
+                MiniLM tokenizer
+              </div>
+            </div>
+            {(minilmModelState === "error" ||
+              clipImageModelState === "error" ||
+              clipTextModelState === "error") && (
+              <div className="my-3 max-w-96 text-center text-danger">
+                An error has occurred. Check Frigate logs.
+              </div>
+            )}
+            <div className="max-w-96 text-center">
+              You may want to reindex the embeddings of your tracked objects
+              once the models are downloaded.
+            </div>
+            <div className="flex max-w-96 items-center text-primary-variant">
+              <Link
+                to="https://docs.frigate.video/configuration/semantic_search"
+                target="_blank"
+                rel="noopener noreferrer"
+                className="inline"
+              >
+                Read the documentation{" "}
+                <LuExternalLink className="ml-2 inline-flex size-3" />
+              </Link>
+            </div>
           </div>
         </div>
       ) : (

View File

@ -12,5 +12,5 @@ export type LogLine = {
content: string; content: string;
}; };
export const logTypes = ["frigate", "go2rtc", "nginx", "chroma"] as const; export const logTypes = ["frigate", "go2rtc", "nginx"] as const;
export type LogType = (typeof logTypes)[number]; export type LogType = (typeof logTypes)[number];

View File

@@ -56,4 +56,10 @@ export interface FrigateCameraState {
   objects: ObjectType[];
 }

+export type ModelState =
+  | "not_downloaded"
+  | "downloading"
+  | "downloaded"
+  | "error";
+
 export type ToggleableSetting = "ON" | "OFF";

View File

@@ -128,46 +128,6 @@ export function parseLogLines(logService: LogType, logs: string[]) {
         };
       })
       .filter((value) => value != null) as LogLine[];
-  } else if (logService == "chroma") {
-    return logs
-      .map((line) => {
-        const match = frigateDateStamp.exec(line);
-
-        if (!match) {
-          const infoIndex = line.indexOf("[INFO]");
-
-          if (infoIndex != -1) {
-            return {
-              dateStamp: line.substring(0, 19),
-              severity: "info",
-              section: "startup",
-              content: line.substring(infoIndex + 6).trim(),
-            };
-          }
-
-          return null;
-        }
-
-        const startup =
-          line.indexOf("Starting component") !== -1 ||
-          line.indexOf("startup") !== -1 ||
-          line.indexOf("Started") !== -1 ||
-          line.indexOf("Uvicorn") !== -1;
-        const api = !!httpMethods.exec(line);
-
-        const tag = startup ? "startup" : api ? "API" : "server";
-
-        return {
-          dateStamp: match.toString().slice(1, -1),
-          severity: pythonSeverity
-            .exec(line)
-            ?.at(0)
-            ?.toString()
-            ?.toLowerCase() as LogSeverity,
-          section: tag,
-          content: line.substring(match.index + match[0].length).trim(),
-        };
-      })
-      .filter((value) => value != null) as LogLine[];
   }

   return [];

View File

@@ -189,19 +189,9 @@ export default function SearchView({
   // confidence score - probably needs tweaking
-  const zScoreToConfidence = (score: number, source: string) => {
-    let midpoint, scale;
-
-    if (source === "thumbnail") {
-      midpoint = 2;
-      scale = 0.5;
-    } else {
-      midpoint = 0.5;
-      scale = 1.5;
-    }
-
+  const zScoreToConfidence = (score: number) => {
     // Sigmoid function: 1 / (1 + e^x)
-    const confidence = 1 / (1 + Math.exp((score - midpoint) * scale));
+    const confidence = 1 / (1 + Math.exp(score));
     return Math.round(confidence * 100);
   };
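
The per-source midpoint/scale tuning can be dropped here because the backend's ZScoreNormalization (see the earlier hunk) now applies scale_factor and bias before scores reach the UI, so one plain sigmoid handles both sources. The same mapping in Python, with a few sample values:

import math

def z_score_to_confidence(score: float) -> int:
    # mirrors the TypeScript above: lower (more negative) score -> higher confidence
    return round(100 / (1 + math.exp(score)))

print(z_score_to_confidence(-2.0))  # 88
print(z_score_to_confidence(0.0))   # 50
print(z_score_to_confidence(2.0))   # 12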
@@ -412,21 +402,13 @@ export default function SearchView({
                     ) : (
                       <LuText className="mr-1 size-3" />
                     )}
-                    {zScoreToConfidence(
-                      value.search_distance,
-                      value.search_source,
-                    )}
-                    %
+                    {zScoreToConfidence(value.search_distance)}%
                   </Chip>
                 </TooltipTrigger>
                 <TooltipPortal>
                   <TooltipContent>
                     Matched {value.search_source} at{" "}
-                    {zScoreToConfidence(
-                      value.search_distance,
-                      value.search_source,
-                    )}
-                    %
+                    {zScoreToConfidence(value.search_distance)}%
                   </TooltipContent>
                 </TooltipPortal>
               </Tooltip>