Support batch embeddings when reindexing (#14320)

* Refactor onnx embeddings to handle multiple inputs by default
* Process items in batches when reindexing
parent 0fc7999780
commit e8b2fde753
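The change is easiest to read as a batching refactor: instead of one embedding call and one SQL INSERT per event, the reindex loop now collects a page of events and processes the whole page at once. A minimal, self-contained sketch of that chunking pattern follows; the helper names and the toy embedding function are illustrative, not part of the commit, and only batch_size = 32 comes from the diff below.

# Sketch of per-page batching; only batch_size = 32 is taken from the diff.
from typing import Iterator, List

def pages(items: List[str], batch_size: int = 32) -> Iterator[List[str]]:
    # Yield fixed-size chunks of `items`.
    for i in range(0, len(items), batch_size):
        yield items[i : i + batch_size]

def embed_batch(texts: List[str]) -> List[List[float]]:
    # Toy stand-in for the real text/vision embedding call.
    return [[float(len(t))] for t in texts]

descriptions = [f"description {i}" for i in range(70)]
for batch in pages(descriptions):
    embeddings = embed_batch(batch)  # one model call per page, not per event
    assert len(embeddings) == len(batch)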
@@ -6,6 +6,7 @@ import logging
 import os
 import time
 
+from numpy import ndarray
 from PIL import Image
 from playhouse.shortcuts import model_to_dict
 
@@ -88,12 +89,6 @@ class Embeddings:
             },
         )
 
-        def jina_text_embedding_function(outputs):
-            return outputs[0]
-
-        def jina_vision_embedding_function(outputs):
-            return outputs[0]
-
         self.text_embedding = GenericONNXEmbedding(
             model_name="jinaai/jina-clip-v1",
             model_file="text_model_fp16.onnx",
@@ -101,7 +96,6 @@ class Embeddings:
             download_urls={
                 "text_model_fp16.onnx": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/text_model_fp16.onnx",
             },
-            embedding_function=jina_text_embedding_function,
             model_size=config.model_size,
             model_type="text",
             requestor=self.requestor,
@@ -123,14 +117,13 @@ class Embeddings:
             model_name="jinaai/jina-clip-v1",
             model_file=model_file,
             download_urls=download_urls,
-            embedding_function=jina_vision_embedding_function,
             model_size=config.model_size,
             model_type="vision",
             requestor=self.requestor,
             device="GPU" if config.model_size == "large" else "CPU",
         )
 
-    def upsert_thumbnail(self, event_id: str, thumbnail: bytes):
+    def upsert_thumbnail(self, event_id: str, thumbnail: bytes) -> ndarray:
         # Convert thumbnail bytes to PIL Image
         image = Image.open(io.BytesIO(thumbnail)).convert("RGB")
         embedding = self.vision_embedding([image])[0]
@@ -145,7 +138,25 @@ class Embeddings:
 
         return embedding
 
-    def upsert_description(self, event_id: str, description: str):
+    def batch_upsert_thumbnail(self, event_thumbs: dict[str, bytes]) -> list[ndarray]:
+        images = [
+            Image.open(io.BytesIO(thumb)).convert("RGB")
+            for thumb in event_thumbs.values()
+        ]
+        ids = list(event_thumbs.keys())
+        embeddings = self.vision_embedding(images)
+        items = [(ids[i], serialize(embeddings[i])) for i in range(len(ids))]
+
+        self.db.execute_sql(
+            """
+            INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding)
+            VALUES {}
+            """.format(", ".join(["(?, ?)"] * len(items))),
+            items,
+        )
+        return embeddings
+
+    def upsert_description(self, event_id: str, description: str) -> ndarray:
         embedding = self.text_embedding([description])[0]
         self.db.execute_sql(
             """
@@ -157,6 +168,21 @@ class Embeddings:
 
         return embedding
 
+    def batch_upsert_description(self, event_descriptions: dict[str, str]) -> ndarray:
+        embeddings = self.text_embedding(list(event_descriptions.values()))
+        ids = list(event_descriptions.keys())
+        items = [(ids[i], serialize(embeddings[i])) for i in range(len(ids))]
+
+        self.db.execute_sql(
+            """
+            INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
+            VALUES {}
+            """.format(", ".join(["(?, ?)"] * len(items))),
+            items,
+        )
+
+        return embeddings
+
     def reindex(self) -> None:
         logger.info("Indexing tracked object embeddings...")
 
@@ -192,9 +218,8 @@ class Embeddings:
         )
         totals["total_objects"] = total_events
 
-        batch_size = 100
+        batch_size = 32
         current_page = 1
-        processed_events = 0
 
         events = (
             Event.select()
@@ -208,37 +233,43 @@ class Embeddings:
 
         while len(events) > 0:
             event: Event
+            batch_thumbs = {}
+            batch_descs = {}
             for event in events:
-                thumbnail = base64.b64decode(event.thumbnail)
-                self.upsert_thumbnail(event.id, thumbnail)
+                batch_thumbs[event.id] = base64.b64decode(event.thumbnail)
                 totals["thumbnails"] += 1
 
                 if description := event.data.get("description", "").strip():
+                    batch_descs[event.id] = description
                     totals["descriptions"] += 1
-                    self.upsert_description(event.id, description)
 
                 totals["processed_objects"] += 1
 
-                # report progress every 10 events so we don't spam the logs
-                if (totals["processed_objects"] % 10) == 0:
-                    progress = (processed_events / total_events) * 100
-                    logger.debug(
-                        "Processed %d/%d events (%.2f%% complete) | Thumbnails: %d, Descriptions: %d",
-                        processed_events,
-                        total_events,
-                        progress,
-                        totals["thumbnails"],
-                        totals["descriptions"],
-                    )
-
-                # Calculate time remaining
-                elapsed_time = time.time() - st
-                avg_time_per_event = elapsed_time / totals["processed_objects"]
-                remaining_events = total_events - totals["processed_objects"]
-                time_remaining = avg_time_per_event * remaining_events
-                totals["time_remaining"] = int(time_remaining)
-
-                self.requestor.send_data(UPDATE_EMBEDDINGS_REINDEX_PROGRESS, totals)
+            # run batch embedding
+            self.batch_upsert_thumbnail(batch_thumbs)
+
+            if batch_descs:
+                self.batch_upsert_description(batch_descs)
+
+            # report progress every batch so we don't spam the logs
+            progress = (totals["processed_objects"] / total_events) * 100
+            logger.debug(
+                "Processed %d/%d events (%.2f%% complete) | Thumbnails: %d, Descriptions: %d",
+                totals["processed_objects"],
+                total_events,
+                progress,
+                totals["thumbnails"],
+                totals["descriptions"],
+            )
+
+            # Calculate time remaining
+            elapsed_time = time.time() - st
+            avg_time_per_event = elapsed_time / totals["processed_objects"]
+            remaining_events = total_events - totals["processed_objects"]
+            time_remaining = avg_time_per_event * remaining_events
+            totals["time_remaining"] = int(time_remaining)
+
+            self.requestor.send_data(UPDATE_EMBEDDINGS_REINDEX_PROGRESS, totals)
 
             # Move to the next page
             current_page += 1
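The new batch_upsert_thumbnail / batch_upsert_description helpers write a whole page in one statement by repeating a "(?, ?)" placeholder group per row. A rough standalone sketch of that multi-row INSERT OR REPLACE, using the standard-library sqlite3 module for illustration (the in-memory table and the flattened parameter list are assumptions of the sketch; the real code goes through peewee's execute_sql against sqlite-vec tables):

# Sketch of the multi-row upsert pattern: one "(?, ?)" group per row,
# joined with commas, bound to a flat parameter sequence.
import sqlite3

db = sqlite3.connect(":memory:")
db.execute("CREATE TABLE vec_thumbnails(id TEXT PRIMARY KEY, thumbnail_embedding BLOB)")

items = [("evt_a", b"\x01\x02"), ("evt_b", b"\x03\x04")]
placeholders = ", ".join(["(?, ?)"] * len(items))  # "(?, ?), (?, ?)"
flat_params = [value for row in items for value in row]
db.execute(
    f"INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding) VALUES {placeholders}",
    flat_params,
)
assert db.execute("SELECT COUNT(*) FROM vec_thumbnails").fetchone()[0] == 2

Pairing ids[i] with embeddings[i] in the helpers is safe because Python dicts preserve insertion order, so keys() and values() iterate in the same order.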
@@ -2,7 +2,7 @@ import logging
 import os
 import warnings
 from io import BytesIO
-from typing import Callable, Dict, List, Optional, Union
+from typing import Dict, List, Optional, Union
 
 import numpy as np
 import requests
@@ -39,7 +39,6 @@ class GenericONNXEmbedding:
         model_name: str,
         model_file: str,
         download_urls: Dict[str, str],
-        embedding_function: Callable[[List[np.ndarray]], np.ndarray],
         model_size: str,
         model_type: str,
         requestor: InterProcessRequestor,
@@ -51,7 +50,6 @@ class GenericONNXEmbedding:
         self.tokenizer_file = tokenizer_file
         self.requestor = requestor
         self.download_urls = download_urls
-        self.embedding_function = embedding_function
         self.model_type = model_type  # 'text' or 'vision'
         self.model_size = model_size
         self.device = device
@@ -157,7 +155,6 @@ class GenericONNXEmbedding:
         self, inputs: Union[List[str], List[Image.Image], List[str]]
     ) -> List[np.ndarray]:
         self._load_model_and_tokenizer()
-
         if self.runner is None or (
             self.tokenizer is None and self.feature_extractor is None
         ):
@@ -167,23 +164,27 @@ class GenericONNXEmbedding:
             return []
 
         if self.model_type == "text":
-            processed_inputs = self.tokenizer(
-                inputs, padding=True, truncation=True, return_tensors="np"
-            )
+            processed_inputs = [
+                self.tokenizer(text, padding=True, truncation=True, return_tensors="np")
+                for text in inputs
+            ]
         else:
             processed_images = [self._process_image(img) for img in inputs]
-            processed_inputs = self.feature_extractor(
-                images=processed_images, return_tensors="np"
-            )
+            processed_inputs = [
+                self.feature_extractor(images=image, return_tensors="np")
+                for image in processed_images
+            ]
 
         input_names = self.runner.get_input_names()
-        onnx_inputs = {
-            name: processed_inputs[name]
-            for name in input_names
-            if name in processed_inputs
-        }
+        onnx_inputs = {name: [] for name in input_names}
+        input: dict[str, any]
+        for input in processed_inputs:
+            for key, value in input.items():
+                if key in input_names:
+                    onnx_inputs[key].append(value[0])
 
-        outputs = self.runner.run(onnx_inputs)
-        embeddings = self.embedding_function(outputs)
+        for key in onnx_inputs.keys():
+            onnx_inputs[key] = np.array(onnx_inputs[key])
 
+        embeddings = self.runner.run(onnx_inputs)[0]
         return [embedding for embedding in embeddings]
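With the embedding_function callback removed, GenericONNXEmbedding batches inputs itself: each item is preprocessed individually, the leading batch axis of each resulting array is dropped, and the items are restacked into one tensor per ONNX input name. A small sketch of that stacking step, where the "pixel_values" input name and shapes are illustrative assumptions; it also assumes every item preprocesses to the same shape, which holds for fixed-size image features but not necessarily for per-item padded text:

# Sketch of the stacking step _preprocess_inputs now performs: per-item
# preprocessor output has a leading batch axis of 1; value[0] drops it,
# and np.array() restacks the items into one batched tensor per input.
import numpy as np

per_item = [
    {"pixel_values": np.zeros((1, 3, 224, 224), dtype=np.float32)},
    {"pixel_values": np.ones((1, 3, 224, 224), dtype=np.float32)},
]
input_names = ["pixel_values"]

onnx_inputs: dict[str, list] = {name: [] for name in input_names}
for item in per_item:
    for key, value in item.items():
        if key in input_names:
            onnx_inputs[key].append(value[0])  # drop the per-item batch axis

batched = {key: np.array(stack) for key, stack in onnx_inputs.items()}
assert batched["pixel_values"].shape == (2, 3, 224, 224)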