Semantic Search API (#12105)

* initial event search api implementation

* fix lint

* fix tests

* move chromadb imports and pysqlite hotswap to fix tests

* remove unused import

* switch default limit to 50

* fix events accidentally pulling inside chroma results loop
Jason Hunter 2024-06-23 09:13:02 -04:00 committed by Nicolas Mowen
parent 36cbffcc5e
commit 9e825811f2
8 changed files with 359 additions and 23 deletions

View File

@ -1,12 +1,9 @@
import faulthandler
import sys
import threading

from flask import cli

# Hotswap the sqlite3 module for Chroma compatibility
__import__("pysqlite3")
sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")

from frigate.app import FrigateApp

faulthandler.enable()
@ -15,8 +12,6 @@ threading.current_thread().name = "frigate"
cli.show_server_banner = lambda *x: None

if __name__ == "__main__":
    from frigate.app import FrigateApp

    frigate_app = FrigateApp()
    frigate_app.start()

View File

@ -23,6 +23,7 @@ from frigate.api.preview import PreviewBp
from frigate.api.review import ReviewBp
from frigate.config import FrigateConfig
from frigate.const import CONFIG_DIR
from frigate.embeddings import EmbeddingsContext
from frigate.events.external import ExternalEventProcessor
from frigate.models import Event, Timeline
from frigate.plus import PlusApi
@ -52,6 +53,7 @@ bp.register_blueprint(AuthBp)
def create_app(
    frigate_config,
    database: SqliteQueueDatabase,
    embeddings: EmbeddingsContext,
    detected_frames_processor,
    storage_maintainer: StorageMaintainer,
    onvif: OnvifController,
@ -79,6 +81,7 @@ def create_app(
            database.close()

    app.frigate_config = frigate_config
    app.embeddings = embeddings
    app.detected_frames_processor = detected_frames_processor
    app.storage_maintainer = storage_maintainer
    app.onvif = onvif
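
For orientation, the change above threads the embeddings client through create_app() and stashes it as an attribute on the Flask app so request handlers can reach it via current_app. Below is a minimal, self-contained sketch of that wiring with made-up names (SearchContext, make_app); it illustrates the pattern under those assumptions and is not Frigate's actual code.

# Minimal sketch of the wiring above: attach a shared context object to the
# Flask app at creation time and read it back through current_app in a view.
# SearchContext and make_app are illustrative names, not Frigate's.
from flask import Flask, current_app, jsonify


class SearchContext:
    """Stand-in for EmbeddingsContext in this sketch."""

    def query(self, text: str) -> list[str]:
        return [f"result for {text}"]


def make_app(context: SearchContext) -> Flask:
    app = Flask(__name__)
    app.search_context = context  # same idea as app.embeddings = embeddings

    @app.route("/search/<text>")
    def search(text: str):
        # Handlers reach the shared object through current_app, as events.py does.
        ctx: SearchContext = current_app.search_context
        return jsonify(ctx.query(text))

    return app


if __name__ == "__main__":
    make_app(SearchContext()).run(port=5001)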

View File

@ -1,5 +1,7 @@
"""Event apis."""
import base64
import io
import logging
import os
from datetime import datetime
@ -8,6 +10,7 @@ from pathlib import Path
from urllib.parse import unquote
import cv2
import numpy as np
from flask import (
    Blueprint,
    current_app,
@ -15,13 +18,16 @@ from flask import (
    make_response,
    request,
)
from peewee import DoesNotExist, fn, operator
from peewee import JOIN, DoesNotExist, fn, operator
from PIL import Image
from playhouse.shortcuts import model_to_dict

from frigate.const import (
    CLIPS_DIR,
)
from frigate.models import Event, Timeline
from frigate.embeddings import EmbeddingsContext
from frigate.embeddings.embeddings import get_metadata
from frigate.models import Event, ReviewSegment, Timeline
from frigate.object_processing import TrackedObject
from frigate.util.builtin import get_tz_modifiers
@ -245,6 +251,189 @@ def events():
    return jsonify(list(events))


@EventBp.route("/events/search")
def events_search():
    query = request.args.get("query", type=str)
    search_type = request.args.get("search_type", "text", type=str)
    include_thumbnails = request.args.get("include_thumbnails", default=1, type=int)
    limit = request.args.get("limit", 50, type=int)

    # Filters
    cameras = request.args.get("cameras", "all", type=str)
    labels = request.args.get("labels", "all", type=str)
    zones = request.args.get("zones", "all", type=str)
    after = request.args.get("after", type=float)
    before = request.args.get("before", type=float)

    if not query:
        return make_response(
            jsonify(
                {
                    "success": False,
                    "message": "A search query must be supplied",
                }
            ),
            400,
        )

    if not current_app.frigate_config.semantic_search.enabled:
        return make_response(
            jsonify(
                {
                    "success": False,
                    "message": "Semantic search is not enabled",
                }
            ),
            400,
        )

    context: EmbeddingsContext = current_app.embeddings

    selected_columns = [
        Event.id,
        Event.camera,
        Event.label,
        Event.sub_label,
        Event.zones,
        Event.start_time,
        Event.end_time,
        Event.data,
        ReviewSegment.thumb_path,
    ]

    if include_thumbnails:
        selected_columns.append(Event.thumbnail)

    # Build the where clause for the embeddings query
    embeddings_filters = []

    if cameras != "all":
        camera_list = cameras.split(",")
        embeddings_filters.append({"camera": {"$in": camera_list}})

    if labels != "all":
        label_list = labels.split(",")
        embeddings_filters.append({"label": {"$in": label_list}})

    if zones != "all":
        filtered_zones = zones.split(",")
        zone_filters = [{f"zones_{zone}": {"$eq": True}} for zone in filtered_zones]
        if len(zone_filters) > 1:
            embeddings_filters.append({"$or": zone_filters})
        else:
            embeddings_filters.append(zone_filters[0])

    if after:
        embeddings_filters.append({"start_time": {"$gt": after}})

    if before:
        embeddings_filters.append({"start_time": {"$lt": before}})

    where = None
    if len(embeddings_filters) > 1:
        where = {"$and": embeddings_filters}
    elif len(embeddings_filters) == 1:
        where = embeddings_filters[0]

    thumb_ids = {}
    desc_ids = {}

    if search_type == "thumbnail":
        # Grab the ids of events that match the thumbnail image embeddings
        try:
            search_event: Event = Event.get(Event.id == query)
        except DoesNotExist:
            return make_response(
                jsonify(
                    {
                        "success": False,
                        "message": "Event not found",
                    }
                ),
                404,
            )
        thumbnail = base64.b64decode(search_event.thumbnail)
        img = np.array(Image.open(io.BytesIO(thumbnail)).convert("RGB"))
        thumb_result = context.embeddings.thumbnail.query(
            query_images=[img],
            n_results=limit,
            where=where,
        )
        thumb_ids = dict(zip(thumb_result["ids"][0], thumb_result["distances"][0]))
    else:
        thumb_result = context.embeddings.thumbnail.query(
            query_texts=[query],
            n_results=limit,
            where=where,
        )
        # Do a rudimentary normalization of the difference in distances returned by CLIP and MiniLM.
        thumb_ids = dict(
            zip(
                thumb_result["ids"][0],
                context.thumb_stats.normalize(thumb_result["distances"][0]),
            )
        )
        desc_result = context.embeddings.description.query(
            query_texts=[query],
            n_results=limit,
            where=where,
        )
        desc_ids = dict(
            zip(
                desc_result["ids"][0],
                context.desc_stats.normalize(desc_result["distances"][0]),
            )
        )

    results = {}
    for event_id in thumb_ids.keys() | desc_ids:
        min_distance = min(
            i
            for i in (thumb_ids.get(event_id), desc_ids.get(event_id))
            if i is not None
        )
        results[event_id] = {
            "distance": min_distance,
            "source": "thumbnail"
            if min_distance == thumb_ids.get(event_id)
            else "description",
        }

    if not results:
        return jsonify([])

    # Get the event data
    events = (
        Event.select(*selected_columns)
        .join(
            ReviewSegment,
            JOIN.LEFT_OUTER,
            on=(fn.json_extract(ReviewSegment.data, "$.detections").contains(Event.id)),
        )
        .where(Event.id << list(results.keys()))
        .dicts()
        .iterator()
    )
    events = list(events)

    events = [
        {k: v for k, v in event.items() if k != "data"}
        | {
            k: v
            for k, v in event["data"].items()
            if k in ["type", "score", "top_score", "description"]
        }
        | {
            "search_distance": results[event["id"]]["distance"],
            "search_source": results[event["id"]]["source"],
        }
        for event in events
    ]

    events = sorted(events, key=lambda x: x["search_distance"])[:limit]

    return jsonify(events)


@EventBp.route("/events/summary")
def events_summary():
    tz_name = request.args.get("timezone", default="utc", type=str)
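
Before the rest of this file's changes, here is a hedged client-side sketch of the /events/search endpoint defined above. It assumes a running Frigate instance whose HTTP API is reachable at http://127.0.0.1:5000/api and uses a placeholder event id; the parameter names (query, search_type, cameras, labels, zones, after, before, limit) come straight from the handler.

# Hedged usage sketch for the endpoint above. It assumes a Frigate instance
# whose HTTP API is reachable at http://127.0.0.1:5000/api; the event id used
# for the thumbnail search is a placeholder, not a real id.
import requests

BASE = "http://127.0.0.1:5000/api"

# Text search: match against thumbnail (CLIP) and description (MiniLM)
# embeddings, filtered to one camera and a time window.
resp = requests.get(
    f"{BASE}/events/search",
    params={
        "query": "person carrying a package",
        "search_type": "text",
        "cameras": "front_door",
        "after": 1719100800,  # unix timestamps, as the before/after filters expect
        "limit": 10,
    },
)
resp.raise_for_status()
for event in resp.json():
    # Lower search_distance means a closer match; search_source reports whether
    # the thumbnail or the description embedding produced it.
    print(event["id"], event["search_source"], round(event["search_distance"], 3))

# Image similarity: pass an existing event id as the query with
# search_type=thumbnail to find events with visually similar thumbnails.
similar = requests.get(
    f"{BASE}/events/search",
    params={"query": "1719100000.123456-abcd12", "search_type": "thumbnail"},
)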
@ -604,6 +793,52 @@ def set_sub_label(id):
    )


@EventBp.route("/events/<id>/description", methods=("POST",))
def set_description(id):
    try:
        event: Event = Event.get(Event.id == id)
    except DoesNotExist:
        return make_response(
            jsonify({"success": False, "message": "Event " + id + " not found"}), 404
        )

    json: dict[str, any] = request.get_json(silent=True) or {}
    new_description = json.get("description")

    if new_description is None or len(new_description) == 0:
        return make_response(
            jsonify(
                {
                    "success": False,
                    "message": "description cannot be empty",
                }
            ),
            400,
        )

    event.data["description"] = new_description
    event.save()

    # If semantic search is enabled, update the index
    if current_app.frigate_config.semantic_search.enabled:
        context: EmbeddingsContext = current_app.embeddings
        context.embeddings.description.upsert(
            documents=[new_description],
            metadatas=[get_metadata(event)],
            ids=[id],
        )

    return make_response(
        jsonify(
            {
                "success": True,
                "message": "Event " + id + " description set to " + new_description,
            }
        ),
        200,
    )


@EventBp.route("/events/<id>", methods=("DELETE",))
def delete_event(id):
    try:
@ -625,6 +860,11 @@ def delete_event(id):
    event.delete_instance()
    Timeline.delete().where(Timeline.source_id == id).execute()

    # If semantic search is enabled, update the index
    if current_app.frigate_config.semantic_search.enabled:
        context: EmbeddingsContext = current_app.embeddings
        context.embeddings.thumbnail.delete(ids=[id])
        context.embeddings.description.delete(ids=[id])

    return make_response(
        jsonify({"success": True, "message": "Event " + id + " deleted"}), 200
    )
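
To close out this file, a hedged sketch of driving the description and delete endpoints above. It again assumes the API at http://127.0.0.1:5000/api and a placeholder event id; with semantic search enabled, setting a description upserts its embedding and deleting the event removes both of its embeddings, as the handlers show.

# Hedged sketch of the description and delete endpoints above, again assuming
# the API at http://127.0.0.1:5000/api and a placeholder event id.
import requests

BASE = "http://127.0.0.1:5000/api"
event_id = "1719100000.123456-abcd12"  # placeholder for illustration

# Set a description. With semantic search enabled, the handler upserts the text
# into the Chroma description collection so the event becomes searchable by it.
resp = requests.post(
    f"{BASE}/events/{event_id}/description",
    json={"description": "delivery driver leaving a package at the door"},
)
print(resp.status_code, resp.json().get("message"))

# An empty description is rejected with HTTP 400 by the handler.
empty = requests.post(f"{BASE}/events/{event_id}/description", json={"description": ""})
assert empty.status_code == 400

# Deleting the event also removes its thumbnail and description embeddings.
requests.delete(f"{BASE}/events/{event_id}")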

View File

@ -37,8 +37,7 @@ from frigate.const import (
MODEL_CACHE_DIR,
RECORD_DIR,
)
from frigate.embeddings import manage_embeddings
from frigate.embeddings.embeddings import Embeddings
from frigate.embeddings import EmbeddingsContext, manage_embeddings
from frigate.events.audio import listen_to_audio
from frigate.events.cleanup import EventCleanup
from frigate.events.external import ExternalEventProcessor
@ -322,7 +321,7 @@ class FrigateApp:
    def init_embeddings_manager(self) -> None:
        # Create a client for other processes to use
        self.embeddings = Embeddings()
        self.embeddings = EmbeddingsContext()
        embedding_process = mp.Process(
            target=manage_embeddings,
            name="embeddings_manager",
@ -384,6 +383,7 @@ class FrigateApp:
        self.flask_app = create_app(
            self.config,
            self.db,
            self.embeddings,
            self.detected_frames_processor,
            self.storage_maintainer,
            self.onvif_controller,
@ -811,6 +811,9 @@ class FrigateApp:
        self.frigate_watchdog.join()
        self.db.stop()

        # Save embeddings stats to disk
        self.embeddings.save_stats()

        # Stop Communicators
        self.inter_process_communicator.stop()
        self.inter_config_updater.stop()
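
init_embeddings_manager() keeps a lightweight EmbeddingsContext in the main process and pushes the heavy embedding work into a separate manage_embeddings process, while stop() now persists the normalization stats on shutdown. The sketch below shows only the generic parent/worker split; run_worker and WorkerClient are illustrative names, not Frigate's implementation.

# Generic sketch of the parent/worker split used by init_embeddings_manager():
# the parent keeps a cheap client object and runs the heavy work in a separate
# process. run_worker and WorkerClient are illustrative names only.
import multiprocessing as mp


def run_worker(stop_event) -> None:
    # Stands in for manage_embeddings(): load models, bind the DB, then serve
    # requests until the parent asks us to stop.
    stop_event.wait()


class WorkerClient:
    """Stand-in for EmbeddingsContext: cheap to create, lives in the parent."""


if __name__ == "__main__":
    stop_event = mp.Event()
    client = WorkerClient()

    process = mp.Process(
        target=run_worker, name="embeddings_manager", args=(stop_event,)
    )
    process.daemon = True
    process.start()

    # ... the application runs and request handlers use `client` here ...

    stop_event.set()  # on shutdown: signal the worker,
    process.join()    # wait for it, and (in Frigate) save_stats() to disk.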

View File

@ -1,9 +1,9 @@
"""ChromaDB embeddings database."""
import json
import logging
import multiprocessing as mp
import signal
import sys
import threading
from types import FrameType
from typing import Optional
@ -12,9 +12,14 @@ from playhouse.sqliteq import SqliteQueueDatabase
from setproctitle import setproctitle
from frigate.config import FrigateConfig
from frigate.const import CONFIG_DIR
from frigate.models import Event
from frigate.util.services import listen
from .embeddings import Embeddings
from .maintainer import EmbeddingMaintainer
from .util import ZScoreNormalization
logger = logging.getLogger(__name__)
@ -48,12 +53,6 @@ def manage_embeddings(config: FrigateConfig) -> None:
    models = [Event]
    db.bind(models)

    # Hotswap the sqlite3 module for Chroma compatibility
    __import__("pysqlite3")
    sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
    from .embeddings import Embeddings
    from .maintainer import EmbeddingMaintainer

    embeddings = Embeddings()

    # Check if we need to re-index events
@ -65,3 +64,28 @@ def manage_embeddings(config: FrigateConfig) -> None:
        stop_event,
    )
    maintainer.start()


class EmbeddingsContext:
    def __init__(self):
        self.embeddings = Embeddings()
        self.thumb_stats = ZScoreNormalization()
        self.desc_stats = ZScoreNormalization()

        # load stats from disk
        try:
            with open(f"{CONFIG_DIR}/.search_stats.json", "r") as f:
                data = json.loads(f.read())
                self.thumb_stats.from_dict(data["thumb_stats"])
                self.desc_stats.from_dict(data["desc_stats"])
        except FileNotFoundError:
            pass

    def save_stats(self):
        """Write the stats to disk as JSON on exit."""
        contents = {
            "thumb_stats": self.thumb_stats.to_dict(),
            "desc_stats": self.desc_stats.to_dict(),
        }
        with open(f"{CONFIG_DIR}/.search_stats.json", "w") as f:
            f.write(json.dumps(contents))
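
EmbeddingsContext keeps two running ZScoreNormalization instances and round-trips them through CONFIG_DIR/.search_stats.json. The sketch below exercises just that persistence path, writing to a temp file instead, and assumes the frigate package (and its frigate.embeddings.util module) is importable in your environment.

# Hedged sketch of the stats persistence above, writing to a temp file instead
# of CONFIG_DIR/.search_stats.json; assumes the frigate package is importable.
import json
import tempfile
from pathlib import Path

from frigate.embeddings.util import ZScoreNormalization

stats_path = Path(tempfile.gettempdir()) / "search_stats_example.json"

thumb_stats = ZScoreNormalization()
thumb_stats.normalize([0.9, 1.1, 1.3])  # feed one batch of thumbnail distances

# Persist the running moments, as save_stats() does at shutdown.
stats_path.write_text(json.dumps({"thumb_stats": thumb_stats.to_dict()}))

# Restore them on the next start, as EmbeddingsContext.__init__ does.
data = json.loads(stats_path.read_text())
restored = ZScoreNormalization().from_dict(data["thumb_stats"])
print(restored.n, round(restored.mean, 3), round(restored.stddev, 3))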

View File

@ -3,19 +3,32 @@
import base64
import io
import logging
import sys
import time
import numpy as np
from chromadb import Collection
from chromadb import HttpClient as ChromaClient
from chromadb.config import Settings
from PIL import Image
from playhouse.shortcuts import model_to_dict
from frigate.models import Event
from .functions.clip import ClipEmbedding
from .functions.minilm_l6_v2 import MiniLMEmbedding
# Hotswap the sqlite3 module for Chroma compatibility
try:
    from chromadb import Collection
    from chromadb import HttpClient as ChromaClient
    from chromadb.config import Settings

    from .functions.clip import ClipEmbedding
    from .functions.minilm_l6_v2 import MiniLMEmbedding
except RuntimeError:
    __import__("pysqlite3")
    sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")

    from chromadb import Collection
    from chromadb import HttpClient as ChromaClient
    from chromadb.config import Settings

    from .functions.clip import ClipEmbedding
    from .functions.minilm_l6_v2 import MiniLMEmbedding

logger = logging.getLogger(__name__)
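
Chroma refuses to load when the interpreter's bundled sqlite3 is too old, raising RuntimeError, which is exactly what the try/except above catches before swapping in pysqlite3 and retrying the imports. A standalone sketch of the same hotswap technique, assuming the pysqlite3-binary package is installed:

# Standalone sketch of the sqlite3 hotswap above; assumes the pysqlite3-binary
# package is installed. The swap must happen before anything else caches a
# reference to the stock sqlite3 module.
import sys

try:
    import chromadb  # Chroma raises RuntimeError when sqlite3 is too old
except RuntimeError:
    __import__("pysqlite3")
    # Repoint the "sqlite3" entry in the module cache at pysqlite3 so every
    # later `import sqlite3` (including Chroma's) resolves to the newer build.
    sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
    import chromadb

print(chromadb.__version__)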

View File

@ -0,0 +1,47 @@
"""Z-score normalization for search distance."""

import math


class ZScoreNormalization:
    """Running Z-score normalization for search distance."""

    def __init__(self):
        self.n = 0
        self.mean = 0
        self.m2 = 0

    @property
    def variance(self):
        return self.m2 / (self.n - 1) if self.n > 1 else 0.0

    @property
    def stddev(self):
        return math.sqrt(self.variance)

    def normalize(self, distances: list[float]):
        self._update(distances)
        if self.stddev == 0:
            return distances
        return [(x - self.mean) / self.stddev for x in distances]

    def _update(self, distances: list[float]):
        for x in distances:
            self.n += 1
            delta = x - self.mean
            self.mean += delta / self.n
            delta2 = x - self.mean
            self.m2 += delta * delta2

    def to_dict(self):
        return {
            "n": self.n,
            "mean": self.mean,
            "m2": self.m2,
        }

    def from_dict(self, data: dict):
        self.n = data["n"]
        self.mean = data["mean"]
        self.m2 = data["m2"]
        return self
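
ZScoreNormalization is Welford's online mean/variance update, so each batch of distances is normalized against every distance seen so far (including the current batch). A small sketch that exercises it and cross-checks the running moments against Python's statistics module, assuming frigate.embeddings.util is importable:

# Quick sketch exercising ZScoreNormalization and checking the running moments
# against a plain batch computation; assumes frigate.embeddings.util imports.
import statistics

from frigate.embeddings.util import ZScoreNormalization

norm = ZScoreNormalization()
batches = [[0.82, 0.91, 1.05], [0.77, 1.20], [0.95, 0.88, 1.10]]

for batch in batches:
    # z-scores relative to everything seen so far, including this batch
    scores = norm.normalize(batch)
    print([round(s, 2) for s in scores])

all_values = [x for batch in batches for x in batch]
assert norm.n == len(all_values)
assert abs(norm.mean - statistics.fmean(all_values)) < 1e-9
assert abs(norm.stddev - statistics.stdev(all_values)) < 1e-9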

View File

@ -120,6 +120,7 @@ class TestHttp(unittest.TestCase):
            None,
            None,
            None,
            None,
            PlusApi(),
            None,
        )
@ -156,6 +157,7 @@ class TestHttp(unittest.TestCase):
            None,
            None,
            None,
            None,
            PlusApi(),
            None,
        )
@ -177,6 +179,7 @@ class TestHttp(unittest.TestCase):
            None,
            None,
            None,
            None,
            PlusApi(),
            None,
        )
@ -197,6 +200,7 @@ class TestHttp(unittest.TestCase):
            None,
            None,
            None,
            None,
            PlusApi(),
            None,
        )
@ -219,6 +223,7 @@ class TestHttp(unittest.TestCase):
            None,
            None,
            None,
            None,
            PlusApi(),
            None,
        )
@ -245,6 +250,7 @@ class TestHttp(unittest.TestCase):
            None,
            None,
            None,
            None,
            PlusApi(),
            None,
        )
@ -283,6 +289,7 @@ class TestHttp(unittest.TestCase):
            None,
            None,
            None,
            None,
            PlusApi(),
            None,
        )
@ -318,6 +325,7 @@ class TestHttp(unittest.TestCase):
            None,
            None,
            None,
            None,
            PlusApi(),
            None,
        )
@ -343,6 +351,7 @@ class TestHttp(unittest.TestCase):
            None,
            None,
            None,
            None,
            PlusApi(),
            None,
        )
@ -360,6 +369,7 @@ class TestHttp(unittest.TestCase):
            None,
            None,
            None,
            None,
            PlusApi(),
            None,
        )
@ -381,6 +391,7 @@ class TestHttp(unittest.TestCase):
            None,
            None,
            None,
            None,
            PlusApi(),
            stats,
        )