From 8a8a0c7decfb0aa70cdc87d8a7de3263d4bec265 Mon Sep 17 00:00:00 2001 From: Josh Hawkins <32435876+hawkeye217@users.noreply.github.com> Date: Fri, 11 Oct 2024 13:11:11 -0500 Subject: [PATCH] Embeddings normalization fixes (#14284) * Use cosine distance metric for vec tables * Only apply normalization to multi modal searches * Catch possible edge case in stddev calc * Use sigmoid function for normalization for multi modal searches only * Ensure we get model state on initial page load * Only save stats for multi modal searches and only use cosine similarity for image -> image search --- frigate/api/event.py | 33 ++++++++++++++--------------- frigate/db/sqlitevecq.py | 4 ++-- frigate/embeddings/util.py | 7 +++--- web/src/pages/Explore.tsx | 9 ++++++++ web/src/views/search/SearchView.tsx | 14 ++++++++---- 5 files changed, 41 insertions(+), 26 deletions(-) diff --git a/frigate/api/event.py b/frigate/api/event.py index 3a8d003ad..c716bba13 100644 --- a/frigate/api/event.py +++ b/frigate/api/event.py @@ -473,12 +473,7 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends()) ) thumb_result = context.search_thumbnail(search_event) - thumb_ids = dict( - zip( - [result[0] for result in thumb_result], - context.thumb_stats.normalize([result[1] for result in thumb_result]), - ) - ) + thumb_ids = {result[0]: result[1] for result in thumb_result} search_results = { event_id: {"distance": distance, "source": "thumbnail"} for event_id, distance in thumb_ids.items() @@ -486,15 +481,18 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends()) else: search_types = search_type.split(",") + # only save stats for multi-modal searches + save_stats = "thumbnail" in search_types and "description" in search_types + if "thumbnail" in search_types: thumb_result = context.search_thumbnail(query) + + thumb_distances = context.thumb_stats.normalize( + [result[1] for result in thumb_result], save_stats + ) + thumb_ids = dict( - zip( - [result[0] for result in thumb_result], - context.thumb_stats.normalize( - [result[1] for result in thumb_result] - ), - ) + zip([result[0] for result in thumb_result], thumb_distances) ) search_results.update( { @@ -505,12 +503,13 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends()) if "description" in search_types: desc_result = context.search_description(query) - desc_ids = dict( - zip( - [result[0] for result in desc_result], - context.desc_stats.normalize([result[1] for result in desc_result]), - ) + + desc_distances = context.desc_stats.normalize( + [result[1] for result in desc_result], save_stats ) + + desc_ids = dict(zip([result[0] for result in desc_result], desc_distances)) + for event_id, distance in desc_ids.items(): if ( event_id not in search_results diff --git a/frigate/db/sqlitevecq.py b/frigate/db/sqlitevecq.py index 398adbd2d..ccb75ae54 100644 --- a/frigate/db/sqlitevecq.py +++ b/frigate/db/sqlitevecq.py @@ -42,12 +42,12 @@ class SqliteVecQueueDatabase(SqliteQueueDatabase): self.execute_sql(""" CREATE VIRTUAL TABLE IF NOT EXISTS vec_thumbnails USING vec0( id TEXT PRIMARY KEY, - thumbnail_embedding FLOAT[768] + thumbnail_embedding FLOAT[768] distance_metric=cosine ); """) self.execute_sql(""" CREATE VIRTUAL TABLE IF NOT EXISTS vec_descriptions USING vec0( id TEXT PRIMARY KEY, - description_embedding FLOAT[768] + description_embedding FLOAT[768] distance_metric=cosine ); """) diff --git a/frigate/embeddings/util.py b/frigate/embeddings/util.py index 0b2acd4d6..bc1a952ec 100644 --- a/frigate/embeddings/util.py +++ b/frigate/embeddings/util.py @@ -20,10 +20,11 @@ class ZScoreNormalization: @property def stddev(self): - return math.sqrt(self.variance) + return math.sqrt(self.variance) if self.variance > 0 else 0.0 - def normalize(self, distances: list[float]): - self._update(distances) + def normalize(self, distances: list[float], save_stats: bool): + if save_stats: + self._update(distances) if self.stddev == 0: return distances return [ diff --git a/web/src/pages/Explore.tsx b/web/src/pages/Explore.tsx index a3d7d3085..d3c5f7d9b 100644 --- a/web/src/pages/Explore.tsx +++ b/web/src/pages/Explore.tsx @@ -2,6 +2,7 @@ import { useEmbeddingsReindexProgress, useEventUpdate, useModelState, + useWs, } from "@/api/ws"; import ActivityIndicator from "@/components/indicators/activity-indicator"; import AnimatedCircularProgressBar from "@/components/ui/circular-progress-bar"; @@ -202,6 +203,14 @@ export default function Explore() { // model states + const { send: sendCommand } = useWs("model_state", "modelState"); + + useEffect(() => { + sendCommand("modelState"); + // only run on mount + // eslint-disable-next-line react-hooks/exhaustive-deps + }, []); + const { payload: textModelState } = useModelState( "jinaai/jina-clip-v1-text_model_fp16.onnx", ); diff --git a/web/src/views/search/SearchView.tsx b/web/src/views/search/SearchView.tsx index 203942083..4c33f7dc8 100644 --- a/web/src/views/search/SearchView.tsx +++ b/web/src/views/search/SearchView.tsx @@ -187,13 +187,19 @@ export default function SearchView({ } }, [searchResults, searchDetail]); - // confidence score - probably needs tweaking + // confidence score const zScoreToConfidence = (score: number) => { - // Sigmoid function: 1 / (1 + e^x) - const confidence = 1 / (1 + Math.exp(score)); + // Normalizing is not needed for similarity searches + // Sigmoid function for normalized: 1 / (1 + e^x) + // Cosine for similarity + if (searchFilter) { + const notNormalized = searchFilter?.search_type?.includes("similarity"); - return Math.round(confidence * 100); + const confidence = notNormalized ? 1 - score : 1 / (1 + Math.exp(score)); + + return Math.round(confidence * 100); + } }; const hasExistingSearch = useMemo(