From e8b2fde753f50421dde2ef725b51ac6f5999ee5f Mon Sep 17 00:00:00 2001
From: Nicolas Mowen
Date: Sun, 13 Oct 2024 11:33:27 -0600
Subject: [PATCH] Support batch embeddings when reindexing (#14320)

* Refactor onnx embeddings to handle multiple inputs by default

* Process items in batches when reindexing
---
 frigate/embeddings/embeddings.py     | 98 ++++++++++++++++++----------
 frigate/embeddings/functions/onnx.py | 38 +++++-----
 2 files changed, 86 insertions(+), 50 deletions(-)

diff --git a/frigate/embeddings/embeddings.py b/frigate/embeddings/embeddings.py
index b5b166b00..8d12feb32 100644
--- a/frigate/embeddings/embeddings.py
+++ b/frigate/embeddings/embeddings.py
@@ -6,6 +6,7 @@ import logging
 import os
 import time
 
+from numpy import ndarray
 from PIL import Image
 from playhouse.shortcuts import model_to_dict
@@ -88,12 +89,6 @@ class Embeddings:
             },
         )
 
-        def jina_text_embedding_function(outputs):
-            return outputs[0]
-
-        def jina_vision_embedding_function(outputs):
-            return outputs[0]
-
         self.text_embedding = GenericONNXEmbedding(
             model_name="jinaai/jina-clip-v1",
             model_file="text_model_fp16.onnx",
@@ -101,7 +96,6 @@ class Embeddings:
             download_urls={
                 "text_model_fp16.onnx": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/text_model_fp16.onnx",
             },
-            embedding_function=jina_text_embedding_function,
             model_size=config.model_size,
             model_type="text",
             requestor=self.requestor,
@@ -123,14 +117,13 @@ class Embeddings:
             model_name="jinaai/jina-clip-v1",
             model_file=model_file,
             download_urls=download_urls,
-            embedding_function=jina_vision_embedding_function,
             model_size=config.model_size,
             model_type="vision",
             requestor=self.requestor,
             device="GPU" if config.model_size == "large" else "CPU",
         )
 
-    def upsert_thumbnail(self, event_id: str, thumbnail: bytes):
+    def upsert_thumbnail(self, event_id: str, thumbnail: bytes) -> ndarray:
         # Convert thumbnail bytes to PIL Image
         image = Image.open(io.BytesIO(thumbnail)).convert("RGB")
         embedding = self.vision_embedding([image])[0]
@@ -145,7 +138,26 @@ class Embeddings:
 
         return embedding
 
-    def upsert_description(self, event_id: str, description: str):
+    def batch_upsert_thumbnail(self, event_thumbs: dict[str, bytes]) -> list[ndarray]:
+        images = [
+            Image.open(io.BytesIO(thumb)).convert("RGB")
+            for thumb in event_thumbs.values()
+        ]
+        ids = list(event_thumbs.keys())
+        embeddings = self.vision_embedding(images)
+        # Flatten to id/blob pairs; execute_sql expects one flat parameter list
+        items = [x for id_, emb in zip(ids, embeddings) for x in (id_, serialize(emb))]
+
+        self.db.execute_sql(
+            """
+            INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding)
+            VALUES {}
+            """.format(", ".join(["(?, ?)"] * len(ids))),
+            items,
+        )
+        return embeddings
+
+    def upsert_description(self, event_id: str, description: str) -> ndarray:
         embedding = self.text_embedding([description])[0]
         self.db.execute_sql(
             """
@@ -157,6 +169,21 @@ class Embeddings:
 
         return embedding
 
+    def batch_upsert_description(self, event_descriptions: dict[str, str]) -> list[ndarray]:
+        embeddings = self.text_embedding(list(event_descriptions.values()))
+        ids = list(event_descriptions.keys())
+        items = [x for id_, emb in zip(ids, embeddings) for x in (id_, serialize(emb))]
+
+        self.db.execute_sql(
+            """
+            INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
+            VALUES {}
+            """.format(", ".join(["(?, ?)"] * len(ids))),
+            items,
+        )
+
+        return embeddings
+
     def reindex(self) -> None:
         logger.info("Indexing tracked object embeddings...")
 
@@ -192,9 +219,8 @@ class Embeddings:
         )
         totals["total_objects"] = total_events
 
-        batch_size = 100
+        batch_size = 32
         current_page = 1
-        processed_events = 0
 
         events = (
             Event.select()
@@ -208,37 +234,43 @@ class Embeddings:
         while len(events) > 0:
             event: Event
+            batch_thumbs = {}
+            batch_descs = {}
             for event in events:
-                thumbnail = base64.b64decode(event.thumbnail)
-                self.upsert_thumbnail(event.id, thumbnail)
+                batch_thumbs[event.id] = base64.b64decode(event.thumbnail)
                 totals["thumbnails"] += 1
 
                 if description := event.data.get("description", "").strip():
+                    batch_descs[event.id] = description
                     totals["descriptions"] += 1
-                    self.upsert_description(event.id, description)
 
                 totals["processed_objects"] += 1
 
-            # report progress every 10 events so we don't spam the logs
-            if (totals["processed_objects"] % 10) == 0:
-                progress = (processed_events / total_events) * 100
-                logger.debug(
-                    "Processed %d/%d events (%.2f%% complete) | Thumbnails: %d, Descriptions: %d",
-                    processed_events,
-                    total_events,
-                    progress,
-                    totals["thumbnails"],
-                    totals["descriptions"],
-                )
+            # Run the batch embeddings for this page of events
+            self.batch_upsert_thumbnail(batch_thumbs)
 
-                # Calculate time remaining
-                elapsed_time = time.time() - st
-                avg_time_per_event = elapsed_time / totals["processed_objects"]
-                remaining_events = total_events - totals["processed_objects"]
-                time_remaining = avg_time_per_event * remaining_events
-                totals["time_remaining"] = int(time_remaining)
+            if batch_descs:
+                self.batch_upsert_description(batch_descs)
 
-                self.requestor.send_data(UPDATE_EMBEDDINGS_REINDEX_PROGRESS, totals)
+            # report progress every batch so we don't spam the logs
+            progress = (totals["processed_objects"] / total_events) * 100
+            logger.debug(
+                "Processed %d/%d events (%.2f%% complete) | Thumbnails: %d, Descriptions: %d",
+                totals["processed_objects"],
+                total_events,
+                progress,
+                totals["thumbnails"],
+                totals["descriptions"],
+            )
+
+            # Calculate time remaining
+            elapsed_time = time.time() - st
+            avg_time_per_event = elapsed_time / totals["processed_objects"]
+            remaining_events = total_events - totals["processed_objects"]
+            time_remaining = avg_time_per_event * remaining_events
+            totals["time_remaining"] = int(time_remaining)
+
+            self.requestor.send_data(UPDATE_EMBEDDINGS_REINDEX_PROGRESS, totals)
 
             # Move to the next page
             current_page += 1
diff --git a/frigate/embeddings/functions/onnx.py b/frigate/embeddings/functions/onnx.py
index e836ba960..765a7e88c 100644
--- a/frigate/embeddings/functions/onnx.py
+++ b/frigate/embeddings/functions/onnx.py
@@ -2,7 +2,7 @@ import logging
 import os
 import warnings
 from io import BytesIO
-from typing import Callable, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
 import requests
@@ -39,7 +39,6 @@ class GenericONNXEmbedding:
         model_name: str,
         model_file: str,
         download_urls: Dict[str, str],
-        embedding_function: Callable[[List[np.ndarray]], np.ndarray],
         model_size: str,
         model_type: str,
         requestor: InterProcessRequestor,
@@ -51,7 +50,6 @@ class GenericONNXEmbedding:
         self.tokenizer_file = tokenizer_file
         self.requestor = requestor
         self.download_urls = download_urls
-        self.embedding_function = embedding_function
         self.model_type = model_type  # 'text' or 'vision'
         self.model_size = model_size
         self.device = device
@@ -157,7 +155,6 @@ class GenericONNXEmbedding:
         self, inputs: Union[List[str], List[Image.Image]]
     ) -> List[np.ndarray]:
         self._load_model_and_tokenizer()
-
         if self.runner is None or (
             self.tokenizer is None and self.feature_extractor is None
         ):
@@ -167,23 +164,30 @@ class GenericONNXEmbedding:
             return []
 
         if self.model_type == "text":
-            processed_inputs = self.tokenizer(
-                inputs, padding=True, truncation=True, return_tensors="np"
-            )
+            # pad every text to the longest tokenized length in the batch so the per-item arrays stack cleanly
+            max_length = max(len(self.tokenizer.encode(text)) for text in inputs)
+            processed_inputs = [
+                self.tokenizer(text, padding="max_length", truncation=True, max_length=max_length, return_tensors="np")
+                for text in inputs
+            ]
         else:
             processed_images = [self._process_image(img) for img in inputs]
-            processed_inputs = self.feature_extractor(
-                images=processed_images, return_tensors="np"
-            )
+            processed_inputs = [
+                self.feature_extractor(images=image, return_tensors="np")
+                for image in processed_images
+            ]
 
         input_names = self.runner.get_input_names()
-        onnx_inputs = {
-            name: processed_inputs[name]
-            for name in input_names
-            if name in processed_inputs
-        }
+        onnx_inputs = {name: [] for name in input_names}
+        input: dict[str, Any]
+        for input in processed_inputs:
+            for key, value in input.items():
+                if key in input_names:
+                    onnx_inputs[key].append(value[0])
 
-        outputs = self.runner.run(onnx_inputs)
-        embeddings = self.embedding_function(outputs)
+        # stack each input's per-item arrays into a single batched tensor
+        for key in onnx_inputs.keys():
+            onnx_inputs[key] = np.array(onnx_inputs[key])
 
+        embeddings = self.runner.run(onnx_inputs)[0]
         return [embedding for embedding in embeddings]
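
Usage note (a sketch, not part of the patch): with this change the embedding
callables accept lists, so a single ONNX forward pass covers a whole page of
events and reindex() only has to assemble id-keyed dicts. A minimal
illustration, assuming an already-initialized `Embeddings` instance named
`embeddings` and a hypothetical `page` of Event rows shaped like the reindex
loop above:

    import base64

    batch_thumbs = {e.id: base64.b64decode(e.thumbnail) for e in page}
    batch_descs = {
        e.id: d for e in page if (d := e.data.get("description", "").strip())
    }

    # one vision-model run embeds every thumbnail on the page
    embeddings.batch_upsert_thumbnail(batch_thumbs)

    # descriptions are optional, so only embed events that have one
    if batch_descs:
        embeddings.batch_upsert_description(batch_descs)

Batching pays off because each runner.run() call carries fixed session
overhead, which stacking a page of inputs into one tensor amortizes; the page
size also drops from 100 to 32, presumably to bound the memory a single
batched inference can consume.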