diff --git a/frigate/embeddings/embeddings.py b/frigate/embeddings/embeddings.py index 9ee508823..d77a9eecf 100644 --- a/frigate/embeddings/embeddings.py +++ b/frigate/embeddings/embeddings.py @@ -1,13 +1,11 @@ """SQLite-vec embeddings database.""" import base64 -import io import logging import os import time from numpy import ndarray -from PIL import Image from playhouse.shortcuts import model_to_dict from frigate.comms.inter_process import InterProcessRequestor @@ -22,7 +20,7 @@ from frigate.models import Event from frigate.types import ModelStatusTypesEnum from frigate.util.builtin import serialize -from .functions.onnx import GenericONNXEmbedding +from .functions.onnx import GenericONNXEmbedding, ModelTypeEnum logger = logging.getLogger(__name__) @@ -97,7 +95,7 @@ class Embeddings: "text_model_fp16.onnx": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/text_model_fp16.onnx", }, model_size=config.model_size, - model_type="text", + model_type=ModelTypeEnum.text, requestor=self.requestor, device="CPU", ) @@ -118,83 +116,102 @@ class Embeddings: model_file=model_file, download_urls=download_urls, model_size=config.model_size, - model_type="vision", + model_type=ModelTypeEnum.vision, requestor=self.requestor, device="GPU" if config.model_size == "large" else "CPU", ) - def upsert_thumbnail(self, event_id: str, thumbnail: bytes) -> ndarray: - # Convert thumbnail bytes to PIL Image - image = Image.open(io.BytesIO(thumbnail)).convert("RGB") - embedding = self.vision_embedding([image])[0] + def embed_thumbnail( + self, event_id: str, thumbnail: bytes, upsert: bool = True + ) -> ndarray: + """Embed thumbnail and optionally insert into DB. - self.db.execute_sql( - """ - INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding) - VALUES(?, ?) - """, - (event_id, serialize(embedding)), - ) + @param: event_id in Events DB + @param: thumbnail bytes in jpg format + @param: upsert If embedding should be upserted into vec DB + """ + # Convert thumbnail bytes to PIL Image + embedding = self.vision_embedding([thumbnail])[0] + + if upsert: + self.db.execute_sql( + """ + INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding) + VALUES(?, ?) + """, + (event_id, serialize(embedding)), + ) return embedding - def batch_upsert_thumbnail(self, event_thumbs: dict[str, bytes]) -> list[ndarray]: - images = [ - Image.open(io.BytesIO(thumb)).convert("RGB") - for thumb in event_thumbs.values() - ] + def batch_embed_thumbnail( + self, event_thumbs: dict[str, bytes], upsert: bool = True + ) -> list[ndarray]: + """Embed thumbnails and optionally insert into DB. + + @param: event_thumbs Map of Event IDs in DB to thumbnail bytes in jpg format + @param: upsert If embedding should be upserted into vec DB + """ ids = list(event_thumbs.keys()) - embeddings = self.vision_embedding(images) + embeddings = self.vision_embedding(list(event_thumbs.values())) - items = [] + if upsert: + items = [] - for i in range(len(ids)): - items.append(ids[i]) - items.append(serialize(embeddings[i])) + for i in range(len(ids)): + items.append(ids[i]) + items.append(serialize(embeddings[i])) + + self.db.execute_sql( + """ + INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding) + VALUES {} + """.format(", ".join(["(?, ?)"] * len(ids))), + items, + ) - self.db.execute_sql( - """ - INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding) - VALUES {} - """.format(", ".join(["(?, ?)"] * len(ids))), - items, - ) return embeddings - def upsert_description(self, event_id: str, description: str) -> ndarray: + def embed_description( + self, event_id: str, description: str, upsert: bool = True + ) -> ndarray: embedding = self.text_embedding([description])[0] - self.db.execute_sql( - """ - INSERT OR REPLACE INTO vec_descriptions(id, description_embedding) - VALUES(?, ?) - """, - (event_id, serialize(embedding)), - ) + + if upsert: + self.db.execute_sql( + """ + INSERT OR REPLACE INTO vec_descriptions(id, description_embedding) + VALUES(?, ?) + """, + (event_id, serialize(embedding)), + ) return embedding - def batch_upsert_description(self, event_descriptions: dict[str, str]) -> ndarray: + def batch_embed_description( + self, event_descriptions: dict[str, str], upsert: bool = True + ) -> ndarray: # upsert embeddings one by one to avoid token limit embeddings = [] for desc in event_descriptions.values(): embeddings.append(self.text_embedding([desc])[0]) - ids = list(event_descriptions.keys()) + if upsert: + ids = list(event_descriptions.keys()) + items = [] - items = [] + for i in range(len(ids)): + items.append(ids[i]) + items.append(serialize(embeddings[i])) - for i in range(len(ids)): - items.append(ids[i]) - items.append(serialize(embeddings[i])) - - self.db.execute_sql( - """ - INSERT OR REPLACE INTO vec_descriptions(id, description_embedding) - VALUES {} - """.format(", ".join(["(?, ?)"] * len(ids))), - items, - ) + self.db.execute_sql( + """ + INSERT OR REPLACE INTO vec_descriptions(id, description_embedding) + VALUES {} + """.format(", ".join(["(?, ?)"] * len(ids))), + items, + ) return embeddings @@ -261,10 +278,10 @@ class Embeddings: totals["processed_objects"] += 1 # run batch embedding - self.batch_upsert_thumbnail(batch_thumbs) + self.batch_embed_thumbnail(batch_thumbs) if batch_descs: - self.batch_upsert_description(batch_descs) + self.batch_embed_description(batch_descs) # report progress every batch so we don't spam the logs progress = (totals["processed_objects"] / total_events) * 100 diff --git a/frigate/embeddings/functions/onnx.py b/frigate/embeddings/functions/onnx.py index 574822d59..6ea495a30 100644 --- a/frigate/embeddings/functions/onnx.py +++ b/frigate/embeddings/functions/onnx.py @@ -1,6 +1,7 @@ import logging import os import warnings +from enum import Enum from io import BytesIO from typing import Dict, List, Optional, Union @@ -31,6 +32,12 @@ disable_progress_bar() logger = logging.getLogger(__name__) +class ModelTypeEnum(str, Enum): + face = "face" + vision = "vision" + text = "text" + + class GenericONNXEmbedding: """Generic embedding function for ONNX models (text and vision).""" @@ -88,7 +95,10 @@ class GenericONNXEmbedding: file_name = os.path.basename(path) if file_name in self.download_urls: ModelDownloader.download_from_url(self.download_urls[file_name], path) - elif file_name == self.tokenizer_file and self.model_type == "text": + elif ( + file_name == self.tokenizer_file + and self.model_type == ModelTypeEnum.text + ): if not os.path.exists(path + "/" + self.model_name): logger.info(f"Downloading {self.model_name} tokenizer") tokenizer = AutoTokenizer.from_pretrained( @@ -119,7 +129,7 @@ class GenericONNXEmbedding: if self.runner is None: if self.downloader: self.downloader.wait_for_download() - if self.model_type == "text": + if self.model_type == ModelTypeEnum.text: self.tokenizer = self._load_tokenizer() else: self.feature_extractor = self._load_feature_extractor() @@ -143,11 +153,35 @@ class GenericONNXEmbedding: f"{MODEL_CACHE_DIR}/{self.model_name}", ) + def _preprocess_inputs(self, raw_inputs: any) -> any: + if self.model_type == ModelTypeEnum.text: + max_length = max(len(self.tokenizer.encode(text)) for text in raw_inputs) + return [ + self.tokenizer( + text, + padding="max_length", + truncation=True, + max_length=max_length, + return_tensors="np", + ) + for text in raw_inputs + ] + elif self.model_type == ModelTypeEnum.vision: + processed_images = [self._process_image(img) for img in raw_inputs] + return [ + self.feature_extractor(images=image, return_tensors="np") + for image in processed_images + ] + else: + raise ValueError(f"Unable to preprocess inputs for {self.model_type}") + def _process_image(self, image): if isinstance(image, str): if image.startswith("http"): response = requests.get(image) image = Image.open(BytesIO(response.content)).convert("RGB") + elif isinstance(image, bytes): + image = Image.open(BytesIO(image)).convert("RGB") return image @@ -163,25 +197,7 @@ class GenericONNXEmbedding: ) return [] - if self.model_type == "text": - max_length = max(len(self.tokenizer.encode(text)) for text in inputs) - processed_inputs = [ - self.tokenizer( - text, - padding="max_length", - truncation=True, - max_length=max_length, - return_tensors="np", - ) - for text in inputs - ] - else: - processed_images = [self._process_image(img) for img in inputs] - processed_inputs = [ - self.feature_extractor(images=image, return_tensors="np") - for image in processed_images - ] - + processed_inputs = self._preprocess_inputs(inputs) input_names = self.runner.get_input_names() onnx_inputs = {name: [] for name in input_names} input: dict[str, any] diff --git a/frigate/embeddings/maintainer.py b/frigate/embeddings/maintainer.py index 7ce63e7f8..1578a0fe3 100644 --- a/frigate/embeddings/maintainer.py +++ b/frigate/embeddings/maintainer.py @@ -86,7 +86,7 @@ class EmbeddingMaintainer(threading.Thread): try: if topic == EmbeddingsRequestEnum.embed_description.value: return serialize( - self.embeddings.upsert_description( + self.embeddings.embed_description( data["id"], data["description"] ), pack=False, @@ -94,7 +94,7 @@ class EmbeddingMaintainer(threading.Thread): elif topic == EmbeddingsRequestEnum.embed_thumbnail.value: thumbnail = base64.b64decode(data["thumbnail"]) return serialize( - self.embeddings.upsert_thumbnail(data["id"], thumbnail), + self.embeddings.embed_thumbnail(data["id"], thumbnail), pack=False, ) elif topic == EmbeddingsRequestEnum.generate_search.value: @@ -270,7 +270,7 @@ class EmbeddingMaintainer(threading.Thread): def _embed_thumbnail(self, event_id: str, thumbnail: bytes) -> None: """Embed the thumbnail for an event.""" - self.embeddings.upsert_thumbnail(event_id, thumbnail) + self.embeddings.embed_thumbnail(event_id, thumbnail) def _embed_description(self, event: Event, thumbnails: list[bytes]) -> None: """Embed the description for an event.""" @@ -290,8 +290,8 @@ class EmbeddingMaintainer(threading.Thread): {"id": event.id, "description": description}, ) - # Encode the description - self.embeddings.upsert_description(event.id, description) + # Embed the description + self.embeddings.embed_description(event.id, description) logger.debug( "Generated description for %s (%d images): %s", diff --git a/frigate/util/services.py b/frigate/util/services.py index 3f8ecf32c..7ff46f039 100644 --- a/frigate/util/services.py +++ b/frigate/util/services.py @@ -279,10 +279,27 @@ def get_intel_gpu_stats() -> dict[str, str]: logger.error(f"Unable to poll intel GPU stats: {p.stderr}") return None else: + output = "".join(p.stdout.split()) + try: - data = json.loads(f'[{"".join(p.stdout.split())}]') + data = json.loads(f"[{output}]") except json.JSONDecodeError: - return {"gpu": "-%", "mem": "-%"} + data = None + + # json is incomplete, remove characters until we get to valid json + while True: + while output and output[-1] != "}": + output = output[:-1] + + if not output: + return {"gpu": "", "mem": ""} + + try: + data = json.loads(f"[{output}]") + break + except json.JSONDecodeError: + output = output[:-1] + continue results: dict[str, str] = {} render = {"global": []}