mirror of
				https://github.com/blakeblackshear/frigate.git
				synced 2025-10-27 10:52:11 +01:00 
			
		
		
		
	Support batch embeddings when reindexing (#14320)
* Refactor onnx embeddings to handle multiple inputs by default * Process items in batches when reindexing
This commit is contained in:
		
							parent
							
								
									0fc7999780
								
							
						
					
					
						commit
						e8b2fde753
					
				| @ -6,6 +6,7 @@ import logging | ||||
| import os | ||||
| import time | ||||
| 
 | ||||
| from numpy import ndarray | ||||
| from PIL import Image | ||||
| from playhouse.shortcuts import model_to_dict | ||||
| 
 | ||||
| @ -88,12 +89,6 @@ class Embeddings: | ||||
|                 }, | ||||
|             ) | ||||
| 
 | ||||
|         def jina_text_embedding_function(outputs): | ||||
|             return outputs[0] | ||||
| 
 | ||||
|         def jina_vision_embedding_function(outputs): | ||||
|             return outputs[0] | ||||
| 
 | ||||
|         self.text_embedding = GenericONNXEmbedding( | ||||
|             model_name="jinaai/jina-clip-v1", | ||||
|             model_file="text_model_fp16.onnx", | ||||
| @ -101,7 +96,6 @@ class Embeddings: | ||||
|             download_urls={ | ||||
|                 "text_model_fp16.onnx": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/text_model_fp16.onnx", | ||||
|             }, | ||||
|             embedding_function=jina_text_embedding_function, | ||||
|             model_size=config.model_size, | ||||
|             model_type="text", | ||||
|             requestor=self.requestor, | ||||
| @ -123,14 +117,13 @@ class Embeddings: | ||||
|             model_name="jinaai/jina-clip-v1", | ||||
|             model_file=model_file, | ||||
|             download_urls=download_urls, | ||||
|             embedding_function=jina_vision_embedding_function, | ||||
|             model_size=config.model_size, | ||||
|             model_type="vision", | ||||
|             requestor=self.requestor, | ||||
|             device="GPU" if config.model_size == "large" else "CPU", | ||||
|         ) | ||||
| 
 | ||||
|     def upsert_thumbnail(self, event_id: str, thumbnail: bytes): | ||||
|     def upsert_thumbnail(self, event_id: str, thumbnail: bytes) -> ndarray: | ||||
|         # Convert thumbnail bytes to PIL Image | ||||
|         image = Image.open(io.BytesIO(thumbnail)).convert("RGB") | ||||
|         embedding = self.vision_embedding([image])[0] | ||||
| @ -145,7 +138,25 @@ class Embeddings: | ||||
| 
 | ||||
|         return embedding | ||||
| 
 | ||||
|     def upsert_description(self, event_id: str, description: str): | ||||
|     def batch_upsert_thumbnail(self, event_thumbs: dict[str, bytes]) -> list[ndarray]: | ||||
|         images = [ | ||||
|             Image.open(io.BytesIO(thumb)).convert("RGB") | ||||
|             for thumb in event_thumbs.values() | ||||
|         ] | ||||
|         ids = list(event_thumbs.keys()) | ||||
|         embeddings = self.vision_embedding(images) | ||||
|         items = [(ids[i], serialize(embeddings[i])) for i in range(len(ids))] | ||||
| 
 | ||||
|         self.db.execute_sql( | ||||
|             """ | ||||
|             INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding) | ||||
|             VALUES {} | ||||
|             """.format(", ".join(["(?, ?)"] * len(items))), | ||||
|             items, | ||||
|         ) | ||||
|         return embeddings | ||||
| 
 | ||||
|     def upsert_description(self, event_id: str, description: str) -> ndarray: | ||||
|         embedding = self.text_embedding([description])[0] | ||||
|         self.db.execute_sql( | ||||
|             """ | ||||
| @ -157,6 +168,21 @@ class Embeddings: | ||||
| 
 | ||||
|         return embedding | ||||
| 
 | ||||
|     def batch_upsert_description(self, event_descriptions: dict[str, str]) -> ndarray: | ||||
|         embeddings = self.text_embedding(list(event_descriptions.values())) | ||||
|         ids = list(event_descriptions.keys()) | ||||
|         items = [(ids[i], serialize(embeddings[i])) for i in range(len(ids))] | ||||
| 
 | ||||
|         self.db.execute_sql( | ||||
|             """ | ||||
|             INSERT OR REPLACE INTO vec_descriptions(id, description_embedding) | ||||
|             VALUES {} | ||||
|             """.format(", ".join(["(?, ?)"] * len(items))), | ||||
|             items, | ||||
|         ) | ||||
| 
 | ||||
|         return embeddings | ||||
| 
 | ||||
|     def reindex(self) -> None: | ||||
|         logger.info("Indexing tracked object embeddings...") | ||||
| 
 | ||||
| @ -192,9 +218,8 @@ class Embeddings: | ||||
|         ) | ||||
|         totals["total_objects"] = total_events | ||||
| 
 | ||||
|         batch_size = 100 | ||||
|         batch_size = 32 | ||||
|         current_page = 1 | ||||
|         processed_events = 0 | ||||
| 
 | ||||
|         events = ( | ||||
|             Event.select() | ||||
| @ -208,23 +233,29 @@ class Embeddings: | ||||
| 
 | ||||
|         while len(events) > 0: | ||||
|             event: Event | ||||
|             batch_thumbs = {} | ||||
|             batch_descs = {} | ||||
|             for event in events: | ||||
|                 thumbnail = base64.b64decode(event.thumbnail) | ||||
|                 self.upsert_thumbnail(event.id, thumbnail) | ||||
|                 batch_thumbs[event.id] = base64.b64decode(event.thumbnail) | ||||
|                 totals["thumbnails"] += 1 | ||||
| 
 | ||||
|                 if description := event.data.get("description", "").strip(): | ||||
|                     batch_descs[event.id] = description | ||||
|                     totals["descriptions"] += 1 | ||||
|                     self.upsert_description(event.id, description) | ||||
| 
 | ||||
|                 totals["processed_objects"] += 1 | ||||
| 
 | ||||
|                 # report progress every 10 events so we don't spam the logs | ||||
|                 if (totals["processed_objects"] % 10) == 0: | ||||
|                     progress = (processed_events / total_events) * 100 | ||||
|             # run batch embedding | ||||
|             self.batch_upsert_thumbnail(batch_thumbs) | ||||
| 
 | ||||
|             if batch_descs: | ||||
|                 self.batch_upsert_description(batch_descs) | ||||
| 
 | ||||
|             # report progress every batch so we don't spam the logs | ||||
|             progress = (totals["processed_objects"] / total_events) * 100 | ||||
|             logger.debug( | ||||
|                 "Processed %d/%d events (%.2f%% complete) | Thumbnails: %d, Descriptions: %d", | ||||
|                         processed_events, | ||||
|                 totals["processed_objects"], | ||||
|                 total_events, | ||||
|                 progress, | ||||
|                 totals["thumbnails"], | ||||
|  | ||||
| @ -2,7 +2,7 @@ import logging | ||||
| import os | ||||
| import warnings | ||||
| from io import BytesIO | ||||
| from typing import Callable, Dict, List, Optional, Union | ||||
| from typing import Dict, List, Optional, Union | ||||
| 
 | ||||
| import numpy as np | ||||
| import requests | ||||
| @ -39,7 +39,6 @@ class GenericONNXEmbedding: | ||||
|         model_name: str, | ||||
|         model_file: str, | ||||
|         download_urls: Dict[str, str], | ||||
|         embedding_function: Callable[[List[np.ndarray]], np.ndarray], | ||||
|         model_size: str, | ||||
|         model_type: str, | ||||
|         requestor: InterProcessRequestor, | ||||
| @ -51,7 +50,6 @@ class GenericONNXEmbedding: | ||||
|         self.tokenizer_file = tokenizer_file | ||||
|         self.requestor = requestor | ||||
|         self.download_urls = download_urls | ||||
|         self.embedding_function = embedding_function | ||||
|         self.model_type = model_type  # 'text' or 'vision' | ||||
|         self.model_size = model_size | ||||
|         self.device = device | ||||
| @ -157,7 +155,6 @@ class GenericONNXEmbedding: | ||||
|         self, inputs: Union[List[str], List[Image.Image], List[str]] | ||||
|     ) -> List[np.ndarray]: | ||||
|         self._load_model_and_tokenizer() | ||||
| 
 | ||||
|         if self.runner is None or ( | ||||
|             self.tokenizer is None and self.feature_extractor is None | ||||
|         ): | ||||
| @ -167,23 +164,27 @@ class GenericONNXEmbedding: | ||||
|             return [] | ||||
| 
 | ||||
|         if self.model_type == "text": | ||||
|             processed_inputs = self.tokenizer( | ||||
|                 inputs, padding=True, truncation=True, return_tensors="np" | ||||
|             ) | ||||
|             processed_inputs = [ | ||||
|                 self.tokenizer(text, padding=True, truncation=True, return_tensors="np") | ||||
|                 for text in inputs | ||||
|             ] | ||||
|         else: | ||||
|             processed_images = [self._process_image(img) for img in inputs] | ||||
|             processed_inputs = self.feature_extractor( | ||||
|                 images=processed_images, return_tensors="np" | ||||
|             ) | ||||
|             processed_inputs = [ | ||||
|                 self.feature_extractor(images=image, return_tensors="np") | ||||
|                 for image in processed_images | ||||
|             ] | ||||
| 
 | ||||
|         input_names = self.runner.get_input_names() | ||||
|         onnx_inputs = { | ||||
|             name: processed_inputs[name] | ||||
|             for name in input_names | ||||
|             if name in processed_inputs | ||||
|         } | ||||
|         onnx_inputs = {name: [] for name in input_names} | ||||
|         input: dict[str, any] | ||||
|         for input in processed_inputs: | ||||
|             for key, value in input.items(): | ||||
|                 if key in input_names: | ||||
|                     onnx_inputs[key].append(value[0]) | ||||
| 
 | ||||
|         outputs = self.runner.run(onnx_inputs) | ||||
|         embeddings = self.embedding_function(outputs) | ||||
|         for key in onnx_inputs.keys(): | ||||
|             onnx_inputs[key] = np.array(onnx_inputs[key]) | ||||
| 
 | ||||
|         embeddings = self.runner.run(onnx_inputs)[0] | ||||
|         return [embedding for embedding in embeddings] | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user