mirror of
				https://github.com/blakeblackshear/frigate.git
				synced 2025-10-27 10:52:11 +01:00 
			
		
		
		
	Restructure embeddings (#14266)
* Restructure embeddings * Use ZMQ to proxy embeddings requests * Handle serialization * Formatting * Remove unused
This commit is contained in:
		
							parent
							
								
									a2ca18a714
								
							
						
					
					
						commit
						8ade85edec
					
				@ -472,7 +472,7 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
 | 
			
		||||
                status_code=404,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        thumb_result = context.embeddings.search_thumbnail(search_event)
 | 
			
		||||
        thumb_result = context.search_thumbnail(search_event)
 | 
			
		||||
        thumb_ids = dict(
 | 
			
		||||
            zip(
 | 
			
		||||
                [result[0] for result in thumb_result],
 | 
			
		||||
@ -487,7 +487,7 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
 | 
			
		||||
        search_types = search_type.split(",")
 | 
			
		||||
 | 
			
		||||
        if "thumbnail" in search_types:
 | 
			
		||||
            thumb_result = context.embeddings.search_thumbnail(query)
 | 
			
		||||
            thumb_result = context.search_thumbnail(query)
 | 
			
		||||
            thumb_ids = dict(
 | 
			
		||||
                zip(
 | 
			
		||||
                    [result[0] for result in thumb_result],
 | 
			
		||||
@ -504,7 +504,7 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        if "description" in search_types:
 | 
			
		||||
            desc_result = context.embeddings.search_description(query)
 | 
			
		||||
            desc_result = context.search_description(query)
 | 
			
		||||
            desc_ids = dict(
 | 
			
		||||
                zip(
 | 
			
		||||
                    [result[0] for result in desc_result],
 | 
			
		||||
@ -944,9 +944,9 @@ def set_description(
 | 
			
		||||
    # If semantic search is enabled, update the index
 | 
			
		||||
    if request.app.frigate_config.semantic_search.enabled:
 | 
			
		||||
        context: EmbeddingsContext = request.app.embeddings
 | 
			
		||||
        context.embeddings.upsert_description(
 | 
			
		||||
            event_id=event_id,
 | 
			
		||||
            description=new_description,
 | 
			
		||||
        context.update_description(
 | 
			
		||||
            event_id,
 | 
			
		||||
            new_description,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    response_message = (
 | 
			
		||||
@ -1033,8 +1033,8 @@ def delete_event(request: Request, event_id: str):
 | 
			
		||||
    # If semantic search is enabled, update the index
 | 
			
		||||
    if request.app.frigate_config.semantic_search.enabled:
 | 
			
		||||
        context: EmbeddingsContext = request.app.embeddings
 | 
			
		||||
        context.embeddings.delete_thumbnail(id=[event_id])
 | 
			
		||||
        context.embeddings.delete_description(id=[event_id])
 | 
			
		||||
        context.db.delete_embeddings_thumbnail(id=[event_id])
 | 
			
		||||
        context.db.delete_embeddings_description(id=[event_id])
 | 
			
		||||
    return JSONResponse(
 | 
			
		||||
        content=({"success": True, "message": "Event " + event_id + " deleted"}),
 | 
			
		||||
        status_code=200,
 | 
			
		||||
 | 
			
		||||
@ -276,7 +276,7 @@ class FrigateApp:
 | 
			
		||||
    def init_embeddings_client(self) -> None:
 | 
			
		||||
        if self.config.semantic_search.enabled:
 | 
			
		||||
            # Create a client for other processes to use
 | 
			
		||||
            self.embeddings = EmbeddingsContext(self.config, self.db)
 | 
			
		||||
            self.embeddings = EmbeddingsContext(self.db)
 | 
			
		||||
 | 
			
		||||
    def init_external_event_processor(self) -> None:
 | 
			
		||||
        self.external_event_processor = ExternalEventProcessor(self.config)
 | 
			
		||||
@ -699,7 +699,7 @@ class FrigateApp:
 | 
			
		||||
 | 
			
		||||
        # Save embeddings stats to disk
 | 
			
		||||
        if self.embeddings:
 | 
			
		||||
            self.embeddings.save_stats()
 | 
			
		||||
            self.embeddings.stop()
 | 
			
		||||
 | 
			
		||||
        # Stop Communicators
 | 
			
		||||
        self.inter_process_communicator.stop()
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										62
									
								
								frigate/comms/embeddings_updater.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										62
									
								
								frigate/comms/embeddings_updater.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,62 @@
 | 
			
		||||
"""Facilitates communication between processes."""
 | 
			
		||||
 | 
			
		||||
from enum import Enum
 | 
			
		||||
from typing import Callable
 | 
			
		||||
 | 
			
		||||
import zmq
 | 
			
		||||
 | 
			
		||||
SOCKET_REP_REQ = "ipc:///tmp/cache/embeddings"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class EmbeddingsRequestEnum(Enum):
    """Topic names understood on the embeddings request/reply socket.

    Each member's value is the plain string that is sent as the topic
    of a request.
    """

    # Embed (and store) the description text of an event.
    embed_description = "embed_description"
    # Embed (and store) the thumbnail image of an event.
    embed_thumbnail = "embed_thumbnail"
    # Produce an embedding for a free-text search query.
    generate_search = "generate_search"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class EmbeddingsResponder:
    """Reply side of the embeddings request/reply channel.

    Binds a ZMQ ``REP`` socket on ``SOCKET_REP_REQ`` and answers the JSON
    requests that arrive on it.
    """

    def __init__(self) -> None:
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.REP)
        self.socket.bind(SOCKET_REP_REQ)

    def check_for_request(self, process: Callable) -> None:
        """Answer every request that is currently queued on the socket.

        Each request is a JSON ``(topic, value)`` pair. ``process`` is called
        as ``process(topic, value)`` and its return value is sent back as the
        JSON reply; a ``None`` result is replaced by an empty list.

        The loop ends when no request shows up within the 1 second select
        timeout or when a ZMQ error occurs.

        Raises:
            Exception: whatever non-ZMQ error was raised while decoding or
                handling a request. An empty reply is sent first so the
                requesting side is not left waiting.
        """
        while True:  # load all messages that are queued
            has_message, _, _ = zmq.select([self.socket], [], [], 1)

            if not has_message:
                break

            try:
                (topic, value) = self.socket.recv_json(flags=zmq.NOBLOCK)
                response = process(topic, value)
                self.socket.send_json(response if response is not None else [])
            except zmq.ZMQError:
                break
            except Exception:
                # A REP socket must answer every request it has received and
                # the requestor blocks until that answer arrives. A non-ZMQ
                # error can only occur after the request was received and
                # before a reply went out, so send the empty reply to keep
                # both ends in step, then let the error surface as before.
                self.socket.send_json([])
                raise

    def stop(self) -> None:
        """Close the socket and tear down the ZMQ context."""
        self.socket.close()
        self.context.destroy()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class EmbeddingsRequestor:
    """Simplifies sending data to EmbeddingsResponder and getting a reply."""

    def __init__(self) -> None:
        # One context and one REQ socket per requestor, connected to the
        # shared endpoint.
        zmq_context = zmq.Context()
        req_socket = zmq_context.socket(zmq.REQ)
        req_socket.connect(SOCKET_REP_REQ)

        self.context = zmq_context
        self.socket = req_socket

    def send_data(self, topic: str, data: any) -> str:
        """Send ``(topic, data)`` as JSON and block until the reply arrives.

        Returns the JSON-decoded reply exactly as received from the socket.
        """
        request = (topic, data)
        self.socket.send_json(request)
        reply = self.socket.recv_json()
        return reply

    def stop(self) -> None:
        """Close the socket, then destroy the ZMQ context."""
        self.socket.close()
        self.context.destroy()
 | 
			
		||||
@ -20,3 +20,11 @@ class SqliteVecQueueDatabase(SqliteQueueDatabase):
 | 
			
		||||
        conn.enable_load_extension(True)
 | 
			
		||||
        conn.load_extension(self.sqlite_vec_path)
 | 
			
		||||
        conn.enable_load_extension(False)
 | 
			
		||||
 | 
			
		||||
    def delete_embeddings_thumbnail(self, event_ids: list[str]) -> None:
        """Delete the ``vec_thumbnails`` rows whose id is in ``event_ids``.

        The list must be passed positionally or as ``event_ids=...``; the
        ids are bound as SQL parameters, one ``?`` placeholder per id.
        """
        placeholders = ",".join("?" for _ in event_ids)
        statement = f"DELETE FROM vec_thumbnails WHERE id IN ({placeholders})"
        self.execute_sql(statement, event_ids)
 | 
			
		||||
 | 
			
		||||
    def delete_embeddings_description(self, event_ids: list[str]) -> None:
        """Delete the ``vec_descriptions`` rows whose id is in ``event_ids``.

        The list must be passed positionally or as ``event_ids=...``; the
        ids are bound as SQL parameters, one ``?`` placeholder per id.
        """
        placeholders = ",".join("?" for _ in event_ids)
        statement = f"DELETE FROM vec_descriptions WHERE id IN ({placeholders})"
        self.execute_sql(statement, event_ids)
 | 
			
		||||
 | 
			
		||||
@ -7,14 +7,16 @@ import os
 | 
			
		||||
import signal
 | 
			
		||||
import threading
 | 
			
		||||
from types import FrameType
 | 
			
		||||
from typing import Optional
 | 
			
		||||
from typing import Optional, Union
 | 
			
		||||
 | 
			
		||||
from setproctitle import setproctitle
 | 
			
		||||
 | 
			
		||||
from frigate.comms.embeddings_updater import EmbeddingsRequestEnum, EmbeddingsRequestor
 | 
			
		||||
from frigate.config import FrigateConfig
 | 
			
		||||
from frigate.const import CONFIG_DIR
 | 
			
		||||
from frigate.db.sqlitevecq import SqliteVecQueueDatabase
 | 
			
		||||
from frigate.models import Event
 | 
			
		||||
from frigate.util.builtin import serialize
 | 
			
		||||
from frigate.util.services import listen
 | 
			
		||||
 | 
			
		||||
from .embeddings import Embeddings
 | 
			
		||||
@ -70,10 +72,11 @@ def manage_embeddings(config: FrigateConfig) -> None:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class EmbeddingsContext:
 | 
			
		||||
    def __init__(self, config: FrigateConfig, db: SqliteVecQueueDatabase):
 | 
			
		||||
        self.embeddings = Embeddings(config.semantic_search, db)
 | 
			
		||||
    def __init__(self, db: SqliteVecQueueDatabase):
 | 
			
		||||
        self.db = db
 | 
			
		||||
        self.thumb_stats = ZScoreNormalization()
 | 
			
		||||
        self.desc_stats = ZScoreNormalization()
 | 
			
		||||
        self.requestor = EmbeddingsRequestor()
 | 
			
		||||
 | 
			
		||||
        # load stats from disk
 | 
			
		||||
        try:
 | 
			
		||||
@ -84,7 +87,7 @@ class EmbeddingsContext:
 | 
			
		||||
        except FileNotFoundError:
 | 
			
		||||
            pass
 | 
			
		||||
 | 
			
		||||
    def save_stats(self):
 | 
			
		||||
    def stop(self):
 | 
			
		||||
        """Write the stats to disk as JSON on exit."""
 | 
			
		||||
        contents = {
 | 
			
		||||
            "thumb_stats": self.thumb_stats.to_dict(),
 | 
			
		||||
@ -92,3 +95,100 @@ class EmbeddingsContext:
 | 
			
		||||
        }
 | 
			
		||||
        with open(os.path.join(CONFIG_DIR, ".search_stats.json"), "w") as f:
 | 
			
		||||
            json.dump(contents, f)
 | 
			
		||||
        self.requestor.stop()
 | 
			
		||||
 | 
			
		||||
    def search_thumbnail(
        self, query: Union[Event, str], event_ids: list[str] = None
    ) -> list[tuple[str, float]]:
        """Find the stored thumbnail embeddings closest to ``query``.

        When ``query`` is an ``Event`` its stored thumbnail embedding is
        used, and one is requested from the embeddings process if no row
        exists yet. Any other ``query`` is sent to the embeddings process as
        search text.

        Returns the ``(id, distance)`` rows of ``vec_thumbnails`` produced by
        the match query (``k = 100``, ordered by distance), optionally
        restricted to ``event_ids``.
        """
        is_event = query.__class__ == Event

        if not is_event:
            query_embedding = serialize(
                self.requestor.send_data(
                    EmbeddingsRequestEnum.generate_search.value, query
                )
            )
        else:
            cursor = self.db.execute_sql(
                """
                SELECT thumbnail_embedding FROM vec_thumbnails WHERE id = ?
                """,
                [query.id],
            )
            stored = cursor.fetchone() if cursor else None

            if stored:
                # The stored value is passed straight back as the match
                # parameter.
                query_embedding = stored[0]
            else:
                # If no embedding found, generate it and return it
                generated = self.requestor.send_data(
                    EmbeddingsRequestEnum.embed_thumbnail.value,
                    {"id": query.id, "thumbnail": query.thumbnail},
                )
                query_embedding = serialize(generated)

        clauses = [
            """
            SELECT
                id,
                distance
            FROM vec_thumbnails
            WHERE thumbnail_embedding MATCH ?
                AND k = 100
        """
        ]
        bind_values = [query_embedding]

        # Add the IN clause if event_ids is provided and not empty
        # this is the only filter supported by sqlite-vec as of 0.1.3
        # but it seems to be broken in this version
        if event_ids:
            clauses.append(" AND id IN ({})".format(",".join("?" * len(event_ids))))
            bind_values = bind_values + event_ids

        # order by distance DESC is not implemented in this version of sqlite-vec
        # when it's implemented, we can use cosine similarity
        clauses.append(" ORDER BY distance")

        return self.db.execute_sql("".join(clauses), bind_values).fetchall()
 | 
			
		||||
 | 
			
		||||
    def search_description(
        self, query_text: str, event_ids: list[str] = None
    ) -> list[tuple[str, float]]:
        """Find the stored description embeddings closest to ``query_text``.

        The text is sent to the embeddings process to obtain its embedding,
        which is then matched against ``vec_descriptions``.

        Returns the ``(id, distance)`` rows produced by the match query
        (``k = 100``, ordered by distance), optionally restricted to
        ``event_ids``.
        """
        raw_embedding = self.requestor.send_data(
            EmbeddingsRequestEnum.generate_search.value, query_text
        )
        query_embedding = serialize(raw_embedding)

        # Prepare the base SQL query
        clauses = [
            """
            SELECT
                id,
                distance
            FROM vec_descriptions
            WHERE description_embedding MATCH ?
                AND k = 100
        """
        ]
        bind_values = [query_embedding]

        # Add the IN clause if event_ids is provided and not empty
        # this is the only filter supported by sqlite-vec as of 0.1.3
        # but it seems to be broken in this version
        if event_ids:
            clauses.append(" AND id IN ({})".format(",".join("?" * len(event_ids))))
            bind_values = bind_values + event_ids

        # order by distance DESC is not implemented in this version of sqlite-vec
        # when it's implemented, we can use cosine similarity
        clauses.append(" ORDER BY distance")

        return self.db.execute_sql("".join(clauses), bind_values).fetchall()
 | 
			
		||||
 | 
			
		||||
    def update_description(self, event_id: str, description: str) -> None:
        """Ask the embeddings process to embed a new description for an event.

        Sends an ``embed_description`` request carrying the event id and the
        description text; the reply is not used.
        """
        payload = {"id": event_id, "description": description}
        self.requestor.send_data(
            EmbeddingsRequestEnum.embed_description.value, payload
        )
 | 
			
		||||
 | 
			
		||||
@ -3,11 +3,8 @@
 | 
			
		||||
import base64
 | 
			
		||||
import io
 | 
			
		||||
import logging
 | 
			
		||||
import struct
 | 
			
		||||
import time
 | 
			
		||||
from typing import List, Tuple, Union
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
from PIL import Image
 | 
			
		||||
from playhouse.shortcuts import model_to_dict
 | 
			
		||||
 | 
			
		||||
@ -17,6 +14,7 @@ from frigate.const import UPDATE_MODEL_STATE
 | 
			
		||||
from frigate.db.sqlitevecq import SqliteVecQueueDatabase
 | 
			
		||||
from frigate.models import Event
 | 
			
		||||
from frigate.types import ModelStatusTypesEnum
 | 
			
		||||
from frigate.util.builtin import serialize
 | 
			
		||||
 | 
			
		||||
from .functions.onnx import GenericONNXEmbedding
 | 
			
		||||
 | 
			
		||||
@ -54,30 +52,6 @@ def get_metadata(event: Event) -> dict:
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def serialize(vector: Union[List[float], np.ndarray, float]) -> bytes:
 | 
			
		||||
    """Serializes a list of floats, numpy array, or single float into a compact "raw bytes" format"""
 | 
			
		||||
    if isinstance(vector, np.ndarray):
 | 
			
		||||
        # Convert numpy array to list of floats
 | 
			
		||||
        vector = vector.flatten().tolist()
 | 
			
		||||
    elif isinstance(vector, (float, np.float32, np.float64)):
 | 
			
		||||
        # Handle single float values
 | 
			
		||||
        vector = [vector]
 | 
			
		||||
    elif not isinstance(vector, list):
 | 
			
		||||
        raise TypeError(
 | 
			
		||||
            f"Input must be a list of floats, a numpy array, or a single float. Got {type(vector)}"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        return struct.pack("%sf" % len(vector), *vector)
 | 
			
		||||
    except struct.error as e:
 | 
			
		||||
        raise ValueError(f"Failed to pack vector: {e}. Vector: {vector}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def deserialize(bytes_data: bytes) -> List[float]:
 | 
			
		||||
    """Deserializes a compact "raw bytes" format into a list of floats"""
 | 
			
		||||
    return list(struct.unpack("%sf" % (len(bytes_data) // 4), bytes_data))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Embeddings:
 | 
			
		||||
    """SQLite-vec embeddings database."""
 | 
			
		||||
 | 
			
		||||
@ -190,106 +164,6 @@ class Embeddings:
 | 
			
		||||
 | 
			
		||||
        return embedding
 | 
			
		||||
 | 
			
		||||
    def delete_thumbnail(self, event_ids: List[str]) -> None:
 | 
			
		||||
        ids = ",".join(["?" for _ in event_ids])
 | 
			
		||||
        self.db.execute_sql(
 | 
			
		||||
            f"DELETE FROM vec_thumbnails WHERE id IN ({ids})", event_ids
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def delete_description(self, event_ids: List[str]) -> None:
 | 
			
		||||
        ids = ",".join(["?" for _ in event_ids])
 | 
			
		||||
        self.db.execute_sql(
 | 
			
		||||
            f"DELETE FROM vec_descriptions WHERE id IN ({ids})", event_ids
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def search_thumbnail(
 | 
			
		||||
        self, query: Union[Event, str], event_ids: List[str] = None
 | 
			
		||||
    ) -> List[Tuple[str, float]]:
 | 
			
		||||
        if query.__class__ == Event:
 | 
			
		||||
            cursor = self.db.execute_sql(
 | 
			
		||||
                """
 | 
			
		||||
                SELECT thumbnail_embedding FROM vec_thumbnails WHERE id = ?
 | 
			
		||||
                """,
 | 
			
		||||
                [query.id],
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            row = cursor.fetchone() if cursor else None
 | 
			
		||||
 | 
			
		||||
            if row:
 | 
			
		||||
                query_embedding = deserialize(
 | 
			
		||||
                    row[0]
 | 
			
		||||
                )  # Deserialize the thumbnail embedding
 | 
			
		||||
            else:
 | 
			
		||||
                # If no embedding found, generate it and return it
 | 
			
		||||
                thumbnail = base64.b64decode(query.thumbnail)
 | 
			
		||||
                query_embedding = self.upsert_thumbnail(query.id, thumbnail)
 | 
			
		||||
        else:
 | 
			
		||||
            query_embedding = self.text_embedding([query])[0]
 | 
			
		||||
 | 
			
		||||
        sql_query = """
 | 
			
		||||
            SELECT
 | 
			
		||||
                id,
 | 
			
		||||
                distance
 | 
			
		||||
            FROM vec_thumbnails
 | 
			
		||||
            WHERE thumbnail_embedding MATCH ?
 | 
			
		||||
                AND k = 100
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        # Add the IN clause if event_ids is provided and not empty
 | 
			
		||||
        # this is the only filter supported by sqlite-vec as of 0.1.3
 | 
			
		||||
        # but it seems to be broken in this version
 | 
			
		||||
        if event_ids:
 | 
			
		||||
            sql_query += " AND id IN ({})".format(",".join("?" * len(event_ids)))
 | 
			
		||||
 | 
			
		||||
        # order by distance DESC is not implemented in this version of sqlite-vec
 | 
			
		||||
        # when it's implemented, we can use cosine similarity
 | 
			
		||||
        sql_query += " ORDER BY distance"
 | 
			
		||||
 | 
			
		||||
        parameters = (
 | 
			
		||||
            [serialize(query_embedding)] + event_ids
 | 
			
		||||
            if event_ids
 | 
			
		||||
            else [serialize(query_embedding)]
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        results = self.db.execute_sql(sql_query, parameters).fetchall()
 | 
			
		||||
 | 
			
		||||
        return results
 | 
			
		||||
 | 
			
		||||
    def search_description(
 | 
			
		||||
        self, query_text: str, event_ids: List[str] = None
 | 
			
		||||
    ) -> List[Tuple[str, float]]:
 | 
			
		||||
        query_embedding = self.text_embedding([query_text])[0]
 | 
			
		||||
 | 
			
		||||
        # Prepare the base SQL query
 | 
			
		||||
        sql_query = """
 | 
			
		||||
            SELECT
 | 
			
		||||
                id,
 | 
			
		||||
                distance
 | 
			
		||||
            FROM vec_descriptions
 | 
			
		||||
            WHERE description_embedding MATCH ?
 | 
			
		||||
                AND k = 100
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        # Add the IN clause if event_ids is provided and not empty
 | 
			
		||||
        # this is the only filter supported by sqlite-vec as of 0.1.3
 | 
			
		||||
        # but it seems to be broken in this version
 | 
			
		||||
        if event_ids:
 | 
			
		||||
            sql_query += " AND id IN ({})".format(",".join("?" * len(event_ids)))
 | 
			
		||||
 | 
			
		||||
        # order by distance DESC is not implemented in this version of sqlite-vec
 | 
			
		||||
        # when it's implemented, we can use cosine similarity
 | 
			
		||||
        sql_query += " ORDER BY distance"
 | 
			
		||||
 | 
			
		||||
        parameters = (
 | 
			
		||||
            [serialize(query_embedding)] + event_ids
 | 
			
		||||
            if event_ids
 | 
			
		||||
            else [serialize(query_embedding)]
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        results = self.db.execute_sql(sql_query, parameters).fetchall()
 | 
			
		||||
 | 
			
		||||
        return results
 | 
			
		||||
 | 
			
		||||
    def reindex(self) -> None:
 | 
			
		||||
        logger.info("Indexing event embeddings...")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -12,6 +12,7 @@ import numpy as np
 | 
			
		||||
from peewee import DoesNotExist
 | 
			
		||||
from playhouse.sqliteq import SqliteQueueDatabase
 | 
			
		||||
 | 
			
		||||
from frigate.comms.embeddings_updater import EmbeddingsRequestEnum, EmbeddingsResponder
 | 
			
		||||
from frigate.comms.event_metadata_updater import (
 | 
			
		||||
    EventMetadataSubscriber,
 | 
			
		||||
    EventMetadataTypeEnum,
 | 
			
		||||
@ -23,6 +24,7 @@ from frigate.const import CLIPS_DIR, UPDATE_EVENT_DESCRIPTION
 | 
			
		||||
from frigate.events.types import EventTypeEnum
 | 
			
		||||
from frigate.genai import get_genai_client
 | 
			
		||||
from frigate.models import Event
 | 
			
		||||
from frigate.util.builtin import serialize
 | 
			
		||||
from frigate.util.image import SharedMemoryFrameManager, calculate_region
 | 
			
		||||
 | 
			
		||||
from .embeddings import Embeddings
 | 
			
		||||
@ -48,6 +50,7 @@ class EmbeddingMaintainer(threading.Thread):
 | 
			
		||||
        self.event_metadata_subscriber = EventMetadataSubscriber(
 | 
			
		||||
            EventMetadataTypeEnum.regenerate_description
 | 
			
		||||
        )
 | 
			
		||||
        self.embeddings_responder = EmbeddingsResponder()
 | 
			
		||||
        self.frame_manager = SharedMemoryFrameManager()
 | 
			
		||||
        # create communication for updating event descriptions
 | 
			
		||||
        self.requestor = InterProcessRequestor()
 | 
			
		||||
@ -58,6 +61,7 @@ class EmbeddingMaintainer(threading.Thread):
 | 
			
		||||
    def run(self) -> None:
 | 
			
		||||
        """Maintain a SQLite-vec database for semantic search."""
 | 
			
		||||
        while not self.stop_event.is_set():
 | 
			
		||||
            self._process_requests()
 | 
			
		||||
            self._process_updates()
 | 
			
		||||
            self._process_finalized()
 | 
			
		||||
            self._process_event_metadata()
 | 
			
		||||
@ -65,9 +69,30 @@ class EmbeddingMaintainer(threading.Thread):
 | 
			
		||||
        self.event_subscriber.stop()
 | 
			
		||||
        self.event_end_subscriber.stop()
 | 
			
		||||
        self.event_metadata_subscriber.stop()
 | 
			
		||||
        self.embeddings_responder.stop()
 | 
			
		||||
        self.requestor.stop()
 | 
			
		||||
        logger.info("Exiting embeddings maintenance...")
 | 
			
		||||
 | 
			
		||||
    def _process_requests(self) -> None:
        """Process embeddings requests"""

        def handle_request(topic: str, data: str) -> str:
            # For the two embed topics ``data`` is indexed by key ("id" plus
            # "description" or "thumbnail"); for generate_search it is passed
            # on as the text to embed. Every result goes through serialize
            # with pack=False before it is returned as the reply.
            if topic == EmbeddingsRequestEnum.generate_search.value:
                return serialize(self.embeddings.text_embedding([data])[0], pack=False)

            if topic == EmbeddingsRequestEnum.embed_thumbnail.value:
                thumbnail = base64.b64decode(data["thumbnail"])
                embedding = self.embeddings.upsert_thumbnail(data["id"], thumbnail)
                return serialize(embedding, pack=False)

            if topic == EmbeddingsRequestEnum.embed_description.value:
                embedding = self.embeddings.upsert_description(
                    data["id"], data["description"]
                )
                return serialize(embedding, pack=False)

            # Any other topic yields no result.
            return None

        self.embeddings_responder.check_for_request(handle_request)
 | 
			
		||||
 | 
			
		||||
    def _process_updates(self) -> None:
 | 
			
		||||
        """Process event updates"""
 | 
			
		||||
        update = self.event_subscriber.check_for_update()
 | 
			
		||||
 | 
			
		||||
@ -8,10 +8,11 @@ import multiprocessing as mp
 | 
			
		||||
import queue
 | 
			
		||||
import re
 | 
			
		||||
import shlex
 | 
			
		||||
import struct
 | 
			
		||||
import urllib.parse
 | 
			
		||||
from collections.abc import Mapping
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from typing import Any, Optional, Tuple
 | 
			
		||||
from typing import Any, Optional, Tuple, Union
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
import pytz
 | 
			
		||||
@ -342,3 +343,32 @@ def generate_color_palette(n):
 | 
			
		||||
        colors.append(interpolate(color1, color2, factor))
 | 
			
		||||
 | 
			
		||||
    return colors
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def serialize(
    vector: Union[list[float], np.ndarray, float], pack: bool = True
) -> Union[bytes, list[float]]:
    """Serializes a list of floats, numpy array, or single float into a compact "raw bytes" format

    Args:
        vector: A list of floats, a numpy array (flattened before use), or a
            single float (treated as a one-element list).
        pack: When True (default) the values are packed as consecutive
            ``struct`` ``"f"`` floats and returned as bytes. When False the
            normalized list of floats is returned unpacked.

    Returns:
        The packed bytes, or the normalized list when ``pack`` is False.

    Raises:
        TypeError: If ``vector`` is not a list, numpy array, or float.
        ValueError: If the values cannot be packed; the original
            ``struct.error`` is attached as the cause.
    """
    if isinstance(vector, np.ndarray):
        # Convert numpy array to list of floats
        vector = vector.flatten().tolist()
    elif isinstance(vector, (float, np.float32, np.float64)):
        # Handle single float values
        vector = [vector]
    elif not isinstance(vector, list):
        raise TypeError(
            f"Input must be a list of floats, a numpy array, or a single float. Got {type(vector)}"
        )

    if not pack:
        # Nothing to pack, so nothing here can raise struct.error.
        return vector

    try:
        return struct.pack("%sf" % len(vector), *vector)
    except struct.error as e:
        raise ValueError(f"Failed to pack vector: {e}. Vector: {vector}") from e
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def deserialize(bytes_data: bytes) -> list[float]:
    """Deserializes a compact "raw bytes" format into a list of floats

    The input is read as consecutive ``struct`` ``"f"`` values, with the
    count taken as ``len(bytes_data) // 4``.
    """
    float_count = len(bytes_data) // 4
    unpacked = struct.unpack(f"{float_count}f", bytes_data)
    return list(unpacked)
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user