From 1ec459ea3a0cf123e04514e059e9a5af0c51aac3 Mon Sep 17 00:00:00 2001 From: Josh Hawkins <32435876+hawkeye217@users.noreply.github.com> Date: Sun, 13 Oct 2024 16:25:13 -0500 Subject: [PATCH] Batch embeddings fixes (#14325) * fixes * more readable loops * more robust key check and warning message * ensure we get reindex progress on mount * use correct var for length --- frigate/embeddings/embeddings.py | 40 +++++++++++++++++----------- frigate/embeddings/functions/onnx.py | 16 ++++++++--- web/src/pages/Explore.tsx | 29 +++++++++++++++----- 3 files changed, 60 insertions(+), 25 deletions(-) diff --git a/frigate/embeddings/embeddings.py b/frigate/embeddings/embeddings.py index 8d12feb32..cb0626f7b 100644 --- a/frigate/embeddings/embeddings.py +++ b/frigate/embeddings/embeddings.py @@ -145,13 +145,18 @@ class Embeddings: ] ids = list(event_thumbs.keys()) embeddings = self.vision_embedding(images) - items = [(ids[i], serialize(embeddings[i])) for i in range(len(ids))] + + items = [] + + for i in range(len(ids)): + items.append(ids[i]) + items.append(serialize(embeddings[i])) self.db.execute_sql( """ INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding) VALUES {} - """.format(", ".join(["(?, ?)"] * len(items))), + """.format(", ".join(["(?, ?)"] * len(ids))), items, ) return embeddings @@ -171,13 +176,18 @@ class Embeddings: def batch_upsert_description(self, event_descriptions: dict[str, str]) -> ndarray: embeddings = self.text_embedding(list(event_descriptions.values())) ids = list(event_descriptions.keys()) - items = [(ids[i], serialize(embeddings[i])) for i in range(len(ids))] + + items = [] + + for i in range(len(ids)): + items.append(ids[i]) + items.append(serialize(embeddings[i])) self.db.execute_sql( """ INSERT OR REPLACE INTO vec_descriptions(id, description_embedding) VALUES {} - """.format(", ".join(["(?, ?)"] * len(items))), + """.format(", ".join(["(?, ?)"] * len(ids))), items, ) @@ -196,16 +206,6 @@ class Embeddings: os.remove(os.path.join(CONFIG_DIR, ".search_stats.json")) st = time.time() - totals = { - "thumbnails": 0, - "descriptions": 0, - "processed_objects": 0, - "total_objects": 0, - "time_remaining": 0, - "status": "indexing", - } - - self.requestor.send_data(UPDATE_EMBEDDINGS_REINDEX_PROGRESS, totals) # Get total count of events to process total_events = ( @@ -216,11 +216,21 @@ class Embeddings: ) .count() ) - totals["total_objects"] = total_events batch_size = 32 current_page = 1 + totals = { + "thumbnails": 0, + "descriptions": 0, + "processed_objects": total_events - 1 if total_events < batch_size else 0, + "total_objects": total_events, + "time_remaining": 0 if total_events < batch_size else -1, + "status": "indexing", + } + + self.requestor.send_data(UPDATE_EMBEDDINGS_REINDEX_PROGRESS, totals) + events = ( Event.select() .where( diff --git a/frigate/embeddings/functions/onnx.py b/frigate/embeddings/functions/onnx.py index 765a7e88c..574822d59 100644 --- a/frigate/embeddings/functions/onnx.py +++ b/frigate/embeddings/functions/onnx.py @@ -164,8 +164,15 @@ class GenericONNXEmbedding: return [] if self.model_type == "text": + max_length = max(len(self.tokenizer.encode(text)) for text in inputs) processed_inputs = [ - self.tokenizer(text, padding=True, truncation=True, return_tensors="np") + self.tokenizer( + text, + padding="max_length", + truncation=True, + max_length=max_length, + return_tensors="np", + ) for text in inputs ] else: @@ -183,8 +190,11 @@ class GenericONNXEmbedding: if key in input_names: onnx_inputs[key].append(value[0]) - for key in onnx_inputs.keys(): - onnx_inputs[key] = np.array(onnx_inputs[key]) + for key in input_names: + if onnx_inputs.get(key): + onnx_inputs[key] = np.stack(onnx_inputs[key]) + else: + logger.warning(f"Expected input '{key}' not found in onnx_inputs") embeddings = self.runner.run(onnx_inputs)[0] return [embedding for embedding in embeddings] diff --git a/web/src/pages/Explore.tsx b/web/src/pages/Explore.tsx index 4aebaefd1..e4bb49521 100644 --- a/web/src/pages/Explore.tsx +++ b/web/src/pages/Explore.tsx @@ -193,6 +193,17 @@ export default function Explore() { // embeddings reindex progress + const { send: sendReindexCommand } = useWs( + "embeddings_reindex_progress", + "embeddingsReindexProgress", + ); + + useEffect(() => { + sendReindexCommand("embeddingsReindexProgress"); + // only run on mount + // eslint-disable-next-line react-hooks/exhaustive-deps + }, []); + const { payload: reindexProgress } = useEmbeddingsReindexProgress(); const embeddingsReindexing = useMemo(() => { @@ -210,10 +221,10 @@ export default function Explore() { // model states - const { send: sendCommand } = useWs("model_state", "modelState"); + const { send: sendModelCommand } = useWs("model_state", "modelState"); useEffect(() => { - sendCommand("modelState"); + sendModelCommand("modelState"); // only run on mount // eslint-disable-next-line react-hooks/exhaustive-deps }, []); @@ -299,14 +310,18 @@ export default function Explore() { />