mirror of
				https://github.com/blakeblackshear/frigate.git
				synced 2025-10-27 10:52:11 +01:00 
			
		
		
		
	Restructure embeddings (#14266)
* Restructure embeddings * Use ZMQ to proxy embeddings requests * Handle serialization * Formatting * Remove unused
This commit is contained in:
		
							parent
							
								
									a2ca18a714
								
							
						
					
					
						commit
						8ade85edec
					
				@ -472,7 +472,7 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
 | 
				
			|||||||
                status_code=404,
 | 
					                status_code=404,
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        thumb_result = context.embeddings.search_thumbnail(search_event)
 | 
					        thumb_result = context.search_thumbnail(search_event)
 | 
				
			||||||
        thumb_ids = dict(
 | 
					        thumb_ids = dict(
 | 
				
			||||||
            zip(
 | 
					            zip(
 | 
				
			||||||
                [result[0] for result in thumb_result],
 | 
					                [result[0] for result in thumb_result],
 | 
				
			||||||
@ -487,7 +487,7 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
 | 
				
			|||||||
        search_types = search_type.split(",")
 | 
					        search_types = search_type.split(",")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if "thumbnail" in search_types:
 | 
					        if "thumbnail" in search_types:
 | 
				
			||||||
            thumb_result = context.embeddings.search_thumbnail(query)
 | 
					            thumb_result = context.search_thumbnail(query)
 | 
				
			||||||
            thumb_ids = dict(
 | 
					            thumb_ids = dict(
 | 
				
			||||||
                zip(
 | 
					                zip(
 | 
				
			||||||
                    [result[0] for result in thumb_result],
 | 
					                    [result[0] for result in thumb_result],
 | 
				
			||||||
@ -504,7 +504,7 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
 | 
				
			|||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if "description" in search_types:
 | 
					        if "description" in search_types:
 | 
				
			||||||
            desc_result = context.embeddings.search_description(query)
 | 
					            desc_result = context.search_description(query)
 | 
				
			||||||
            desc_ids = dict(
 | 
					            desc_ids = dict(
 | 
				
			||||||
                zip(
 | 
					                zip(
 | 
				
			||||||
                    [result[0] for result in desc_result],
 | 
					                    [result[0] for result in desc_result],
 | 
				
			||||||
@ -944,9 +944,9 @@ def set_description(
 | 
				
			|||||||
    # If semantic search is enabled, update the index
 | 
					    # If semantic search is enabled, update the index
 | 
				
			||||||
    if request.app.frigate_config.semantic_search.enabled:
 | 
					    if request.app.frigate_config.semantic_search.enabled:
 | 
				
			||||||
        context: EmbeddingsContext = request.app.embeddings
 | 
					        context: EmbeddingsContext = request.app.embeddings
 | 
				
			||||||
        context.embeddings.upsert_description(
 | 
					        context.update_description(
 | 
				
			||||||
            event_id=event_id,
 | 
					            event_id,
 | 
				
			||||||
            description=new_description,
 | 
					            new_description,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    response_message = (
 | 
					    response_message = (
 | 
				
			||||||
@ -1033,8 +1033,8 @@ def delete_event(request: Request, event_id: str):
 | 
				
			|||||||
    # If semantic search is enabled, update the index
 | 
					    # If semantic search is enabled, update the index
 | 
				
			||||||
    if request.app.frigate_config.semantic_search.enabled:
 | 
					    if request.app.frigate_config.semantic_search.enabled:
 | 
				
			||||||
        context: EmbeddingsContext = request.app.embeddings
 | 
					        context: EmbeddingsContext = request.app.embeddings
 | 
				
			||||||
        context.embeddings.delete_thumbnail(id=[event_id])
 | 
					        context.db.delete_embeddings_thumbnail(id=[event_id])
 | 
				
			||||||
        context.embeddings.delete_description(id=[event_id])
 | 
					        context.db.delete_embeddings_description(id=[event_id])
 | 
				
			||||||
    return JSONResponse(
 | 
					    return JSONResponse(
 | 
				
			||||||
        content=({"success": True, "message": "Event " + event_id + " deleted"}),
 | 
					        content=({"success": True, "message": "Event " + event_id + " deleted"}),
 | 
				
			||||||
        status_code=200,
 | 
					        status_code=200,
 | 
				
			||||||
 | 
				
			|||||||
@ -276,7 +276,7 @@ class FrigateApp:
 | 
				
			|||||||
    def init_embeddings_client(self) -> None:
 | 
					    def init_embeddings_client(self) -> None:
 | 
				
			||||||
        if self.config.semantic_search.enabled:
 | 
					        if self.config.semantic_search.enabled:
 | 
				
			||||||
            # Create a client for other processes to use
 | 
					            # Create a client for other processes to use
 | 
				
			||||||
            self.embeddings = EmbeddingsContext(self.config, self.db)
 | 
					            self.embeddings = EmbeddingsContext(self.db)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def init_external_event_processor(self) -> None:
 | 
					    def init_external_event_processor(self) -> None:
 | 
				
			||||||
        self.external_event_processor = ExternalEventProcessor(self.config)
 | 
					        self.external_event_processor = ExternalEventProcessor(self.config)
 | 
				
			||||||
@ -699,7 +699,7 @@ class FrigateApp:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        # Save embeddings stats to disk
 | 
					        # Save embeddings stats to disk
 | 
				
			||||||
        if self.embeddings:
 | 
					        if self.embeddings:
 | 
				
			||||||
            self.embeddings.save_stats()
 | 
					            self.embeddings.stop()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Stop Communicators
 | 
					        # Stop Communicators
 | 
				
			||||||
        self.inter_process_communicator.stop()
 | 
					        self.inter_process_communicator.stop()
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										62
									
								
								frigate/comms/embeddings_updater.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										62
									
								
								frigate/comms/embeddings_updater.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,62 @@
 | 
				
			|||||||
 | 
					"""Facilitates communication between processes."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from enum import Enum
 | 
				
			||||||
 | 
					from typing import Callable
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import zmq
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					SOCKET_REP_REQ = "ipc:///tmp/cache/embeddings"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class EmbeddingsRequestEnum(Enum):
 | 
				
			||||||
 | 
					    embed_description = "embed_description"
 | 
				
			||||||
 | 
					    embed_thumbnail = "embed_thumbnail"
 | 
				
			||||||
 | 
					    generate_search = "generate_search"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class EmbeddingsResponder:
 | 
				
			||||||
 | 
					    def __init__(self) -> None:
 | 
				
			||||||
 | 
					        self.context = zmq.Context()
 | 
				
			||||||
 | 
					        self.socket = self.context.socket(zmq.REP)
 | 
				
			||||||
 | 
					        self.socket.bind(SOCKET_REP_REQ)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def check_for_request(self, process: Callable) -> None:
 | 
				
			||||||
 | 
					        while True:  # load all messages that are queued
 | 
				
			||||||
 | 
					            has_message, _, _ = zmq.select([self.socket], [], [], 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if not has_message:
 | 
				
			||||||
 | 
					                break
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
 | 
					                (topic, value) = self.socket.recv_json(flags=zmq.NOBLOCK)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                response = process(topic, value)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                if response is not None:
 | 
				
			||||||
 | 
					                    self.socket.send_json(response)
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
 | 
					                    self.socket.send_json([])
 | 
				
			||||||
 | 
					            except zmq.ZMQError:
 | 
				
			||||||
 | 
					                break
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def stop(self) -> None:
 | 
				
			||||||
 | 
					        self.socket.close()
 | 
				
			||||||
 | 
					        self.context.destroy()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class EmbeddingsRequestor:
 | 
				
			||||||
 | 
					    """Simplifies sending data to EmbeddingsResponder and getting a reply."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self) -> None:
 | 
				
			||||||
 | 
					        self.context = zmq.Context()
 | 
				
			||||||
 | 
					        self.socket = self.context.socket(zmq.REQ)
 | 
				
			||||||
 | 
					        self.socket.connect(SOCKET_REP_REQ)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def send_data(self, topic: str, data: any) -> str:
 | 
				
			||||||
 | 
					        """Sends data and then waits for reply."""
 | 
				
			||||||
 | 
					        self.socket.send_json((topic, data))
 | 
				
			||||||
 | 
					        return self.socket.recv_json()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def stop(self) -> None:
 | 
				
			||||||
 | 
					        self.socket.close()
 | 
				
			||||||
 | 
					        self.context.destroy()
 | 
				
			||||||
@ -20,3 +20,11 @@ class SqliteVecQueueDatabase(SqliteQueueDatabase):
 | 
				
			|||||||
        conn.enable_load_extension(True)
 | 
					        conn.enable_load_extension(True)
 | 
				
			||||||
        conn.load_extension(self.sqlite_vec_path)
 | 
					        conn.load_extension(self.sqlite_vec_path)
 | 
				
			||||||
        conn.enable_load_extension(False)
 | 
					        conn.enable_load_extension(False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def delete_embeddings_thumbnail(self, event_ids: list[str]) -> None:
 | 
				
			||||||
 | 
					        ids = ",".join(["?" for _ in event_ids])
 | 
				
			||||||
 | 
					        self.execute_sql(f"DELETE FROM vec_thumbnails WHERE id IN ({ids})", event_ids)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def delete_embeddings_description(self, event_ids: list[str]) -> None:
 | 
				
			||||||
 | 
					        ids = ",".join(["?" for _ in event_ids])
 | 
				
			||||||
 | 
					        self.execute_sql(f"DELETE FROM vec_descriptions WHERE id IN ({ids})", event_ids)
 | 
				
			||||||
 | 
				
			|||||||
@ -7,14 +7,16 @@ import os
 | 
				
			|||||||
import signal
 | 
					import signal
 | 
				
			||||||
import threading
 | 
					import threading
 | 
				
			||||||
from types import FrameType
 | 
					from types import FrameType
 | 
				
			||||||
from typing import Optional
 | 
					from typing import Optional, Union
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from setproctitle import setproctitle
 | 
					from setproctitle import setproctitle
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from frigate.comms.embeddings_updater import EmbeddingsRequestEnum, EmbeddingsRequestor
 | 
				
			||||||
from frigate.config import FrigateConfig
 | 
					from frigate.config import FrigateConfig
 | 
				
			||||||
from frigate.const import CONFIG_DIR
 | 
					from frigate.const import CONFIG_DIR
 | 
				
			||||||
from frigate.db.sqlitevecq import SqliteVecQueueDatabase
 | 
					from frigate.db.sqlitevecq import SqliteVecQueueDatabase
 | 
				
			||||||
from frigate.models import Event
 | 
					from frigate.models import Event
 | 
				
			||||||
 | 
					from frigate.util.builtin import serialize
 | 
				
			||||||
from frigate.util.services import listen
 | 
					from frigate.util.services import listen
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .embeddings import Embeddings
 | 
					from .embeddings import Embeddings
 | 
				
			||||||
@ -70,10 +72,11 @@ def manage_embeddings(config: FrigateConfig) -> None:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class EmbeddingsContext:
 | 
					class EmbeddingsContext:
 | 
				
			||||||
    def __init__(self, config: FrigateConfig, db: SqliteVecQueueDatabase):
 | 
					    def __init__(self, db: SqliteVecQueueDatabase):
 | 
				
			||||||
        self.embeddings = Embeddings(config.semantic_search, db)
 | 
					        self.db = db
 | 
				
			||||||
        self.thumb_stats = ZScoreNormalization()
 | 
					        self.thumb_stats = ZScoreNormalization()
 | 
				
			||||||
        self.desc_stats = ZScoreNormalization()
 | 
					        self.desc_stats = ZScoreNormalization()
 | 
				
			||||||
 | 
					        self.requestor = EmbeddingsRequestor()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # load stats from disk
 | 
					        # load stats from disk
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
@ -84,7 +87,7 @@ class EmbeddingsContext:
 | 
				
			|||||||
        except FileNotFoundError:
 | 
					        except FileNotFoundError:
 | 
				
			||||||
            pass
 | 
					            pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def save_stats(self):
 | 
					    def stop(self):
 | 
				
			||||||
        """Write the stats to disk as JSON on exit."""
 | 
					        """Write the stats to disk as JSON on exit."""
 | 
				
			||||||
        contents = {
 | 
					        contents = {
 | 
				
			||||||
            "thumb_stats": self.thumb_stats.to_dict(),
 | 
					            "thumb_stats": self.thumb_stats.to_dict(),
 | 
				
			||||||
@ -92,3 +95,100 @@ class EmbeddingsContext:
 | 
				
			|||||||
        }
 | 
					        }
 | 
				
			||||||
        with open(os.path.join(CONFIG_DIR, ".search_stats.json"), "w") as f:
 | 
					        with open(os.path.join(CONFIG_DIR, ".search_stats.json"), "w") as f:
 | 
				
			||||||
            json.dump(contents, f)
 | 
					            json.dump(contents, f)
 | 
				
			||||||
 | 
					        self.requestor.stop()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def search_thumbnail(
 | 
				
			||||||
 | 
					        self, query: Union[Event, str], event_ids: list[str] = None
 | 
				
			||||||
 | 
					    ) -> list[tuple[str, float]]:
 | 
				
			||||||
 | 
					        if query.__class__ == Event:
 | 
				
			||||||
 | 
					            cursor = self.db.execute_sql(
 | 
				
			||||||
 | 
					                """
 | 
				
			||||||
 | 
					                SELECT thumbnail_embedding FROM vec_thumbnails WHERE id = ?
 | 
				
			||||||
 | 
					                """,
 | 
				
			||||||
 | 
					                [query.id],
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            row = cursor.fetchone() if cursor else None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if row:
 | 
				
			||||||
 | 
					                query_embedding = row[0]
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                # If no embedding found, generate it and return it
 | 
				
			||||||
 | 
					                query_embedding = serialize(
 | 
				
			||||||
 | 
					                    self.requestor.send_data(
 | 
				
			||||||
 | 
					                        EmbeddingsRequestEnum.embed_thumbnail.value,
 | 
				
			||||||
 | 
					                        {"id": query.id, "thumbnail": query.thumbnail},
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            query_embedding = serialize(
 | 
				
			||||||
 | 
					                self.requestor.send_data(
 | 
				
			||||||
 | 
					                    EmbeddingsRequestEnum.generate_search.value, query
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        sql_query = """
 | 
				
			||||||
 | 
					            SELECT
 | 
				
			||||||
 | 
					                id,
 | 
				
			||||||
 | 
					                distance
 | 
				
			||||||
 | 
					            FROM vec_thumbnails
 | 
				
			||||||
 | 
					            WHERE thumbnail_embedding MATCH ?
 | 
				
			||||||
 | 
					                AND k = 100
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Add the IN clause if event_ids is provided and not empty
 | 
				
			||||||
 | 
					        # this is the only filter supported by sqlite-vec as of 0.1.3
 | 
				
			||||||
 | 
					        # but it seems to be broken in this version
 | 
				
			||||||
 | 
					        if event_ids:
 | 
				
			||||||
 | 
					            sql_query += " AND id IN ({})".format(",".join("?" * len(event_ids)))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # order by distance DESC is not implemented in this version of sqlite-vec
 | 
				
			||||||
 | 
					        # when it's implemented, we can use cosine similarity
 | 
				
			||||||
 | 
					        sql_query += " ORDER BY distance"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        parameters = [query_embedding] + event_ids if event_ids else [query_embedding]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        results = self.db.execute_sql(sql_query, parameters).fetchall()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return results
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def search_description(
 | 
				
			||||||
 | 
					        self, query_text: str, event_ids: list[str] = None
 | 
				
			||||||
 | 
					    ) -> list[tuple[str, float]]:
 | 
				
			||||||
 | 
					        query_embedding = serialize(
 | 
				
			||||||
 | 
					            self.requestor.send_data(
 | 
				
			||||||
 | 
					                EmbeddingsRequestEnum.generate_search.value, query_text
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Prepare the base SQL query
 | 
				
			||||||
 | 
					        sql_query = """
 | 
				
			||||||
 | 
					            SELECT
 | 
				
			||||||
 | 
					                id,
 | 
				
			||||||
 | 
					                distance
 | 
				
			||||||
 | 
					            FROM vec_descriptions
 | 
				
			||||||
 | 
					            WHERE description_embedding MATCH ?
 | 
				
			||||||
 | 
					                AND k = 100
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Add the IN clause if event_ids is provided and not empty
 | 
				
			||||||
 | 
					        # this is the only filter supported by sqlite-vec as of 0.1.3
 | 
				
			||||||
 | 
					        # but it seems to be broken in this version
 | 
				
			||||||
 | 
					        if event_ids:
 | 
				
			||||||
 | 
					            sql_query += " AND id IN ({})".format(",".join("?" * len(event_ids)))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # order by distance DESC is not implemented in this version of sqlite-vec
 | 
				
			||||||
 | 
					        # when it's implemented, we can use cosine similarity
 | 
				
			||||||
 | 
					        sql_query += " ORDER BY distance"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        parameters = [query_embedding] + event_ids if event_ids else [query_embedding]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        results = self.db.execute_sql(sql_query, parameters).fetchall()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return results
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def update_description(self, event_id: str, description: str) -> None:
 | 
				
			||||||
 | 
					        self.requestor.send_data(
 | 
				
			||||||
 | 
					            EmbeddingsRequestEnum.embed_description.value,
 | 
				
			||||||
 | 
					            {"id": event_id, "description": description},
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
				
			|||||||
@ -3,11 +3,8 @@
 | 
				
			|||||||
import base64
 | 
					import base64
 | 
				
			||||||
import io
 | 
					import io
 | 
				
			||||||
import logging
 | 
					import logging
 | 
				
			||||||
import struct
 | 
					 | 
				
			||||||
import time
 | 
					import time
 | 
				
			||||||
from typing import List, Tuple, Union
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
import numpy as np
 | 
					 | 
				
			||||||
from PIL import Image
 | 
					from PIL import Image
 | 
				
			||||||
from playhouse.shortcuts import model_to_dict
 | 
					from playhouse.shortcuts import model_to_dict
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -17,6 +14,7 @@ from frigate.const import UPDATE_MODEL_STATE
 | 
				
			|||||||
from frigate.db.sqlitevecq import SqliteVecQueueDatabase
 | 
					from frigate.db.sqlitevecq import SqliteVecQueueDatabase
 | 
				
			||||||
from frigate.models import Event
 | 
					from frigate.models import Event
 | 
				
			||||||
from frigate.types import ModelStatusTypesEnum
 | 
					from frigate.types import ModelStatusTypesEnum
 | 
				
			||||||
 | 
					from frigate.util.builtin import serialize
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .functions.onnx import GenericONNXEmbedding
 | 
					from .functions.onnx import GenericONNXEmbedding
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -54,30 +52,6 @@ def get_metadata(event: Event) -> dict:
 | 
				
			|||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def serialize(vector: Union[List[float], np.ndarray, float]) -> bytes:
 | 
					 | 
				
			||||||
    """Serializes a list of floats, numpy array, or single float into a compact "raw bytes" format"""
 | 
					 | 
				
			||||||
    if isinstance(vector, np.ndarray):
 | 
					 | 
				
			||||||
        # Convert numpy array to list of floats
 | 
					 | 
				
			||||||
        vector = vector.flatten().tolist()
 | 
					 | 
				
			||||||
    elif isinstance(vector, (float, np.float32, np.float64)):
 | 
					 | 
				
			||||||
        # Handle single float values
 | 
					 | 
				
			||||||
        vector = [vector]
 | 
					 | 
				
			||||||
    elif not isinstance(vector, list):
 | 
					 | 
				
			||||||
        raise TypeError(
 | 
					 | 
				
			||||||
            f"Input must be a list of floats, a numpy array, or a single float. Got {type(vector)}"
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    try:
 | 
					 | 
				
			||||||
        return struct.pack("%sf" % len(vector), *vector)
 | 
					 | 
				
			||||||
    except struct.error as e:
 | 
					 | 
				
			||||||
        raise ValueError(f"Failed to pack vector: {e}. Vector: {vector}")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def deserialize(bytes_data: bytes) -> List[float]:
 | 
					 | 
				
			||||||
    """Deserializes a compact "raw bytes" format into a list of floats"""
 | 
					 | 
				
			||||||
    return list(struct.unpack("%sf" % (len(bytes_data) // 4), bytes_data))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class Embeddings:
 | 
					class Embeddings:
 | 
				
			||||||
    """SQLite-vec embeddings database."""
 | 
					    """SQLite-vec embeddings database."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -190,106 +164,6 @@ class Embeddings:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        return embedding
 | 
					        return embedding
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def delete_thumbnail(self, event_ids: List[str]) -> None:
 | 
					 | 
				
			||||||
        ids = ",".join(["?" for _ in event_ids])
 | 
					 | 
				
			||||||
        self.db.execute_sql(
 | 
					 | 
				
			||||||
            f"DELETE FROM vec_thumbnails WHERE id IN ({ids})", event_ids
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def delete_description(self, event_ids: List[str]) -> None:
 | 
					 | 
				
			||||||
        ids = ",".join(["?" for _ in event_ids])
 | 
					 | 
				
			||||||
        self.db.execute_sql(
 | 
					 | 
				
			||||||
            f"DELETE FROM vec_descriptions WHERE id IN ({ids})", event_ids
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def search_thumbnail(
 | 
					 | 
				
			||||||
        self, query: Union[Event, str], event_ids: List[str] = None
 | 
					 | 
				
			||||||
    ) -> List[Tuple[str, float]]:
 | 
					 | 
				
			||||||
        if query.__class__ == Event:
 | 
					 | 
				
			||||||
            cursor = self.db.execute_sql(
 | 
					 | 
				
			||||||
                """
 | 
					 | 
				
			||||||
                SELECT thumbnail_embedding FROM vec_thumbnails WHERE id = ?
 | 
					 | 
				
			||||||
                """,
 | 
					 | 
				
			||||||
                [query.id],
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            row = cursor.fetchone() if cursor else None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if row:
 | 
					 | 
				
			||||||
                query_embedding = deserialize(
 | 
					 | 
				
			||||||
                    row[0]
 | 
					 | 
				
			||||||
                )  # Deserialize the thumbnail embedding
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                # If no embedding found, generate it and return it
 | 
					 | 
				
			||||||
                thumbnail = base64.b64decode(query.thumbnail)
 | 
					 | 
				
			||||||
                query_embedding = self.upsert_thumbnail(query.id, thumbnail)
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            query_embedding = self.text_embedding([query])[0]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        sql_query = """
 | 
					 | 
				
			||||||
            SELECT
 | 
					 | 
				
			||||||
                id,
 | 
					 | 
				
			||||||
                distance
 | 
					 | 
				
			||||||
            FROM vec_thumbnails
 | 
					 | 
				
			||||||
            WHERE thumbnail_embedding MATCH ?
 | 
					 | 
				
			||||||
                AND k = 100
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Add the IN clause if event_ids is provided and not empty
 | 
					 | 
				
			||||||
        # this is the only filter supported by sqlite-vec as of 0.1.3
 | 
					 | 
				
			||||||
        # but it seems to be broken in this version
 | 
					 | 
				
			||||||
        if event_ids:
 | 
					 | 
				
			||||||
            sql_query += " AND id IN ({})".format(",".join("?" * len(event_ids)))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # order by distance DESC is not implemented in this version of sqlite-vec
 | 
					 | 
				
			||||||
        # when it's implemented, we can use cosine similarity
 | 
					 | 
				
			||||||
        sql_query += " ORDER BY distance"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        parameters = (
 | 
					 | 
				
			||||||
            [serialize(query_embedding)] + event_ids
 | 
					 | 
				
			||||||
            if event_ids
 | 
					 | 
				
			||||||
            else [serialize(query_embedding)]
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        results = self.db.execute_sql(sql_query, parameters).fetchall()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return results
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def search_description(
 | 
					 | 
				
			||||||
        self, query_text: str, event_ids: List[str] = None
 | 
					 | 
				
			||||||
    ) -> List[Tuple[str, float]]:
 | 
					 | 
				
			||||||
        query_embedding = self.text_embedding([query_text])[0]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Prepare the base SQL query
 | 
					 | 
				
			||||||
        sql_query = """
 | 
					 | 
				
			||||||
            SELECT
 | 
					 | 
				
			||||||
                id,
 | 
					 | 
				
			||||||
                distance
 | 
					 | 
				
			||||||
            FROM vec_descriptions
 | 
					 | 
				
			||||||
            WHERE description_embedding MATCH ?
 | 
					 | 
				
			||||||
                AND k = 100
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Add the IN clause if event_ids is provided and not empty
 | 
					 | 
				
			||||||
        # this is the only filter supported by sqlite-vec as of 0.1.3
 | 
					 | 
				
			||||||
        # but it seems to be broken in this version
 | 
					 | 
				
			||||||
        if event_ids:
 | 
					 | 
				
			||||||
            sql_query += " AND id IN ({})".format(",".join("?" * len(event_ids)))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # order by distance DESC is not implemented in this version of sqlite-vec
 | 
					 | 
				
			||||||
        # when it's implemented, we can use cosine similarity
 | 
					 | 
				
			||||||
        sql_query += " ORDER BY distance"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        parameters = (
 | 
					 | 
				
			||||||
            [serialize(query_embedding)] + event_ids
 | 
					 | 
				
			||||||
            if event_ids
 | 
					 | 
				
			||||||
            else [serialize(query_embedding)]
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        results = self.db.execute_sql(sql_query, parameters).fetchall()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return results
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def reindex(self) -> None:
 | 
					    def reindex(self) -> None:
 | 
				
			||||||
        logger.info("Indexing event embeddings...")
 | 
					        logger.info("Indexing event embeddings...")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -12,6 +12,7 @@ import numpy as np
 | 
				
			|||||||
from peewee import DoesNotExist
 | 
					from peewee import DoesNotExist
 | 
				
			||||||
from playhouse.sqliteq import SqliteQueueDatabase
 | 
					from playhouse.sqliteq import SqliteQueueDatabase
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from frigate.comms.embeddings_updater import EmbeddingsRequestEnum, EmbeddingsResponder
 | 
				
			||||||
from frigate.comms.event_metadata_updater import (
 | 
					from frigate.comms.event_metadata_updater import (
 | 
				
			||||||
    EventMetadataSubscriber,
 | 
					    EventMetadataSubscriber,
 | 
				
			||||||
    EventMetadataTypeEnum,
 | 
					    EventMetadataTypeEnum,
 | 
				
			||||||
@ -23,6 +24,7 @@ from frigate.const import CLIPS_DIR, UPDATE_EVENT_DESCRIPTION
 | 
				
			|||||||
from frigate.events.types import EventTypeEnum
 | 
					from frigate.events.types import EventTypeEnum
 | 
				
			||||||
from frigate.genai import get_genai_client
 | 
					from frigate.genai import get_genai_client
 | 
				
			||||||
from frigate.models import Event
 | 
					from frigate.models import Event
 | 
				
			||||||
 | 
					from frigate.util.builtin import serialize
 | 
				
			||||||
from frigate.util.image import SharedMemoryFrameManager, calculate_region
 | 
					from frigate.util.image import SharedMemoryFrameManager, calculate_region
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .embeddings import Embeddings
 | 
					from .embeddings import Embeddings
 | 
				
			||||||
@ -48,6 +50,7 @@ class EmbeddingMaintainer(threading.Thread):
 | 
				
			|||||||
        self.event_metadata_subscriber = EventMetadataSubscriber(
 | 
					        self.event_metadata_subscriber = EventMetadataSubscriber(
 | 
				
			||||||
            EventMetadataTypeEnum.regenerate_description
 | 
					            EventMetadataTypeEnum.regenerate_description
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					        self.embeddings_responder = EmbeddingsResponder()
 | 
				
			||||||
        self.frame_manager = SharedMemoryFrameManager()
 | 
					        self.frame_manager = SharedMemoryFrameManager()
 | 
				
			||||||
        # create communication for updating event descriptions
 | 
					        # create communication for updating event descriptions
 | 
				
			||||||
        self.requestor = InterProcessRequestor()
 | 
					        self.requestor = InterProcessRequestor()
 | 
				
			||||||
@ -58,6 +61,7 @@ class EmbeddingMaintainer(threading.Thread):
 | 
				
			|||||||
    def run(self) -> None:
 | 
					    def run(self) -> None:
 | 
				
			||||||
        """Maintain a SQLite-vec database for semantic search."""
 | 
					        """Maintain a SQLite-vec database for semantic search."""
 | 
				
			||||||
        while not self.stop_event.is_set():
 | 
					        while not self.stop_event.is_set():
 | 
				
			||||||
 | 
					            self._process_requests()
 | 
				
			||||||
            self._process_updates()
 | 
					            self._process_updates()
 | 
				
			||||||
            self._process_finalized()
 | 
					            self._process_finalized()
 | 
				
			||||||
            self._process_event_metadata()
 | 
					            self._process_event_metadata()
 | 
				
			||||||
@ -65,9 +69,30 @@ class EmbeddingMaintainer(threading.Thread):
 | 
				
			|||||||
        self.event_subscriber.stop()
 | 
					        self.event_subscriber.stop()
 | 
				
			||||||
        self.event_end_subscriber.stop()
 | 
					        self.event_end_subscriber.stop()
 | 
				
			||||||
        self.event_metadata_subscriber.stop()
 | 
					        self.event_metadata_subscriber.stop()
 | 
				
			||||||
 | 
					        self.embeddings_responder.stop()
 | 
				
			||||||
        self.requestor.stop()
 | 
					        self.requestor.stop()
 | 
				
			||||||
        logger.info("Exiting embeddings maintenance...")
 | 
					        logger.info("Exiting embeddings maintenance...")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _process_requests(self) -> None:
 | 
				
			||||||
 | 
					        """Process embeddings requests"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        def handle_request(topic: str, data: str) -> str:
 | 
				
			||||||
 | 
					            if topic == EmbeddingsRequestEnum.embed_description.value:
 | 
				
			||||||
 | 
					                return serialize(
 | 
				
			||||||
 | 
					                    self.embeddings.upsert_description(data["id"], data["description"]),
 | 
				
			||||||
 | 
					                    pack=False,
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					            elif topic == EmbeddingsRequestEnum.embed_thumbnail.value:
 | 
				
			||||||
 | 
					                thumbnail = base64.b64decode(data["thumbnail"])
 | 
				
			||||||
 | 
					                return serialize(
 | 
				
			||||||
 | 
					                    self.embeddings.upsert_thumbnail(data["id"], thumbnail),
 | 
				
			||||||
 | 
					                    pack=False,
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					            elif topic == EmbeddingsRequestEnum.generate_search.value:
 | 
				
			||||||
 | 
					                return serialize(self.embeddings.text_embedding([data])[0], pack=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.embeddings_responder.check_for_request(handle_request)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _process_updates(self) -> None:
 | 
					    def _process_updates(self) -> None:
 | 
				
			||||||
        """Process event updates"""
 | 
					        """Process event updates"""
 | 
				
			||||||
        update = self.event_subscriber.check_for_update()
 | 
					        update = self.event_subscriber.check_for_update()
 | 
				
			||||||
 | 
				
			|||||||
@ -8,10 +8,11 @@ import multiprocessing as mp
 | 
				
			|||||||
import queue
 | 
					import queue
 | 
				
			||||||
import re
 | 
					import re
 | 
				
			||||||
import shlex
 | 
					import shlex
 | 
				
			||||||
 | 
					import struct
 | 
				
			||||||
import urllib.parse
 | 
					import urllib.parse
 | 
				
			||||||
from collections.abc import Mapping
 | 
					from collections.abc import Mapping
 | 
				
			||||||
from pathlib import Path
 | 
					from pathlib import Path
 | 
				
			||||||
from typing import Any, Optional, Tuple
 | 
					from typing import Any, Optional, Tuple, Union
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import numpy as np
 | 
					import numpy as np
 | 
				
			||||||
import pytz
 | 
					import pytz
 | 
				
			||||||
@ -342,3 +343,32 @@ def generate_color_palette(n):
 | 
				
			|||||||
        colors.append(interpolate(color1, color2, factor))
 | 
					        colors.append(interpolate(color1, color2, factor))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return colors
 | 
					    return colors
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def serialize(
 | 
				
			||||||
 | 
					    vector: Union[list[float], np.ndarray, float], pack: bool = True
 | 
				
			||||||
 | 
					) -> bytes:
 | 
				
			||||||
 | 
					    """Serializes a list of floats, numpy array, or single float into a compact "raw bytes" format"""
 | 
				
			||||||
 | 
					    if isinstance(vector, np.ndarray):
 | 
				
			||||||
 | 
					        # Convert numpy array to list of floats
 | 
				
			||||||
 | 
					        vector = vector.flatten().tolist()
 | 
				
			||||||
 | 
					    elif isinstance(vector, (float, np.float32, np.float64)):
 | 
				
			||||||
 | 
					        # Handle single float values
 | 
				
			||||||
 | 
					        vector = [vector]
 | 
				
			||||||
 | 
					    elif not isinstance(vector, list):
 | 
				
			||||||
 | 
					        raise TypeError(
 | 
				
			||||||
 | 
					            f"Input must be a list of floats, a numpy array, or a single float. Got {type(vector)}"
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					        if pack:
 | 
				
			||||||
 | 
					            return struct.pack("%sf" % len(vector), *vector)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            return vector
 | 
				
			||||||
 | 
					    except struct.error as e:
 | 
				
			||||||
 | 
					        raise ValueError(f"Failed to pack vector: {e}. Vector: {vector}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def deserialize(bytes_data: bytes) -> list[float]:
 | 
				
			||||||
 | 
					    """Deserializes a compact "raw bytes" format into a list of floats"""
 | 
				
			||||||
 | 
					    return list(struct.unpack("%sf" % (len(bytes_data) // 4), bytes_data))
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user