diff --git a/frigate/__main__.py b/frigate/__main__.py index 7106f0209..844206908 100644 --- a/frigate/__main__.py +++ b/frigate/__main__.py @@ -1,12 +1,9 @@ import faulthandler -import sys import threading from flask import cli -# Hotsawp the sqlite3 module for Chroma compatibility -__import__("pysqlite3") -sys.modules["sqlite3"] = sys.modules.pop("pysqlite3") +from frigate.app import FrigateApp faulthandler.enable() @@ -15,8 +12,6 @@ threading.current_thread().name = "frigate" cli.show_server_banner = lambda *x: None if __name__ == "__main__": - from frigate.app import FrigateApp - frigate_app = FrigateApp() frigate_app.start() diff --git a/frigate/api/app.py b/frigate/api/app.py index 5fec51c03..0e3b0fecd 100644 --- a/frigate/api/app.py +++ b/frigate/api/app.py @@ -23,6 +23,7 @@ from frigate.api.preview import PreviewBp from frigate.api.review import ReviewBp from frigate.config import FrigateConfig from frigate.const import CONFIG_DIR +from frigate.embeddings import EmbeddingsContext from frigate.events.external import ExternalEventProcessor from frigate.models import Event, Timeline from frigate.plus import PlusApi @@ -52,6 +53,7 @@ bp.register_blueprint(AuthBp) def create_app( frigate_config, database: SqliteQueueDatabase, + embeddings: EmbeddingsContext, detected_frames_processor, storage_maintainer: StorageMaintainer, onvif: OnvifController, @@ -79,6 +81,7 @@ def create_app( database.close() app.frigate_config = frigate_config + app.embeddings = embeddings app.detected_frames_processor = detected_frames_processor app.storage_maintainer = storage_maintainer app.onvif = onvif diff --git a/frigate/api/event.py b/frigate/api/event.py index 267d13bc9..0ecb9ddbd 100644 --- a/frigate/api/event.py +++ b/frigate/api/event.py @@ -1,5 +1,7 @@ """Event apis.""" +import base64 +import io import logging import os from datetime import datetime @@ -8,6 +10,7 @@ from pathlib import Path from urllib.parse import unquote import cv2 +import numpy as np from flask import ( Blueprint, current_app, @@ -15,13 +18,16 @@ from flask import ( make_response, request, ) -from peewee import DoesNotExist, fn, operator +from peewee import JOIN, DoesNotExist, fn, operator +from PIL import Image from playhouse.shortcuts import model_to_dict from frigate.const import ( CLIPS_DIR, ) -from frigate.models import Event, Timeline +from frigate.embeddings import EmbeddingsContext +from frigate.embeddings.embeddings import get_metadata +from frigate.models import Event, ReviewSegment, Timeline from frigate.object_processing import TrackedObject from frigate.util.builtin import get_tz_modifiers @@ -245,6 +251,189 @@ def events(): return jsonify(list(events)) +@EventBp.route("/events/search") +def events_search(): + query = request.args.get("query", type=str) + search_type = request.args.get("search_type", "text", type=str) + include_thumbnails = request.args.get("include_thumbnails", default=1, type=int) + limit = request.args.get("limit", 50, type=int) + + # Filters + cameras = request.args.get("cameras", "all", type=str) + labels = request.args.get("labels", "all", type=str) + zones = request.args.get("zones", "all", type=str) + after = request.args.get("after", type=float) + before = request.args.get("before", type=float) + + if not query: + return make_response( + jsonify( + { + "success": False, + "message": "A search query must be supplied", + } + ), + 400, + ) + + if not current_app.frigate_config.semantic_search.enabled: + return make_response( + jsonify( + { + "success": False, + "message": "Semantic search is not enabled", + } + ), + 400, + ) + + context: EmbeddingsContext = current_app.embeddings + + selected_columns = [ + Event.id, + Event.camera, + Event.label, + Event.sub_label, + Event.zones, + Event.start_time, + Event.end_time, + Event.data, + ReviewSegment.thumb_path, + ] + + if include_thumbnails: + selected_columns.append(Event.thumbnail) + + # Build the where clause for the embeddings query + embeddings_filters = [] + + if cameras != "all": + camera_list = cameras.split(",") + embeddings_filters.append({"camera": {"$in": camera_list}}) + + if labels != "all": + label_list = labels.split(",") + embeddings_filters.append({"label": {"$in": label_list}}) + + if zones != "all": + filtered_zones = zones.split(",") + zone_filters = [{f"zones_{zone}": {"$eq": True}} for zone in filtered_zones] + if len(zone_filters) > 1: + embeddings_filters.append({"$or": zone_filters}) + else: + embeddings_filters.append(zone_filters[0]) + + if after: + embeddings_filters.append({"start_time": {"$gt": after}}) + + if before: + embeddings_filters.append({"start_time": {"$lt": before}}) + + where = None + if len(embeddings_filters) > 1: + where = {"$and": embeddings_filters} + elif len(embeddings_filters) == 1: + where = embeddings_filters[0] + + thumb_ids = {} + desc_ids = {} + + if search_type == "thumbnail": + # Grab the ids of events that match the thumbnail image embeddings + try: + search_event: Event = Event.get(Event.id == query) + except DoesNotExist: + return make_response( + jsonify( + { + "success": False, + "message": "Event not found", + } + ), + 404, + ) + thumbnail = base64.b64decode(search_event.thumbnail) + img = np.array(Image.open(io.BytesIO(thumbnail)).convert("RGB")) + thumb_result = context.embeddings.thumbnail.query( + query_images=[img], + n_results=limit, + where=where, + ) + thumb_ids = dict(zip(thumb_result["ids"][0], thumb_result["distances"][0])) + else: + thumb_result = context.embeddings.thumbnail.query( + query_texts=[query], + n_results=limit, + where=where, + ) + # Do a rudimentary normalization of the difference in distances returned by CLIP and MiniLM. + thumb_ids = dict( + zip( + thumb_result["ids"][0], + context.thumb_stats.normalize(thumb_result["distances"][0]), + ) + ) + desc_result = context.embeddings.description.query( + query_texts=[query], + n_results=limit, + where=where, + ) + desc_ids = dict( + zip( + desc_result["ids"][0], + context.desc_stats.normalize(desc_result["distances"][0]), + ) + ) + + results = {} + for event_id in thumb_ids.keys() | desc_ids: + min_distance = min( + i + for i in (thumb_ids.get(event_id), desc_ids.get(event_id)) + if i is not None + ) + results[event_id] = { + "distance": min_distance, + "source": "thumbnail" + if min_distance == thumb_ids.get(event_id) + else "description", + } + + if not results: + return jsonify([]) + + # Get the event data + events = ( + Event.select(*selected_columns) + .join( + ReviewSegment, + JOIN.LEFT_OUTER, + on=(fn.json_extract(ReviewSegment.data, "$.detections").contains(Event.id)), + ) + .where(Event.id << list(results.keys())) + .dicts() + .iterator() + ) + events = list(events) + + events = [ + {k: v for k, v in event.items() if k != "data"} + | { + k: v + for k, v in event["data"].items() + if k in ["type", "score", "top_score", "description"] + } + | { + "search_distance": results[event["id"]]["distance"], + "search_source": results[event["id"]]["source"], + } + for event in events + ] + events = sorted(events, key=lambda x: x["search_distance"])[:limit] + + return jsonify(events) + + @EventBp.route("/events/summary") def events_summary(): tz_name = request.args.get("timezone", default="utc", type=str) @@ -604,6 +793,52 @@ def set_sub_label(id): ) +@EventBp.route("/events//description", methods=("POST",)) +def set_description(id): + try: + event: Event = Event.get(Event.id == id) + except DoesNotExist: + return make_response( + jsonify({"success": False, "message": "Event " + id + " not found"}), 404 + ) + + json: dict[str, any] = request.get_json(silent=True) or {} + new_description = json.get("description") + + if new_description is None or len(new_description) == 0: + return make_response( + jsonify( + { + "success": False, + "message": "description cannot be empty", + } + ), + 400, + ) + + event.data["description"] = new_description + event.save() + + # If semantic search is enabled, update the index + if current_app.frigate_config.semantic_search.enabled: + context: EmbeddingsContext = current_app.embeddings + context.embeddings.description.upsert( + documents=[new_description], + metadatas=[get_metadata(event)], + ids=[id], + ) + + return make_response( + jsonify( + { + "success": True, + "message": "Event " + id + " description set to " + new_description, + } + ), + 200, + ) + + @EventBp.route("/events/", methods=("DELETE",)) def delete_event(id): try: @@ -625,6 +860,11 @@ def delete_event(id): event.delete_instance() Timeline.delete().where(Timeline.source_id == id).execute() + # If semantic search is enabled, update the index + if current_app.frigate_config.semantic_search.enabled: + context: EmbeddingsContext = current_app.embeddings + context.embeddings.thumbnail.delete(ids=[id]) + context.embeddings.description.delete(ids=[id]) return make_response( jsonify({"success": True, "message": "Event " + id + " deleted"}), 200 ) diff --git a/frigate/app.py b/frigate/app.py index 840686f0a..ef9360354 100644 --- a/frigate/app.py +++ b/frigate/app.py @@ -37,8 +37,7 @@ from frigate.const import ( MODEL_CACHE_DIR, RECORD_DIR, ) -from frigate.embeddings import manage_embeddings -from frigate.embeddings.embeddings import Embeddings +from frigate.embeddings import EmbeddingsContext, manage_embeddings from frigate.events.audio import listen_to_audio from frigate.events.cleanup import EventCleanup from frigate.events.external import ExternalEventProcessor @@ -322,7 +321,7 @@ class FrigateApp: def init_embeddings_manager(self) -> None: # Create a client for other processes to use - self.embeddings = Embeddings() + self.embeddings = EmbeddingsContext() embedding_process = mp.Process( target=manage_embeddings, name="embeddings_manager", @@ -384,6 +383,7 @@ class FrigateApp: self.flask_app = create_app( self.config, self.db, + self.embeddings, self.detected_frames_processor, self.storage_maintainer, self.onvif_controller, @@ -811,6 +811,9 @@ class FrigateApp: self.frigate_watchdog.join() self.db.stop() + # Save embeddings stats to disk + self.embeddings.save_stats() + # Stop Communicators self.inter_process_communicator.stop() self.inter_config_updater.stop() diff --git a/frigate/embeddings/__init__.py b/frigate/embeddings/__init__.py index 41af73c01..b3ad22874 100644 --- a/frigate/embeddings/__init__.py +++ b/frigate/embeddings/__init__.py @@ -1,9 +1,9 @@ """ChromaDB embeddings database.""" +import json import logging import multiprocessing as mp import signal -import sys import threading from types import FrameType from typing import Optional @@ -12,9 +12,14 @@ from playhouse.sqliteq import SqliteQueueDatabase from setproctitle import setproctitle from frigate.config import FrigateConfig +from frigate.const import CONFIG_DIR from frigate.models import Event from frigate.util.services import listen +from .embeddings import Embeddings +from .maintainer import EmbeddingMaintainer +from .util import ZScoreNormalization + logger = logging.getLogger(__name__) @@ -48,12 +53,6 @@ def manage_embeddings(config: FrigateConfig) -> None: models = [Event] db.bind(models) - # Hotsawp the sqlite3 module for Chroma compatibility - __import__("pysqlite3") - sys.modules["sqlite3"] = sys.modules.pop("pysqlite3") - from .embeddings import Embeddings - from .maintainer import EmbeddingMaintainer - embeddings = Embeddings() # Check if we need to re-index events @@ -65,3 +64,28 @@ def manage_embeddings(config: FrigateConfig) -> None: stop_event, ) maintainer.start() + + +class EmbeddingsContext: + def __init__(self): + self.embeddings = Embeddings() + self.thumb_stats = ZScoreNormalization() + self.desc_stats = ZScoreNormalization() + + # load stats from disk + try: + with open(f"{CONFIG_DIR}/.search_stats.json", "r") as f: + data = json.loads(f.read()) + self.thumb_stats.from_dict(data["thumb_stats"]) + self.desc_stats.from_dict(data["desc_stats"]) + except FileNotFoundError: + pass + + def save_stats(self): + """Write the stats to disk as JSON on exit.""" + contents = { + "thumb_stats": self.thumb_stats.to_dict(), + "desc_stats": self.desc_stats.to_dict(), + } + with open(f"{CONFIG_DIR}/.search_stats.json", "w") as f: + f.write(json.dumps(contents)) diff --git a/frigate/embeddings/embeddings.py b/frigate/embeddings/embeddings.py index c7a688d12..58dd707bb 100644 --- a/frigate/embeddings/embeddings.py +++ b/frigate/embeddings/embeddings.py @@ -3,19 +3,32 @@ import base64 import io import logging +import sys import time import numpy as np -from chromadb import Collection -from chromadb import HttpClient as ChromaClient -from chromadb.config import Settings from PIL import Image from playhouse.shortcuts import model_to_dict from frigate.models import Event -from .functions.clip import ClipEmbedding -from .functions.minilm_l6_v2 import MiniLMEmbedding +# Hotsawp the sqlite3 module for Chroma compatibility +try: + from chromadb import Collection + from chromadb import HttpClient as ChromaClient + from chromadb.config import Settings + + from .functions.clip import ClipEmbedding + from .functions.minilm_l6_v2 import MiniLMEmbedding +except RuntimeError: + __import__("pysqlite3") + sys.modules["sqlite3"] = sys.modules.pop("pysqlite3") + from chromadb import Collection + from chromadb import HttpClient as ChromaClient + from chromadb.config import Settings + + from .functions.clip import ClipEmbedding + from .functions.minilm_l6_v2 import MiniLMEmbedding logger = logging.getLogger(__name__) diff --git a/frigate/embeddings/util.py b/frigate/embeddings/util.py new file mode 100644 index 000000000..7550716c9 --- /dev/null +++ b/frigate/embeddings/util.py @@ -0,0 +1,47 @@ +"""Z-score normalization for search distance.""" + +import math + + +class ZScoreNormalization: + """Running Z-score normalization for search distance.""" + + def __init__(self): + self.n = 0 + self.mean = 0 + self.m2 = 0 + + @property + def variance(self): + return self.m2 / (self.n - 1) if self.n > 1 else 0.0 + + @property + def stddev(self): + return math.sqrt(self.variance) + + def normalize(self, distances: list[float]): + self._update(distances) + if self.stddev == 0: + return distances + return [(x - self.mean) / self.stddev for x in distances] + + def _update(self, distances: list[float]): + for x in distances: + self.n += 1 + delta = x - self.mean + self.mean += delta / self.n + delta2 = x - self.mean + self.m2 += delta * delta2 + + def to_dict(self): + return { + "n": self.n, + "mean": self.mean, + "m2": self.m2, + } + + def from_dict(self, data: dict): + self.n = data["n"] + self.mean = data["mean"] + self.m2 = data["m2"] + return self diff --git a/frigate/test/test_http.py b/frigate/test/test_http.py index f0cb927f4..936dc80e5 100644 --- a/frigate/test/test_http.py +++ b/frigate/test/test_http.py @@ -120,6 +120,7 @@ class TestHttp(unittest.TestCase): None, None, None, + None, PlusApi(), None, ) @@ -156,6 +157,7 @@ class TestHttp(unittest.TestCase): None, None, None, + None, PlusApi(), None, ) @@ -177,6 +179,7 @@ class TestHttp(unittest.TestCase): None, None, None, + None, PlusApi(), None, ) @@ -197,6 +200,7 @@ class TestHttp(unittest.TestCase): None, None, None, + None, PlusApi(), None, ) @@ -219,6 +223,7 @@ class TestHttp(unittest.TestCase): None, None, None, + None, PlusApi(), None, ) @@ -245,6 +250,7 @@ class TestHttp(unittest.TestCase): None, None, None, + None, PlusApi(), None, ) @@ -283,6 +289,7 @@ class TestHttp(unittest.TestCase): None, None, None, + None, PlusApi(), None, ) @@ -318,6 +325,7 @@ class TestHttp(unittest.TestCase): None, None, None, + None, PlusApi(), None, ) @@ -343,6 +351,7 @@ class TestHttp(unittest.TestCase): None, None, None, + None, PlusApi(), None, ) @@ -360,6 +369,7 @@ class TestHttp(unittest.TestCase): None, None, None, + None, PlusApi(), None, ) @@ -381,6 +391,7 @@ class TestHttp(unittest.TestCase): None, None, None, + None, PlusApi(), stats, )