mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-20 13:54:36 +01:00
Semantic Search Triggers (#18969)
* semantic trigger test * database and model * config * embeddings maintainer and trigger post-processor * api to create, edit, delete triggers * frontend and i18n keys * use thumbnail and description for trigger types * image picker tweaks * initial sync * thumbnail file management * clean up logs and use saved thumbnail on frontend * publish mqtt messages * webpush changes to enable trigger notifications * add enabled switch * add triggers from explore * renaming and deletion fixes * fix typing * UI updates and add last triggering event time and link * log exception instead of return in endpoint * highlight entry in UI when triggered * save and delete thumbnails directly * remove alert action for now and add descriptions * tweaks * clean up * fix types * docs * docs tweaks * docs * reuse enum
This commit is contained in:
committed by
Blake Blackshear
parent
28f816b49a
commit
3609b41217
233
frigate/data_processing/post/semantic_trigger.py
Normal file
233
frigate/data_processing/post/semantic_trigger.py
Normal file
@@ -0,0 +1,233 @@
|
||||
"""Post time processor to trigger actions based on similar embeddings."""
|
||||
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from peewee import DoesNotExist
|
||||
|
||||
from frigate.comms.inter_process import InterProcessRequestor
|
||||
from frigate.config import FrigateConfig
|
||||
from frigate.const import CONFIG_DIR
|
||||
from frigate.data_processing.types import PostProcessDataEnum
|
||||
from frigate.db.sqlitevecq import SqliteVecQueueDatabase
|
||||
from frigate.embeddings.util import ZScoreNormalization
|
||||
from frigate.models import Event, Trigger
|
||||
from frigate.util.builtin import cosine_distance
|
||||
from frigate.util.path import get_event_thumbnail_bytes
|
||||
|
||||
from ..post.api import PostProcessorApi
|
||||
from ..types import DataProcessorMetrics
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
WRITE_DEBUG_IMAGES = False
|
||||
|
||||
|
||||
class SemanticTriggerProcessor(PostProcessorApi):
|
||||
def __init__(
|
||||
self,
|
||||
db: SqliteVecQueueDatabase,
|
||||
config: FrigateConfig,
|
||||
requestor: InterProcessRequestor,
|
||||
metrics: DataProcessorMetrics,
|
||||
embeddings,
|
||||
):
|
||||
super().__init__(config, metrics, None)
|
||||
self.db = db
|
||||
self.embeddings = embeddings
|
||||
self.requestor = requestor
|
||||
self.trigger_embeddings: list[np.ndarray] = []
|
||||
|
||||
self.thumb_stats = ZScoreNormalization()
|
||||
self.desc_stats = ZScoreNormalization()
|
||||
|
||||
# load stats from disk
|
||||
try:
|
||||
with open(os.path.join(CONFIG_DIR, ".search_stats.json"), "r") as f:
|
||||
data = json.loads(f.read())
|
||||
self.thumb_stats.from_dict(data["thumb_stats"])
|
||||
self.desc_stats.from_dict(data["desc_stats"])
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
def process_data(
|
||||
self, data: dict[str, Any], data_type: PostProcessDataEnum
|
||||
) -> None:
|
||||
event_id = data["event_id"]
|
||||
camera = data["camera"]
|
||||
process_type = data["type"]
|
||||
|
||||
if self.config.cameras[camera].semantic_search.triggers is None:
|
||||
return
|
||||
|
||||
triggers = (
|
||||
Trigger.select(
|
||||
Trigger.camera,
|
||||
Trigger.name,
|
||||
Trigger.data,
|
||||
Trigger.type,
|
||||
Trigger.embedding,
|
||||
Trigger.threshold,
|
||||
)
|
||||
.where(Trigger.camera == camera)
|
||||
.dicts()
|
||||
.iterator()
|
||||
)
|
||||
|
||||
for trigger in triggers:
|
||||
if (
|
||||
trigger["name"]
|
||||
not in self.config.cameras[camera].semantic_search.triggers
|
||||
or not self.config.cameras[camera]
|
||||
.semantic_search.triggers[trigger["name"]]
|
||||
.enabled
|
||||
):
|
||||
logger.debug(
|
||||
f"Trigger {trigger['name']} is disabled for camera {camera}"
|
||||
)
|
||||
continue
|
||||
|
||||
logger.debug(
|
||||
f"Processing {trigger['type']} trigger for {event_id} on {trigger['camera']}: {trigger['name']}"
|
||||
)
|
||||
|
||||
trigger_embedding = np.frombuffer(trigger["embedding"], dtype=np.float32)
|
||||
|
||||
# Get embeddings based on type
|
||||
thumbnail_embedding = None
|
||||
description_embedding = None
|
||||
|
||||
if process_type == "image":
|
||||
cursor = self.db.execute_sql(
|
||||
"""
|
||||
SELECT thumbnail_embedding FROM vec_thumbnails WHERE id = ?
|
||||
""",
|
||||
[event_id],
|
||||
)
|
||||
row = cursor.fetchone() if cursor else None
|
||||
if row:
|
||||
thumbnail_embedding = np.frombuffer(row[0], dtype=np.float32)
|
||||
|
||||
if process_type == "text":
|
||||
cursor = self.db.execute_sql(
|
||||
"""
|
||||
SELECT description_embedding FROM vec_descriptions WHERE id = ?
|
||||
""",
|
||||
[event_id],
|
||||
)
|
||||
row = cursor.fetchone() if cursor else None
|
||||
if row:
|
||||
description_embedding = np.frombuffer(row[0], dtype=np.float32)
|
||||
|
||||
# Skip processing if we don't have any embeddings
|
||||
if thumbnail_embedding is None and description_embedding is None:
|
||||
logger.debug(f"No embeddings found for {event_id}")
|
||||
return
|
||||
|
||||
# Determine which embedding to compare based on trigger type
|
||||
if (
|
||||
trigger["type"] in ["text", "thumbnail"]
|
||||
and thumbnail_embedding is not None
|
||||
):
|
||||
data_embedding = thumbnail_embedding
|
||||
normalized_distance = self.thumb_stats.normalize(
|
||||
[cosine_distance(data_embedding, trigger_embedding)],
|
||||
save_stats=False,
|
||||
)[0]
|
||||
elif trigger["type"] == "description" and description_embedding is not None:
|
||||
data_embedding = description_embedding
|
||||
normalized_distance = self.desc_stats.normalize(
|
||||
[cosine_distance(data_embedding, trigger_embedding)],
|
||||
save_stats=False,
|
||||
)[0]
|
||||
|
||||
else:
|
||||
continue
|
||||
|
||||
similarity = 1 - normalized_distance
|
||||
|
||||
logger.debug(
|
||||
f"Trigger {trigger['name']} ({trigger['data'] if trigger['type'] == 'text' or trigger['type'] == 'description' else 'image'}): "
|
||||
f"normalized distance: {normalized_distance:.4f}, "
|
||||
f"similarity: {similarity:.4f}, threshold: {trigger['threshold']}"
|
||||
)
|
||||
|
||||
# Check if similarity meets threshold
|
||||
if similarity >= trigger["threshold"]:
|
||||
logger.info(
|
||||
f"Trigger {trigger['name']} activated with similarity {similarity:.4f}"
|
||||
)
|
||||
|
||||
# Update the trigger's last_triggered and triggering_event_id
|
||||
Trigger.update(
|
||||
last_triggered=datetime.datetime.now(), triggering_event_id=event_id
|
||||
).where(
|
||||
Trigger.camera == camera, Trigger.name == trigger["name"]
|
||||
).execute()
|
||||
|
||||
# Always publish MQTT message
|
||||
self.requestor.send_data(
|
||||
"triggers",
|
||||
json.dumps(
|
||||
{
|
||||
"name": trigger["name"],
|
||||
"camera": camera,
|
||||
"event_id": event_id,
|
||||
"type": trigger["type"],
|
||||
"score": similarity,
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
if (
|
||||
self.config.cameras[camera]
|
||||
.semantic_search.triggers[trigger["name"]]
|
||||
.actions
|
||||
):
|
||||
# TODO: handle actions for the trigger
|
||||
# notifications already handled by webpush
|
||||
pass
|
||||
|
||||
if WRITE_DEBUG_IMAGES:
|
||||
try:
|
||||
event: Event = Event.get(Event.id == event_id)
|
||||
except DoesNotExist:
|
||||
return
|
||||
|
||||
# Skip the event if not an object
|
||||
if event.data.get("type") != "object":
|
||||
return
|
||||
|
||||
thumbnail_bytes = get_event_thumbnail_bytes(event)
|
||||
|
||||
nparr = np.frombuffer(thumbnail_bytes, np.uint8)
|
||||
thumbnail = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
||||
|
||||
font_scale = 0.5
|
||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||
cv2.putText(
|
||||
thumbnail,
|
||||
f"{similarity:.4f}",
|
||||
(10, 30),
|
||||
font,
|
||||
fontScale=font_scale,
|
||||
color=(0, 255, 0),
|
||||
thickness=2,
|
||||
)
|
||||
|
||||
current_time = int(datetime.datetime.now().timestamp())
|
||||
cv2.imwrite(
|
||||
f"debug/frames/trigger-{event_id}_{current_time}.jpg",
|
||||
thumbnail,
|
||||
)
|
||||
|
||||
def handle_request(self, topic, request_data):
|
||||
return None
|
||||
|
||||
def expire_object(self, object_id, camera):
|
||||
pass
|
||||
Reference in New Issue
Block a user