mirror of
https://github.com/blakeblackshear/frigate.git
synced 2024-11-30 19:09:13 +01:00
336 lines
12 KiB
Python
336 lines
12 KiB
Python
"""Maintain embeddings in SQLite-vec."""
|
|
|
|
import base64
|
|
import logging
|
|
import os
|
|
import threading
|
|
from multiprocessing.synchronize import Event as MpEvent
|
|
from typing import Optional
|
|
|
|
import cv2
|
|
import numpy as np
|
|
from peewee import DoesNotExist
|
|
from playhouse.sqliteq import SqliteQueueDatabase
|
|
|
|
from frigate.comms.embeddings_updater import EmbeddingsRequestEnum, EmbeddingsResponder
|
|
from frigate.comms.event_metadata_updater import (
|
|
EventMetadataSubscriber,
|
|
EventMetadataTypeEnum,
|
|
)
|
|
from frigate.comms.events_updater import EventEndSubscriber, EventUpdateSubscriber
|
|
from frigate.comms.inter_process import InterProcessRequestor
|
|
from frigate.config import FrigateConfig
|
|
from frigate.const import CLIPS_DIR, UPDATE_EVENT_DESCRIPTION
|
|
from frigate.events.types import EventTypeEnum
|
|
from frigate.genai import get_genai_client
|
|
from frigate.models import Event
|
|
from frigate.util.builtin import serialize
|
|
from frigate.util.image import SharedMemoryFrameManager, calculate_region
|
|
|
|
from .embeddings import Embeddings
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class EmbeddingMaintainer(threading.Thread):
    """Handle embedding queue and post event updates.

    The thread loops until ``stop_event`` is set. Each iteration answers
    pending embeddings requests, buffers thumbnails for tracked-object
    updates, embeds the thumbnail of ended events (optionally asking the
    GenAI client for a description) and serves "regenerate description"
    requests.
    """

    def __init__(
        self,
        db: SqliteQueueDatabase,
        config: FrigateConfig,
        stop_event: MpEvent,
    ) -> None:
        """Create the embeddings store and the subscribers this thread polls.

        Args:
            db: Database handle handed to ``Embeddings``.
            config: Frigate configuration; ``semantic_search``, ``genai`` and
                the per-camera settings are read from it.
            stop_event: Event that ends the ``run`` loop once it is set.
        """
        super().__init__(name="embeddings_maintainer")
        self.config = config
        self.embeddings = Embeddings(config.semantic_search, db)

        # Check if we need to re-index events
        if config.semantic_search.reindex:
            self.embeddings.reindex()

        self.event_subscriber = EventUpdateSubscriber()
        self.event_end_subscriber = EventEndSubscriber()
        self.event_metadata_subscriber = EventMetadataSubscriber(
            EventMetadataTypeEnum.regenerate_description
        )
        self.embeddings_responder = EmbeddingsResponder()
        self.frame_manager = SharedMemoryFrameManager()
        # create communication for updating event descriptions
        self.requestor = InterProcessRequestor()
        self.stop_event = stop_event
        # Maps an event id to the update payloads received for it; each
        # stored payload carries a "thumbnail" entry (JPEG bytes or None).
        self.tracked_events: dict = {}
        self.genai_client = get_genai_client(config.genai)

    def run(self) -> None:
        """Maintain a SQLite-vec database for semantic search."""
        while not self.stop_event.is_set():
            self._process_requests()
            self._process_updates()
            self._process_finalized()
            self._process_event_metadata()

        # Release every communication endpoint once the loop has ended.
        self.event_subscriber.stop()
        self.event_end_subscriber.stop()
        self.event_metadata_subscriber.stop()
        self.embeddings_responder.stop()
        self.requestor.stop()
        logger.info("Exiting embeddings maintenance...")

    def _process_requests(self) -> None:
        """Process embeddings requests"""

        def _handle_request(topic: str, data) -> Optional[str]:
            """Answer one embeddings request.

            ``data`` is indexed like a mapping for the ``embed_*`` topics and
            is passed through as the text to embed for ``generate_search``.
            Returns the serialized result, or ``None`` when the topic is not
            handled here or the request failed (the failure is logged).
            """
            try:
                if topic == EmbeddingsRequestEnum.embed_description.value:
                    return serialize(
                        self.embeddings.upsert_description(
                            data["id"], data["description"]
                        ),
                        pack=False,
                    )
                elif topic == EmbeddingsRequestEnum.embed_thumbnail.value:
                    thumbnail = base64.b64decode(data["thumbnail"])
                    return serialize(
                        self.embeddings.upsert_thumbnail(data["id"], thumbnail),
                        pack=False,
                    )
                elif topic == EmbeddingsRequestEnum.generate_search.value:
                    return serialize(
                        self.embeddings.text_embedding([data])[0], pack=False
                    )
            except Exception as e:
                # Deliberately broad: a bad request must not take down the
                # maintainer thread.
                logger.error(f"Unable to handle embeddings request {e}")

            return None

        self.embeddings_responder.check_for_request(_handle_request)

    def _process_updates(self) -> None:
        """Process event updates"""
        update = self.event_subscriber.check_for_update(timeout=0.1)

        if update is None:
            return

        source_type, _, camera, data = update

        if not camera or source_type != EventTypeEnum.tracked_object:
            return

        camera_config = self.config.cameras[camera]
        tracked_frames = self.tracked_events.setdefault(data["id"], [])

        # Create our own thumbnail based on the bounding box and the frame time
        try:
            frame_id = f"{camera}{data['frame_time']}"
            yuv_frame = self.frame_manager.get(frame_id, camera_config.frame_shape_yuv)

            if yuv_frame is not None:
                data["thumbnail"] = self._create_thumbnail(yuv_frame, data["box"])
                tracked_frames.append(data)

            self.frame_manager.close(frame_id)
        except FileNotFoundError:
            # Best effort: the frame is not available, so this update simply
            # contributes no thumbnail.
            pass

    def _process_finalized(self) -> None:
        """Process the end of an event."""
        while True:
            ended = self.event_end_subscriber.check_for_update(timeout=0.1)

            if ended is None:
                break

            event_id, camera, updated_db = ended

            if updated_db:
                self._finalize_event(event_id, camera)

            # Delete tracked events based on the event_id. This always runs,
            # including for events that are missing from the database or are
            # not objects, so their buffered thumbnails do not accumulate.
            self.tracked_events.pop(event_id, None)

    def _finalize_event(self, event_id: str, camera: str) -> None:
        """Embed an ended event's thumbnail and start description generation.

        Does nothing when the event is not in the database or is not an
        object. A description is only requested when GenAI is enabled for the
        camera, a client exists, the event has no description yet and the
        camera's object / required-zone filters (when configured) match.
        """
        camera_config = self.config.cameras[camera]

        try:
            event: Event = Event.get(Event.id == event_id)
        except DoesNotExist:
            return

        # Skip the event if not an object
        if event.data.get("type") != "object":
            return

        # Extract valid thumbnail
        thumbnail = base64.b64decode(event.thumbnail)

        # Embed the thumbnail
        self._embed_thumbnail(event_id, thumbnail)

        genai_config = camera_config.genai

        if not genai_config.enabled or self.genai_client is None:
            return

        if event.data.get("description") is not None:
            return

        if genai_config.objects and event.label not in genai_config.objects:
            return

        if genai_config.required_zones and not (
            set(event.zones) & set(genai_config.required_zones)
        ):
            return

        embed_image = self._select_description_images(
            event, thumbnail, use_snapshot=genai_config.use_snapshot
        )

        # Generate the description. Call happens in a thread since it is network bound.
        threading.Thread(
            target=self._embed_description,
            name=f"_embed_description_{event.id}",
            daemon=True,
            args=(
                event,
                embed_image,
            ),
        ).start()

    def _process_event_metadata(self):
        """Serve a pending "regenerate description" request, if there is one."""
        # Check for regenerate description requests
        (topic, event_id, source) = self.event_metadata_subscriber.check_for_update(
            timeout=0.1
        )

        if topic is None:
            return

        if event_id:
            self.handle_regenerate_description(event_id, source)

    def _create_thumbnail(self, yuv_frame, box, height=500) -> Optional[bytes]:
        """Return jpg thumbnail of a region of the frame.

        Args:
            yuv_frame: Frame converted with ``cv2.COLOR_YUV2BGR_I420``.
            box: Bounding box; its first four items are passed to
                ``calculate_region``.
            height: Height in pixels of the returned thumbnail; the width is
                scaled to keep the crop's aspect ratio.

        Returns:
            The encoded JPEG bytes, or ``None`` when encoding fails.
        """
        frame = cv2.cvtColor(yuv_frame, cv2.COLOR_YUV2BGR_I420)
        region = calculate_region(
            frame.shape, box[0], box[1], box[2], box[3], height, multiplier=1.4
        )
        frame = frame[region[1] : region[3], region[0] : region[2]]
        width = int(height * frame.shape[1] / frame.shape[0])
        frame = cv2.resize(frame, dsize=(width, height), interpolation=cv2.INTER_AREA)
        ret, jpg = cv2.imencode(".jpg", frame, [int(cv2.IMWRITE_JPEG_QUALITY), 100])

        if ret:
            return jpg.tobytes()

        return None

    def _read_and_crop_snapshot(self, event: Event) -> Optional[bytes]:
        """Return the event's snapshot cropped to its region as JPEG bytes.

        The snapshot is read from ``CLIPS_DIR`` and cropped using
        ``event.data["region"]``, whose four values are multiplied by the
        image width / height to get pixel coordinates.

        Returns:
            The cropped JPEG bytes, or ``None`` when the file cannot be read,
            decoded or re-encoded, so callers can fall back to thumbnails.
        """
        snapshot_path = os.path.join(CLIPS_DIR, f"{event.camera}-{event.id}.jpg")

        try:
            with open(snapshot_path, "rb") as image_file:
                snapshot_image = image_file.read()
        except OSError as e:
            logger.warning("Unable to read snapshot for %s: %s", event.id, e)
            return None

        img = cv2.imdecode(
            np.frombuffer(snapshot_image, dtype=np.uint8),
            cv2.IMREAD_COLOR,
        )

        if img is None:
            logger.warning("Unable to decode snapshot for %s", event.id)
            return None

        # crop snapshot based on region before sending off to genai
        height, width = img.shape[:2]
        x1_rel, y1_rel, width_rel, height_rel = event.data["region"]

        x1, y1 = int(x1_rel * width), int(y1_rel * height)
        cropped_image = img[
            y1 : y1 + int(height_rel * height),
            x1 : x1 + int(width_rel * width),
        ]

        encoded, buffer = cv2.imencode(".jpg", cropped_image)

        if not encoded:
            logger.warning("Unable to encode cropped snapshot for %s", event.id)
            return None

        return buffer.tobytes()

    def _select_description_images(
        self, event: Event, thumbnail: bytes, use_snapshot: bool
    ) -> list[bytes]:
        """Choose the images handed to the GenAI client for an event.

        Preference order: the cropped snapshot (when requested, the event has
        one and it is usable), then the thumbnails buffered for the event
        while it was tracked, then the single ``thumbnail`` passed in.
        """
        if event.has_snapshot and use_snapshot:
            snapshot_image = self._read_and_crop_snapshot(event)

            if snapshot_image is not None:
                return [snapshot_image]

        # Use each buffered frame's own crop; entries whose thumbnail could
        # not be encoded are stored as None and are skipped here.
        tracked_thumbnails = [
            data["thumbnail"]
            for data in self.tracked_events.get(event.id, [])
            if data.get("thumbnail") is not None
        ]

        return tracked_thumbnails or [thumbnail]

    def _embed_thumbnail(self, event_id: str, thumbnail: bytes) -> None:
        """Embed the thumbnail for an event."""
        self.embeddings.upsert_thumbnail(event_id, thumbnail)

    def _embed_description(self, event: Event, thumbnails: list[bytes]) -> None:
        """Embed the description for an event.

        Asks the GenAI client for a description of ``thumbnails``, sends it
        out as an ``UPDATE_EVENT_DESCRIPTION`` message and stores its
        embedding. Called from a separate thread when an event is finalized
        and directly from ``handle_regenerate_description``.
        """
        camera_config = self.config.cameras[event.camera]

        description = self.genai_client.generate_description(
            camera_config, thumbnails, event
        )

        if not description:
            logger.debug("Failed to generate description for %s", event.id)
            return

        # fire and forget description update
        self.requestor.send_data(
            UPDATE_EVENT_DESCRIPTION,
            {"id": event.id, "description": description},
        )

        # Encode the description
        self.embeddings.upsert_description(event.id, description)

        logger.debug(
            "Generated description for %s (%d images): %s",
            event.id,
            len(thumbnails),
            description,
        )

    def handle_regenerate_description(self, event_id: str, source: str) -> None:
        """Regenerate the description of an existing event.

        Args:
            event_id: Id of the event to look up in the database.
            source: ``"snapshot"`` to describe the event's snapshot (when it
                has one); any other value uses thumbnails.
        """
        try:
            event: Event = Event.get(Event.id == event_id)
        except DoesNotExist:
            logger.error(f"Event {event_id} not found for description regeneration")
            return

        camera_config = self.config.cameras[event.camera]
        if not camera_config.genai.enabled or self.genai_client is None:
            logger.error(f"GenAI not enabled for camera {event.camera}")
            return

        thumbnail = base64.b64decode(event.thumbnail)

        logger.debug(
            f"Trying {source} regeneration for {event}, has_snapshot: {event.has_snapshot}"
        )

        embed_image = self._select_description_images(
            event, thumbnail, use_snapshot=(source == "snapshot")
        )

        self._embed_description(event, embed_image)
|