mirror of
https://github.com/blakeblackshear/frigate.git
synced 2024-11-21 19:07:46 +01:00
123 lines
3.7 KiB
Python
123 lines
3.7 KiB
Python
|
"""ChromaDB embeddings database."""
|
||
|
|
||
|
import base64
|
||
|
import io
|
||
|
import logging
|
||
|
import time
|
||
|
|
||
|
import numpy as np
|
||
|
from chromadb import Collection
|
||
|
from chromadb import HttpClient as ChromaClient
|
||
|
from chromadb.config import Settings
|
||
|
from PIL import Image
|
||
|
from playhouse.shortcuts import model_to_dict
|
||
|
|
||
|
from frigate.models import Event
|
||
|
|
||
|
from .functions.clip import ClipEmbedding
|
||
|
from .functions.minilm_l6_v2 import MiniLMEmbedding
|
||
|
|
||
|
logger = logging.getLogger(__name__)
|
||
|
|
||
|
|
||
|
def get_metadata(event: Event) -> dict:
    """Extract the scalar event metadata that can be stored alongside an embedding.

    Three sources are merged, with later keys overriding earlier ones on collision:

    1. Top-level event columns whose values are ``str``/``int``/``float``/``bool``,
       excluding ``id`` and ``thumbnail``.
    2. Scalar entries of the event's ``data`` JSON, excluding ``description``.
    3. A ``f"{column}_{item}": True`` flag for every string item of any
       list-valued column (e.g. zones).

    Args:
        event: The event row to describe.

    Returns:
        A flat dict mapping metadata keys to scalar values.
    """
    scalar_types = (str, int, float, bool)  # isinstance also excludes None
    event_dict = model_to_dict(event)
    # ``data`` may be NULL for some rows; treat it as empty rather than
    # raising AttributeError on ``.items()``.
    event_data = event_dict.get("data") or {}

    columns = {
        k: v
        for k, v in event_dict.items()
        if k not in ("id", "thumbnail") and isinstance(v, scalar_types)
    }
    data_fields = {
        k: v
        for k, v in event_data.items()
        if k != "description" and isinstance(v, scalar_types)
    }
    # Metadata search doesn't support $contains and an event can have
    # multiple zones, so we need to create a key for each zone.
    list_flags = {
        f"{k}_{x}": True
        for k, v in event_dict.items()
        if isinstance(v, list)
        for x in v
        if isinstance(x, str)
    }

    return columns | data_fields | list_flags
|
||
|
|
||
|
|
||
|
class Embeddings:
    """ChromaDB embeddings database.

    Talks to a ChromaDB server over HTTP at 127.0.0.1 and exposes two
    collections: ``event_thumbnail`` (embedded with ``ClipEmbedding``) and
    ``event_description`` (embedded with ``MiniLMEmbedding``).
    """

    def __init__(self) -> None:
        self.client: ChromaClient = ChromaClient(
            host="127.0.0.1",
            settings=Settings(anonymized_telemetry=False),
        )

    @property
    def thumbnail(self) -> Collection:
        """Return the thumbnail collection, creating it if it does not exist."""
        return self.client.get_or_create_collection(
            name="event_thumbnail", embedding_function=ClipEmbedding()
        )

    @property
    def description(self) -> Collection:
        """Return the description collection, creating it if it does not exist."""
        return self.client.get_or_create_collection(
            name="event_description", embedding_function=MiniLMEmbedding()
        )

    def reindex(self) -> None:
        """Reset the Chroma database and re-embed all eligible events.

        Events qualify when they have a clip or a snapshot and a stored
        thumbnail. Each qualifying event's thumbnail is decoded to an RGB
        array and upserted into the thumbnail collection. Events whose
        ``data`` carries a ``description`` are also upserted into the
        description collection. Thumbnails that cannot be decoded are
        logged and skipped.
        """
        logger.info("Indexing event embeddings...")
        self.client.reset()

        start = time.monotonic()

        thumbnails: dict[str, list] = {"ids": [], "images": [], "metadatas": []}
        descriptions: dict[str, list] = {"ids": [], "documents": [], "metadatas": []}

        # Each comparison must be parenthesised: ``|`` binds tighter than
        # ``==``, so without parentheses the filter on has_clip/has_snapshot
        # collapses into a chained comparison and matches every event.
        events = Event.select().where(
            ((Event.has_clip == True) | (Event.has_snapshot == True))  # noqa: E712
            & Event.thumbnail.is_null(False)
        )

        event: Event
        for event in events.iterator():
            metadata = get_metadata(event)

            try:
                raw = base64.b64decode(event.thumbnail)
                img = np.array(Image.open(io.BytesIO(raw)).convert("RGB"))
            except (ValueError, OSError):
                # binascii.Error (bad base64) subclasses ValueError and PIL
                # raises OSError subclasses for unreadable images; skip this
                # thumbnail rather than aborting the whole reindex.
                logger.warning("Skipping unreadable thumbnail for event %s", event.id)
            else:
                thumbnails["ids"].append(event.id)
                thumbnails["images"].append(img)
                thumbnails["metadatas"].append(metadata)

            # ``data`` may be NULL; treat it as having no description.
            description = (event.data or {}).get("description")
            if description is not None:
                descriptions["ids"].append(event.id)
                descriptions["documents"].append(description)
                descriptions["metadatas"].append(metadata)

        if thumbnails["ids"]:
            self.thumbnail.upsert(
                images=thumbnails["images"],
                metadatas=thumbnails["metadatas"],
                ids=thumbnails["ids"],
            )

        if descriptions["ids"]:
            self.description.upsert(
                documents=descriptions["documents"],
                metadatas=descriptions["metadatas"],
                ids=descriptions["ids"],
            )

        logger.info(
            "Embedded %d thumbnails and %d descriptions in %s seconds",
            len(thumbnails["ids"]),
            len(descriptions["ids"]),
            time.monotonic() - start,
        )
|