mirror of
				https://github.com/blakeblackshear/frigate.git
				synced 2025-10-27 10:52:11 +01:00 
			
		
		
		
	Embeddings normalization fixes (#14284)
* Use cosine distance metric for vec tables * Only apply normalization to multi modal searches * Catch possible edge case in stddev calc * Use sigmoid function for normalization for multi modal searches only * Ensure we get model state on initial page load * Only save stats for multi modal searches and only use cosine similarity for image -> image search
This commit is contained in:
		
							parent
							
								
									d4b9b5a7dd
								
							
						
					
					
						commit
						8a8a0c7dec
					
				| @ -473,12 +473,7 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends()) | ||||
|             ) | ||||
| 
 | ||||
|         thumb_result = context.search_thumbnail(search_event) | ||||
|         thumb_ids = dict( | ||||
|             zip( | ||||
|                 [result[0] for result in thumb_result], | ||||
|                 context.thumb_stats.normalize([result[1] for result in thumb_result]), | ||||
|             ) | ||||
|         ) | ||||
|         thumb_ids = {result[0]: result[1] for result in thumb_result} | ||||
|         search_results = { | ||||
|             event_id: {"distance": distance, "source": "thumbnail"} | ||||
|             for event_id, distance in thumb_ids.items() | ||||
| @ -486,15 +481,18 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends()) | ||||
|     else: | ||||
|         search_types = search_type.split(",") | ||||
| 
 | ||||
|         # only save stats for multi-modal searches | ||||
|         save_stats = "thumbnail" in search_types and "description" in search_types | ||||
| 
 | ||||
|         if "thumbnail" in search_types: | ||||
|             thumb_result = context.search_thumbnail(query) | ||||
| 
 | ||||
|             thumb_distances = context.thumb_stats.normalize( | ||||
|                 [result[1] for result in thumb_result], save_stats | ||||
|             ) | ||||
| 
 | ||||
|             thumb_ids = dict( | ||||
|                 zip( | ||||
|                     [result[0] for result in thumb_result], | ||||
|                     context.thumb_stats.normalize( | ||||
|                         [result[1] for result in thumb_result] | ||||
|                     ), | ||||
|                 ) | ||||
|                 zip([result[0] for result in thumb_result], thumb_distances) | ||||
|             ) | ||||
|             search_results.update( | ||||
|                 { | ||||
| @ -505,12 +503,13 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends()) | ||||
| 
 | ||||
|         if "description" in search_types: | ||||
|             desc_result = context.search_description(query) | ||||
|             desc_ids = dict( | ||||
|                 zip( | ||||
|                     [result[0] for result in desc_result], | ||||
|                     context.desc_stats.normalize([result[1] for result in desc_result]), | ||||
|                 ) | ||||
| 
 | ||||
|             desc_distances = context.desc_stats.normalize( | ||||
|                 [result[1] for result in desc_result], save_stats | ||||
|             ) | ||||
| 
 | ||||
|             desc_ids = dict(zip([result[0] for result in desc_result], desc_distances)) | ||||
| 
 | ||||
|             for event_id, distance in desc_ids.items(): | ||||
|                 if ( | ||||
|                     event_id not in search_results | ||||
|  | ||||
| @ -42,12 +42,12 @@ class SqliteVecQueueDatabase(SqliteQueueDatabase): | ||||
|         self.execute_sql(""" | ||||
|             CREATE VIRTUAL TABLE IF NOT EXISTS vec_thumbnails USING vec0( | ||||
|                 id TEXT PRIMARY KEY, | ||||
|                 thumbnail_embedding FLOAT[768] | ||||
|                 thumbnail_embedding FLOAT[768] distance_metric=cosine | ||||
|             ); | ||||
|         """) | ||||
|         self.execute_sql(""" | ||||
|             CREATE VIRTUAL TABLE IF NOT EXISTS vec_descriptions USING vec0( | ||||
|                 id TEXT PRIMARY KEY, | ||||
|                 description_embedding FLOAT[768] | ||||
|                 description_embedding FLOAT[768] distance_metric=cosine | ||||
|             ); | ||||
|         """) | ||||
|  | ||||
| @ -20,10 +20,11 @@ class ZScoreNormalization: | ||||
| 
 | ||||
|     @property | ||||
|     def stddev(self): | ||||
|         return math.sqrt(self.variance) | ||||
|         return math.sqrt(self.variance) if self.variance > 0 else 0.0 | ||||
| 
 | ||||
|     def normalize(self, distances: list[float]): | ||||
|         self._update(distances) | ||||
|     def normalize(self, distances: list[float], save_stats: bool): | ||||
|         if save_stats: | ||||
|             self._update(distances) | ||||
|         if self.stddev == 0: | ||||
|             return distances | ||||
|         return [ | ||||
|  | ||||
| @ -2,6 +2,7 @@ import { | ||||
|   useEmbeddingsReindexProgress, | ||||
|   useEventUpdate, | ||||
|   useModelState, | ||||
|   useWs, | ||||
| } from "@/api/ws"; | ||||
| import ActivityIndicator from "@/components/indicators/activity-indicator"; | ||||
| import AnimatedCircularProgressBar from "@/components/ui/circular-progress-bar"; | ||||
| @ -202,6 +203,14 @@ export default function Explore() { | ||||
| 
 | ||||
|   // model states
 | ||||
| 
 | ||||
|   const { send: sendCommand } = useWs("model_state", "modelState"); | ||||
| 
 | ||||
|   useEffect(() => { | ||||
|     sendCommand("modelState"); | ||||
|     // only run on mount
 | ||||
|     // eslint-disable-next-line react-hooks/exhaustive-deps
 | ||||
|   }, []); | ||||
| 
 | ||||
|   const { payload: textModelState } = useModelState( | ||||
|     "jinaai/jina-clip-v1-text_model_fp16.onnx", | ||||
|   ); | ||||
|  | ||||
| @ -187,13 +187,19 @@ export default function SearchView({ | ||||
|     } | ||||
|   }, [searchResults, searchDetail]); | ||||
| 
 | ||||
|   // confidence score - probably needs tweaking
 | ||||
|   // confidence score
 | ||||
| 
 | ||||
|   const zScoreToConfidence = (score: number) => { | ||||
|     // Sigmoid function: 1 / (1 + e^x)
 | ||||
|     const confidence = 1 / (1 + Math.exp(score)); | ||||
|     // Normalizing is not needed for similarity searches
 | ||||
|     // Sigmoid function for normalized: 1 / (1 + e^x)
 | ||||
|     // Cosine for similarity
 | ||||
|     if (searchFilter) { | ||||
|       const notNormalized = searchFilter?.search_type?.includes("similarity"); | ||||
| 
 | ||||
|     return Math.round(confidence * 100); | ||||
|       const confidence = notNormalized ? 1 - score : 1 / (1 + Math.exp(score)); | ||||
| 
 | ||||
|       return Math.round(confidence * 100); | ||||
|     } | ||||
|   }; | ||||
| 
 | ||||
|   const hasExistingSearch = useMemo( | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user