blakeblackshear.frigate/frigate/data_processing/post/semantic_trigger.py
Josh Hawkins 57bb0cc397
Semantic Search Triggers (#18969)
* semantic trigger test

* database and model

* config

* embeddings maintainer and trigger post-processor

* api to create, edit, delete triggers

* frontend and i18n keys

* use thumbnail and description for trigger types

* image picker tweaks

* initial sync

* thumbnail file management

* clean up logs and use saved thumbnail on frontend

* publish mqtt messages

* webpush changes to enable trigger notifications

* add enabled switch

* add triggers from explore

* renaming and deletion fixes

* fix typing

* UI updates and add last triggering event time and link

* log exception instead of return in endpoint

* highlight entry in UI when triggered

* save and delete thumbnails directly

* remove alert action for now and add descriptions

* tweaks

* clean up

* fix types

* docs

* docs tweaks

* docs

* reuse enum
2025-07-07 09:03:57 -05:00

"""Post time processor to trigger actions based on similar embeddings."""
import datetime
import json
import logging
import os
from typing import Any
import cv2
import numpy as np
from peewee import DoesNotExist
from frigate.comms.inter_process import InterProcessRequestor
from frigate.config import FrigateConfig
from frigate.const import CONFIG_DIR
from frigate.data_processing.types import PostProcessDataEnum
from frigate.db.sqlitevecq import SqliteVecQueueDatabase
from frigate.embeddings.util import ZScoreNormalization
from frigate.models import Event, Trigger
from frigate.util.builtin import cosine_distance
from frigate.util.path import get_event_thumbnail_bytes
from ..post.api import PostProcessorApi
from ..types import DataProcessorMetrics
logger = logging.getLogger(__name__)
WRITE_DEBUG_IMAGES = False
class SemanticTriggerProcessor(PostProcessorApi):
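    """Post processor that fires saved triggers when an event's embeddings are similar enough to a trigger's embedding."""
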
    def __init__(
        self,
        db: SqliteVecQueueDatabase,
        config: FrigateConfig,
        requestor: InterProcessRequestor,
        metrics: DataProcessorMetrics,
        embeddings,
    ):
        super().__init__(config, metrics, None)
        self.db = db
        self.embeddings = embeddings
        self.requestor = requestor
        self.trigger_embeddings: list[np.ndarray] = []
        self.thumb_stats = ZScoreNormalization()
        self.desc_stats = ZScoreNormalization()

        # load stats from disk
        try:
            with open(os.path.join(CONFIG_DIR, ".search_stats.json"), "r") as f:
                data = json.loads(f.read())
                self.thumb_stats.from_dict(data["thumb_stats"])
                self.desc_stats.from_dict(data["desc_stats"])
        except FileNotFoundError:
            pass

    def process_data(
        self, data: dict[str, Any], data_type: PostProcessDataEnum
    ) -> None:
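        """Compare the event's thumbnail or description embedding against this camera's saved triggers."""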
        event_id = data["event_id"]
        camera = data["camera"]
        process_type = data["type"]

        if self.config.cameras[camera].semantic_search.triggers is None:
            return

        triggers = (
            Trigger.select(
                Trigger.camera,
                Trigger.name,
                Trigger.data,
                Trigger.type,
                Trigger.embedding,
                Trigger.threshold,
            )
            .where(Trigger.camera == camera)
            .dicts()
            .iterator()
        )

        for trigger in triggers:
            if (
                trigger["name"]
                not in self.config.cameras[camera].semantic_search.triggers
                or not self.config.cameras[camera]
                .semantic_search.triggers[trigger["name"]]
                .enabled
            ):
                logger.debug(
                    f"Trigger {trigger['name']} is disabled for camera {camera}"
                )
                continue

            logger.debug(
                f"Processing {trigger['type']} trigger for {event_id} on {trigger['camera']}: {trigger['name']}"
            )
            trigger_embedding = np.frombuffer(trigger["embedding"], dtype=np.float32)

            # Get embeddings based on type
            thumbnail_embedding = None
            description_embedding = None

            if process_type == "image":
                cursor = self.db.execute_sql(
                    """
                    SELECT thumbnail_embedding FROM vec_thumbnails WHERE id = ?
                    """,
                    [event_id],
                )

                row = cursor.fetchone() if cursor else None

                if row:
                    thumbnail_embedding = np.frombuffer(row[0], dtype=np.float32)

            if process_type == "text":
                cursor = self.db.execute_sql(
                    """
                    SELECT description_embedding FROM vec_descriptions WHERE id = ?
                    """,
                    [event_id],
                )

                row = cursor.fetchone() if cursor else None

                if row:
                    description_embedding = np.frombuffer(row[0], dtype=np.float32)

            # Skip processing if we don't have any embeddings
            if thumbnail_embedding is None and description_embedding is None:
                logger.debug(f"No embeddings found for {event_id}")
                return
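            # Cosine distances are z-score normalized with the persisted search stats,
            # so the similarity computed below is on the same scale as the configured
            # trigger threshold.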
            # Determine which embedding to compare based on trigger type
            if (
                trigger["type"] in ["text", "thumbnail"]
                and thumbnail_embedding is not None
            ):
                data_embedding = thumbnail_embedding
                normalized_distance = self.thumb_stats.normalize(
                    [cosine_distance(data_embedding, trigger_embedding)],
                    save_stats=False,
                )[0]
            elif trigger["type"] == "description" and description_embedding is not None:
                data_embedding = description_embedding
                normalized_distance = self.desc_stats.normalize(
                    [cosine_distance(data_embedding, trigger_embedding)],
                    save_stats=False,
                )[0]
            else:
                continue

            similarity = 1 - normalized_distance

            logger.debug(
                f"Trigger {trigger['name']} ({trigger['data'] if trigger['type'] == 'text' or trigger['type'] == 'description' else 'image'}): "
                f"normalized distance: {normalized_distance:.4f}, "
                f"similarity: {similarity:.4f}, threshold: {trigger['threshold']}"
            )

            # Check if similarity meets threshold
            if similarity >= trigger["threshold"]:
                logger.info(
                    f"Trigger {trigger['name']} activated with similarity {similarity:.4f}"
                )

                # Update the trigger's last_triggered and triggering_event_id
                Trigger.update(
                    last_triggered=datetime.datetime.now(), triggering_event_id=event_id
                ).where(
                    Trigger.camera == camera, Trigger.name == trigger["name"]
                ).execute()

                # Always publish MQTT message
                self.requestor.send_data(
                    "triggers",
                    json.dumps(
                        {
                            "name": trigger["name"],
                            "camera": camera,
                            "event_id": event_id,
                            "type": trigger["type"],
                            "score": similarity,
                        }
                    ),
                )
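                # Per the PR description, the webpush notifier and the frontend
                # "triggered" highlight are driven by the payload published above.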

                if (
                    self.config.cameras[camera]
                    .semantic_search.triggers[trigger["name"]]
                    .actions
                ):
                    # TODO: handle actions for the trigger
                    # notifications already handled by webpush
                    pass
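
            # Optionally write an annotated copy of the event thumbnail so the computed
            # similarity can be inspected on disk.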
            if WRITE_DEBUG_IMAGES:
                try:
                    event: Event = Event.get(Event.id == event_id)
                except DoesNotExist:
                    return

                # Skip the event if not an object
                if event.data.get("type") != "object":
                    return

                thumbnail_bytes = get_event_thumbnail_bytes(event)
                nparr = np.frombuffer(thumbnail_bytes, np.uint8)
                thumbnail = cv2.imdecode(nparr, cv2.IMREAD_COLOR)

                font_scale = 0.5
                font = cv2.FONT_HERSHEY_SIMPLEX
                cv2.putText(
                    thumbnail,
                    f"{similarity:.4f}",
                    (10, 30),
                    font,
                    fontScale=font_scale,
                    color=(0, 255, 0),
                    thickness=2,
                )

                current_time = int(datetime.datetime.now().timestamp())
                cv2.imwrite(
                    f"debug/frames/trigger-{event_id}_{current_time}.jpg",
                    thumbnail,
                )
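
    # handle_request and expire_object are part of the PostProcessorApi interface;
    # this processor has no dynamic requests to serve or per-object state to expire.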
    def handle_request(self, topic, request_data):
        return None

    def expire_object(self, object_id, camera):
        pass