mirror of
https://github.com/blakeblackshear/frigate.git
synced 2024-11-21 19:07:46 +01:00
Embeddings normalization fixes (#14284)
* Use cosine distance metric for vec tables * Only apply normalization to multi modal searches * Catch possible edge case in stddev calc * Use sigmoid function for normalization for multi modal searches only * Ensure we get model state on initial page load * Only save stats for multi modal searches and only use cosine similarity for image -> image search
This commit is contained in:
parent
d4b9b5a7dd
commit
8a8a0c7dec
@ -473,12 +473,7 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
|
|||||||
)
|
)
|
||||||
|
|
||||||
thumb_result = context.search_thumbnail(search_event)
|
thumb_result = context.search_thumbnail(search_event)
|
||||||
thumb_ids = dict(
|
thumb_ids = {result[0]: result[1] for result in thumb_result}
|
||||||
zip(
|
|
||||||
[result[0] for result in thumb_result],
|
|
||||||
context.thumb_stats.normalize([result[1] for result in thumb_result]),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
search_results = {
|
search_results = {
|
||||||
event_id: {"distance": distance, "source": "thumbnail"}
|
event_id: {"distance": distance, "source": "thumbnail"}
|
||||||
for event_id, distance in thumb_ids.items()
|
for event_id, distance in thumb_ids.items()
|
||||||
@ -486,15 +481,18 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
|
|||||||
else:
|
else:
|
||||||
search_types = search_type.split(",")
|
search_types = search_type.split(",")
|
||||||
|
|
||||||
|
# only save stats for multi-modal searches
|
||||||
|
save_stats = "thumbnail" in search_types and "description" in search_types
|
||||||
|
|
||||||
if "thumbnail" in search_types:
|
if "thumbnail" in search_types:
|
||||||
thumb_result = context.search_thumbnail(query)
|
thumb_result = context.search_thumbnail(query)
|
||||||
thumb_ids = dict(
|
|
||||||
zip(
|
thumb_distances = context.thumb_stats.normalize(
|
||||||
[result[0] for result in thumb_result],
|
[result[1] for result in thumb_result], save_stats
|
||||||
context.thumb_stats.normalize(
|
|
||||||
[result[1] for result in thumb_result]
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
thumb_ids = dict(
|
||||||
|
zip([result[0] for result in thumb_result], thumb_distances)
|
||||||
)
|
)
|
||||||
search_results.update(
|
search_results.update(
|
||||||
{
|
{
|
||||||
@ -505,12 +503,13 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
|
|||||||
|
|
||||||
if "description" in search_types:
|
if "description" in search_types:
|
||||||
desc_result = context.search_description(query)
|
desc_result = context.search_description(query)
|
||||||
desc_ids = dict(
|
|
||||||
zip(
|
desc_distances = context.desc_stats.normalize(
|
||||||
[result[0] for result in desc_result],
|
[result[1] for result in desc_result], save_stats
|
||||||
context.desc_stats.normalize([result[1] for result in desc_result]),
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
desc_ids = dict(zip([result[0] for result in desc_result], desc_distances))
|
||||||
|
|
||||||
for event_id, distance in desc_ids.items():
|
for event_id, distance in desc_ids.items():
|
||||||
if (
|
if (
|
||||||
event_id not in search_results
|
event_id not in search_results
|
||||||
|
@ -42,12 +42,12 @@ class SqliteVecQueueDatabase(SqliteQueueDatabase):
|
|||||||
self.execute_sql("""
|
self.execute_sql("""
|
||||||
CREATE VIRTUAL TABLE IF NOT EXISTS vec_thumbnails USING vec0(
|
CREATE VIRTUAL TABLE IF NOT EXISTS vec_thumbnails USING vec0(
|
||||||
id TEXT PRIMARY KEY,
|
id TEXT PRIMARY KEY,
|
||||||
thumbnail_embedding FLOAT[768]
|
thumbnail_embedding FLOAT[768] distance_metric=cosine
|
||||||
);
|
);
|
||||||
""")
|
""")
|
||||||
self.execute_sql("""
|
self.execute_sql("""
|
||||||
CREATE VIRTUAL TABLE IF NOT EXISTS vec_descriptions USING vec0(
|
CREATE VIRTUAL TABLE IF NOT EXISTS vec_descriptions USING vec0(
|
||||||
id TEXT PRIMARY KEY,
|
id TEXT PRIMARY KEY,
|
||||||
description_embedding FLOAT[768]
|
description_embedding FLOAT[768] distance_metric=cosine
|
||||||
);
|
);
|
||||||
""")
|
""")
|
||||||
|
@ -20,9 +20,10 @@ class ZScoreNormalization:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def stddev(self):
|
def stddev(self):
|
||||||
return math.sqrt(self.variance)
|
return math.sqrt(self.variance) if self.variance > 0 else 0.0
|
||||||
|
|
||||||
def normalize(self, distances: list[float]):
|
def normalize(self, distances: list[float], save_stats: bool):
|
||||||
|
if save_stats:
|
||||||
self._update(distances)
|
self._update(distances)
|
||||||
if self.stddev == 0:
|
if self.stddev == 0:
|
||||||
return distances
|
return distances
|
||||||
|
@ -2,6 +2,7 @@ import {
|
|||||||
useEmbeddingsReindexProgress,
|
useEmbeddingsReindexProgress,
|
||||||
useEventUpdate,
|
useEventUpdate,
|
||||||
useModelState,
|
useModelState,
|
||||||
|
useWs,
|
||||||
} from "@/api/ws";
|
} from "@/api/ws";
|
||||||
import ActivityIndicator from "@/components/indicators/activity-indicator";
|
import ActivityIndicator from "@/components/indicators/activity-indicator";
|
||||||
import AnimatedCircularProgressBar from "@/components/ui/circular-progress-bar";
|
import AnimatedCircularProgressBar from "@/components/ui/circular-progress-bar";
|
||||||
@ -202,6 +203,14 @@ export default function Explore() {
|
|||||||
|
|
||||||
// model states
|
// model states
|
||||||
|
|
||||||
|
const { send: sendCommand } = useWs("model_state", "modelState");
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
sendCommand("modelState");
|
||||||
|
// only run on mount
|
||||||
|
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||||
|
}, []);
|
||||||
|
|
||||||
const { payload: textModelState } = useModelState(
|
const { payload: textModelState } = useModelState(
|
||||||
"jinaai/jina-clip-v1-text_model_fp16.onnx",
|
"jinaai/jina-clip-v1-text_model_fp16.onnx",
|
||||||
);
|
);
|
||||||
|
@ -187,13 +187,19 @@ export default function SearchView({
|
|||||||
}
|
}
|
||||||
}, [searchResults, searchDetail]);
|
}, [searchResults, searchDetail]);
|
||||||
|
|
||||||
// confidence score - probably needs tweaking
|
// confidence score
|
||||||
|
|
||||||
const zScoreToConfidence = (score: number) => {
|
const zScoreToConfidence = (score: number) => {
|
||||||
// Sigmoid function: 1 / (1 + e^x)
|
// Normalizing is not needed for similarity searches
|
||||||
const confidence = 1 / (1 + Math.exp(score));
|
// Sigmoid function for normalized: 1 / (1 + e^x)
|
||||||
|
// Cosine for similarity
|
||||||
|
if (searchFilter) {
|
||||||
|
const notNormalized = searchFilter?.search_type?.includes("similarity");
|
||||||
|
|
||||||
|
const confidence = notNormalized ? 1 - score : 1 / (1 + Math.exp(score));
|
||||||
|
|
||||||
return Math.round(confidence * 100);
|
return Math.round(confidence * 100);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
const hasExistingSearch = useMemo(
|
const hasExistingSearch = useMemo(
|
||||||
|
Loading…
Reference in New Issue
Block a user