Embeddings normalization fixes (#14284)

* Use cosine distance metric for vec tables

* Only apply normalization to multi-modal searches

* Catch possible edge case in stddev calc

* Use sigmoid function for normalization in multi-modal searches only

* Ensure we get model state on initial page load

* Only save stats for multi-modal searches and only use cosine similarity for image-to-image search
This commit is contained in:
Josh Hawkins 2024-10-11 13:11:11 -05:00 committed by GitHub
parent d4b9b5a7dd
commit 8a8a0c7dec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 41 additions and 26 deletions

View File

@ -473,12 +473,7 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
) )
thumb_result = context.search_thumbnail(search_event) thumb_result = context.search_thumbnail(search_event)
thumb_ids = dict( thumb_ids = {result[0]: result[1] for result in thumb_result}
zip(
[result[0] for result in thumb_result],
context.thumb_stats.normalize([result[1] for result in thumb_result]),
)
)
search_results = { search_results = {
event_id: {"distance": distance, "source": "thumbnail"} event_id: {"distance": distance, "source": "thumbnail"}
for event_id, distance in thumb_ids.items() for event_id, distance in thumb_ids.items()
@ -486,15 +481,18 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
else: else:
search_types = search_type.split(",") search_types = search_type.split(",")
# only save stats for multi-modal searches
save_stats = "thumbnail" in search_types and "description" in search_types
if "thumbnail" in search_types: if "thumbnail" in search_types:
thumb_result = context.search_thumbnail(query) thumb_result = context.search_thumbnail(query)
thumb_distances = context.thumb_stats.normalize(
[result[1] for result in thumb_result], save_stats
)
thumb_ids = dict( thumb_ids = dict(
zip( zip([result[0] for result in thumb_result], thumb_distances)
[result[0] for result in thumb_result],
context.thumb_stats.normalize(
[result[1] for result in thumb_result]
),
)
) )
search_results.update( search_results.update(
{ {
@ -505,12 +503,13 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
if "description" in search_types: if "description" in search_types:
desc_result = context.search_description(query) desc_result = context.search_description(query)
desc_ids = dict(
zip( desc_distances = context.desc_stats.normalize(
[result[0] for result in desc_result], [result[1] for result in desc_result], save_stats
context.desc_stats.normalize([result[1] for result in desc_result]),
)
) )
desc_ids = dict(zip([result[0] for result in desc_result], desc_distances))
for event_id, distance in desc_ids.items(): for event_id, distance in desc_ids.items():
if ( if (
event_id not in search_results event_id not in search_results

View File

@ -42,12 +42,12 @@ class SqliteVecQueueDatabase(SqliteQueueDatabase):
self.execute_sql(""" self.execute_sql("""
CREATE VIRTUAL TABLE IF NOT EXISTS vec_thumbnails USING vec0( CREATE VIRTUAL TABLE IF NOT EXISTS vec_thumbnails USING vec0(
id TEXT PRIMARY KEY, id TEXT PRIMARY KEY,
thumbnail_embedding FLOAT[768] thumbnail_embedding FLOAT[768] distance_metric=cosine
); );
""") """)
self.execute_sql(""" self.execute_sql("""
CREATE VIRTUAL TABLE IF NOT EXISTS vec_descriptions USING vec0( CREATE VIRTUAL TABLE IF NOT EXISTS vec_descriptions USING vec0(
id TEXT PRIMARY KEY, id TEXT PRIMARY KEY,
description_embedding FLOAT[768] description_embedding FLOAT[768] distance_metric=cosine
); );
""") """)

View File

@ -20,10 +20,11 @@ class ZScoreNormalization:
@property @property
def stddev(self): def stddev(self):
return math.sqrt(self.variance) return math.sqrt(self.variance) if self.variance > 0 else 0.0
def normalize(self, distances: list[float]): def normalize(self, distances: list[float], save_stats: bool):
self._update(distances) if save_stats:
self._update(distances)
if self.stddev == 0: if self.stddev == 0:
return distances return distances
return [ return [

View File

@ -2,6 +2,7 @@ import {
useEmbeddingsReindexProgress, useEmbeddingsReindexProgress,
useEventUpdate, useEventUpdate,
useModelState, useModelState,
useWs,
} from "@/api/ws"; } from "@/api/ws";
import ActivityIndicator from "@/components/indicators/activity-indicator"; import ActivityIndicator from "@/components/indicators/activity-indicator";
import AnimatedCircularProgressBar from "@/components/ui/circular-progress-bar"; import AnimatedCircularProgressBar from "@/components/ui/circular-progress-bar";
@ -202,6 +203,14 @@ export default function Explore() {
// model states // model states
const { send: sendCommand } = useWs("model_state", "modelState");
useEffect(() => {
sendCommand("modelState");
// only run on mount
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);
const { payload: textModelState } = useModelState( const { payload: textModelState } = useModelState(
"jinaai/jina-clip-v1-text_model_fp16.onnx", "jinaai/jina-clip-v1-text_model_fp16.onnx",
); );

View File

@ -187,13 +187,19 @@ export default function SearchView({
} }
}, [searchResults, searchDetail]); }, [searchResults, searchDetail]);
// confidence score - probably needs tweaking // confidence score
const zScoreToConfidence = (score: number) => { const zScoreToConfidence = (score: number) => {
// Sigmoid function: 1 / (1 + e^x) // Normalizing is not needed for similarity searches
const confidence = 1 / (1 + Math.exp(score)); // Sigmoid function for normalized: 1 / (1 + e^x)
// Cosine for similarity
if (searchFilter) {
const notNormalized = searchFilter?.search_type?.includes("similarity");
return Math.round(confidence * 100); const confidence = notNormalized ? 1 - score : 1 / (1 + Math.exp(score));
return Math.round(confidence * 100);
}
}; };
const hasExistingSearch = useMemo( const hasExistingSearch = useMemo(