mirror of
				https://github.com/blakeblackshear/frigate.git
				synced 2025-10-27 10:52:11 +01:00 
			
		
		
		
	Semantic Search API (#12105)
* initial event search api implementation * fix lint * fix tests * move chromadb imports and pysqlite hotswap to fix tests * remove unused import * switch default limit to 50 * fix events accidently pulling inside chroma results loop
This commit is contained in:
		
							parent
							
								
									36cbffcc5e
								
							
						
					
					
						commit
						9e825811f2
					
				@ -1,12 +1,9 @@
 | 
				
			|||||||
import faulthandler
 | 
					import faulthandler
 | 
				
			||||||
import sys
 | 
					 | 
				
			||||||
import threading
 | 
					import threading
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from flask import cli
 | 
					from flask import cli
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Hotsawp the sqlite3 module for Chroma compatibility
 | 
					from frigate.app import FrigateApp
 | 
				
			||||||
__import__("pysqlite3")
 | 
					 | 
				
			||||||
sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
faulthandler.enable()
 | 
					faulthandler.enable()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -15,8 +12,6 @@ threading.current_thread().name = "frigate"
 | 
				
			|||||||
cli.show_server_banner = lambda *x: None
 | 
					cli.show_server_banner = lambda *x: None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    from frigate.app import FrigateApp
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    frigate_app = FrigateApp()
 | 
					    frigate_app = FrigateApp()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    frigate_app.start()
 | 
					    frigate_app.start()
 | 
				
			||||||
 | 
				
			|||||||
@ -23,6 +23,7 @@ from frigate.api.preview import PreviewBp
 | 
				
			|||||||
from frigate.api.review import ReviewBp
 | 
					from frigate.api.review import ReviewBp
 | 
				
			||||||
from frigate.config import FrigateConfig
 | 
					from frigate.config import FrigateConfig
 | 
				
			||||||
from frigate.const import CONFIG_DIR
 | 
					from frigate.const import CONFIG_DIR
 | 
				
			||||||
 | 
					from frigate.embeddings import EmbeddingsContext
 | 
				
			||||||
from frigate.events.external import ExternalEventProcessor
 | 
					from frigate.events.external import ExternalEventProcessor
 | 
				
			||||||
from frigate.models import Event, Timeline
 | 
					from frigate.models import Event, Timeline
 | 
				
			||||||
from frigate.plus import PlusApi
 | 
					from frigate.plus import PlusApi
 | 
				
			||||||
@ -52,6 +53,7 @@ bp.register_blueprint(AuthBp)
 | 
				
			|||||||
def create_app(
 | 
					def create_app(
 | 
				
			||||||
    frigate_config,
 | 
					    frigate_config,
 | 
				
			||||||
    database: SqliteQueueDatabase,
 | 
					    database: SqliteQueueDatabase,
 | 
				
			||||||
 | 
					    embeddings: EmbeddingsContext,
 | 
				
			||||||
    detected_frames_processor,
 | 
					    detected_frames_processor,
 | 
				
			||||||
    storage_maintainer: StorageMaintainer,
 | 
					    storage_maintainer: StorageMaintainer,
 | 
				
			||||||
    onvif: OnvifController,
 | 
					    onvif: OnvifController,
 | 
				
			||||||
@ -79,6 +81,7 @@ def create_app(
 | 
				
			|||||||
            database.close()
 | 
					            database.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    app.frigate_config = frigate_config
 | 
					    app.frigate_config = frigate_config
 | 
				
			||||||
 | 
					    app.embeddings = embeddings
 | 
				
			||||||
    app.detected_frames_processor = detected_frames_processor
 | 
					    app.detected_frames_processor = detected_frames_processor
 | 
				
			||||||
    app.storage_maintainer = storage_maintainer
 | 
					    app.storage_maintainer = storage_maintainer
 | 
				
			||||||
    app.onvif = onvif
 | 
					    app.onvif = onvif
 | 
				
			||||||
 | 
				
			|||||||
@ -1,5 +1,7 @@
 | 
				
			|||||||
"""Event apis."""
 | 
					"""Event apis."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import base64
 | 
				
			||||||
 | 
					import io
 | 
				
			||||||
import logging
 | 
					import logging
 | 
				
			||||||
import os
 | 
					import os
 | 
				
			||||||
from datetime import datetime
 | 
					from datetime import datetime
 | 
				
			||||||
@ -8,6 +10,7 @@ from pathlib import Path
 | 
				
			|||||||
from urllib.parse import unquote
 | 
					from urllib.parse import unquote
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import cv2
 | 
					import cv2
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
from flask import (
 | 
					from flask import (
 | 
				
			||||||
    Blueprint,
 | 
					    Blueprint,
 | 
				
			||||||
    current_app,
 | 
					    current_app,
 | 
				
			||||||
@ -15,13 +18,16 @@ from flask import (
 | 
				
			|||||||
    make_response,
 | 
					    make_response,
 | 
				
			||||||
    request,
 | 
					    request,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from peewee import DoesNotExist, fn, operator
 | 
					from peewee import JOIN, DoesNotExist, fn, operator
 | 
				
			||||||
 | 
					from PIL import Image
 | 
				
			||||||
from playhouse.shortcuts import model_to_dict
 | 
					from playhouse.shortcuts import model_to_dict
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from frigate.const import (
 | 
					from frigate.const import (
 | 
				
			||||||
    CLIPS_DIR,
 | 
					    CLIPS_DIR,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from frigate.models import Event, Timeline
 | 
					from frigate.embeddings import EmbeddingsContext
 | 
				
			||||||
 | 
					from frigate.embeddings.embeddings import get_metadata
 | 
				
			||||||
 | 
					from frigate.models import Event, ReviewSegment, Timeline
 | 
				
			||||||
from frigate.object_processing import TrackedObject
 | 
					from frigate.object_processing import TrackedObject
 | 
				
			||||||
from frigate.util.builtin import get_tz_modifiers
 | 
					from frigate.util.builtin import get_tz_modifiers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -245,6 +251,189 @@ def events():
 | 
				
			|||||||
    return jsonify(list(events))
 | 
					    return jsonify(list(events))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@EventBp.route("/events/search")
 | 
				
			||||||
 | 
					def events_search():
 | 
				
			||||||
 | 
					    query = request.args.get("query", type=str)
 | 
				
			||||||
 | 
					    search_type = request.args.get("search_type", "text", type=str)
 | 
				
			||||||
 | 
					    include_thumbnails = request.args.get("include_thumbnails", default=1, type=int)
 | 
				
			||||||
 | 
					    limit = request.args.get("limit", 50, type=int)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Filters
 | 
				
			||||||
 | 
					    cameras = request.args.get("cameras", "all", type=str)
 | 
				
			||||||
 | 
					    labels = request.args.get("labels", "all", type=str)
 | 
				
			||||||
 | 
					    zones = request.args.get("zones", "all", type=str)
 | 
				
			||||||
 | 
					    after = request.args.get("after", type=float)
 | 
				
			||||||
 | 
					    before = request.args.get("before", type=float)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if not query:
 | 
				
			||||||
 | 
					        return make_response(
 | 
				
			||||||
 | 
					            jsonify(
 | 
				
			||||||
 | 
					                {
 | 
				
			||||||
 | 
					                    "success": False,
 | 
				
			||||||
 | 
					                    "message": "A search query must be supplied",
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					            400,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if not current_app.frigate_config.semantic_search.enabled:
 | 
				
			||||||
 | 
					        return make_response(
 | 
				
			||||||
 | 
					            jsonify(
 | 
				
			||||||
 | 
					                {
 | 
				
			||||||
 | 
					                    "success": False,
 | 
				
			||||||
 | 
					                    "message": "Semantic search is not enabled",
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					            400,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    context: EmbeddingsContext = current_app.embeddings
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    selected_columns = [
 | 
				
			||||||
 | 
					        Event.id,
 | 
				
			||||||
 | 
					        Event.camera,
 | 
				
			||||||
 | 
					        Event.label,
 | 
				
			||||||
 | 
					        Event.sub_label,
 | 
				
			||||||
 | 
					        Event.zones,
 | 
				
			||||||
 | 
					        Event.start_time,
 | 
				
			||||||
 | 
					        Event.end_time,
 | 
				
			||||||
 | 
					        Event.data,
 | 
				
			||||||
 | 
					        ReviewSegment.thumb_path,
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if include_thumbnails:
 | 
				
			||||||
 | 
					        selected_columns.append(Event.thumbnail)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Build the where clause for the embeddings query
 | 
				
			||||||
 | 
					    embeddings_filters = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if cameras != "all":
 | 
				
			||||||
 | 
					        camera_list = cameras.split(",")
 | 
				
			||||||
 | 
					        embeddings_filters.append({"camera": {"$in": camera_list}})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if labels != "all":
 | 
				
			||||||
 | 
					        label_list = labels.split(",")
 | 
				
			||||||
 | 
					        embeddings_filters.append({"label": {"$in": label_list}})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if zones != "all":
 | 
				
			||||||
 | 
					        filtered_zones = zones.split(",")
 | 
				
			||||||
 | 
					        zone_filters = [{f"zones_{zone}": {"$eq": True}} for zone in filtered_zones]
 | 
				
			||||||
 | 
					        if len(zone_filters) > 1:
 | 
				
			||||||
 | 
					            embeddings_filters.append({"$or": zone_filters})
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            embeddings_filters.append(zone_filters[0])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if after:
 | 
				
			||||||
 | 
					        embeddings_filters.append({"start_time": {"$gt": after}})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if before:
 | 
				
			||||||
 | 
					        embeddings_filters.append({"start_time": {"$lt": before}})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    where = None
 | 
				
			||||||
 | 
					    if len(embeddings_filters) > 1:
 | 
				
			||||||
 | 
					        where = {"$and": embeddings_filters}
 | 
				
			||||||
 | 
					    elif len(embeddings_filters) == 1:
 | 
				
			||||||
 | 
					        where = embeddings_filters[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    thumb_ids = {}
 | 
				
			||||||
 | 
					    desc_ids = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if search_type == "thumbnail":
 | 
				
			||||||
 | 
					        # Grab the ids of events that match the thumbnail image embeddings
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            search_event: Event = Event.get(Event.id == query)
 | 
				
			||||||
 | 
					        except DoesNotExist:
 | 
				
			||||||
 | 
					            return make_response(
 | 
				
			||||||
 | 
					                jsonify(
 | 
				
			||||||
 | 
					                    {
 | 
				
			||||||
 | 
					                        "success": False,
 | 
				
			||||||
 | 
					                        "message": "Event not found",
 | 
				
			||||||
 | 
					                    }
 | 
				
			||||||
 | 
					                ),
 | 
				
			||||||
 | 
					                404,
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        thumbnail = base64.b64decode(search_event.thumbnail)
 | 
				
			||||||
 | 
					        img = np.array(Image.open(io.BytesIO(thumbnail)).convert("RGB"))
 | 
				
			||||||
 | 
					        thumb_result = context.embeddings.thumbnail.query(
 | 
				
			||||||
 | 
					            query_images=[img],
 | 
				
			||||||
 | 
					            n_results=limit,
 | 
				
			||||||
 | 
					            where=where,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        thumb_ids = dict(zip(thumb_result["ids"][0], thumb_result["distances"][0]))
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        thumb_result = context.embeddings.thumbnail.query(
 | 
				
			||||||
 | 
					            query_texts=[query],
 | 
				
			||||||
 | 
					            n_results=limit,
 | 
				
			||||||
 | 
					            where=where,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        # Do a rudimentary normalization of the difference in distances returned by CLIP and MiniLM.
 | 
				
			||||||
 | 
					        thumb_ids = dict(
 | 
				
			||||||
 | 
					            zip(
 | 
				
			||||||
 | 
					                thumb_result["ids"][0],
 | 
				
			||||||
 | 
					                context.thumb_stats.normalize(thumb_result["distances"][0]),
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        desc_result = context.embeddings.description.query(
 | 
				
			||||||
 | 
					            query_texts=[query],
 | 
				
			||||||
 | 
					            n_results=limit,
 | 
				
			||||||
 | 
					            where=where,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        desc_ids = dict(
 | 
				
			||||||
 | 
					            zip(
 | 
				
			||||||
 | 
					                desc_result["ids"][0],
 | 
				
			||||||
 | 
					                context.desc_stats.normalize(desc_result["distances"][0]),
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    results = {}
 | 
				
			||||||
 | 
					    for event_id in thumb_ids.keys() | desc_ids:
 | 
				
			||||||
 | 
					        min_distance = min(
 | 
				
			||||||
 | 
					            i
 | 
				
			||||||
 | 
					            for i in (thumb_ids.get(event_id), desc_ids.get(event_id))
 | 
				
			||||||
 | 
					            if i is not None
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        results[event_id] = {
 | 
				
			||||||
 | 
					            "distance": min_distance,
 | 
				
			||||||
 | 
					            "source": "thumbnail"
 | 
				
			||||||
 | 
					            if min_distance == thumb_ids.get(event_id)
 | 
				
			||||||
 | 
					            else "description",
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if not results:
 | 
				
			||||||
 | 
					        return jsonify([])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Get the event data
 | 
				
			||||||
 | 
					    events = (
 | 
				
			||||||
 | 
					        Event.select(*selected_columns)
 | 
				
			||||||
 | 
					        .join(
 | 
				
			||||||
 | 
					            ReviewSegment,
 | 
				
			||||||
 | 
					            JOIN.LEFT_OUTER,
 | 
				
			||||||
 | 
					            on=(fn.json_extract(ReviewSegment.data, "$.detections").contains(Event.id)),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        .where(Event.id << list(results.keys()))
 | 
				
			||||||
 | 
					        .dicts()
 | 
				
			||||||
 | 
					        .iterator()
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    events = list(events)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    events = [
 | 
				
			||||||
 | 
					        {k: v for k, v in event.items() if k != "data"}
 | 
				
			||||||
 | 
					        | {
 | 
				
			||||||
 | 
					            k: v
 | 
				
			||||||
 | 
					            for k, v in event["data"].items()
 | 
				
			||||||
 | 
					            if k in ["type", "score", "top_score", "description"]
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        | {
 | 
				
			||||||
 | 
					            "search_distance": results[event["id"]]["distance"],
 | 
				
			||||||
 | 
					            "search_source": results[event["id"]]["source"],
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        for event in events
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
 | 
					    events = sorted(events, key=lambda x: x["search_distance"])[:limit]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return jsonify(events)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@EventBp.route("/events/summary")
 | 
					@EventBp.route("/events/summary")
 | 
				
			||||||
def events_summary():
 | 
					def events_summary():
 | 
				
			||||||
    tz_name = request.args.get("timezone", default="utc", type=str)
 | 
					    tz_name = request.args.get("timezone", default="utc", type=str)
 | 
				
			||||||
@ -604,6 +793,52 @@ def set_sub_label(id):
 | 
				
			|||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@EventBp.route("/events/<id>/description", methods=("POST",))
 | 
				
			||||||
 | 
					def set_description(id):
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					        event: Event = Event.get(Event.id == id)
 | 
				
			||||||
 | 
					    except DoesNotExist:
 | 
				
			||||||
 | 
					        return make_response(
 | 
				
			||||||
 | 
					            jsonify({"success": False, "message": "Event " + id + " not found"}), 404
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    json: dict[str, any] = request.get_json(silent=True) or {}
 | 
				
			||||||
 | 
					    new_description = json.get("description")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if new_description is None or len(new_description) == 0:
 | 
				
			||||||
 | 
					        return make_response(
 | 
				
			||||||
 | 
					            jsonify(
 | 
				
			||||||
 | 
					                {
 | 
				
			||||||
 | 
					                    "success": False,
 | 
				
			||||||
 | 
					                    "message": "description cannot be empty",
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					            400,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    event.data["description"] = new_description
 | 
				
			||||||
 | 
					    event.save()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # If semantic search is enabled, update the index
 | 
				
			||||||
 | 
					    if current_app.frigate_config.semantic_search.enabled:
 | 
				
			||||||
 | 
					        context: EmbeddingsContext = current_app.embeddings
 | 
				
			||||||
 | 
					        context.embeddings.description.upsert(
 | 
				
			||||||
 | 
					            documents=[new_description],
 | 
				
			||||||
 | 
					            metadatas=[get_metadata(event)],
 | 
				
			||||||
 | 
					            ids=[id],
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return make_response(
 | 
				
			||||||
 | 
					        jsonify(
 | 
				
			||||||
 | 
					            {
 | 
				
			||||||
 | 
					                "success": True,
 | 
				
			||||||
 | 
					                "message": "Event " + id + " description set to " + new_description,
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        200,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@EventBp.route("/events/<id>", methods=("DELETE",))
 | 
					@EventBp.route("/events/<id>", methods=("DELETE",))
 | 
				
			||||||
def delete_event(id):
 | 
					def delete_event(id):
 | 
				
			||||||
    try:
 | 
					    try:
 | 
				
			||||||
@ -625,6 +860,11 @@ def delete_event(id):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    event.delete_instance()
 | 
					    event.delete_instance()
 | 
				
			||||||
    Timeline.delete().where(Timeline.source_id == id).execute()
 | 
					    Timeline.delete().where(Timeline.source_id == id).execute()
 | 
				
			||||||
 | 
					    # If semantic search is enabled, update the index
 | 
				
			||||||
 | 
					    if current_app.frigate_config.semantic_search.enabled:
 | 
				
			||||||
 | 
					        context: EmbeddingsContext = current_app.embeddings
 | 
				
			||||||
 | 
					        context.embeddings.thumbnail.delete(ids=[id])
 | 
				
			||||||
 | 
					        context.embeddings.description.delete(ids=[id])
 | 
				
			||||||
    return make_response(
 | 
					    return make_response(
 | 
				
			||||||
        jsonify({"success": True, "message": "Event " + id + " deleted"}), 200
 | 
					        jsonify({"success": True, "message": "Event " + id + " deleted"}), 200
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
				
			|||||||
@ -37,8 +37,7 @@ from frigate.const import (
 | 
				
			|||||||
    MODEL_CACHE_DIR,
 | 
					    MODEL_CACHE_DIR,
 | 
				
			||||||
    RECORD_DIR,
 | 
					    RECORD_DIR,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from frigate.embeddings import manage_embeddings
 | 
					from frigate.embeddings import EmbeddingsContext, manage_embeddings
 | 
				
			||||||
from frigate.embeddings.embeddings import Embeddings
 | 
					 | 
				
			||||||
from frigate.events.audio import listen_to_audio
 | 
					from frigate.events.audio import listen_to_audio
 | 
				
			||||||
from frigate.events.cleanup import EventCleanup
 | 
					from frigate.events.cleanup import EventCleanup
 | 
				
			||||||
from frigate.events.external import ExternalEventProcessor
 | 
					from frigate.events.external import ExternalEventProcessor
 | 
				
			||||||
@ -322,7 +321,7 @@ class FrigateApp:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def init_embeddings_manager(self) -> None:
 | 
					    def init_embeddings_manager(self) -> None:
 | 
				
			||||||
        # Create a client for other processes to use
 | 
					        # Create a client for other processes to use
 | 
				
			||||||
        self.embeddings = Embeddings()
 | 
					        self.embeddings = EmbeddingsContext()
 | 
				
			||||||
        embedding_process = mp.Process(
 | 
					        embedding_process = mp.Process(
 | 
				
			||||||
            target=manage_embeddings,
 | 
					            target=manage_embeddings,
 | 
				
			||||||
            name="embeddings_manager",
 | 
					            name="embeddings_manager",
 | 
				
			||||||
@ -384,6 +383,7 @@ class FrigateApp:
 | 
				
			|||||||
        self.flask_app = create_app(
 | 
					        self.flask_app = create_app(
 | 
				
			||||||
            self.config,
 | 
					            self.config,
 | 
				
			||||||
            self.db,
 | 
					            self.db,
 | 
				
			||||||
 | 
					            self.embeddings,
 | 
				
			||||||
            self.detected_frames_processor,
 | 
					            self.detected_frames_processor,
 | 
				
			||||||
            self.storage_maintainer,
 | 
					            self.storage_maintainer,
 | 
				
			||||||
            self.onvif_controller,
 | 
					            self.onvif_controller,
 | 
				
			||||||
@ -811,6 +811,9 @@ class FrigateApp:
 | 
				
			|||||||
        self.frigate_watchdog.join()
 | 
					        self.frigate_watchdog.join()
 | 
				
			||||||
        self.db.stop()
 | 
					        self.db.stop()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Save embeddings stats to disk
 | 
				
			||||||
 | 
					        self.embeddings.save_stats()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Stop Communicators
 | 
					        # Stop Communicators
 | 
				
			||||||
        self.inter_process_communicator.stop()
 | 
					        self.inter_process_communicator.stop()
 | 
				
			||||||
        self.inter_config_updater.stop()
 | 
					        self.inter_config_updater.stop()
 | 
				
			||||||
 | 
				
			|||||||
@ -1,9 +1,9 @@
 | 
				
			|||||||
"""ChromaDB embeddings database."""
 | 
					"""ChromaDB embeddings database."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import json
 | 
				
			||||||
import logging
 | 
					import logging
 | 
				
			||||||
import multiprocessing as mp
 | 
					import multiprocessing as mp
 | 
				
			||||||
import signal
 | 
					import signal
 | 
				
			||||||
import sys
 | 
					 | 
				
			||||||
import threading
 | 
					import threading
 | 
				
			||||||
from types import FrameType
 | 
					from types import FrameType
 | 
				
			||||||
from typing import Optional
 | 
					from typing import Optional
 | 
				
			||||||
@ -12,9 +12,14 @@ from playhouse.sqliteq import SqliteQueueDatabase
 | 
				
			|||||||
from setproctitle import setproctitle
 | 
					from setproctitle import setproctitle
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from frigate.config import FrigateConfig
 | 
					from frigate.config import FrigateConfig
 | 
				
			||||||
 | 
					from frigate.const import CONFIG_DIR
 | 
				
			||||||
from frigate.models import Event
 | 
					from frigate.models import Event
 | 
				
			||||||
from frigate.util.services import listen
 | 
					from frigate.util.services import listen
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .embeddings import Embeddings
 | 
				
			||||||
 | 
					from .maintainer import EmbeddingMaintainer
 | 
				
			||||||
 | 
					from .util import ZScoreNormalization
 | 
				
			||||||
 | 
					
 | 
				
			||||||
logger = logging.getLogger(__name__)
 | 
					logger = logging.getLogger(__name__)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -48,12 +53,6 @@ def manage_embeddings(config: FrigateConfig) -> None:
 | 
				
			|||||||
    models = [Event]
 | 
					    models = [Event]
 | 
				
			||||||
    db.bind(models)
 | 
					    db.bind(models)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Hotsawp the sqlite3 module for Chroma compatibility
 | 
					 | 
				
			||||||
    __import__("pysqlite3")
 | 
					 | 
				
			||||||
    sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
 | 
					 | 
				
			||||||
    from .embeddings import Embeddings
 | 
					 | 
				
			||||||
    from .maintainer import EmbeddingMaintainer
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    embeddings = Embeddings()
 | 
					    embeddings = Embeddings()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Check if we need to re-index events
 | 
					    # Check if we need to re-index events
 | 
				
			||||||
@ -65,3 +64,28 @@ def manage_embeddings(config: FrigateConfig) -> None:
 | 
				
			|||||||
        stop_event,
 | 
					        stop_event,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    maintainer.start()
 | 
					    maintainer.start()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class EmbeddingsContext:
 | 
				
			||||||
 | 
					    def __init__(self):
 | 
				
			||||||
 | 
					        self.embeddings = Embeddings()
 | 
				
			||||||
 | 
					        self.thumb_stats = ZScoreNormalization()
 | 
				
			||||||
 | 
					        self.desc_stats = ZScoreNormalization()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # load stats from disk
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            with open(f"{CONFIG_DIR}/.search_stats.json", "r") as f:
 | 
				
			||||||
 | 
					                data = json.loads(f.read())
 | 
				
			||||||
 | 
					                self.thumb_stats.from_dict(data["thumb_stats"])
 | 
				
			||||||
 | 
					                self.desc_stats.from_dict(data["desc_stats"])
 | 
				
			||||||
 | 
					        except FileNotFoundError:
 | 
				
			||||||
 | 
					            pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def save_stats(self):
 | 
				
			||||||
 | 
					        """Write the stats to disk as JSON on exit."""
 | 
				
			||||||
 | 
					        contents = {
 | 
				
			||||||
 | 
					            "thumb_stats": self.thumb_stats.to_dict(),
 | 
				
			||||||
 | 
					            "desc_stats": self.desc_stats.to_dict(),
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        with open(f"{CONFIG_DIR}/.search_stats.json", "w") as f:
 | 
				
			||||||
 | 
					            f.write(json.dumps(contents))
 | 
				
			||||||
 | 
				
			|||||||
@ -3,19 +3,32 @@
 | 
				
			|||||||
import base64
 | 
					import base64
 | 
				
			||||||
import io
 | 
					import io
 | 
				
			||||||
import logging
 | 
					import logging
 | 
				
			||||||
 | 
					import sys
 | 
				
			||||||
import time
 | 
					import time
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import numpy as np
 | 
					import numpy as np
 | 
				
			||||||
from chromadb import Collection
 | 
					 | 
				
			||||||
from chromadb import HttpClient as ChromaClient
 | 
					 | 
				
			||||||
from chromadb.config import Settings
 | 
					 | 
				
			||||||
from PIL import Image
 | 
					from PIL import Image
 | 
				
			||||||
from playhouse.shortcuts import model_to_dict
 | 
					from playhouse.shortcuts import model_to_dict
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from frigate.models import Event
 | 
					from frigate.models import Event
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .functions.clip import ClipEmbedding
 | 
					# Hotsawp the sqlite3 module for Chroma compatibility
 | 
				
			||||||
from .functions.minilm_l6_v2 import MiniLMEmbedding
 | 
					try:
 | 
				
			||||||
 | 
					    from chromadb import Collection
 | 
				
			||||||
 | 
					    from chromadb import HttpClient as ChromaClient
 | 
				
			||||||
 | 
					    from chromadb.config import Settings
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    from .functions.clip import ClipEmbedding
 | 
				
			||||||
 | 
					    from .functions.minilm_l6_v2 import MiniLMEmbedding
 | 
				
			||||||
 | 
					except RuntimeError:
 | 
				
			||||||
 | 
					    __import__("pysqlite3")
 | 
				
			||||||
 | 
					    sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
 | 
				
			||||||
 | 
					    from chromadb import Collection
 | 
				
			||||||
 | 
					    from chromadb import HttpClient as ChromaClient
 | 
				
			||||||
 | 
					    from chromadb.config import Settings
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    from .functions.clip import ClipEmbedding
 | 
				
			||||||
 | 
					    from .functions.minilm_l6_v2 import MiniLMEmbedding
 | 
				
			||||||
 | 
					
 | 
				
			||||||
logger = logging.getLogger(__name__)
 | 
					logger = logging.getLogger(__name__)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										47
									
								
								frigate/embeddings/util.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										47
									
								
								frigate/embeddings/util.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,47 @@
 | 
				
			|||||||
 | 
					"""Z-score normalization for search distance."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import math
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ZScoreNormalization:
 | 
				
			||||||
 | 
					    """Running Z-score normalization for search distance."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self):
 | 
				
			||||||
 | 
					        self.n = 0
 | 
				
			||||||
 | 
					        self.mean = 0
 | 
				
			||||||
 | 
					        self.m2 = 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def variance(self):
 | 
				
			||||||
 | 
					        return self.m2 / (self.n - 1) if self.n > 1 else 0.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def stddev(self):
 | 
				
			||||||
 | 
					        return math.sqrt(self.variance)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def normalize(self, distances: list[float]):
 | 
				
			||||||
 | 
					        self._update(distances)
 | 
				
			||||||
 | 
					        if self.stddev == 0:
 | 
				
			||||||
 | 
					            return distances
 | 
				
			||||||
 | 
					        return [(x - self.mean) / self.stddev for x in distances]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _update(self, distances: list[float]):
 | 
				
			||||||
 | 
					        for x in distances:
 | 
				
			||||||
 | 
					            self.n += 1
 | 
				
			||||||
 | 
					            delta = x - self.mean
 | 
				
			||||||
 | 
					            self.mean += delta / self.n
 | 
				
			||||||
 | 
					            delta2 = x - self.mean
 | 
				
			||||||
 | 
					            self.m2 += delta * delta2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def to_dict(self):
 | 
				
			||||||
 | 
					        return {
 | 
				
			||||||
 | 
					            "n": self.n,
 | 
				
			||||||
 | 
					            "mean": self.mean,
 | 
				
			||||||
 | 
					            "m2": self.m2,
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def from_dict(self, data: dict):
 | 
				
			||||||
 | 
					        self.n = data["n"]
 | 
				
			||||||
 | 
					        self.mean = data["mean"]
 | 
				
			||||||
 | 
					        self.m2 = data["m2"]
 | 
				
			||||||
 | 
					        return self
 | 
				
			||||||
@ -120,6 +120,7 @@ class TestHttp(unittest.TestCase):
 | 
				
			|||||||
            None,
 | 
					            None,
 | 
				
			||||||
            None,
 | 
					            None,
 | 
				
			||||||
            None,
 | 
					            None,
 | 
				
			||||||
 | 
					            None,
 | 
				
			||||||
            PlusApi(),
 | 
					            PlusApi(),
 | 
				
			||||||
            None,
 | 
					            None,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
@ -156,6 +157,7 @@ class TestHttp(unittest.TestCase):
 | 
				
			|||||||
            None,
 | 
					            None,
 | 
				
			||||||
            None,
 | 
					            None,
 | 
				
			||||||
            None,
 | 
					            None,
 | 
				
			||||||
 | 
					            None,
 | 
				
			||||||
            PlusApi(),
 | 
					            PlusApi(),
 | 
				
			||||||
            None,
 | 
					            None,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
@ -177,6 +179,7 @@ class TestHttp(unittest.TestCase):
 | 
				
			|||||||
            None,
 | 
					            None,
 | 
				
			||||||
            None,
 | 
					            None,
 | 
				
			||||||
            None,
 | 
					            None,
 | 
				
			||||||
 | 
					            None,
 | 
				
			||||||
            PlusApi(),
 | 
					            PlusApi(),
 | 
				
			||||||
            None,
 | 
					            None,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
@ -197,6 +200,7 @@ class TestHttp(unittest.TestCase):
 | 
				
			|||||||
            None,
 | 
					            None,
 | 
				
			||||||
            None,
 | 
					            None,
 | 
				
			||||||
            None,
 | 
					            None,
 | 
				
			||||||
 | 
					            None,
 | 
				
			||||||
            PlusApi(),
 | 
					            PlusApi(),
 | 
				
			||||||
            None,
 | 
					            None,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
@ -219,6 +223,7 @@ class TestHttp(unittest.TestCase):
 | 
				
			|||||||
            None,
 | 
					            None,
 | 
				
			||||||
            None,
 | 
					            None,
 | 
				
			||||||
            None,
 | 
					            None,
 | 
				
			||||||
 | 
					            None,
 | 
				
			||||||
            PlusApi(),
 | 
					            PlusApi(),
 | 
				
			||||||
            None,
 | 
					            None,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
@ -245,6 +250,7 @@ class TestHttp(unittest.TestCase):
 | 
				
			|||||||
            None,
 | 
					            None,
 | 
				
			||||||
            None,
 | 
					            None,
 | 
				
			||||||
            None,
 | 
					            None,
 | 
				
			||||||
 | 
					            None,
 | 
				
			||||||
            PlusApi(),
 | 
					            PlusApi(),
 | 
				
			||||||
            None,
 | 
					            None,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
@ -283,6 +289,7 @@ class TestHttp(unittest.TestCase):
 | 
				
			|||||||
            None,
 | 
					            None,
 | 
				
			||||||
            None,
 | 
					            None,
 | 
				
			||||||
            None,
 | 
					            None,
 | 
				
			||||||
 | 
					            None,
 | 
				
			||||||
            PlusApi(),
 | 
					            PlusApi(),
 | 
				
			||||||
            None,
 | 
					            None,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
@ -318,6 +325,7 @@ class TestHttp(unittest.TestCase):
 | 
				
			|||||||
            None,
 | 
					            None,
 | 
				
			||||||
            None,
 | 
					            None,
 | 
				
			||||||
            None,
 | 
					            None,
 | 
				
			||||||
 | 
					            None,
 | 
				
			||||||
            PlusApi(),
 | 
					            PlusApi(),
 | 
				
			||||||
            None,
 | 
					            None,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
@ -343,6 +351,7 @@ class TestHttp(unittest.TestCase):
 | 
				
			|||||||
            None,
 | 
					            None,
 | 
				
			||||||
            None,
 | 
					            None,
 | 
				
			||||||
            None,
 | 
					            None,
 | 
				
			||||||
 | 
					            None,
 | 
				
			||||||
            PlusApi(),
 | 
					            PlusApi(),
 | 
				
			||||||
            None,
 | 
					            None,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
@ -360,6 +369,7 @@ class TestHttp(unittest.TestCase):
 | 
				
			|||||||
            None,
 | 
					            None,
 | 
				
			||||||
            None,
 | 
					            None,
 | 
				
			||||||
            None,
 | 
					            None,
 | 
				
			||||||
 | 
					            None,
 | 
				
			||||||
            PlusApi(),
 | 
					            PlusApi(),
 | 
				
			||||||
            None,
 | 
					            None,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
@ -381,6 +391,7 @@ class TestHttp(unittest.TestCase):
 | 
				
			|||||||
            None,
 | 
					            None,
 | 
				
			||||||
            None,
 | 
					            None,
 | 
				
			||||||
            None,
 | 
					            None,
 | 
				
			||||||
 | 
					            None,
 | 
				
			||||||
            PlusApi(),
 | 
					            PlusApi(),
 | 
				
			||||||
            stats,
 | 
					            stats,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user