Support batch embeddings when reindexing (#14320)

* Refactor onnx embeddings to handle multiple inputs by default

* Process items in batches when reindexing
Nicolas Mowen 2024-10-13 11:33:27 -06:00 committed by GitHub
parent 0fc7999780
commit e8b2fde753
2 changed files with 82 additions and 50 deletions
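At a high level, the change replaces one embedding inference per event with one batched inference per page of query results. A minimal sketch of the before/after call pattern, using the method names from the diff below (`events` stands for one page of Event rows, and `embeddings` for an Embeddings instance):

    # before: one ONNX inference per event
    for event in events:
        embeddings.upsert_thumbnail(event.id, base64.b64decode(event.thumbnail))

    # after: collect the page, then run a single batched inference
    batch_thumbs = {e.id: base64.b64decode(e.thumbnail) for e in events}
    embeddings.batch_upsert_thumbnail(batch_thumbs)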

View File

@@ -6,6 +6,7 @@ import logging
 import os
 import time
+from numpy import ndarray
 from PIL import Image
 from playhouse.shortcuts import model_to_dict
@@ -88,12 +89,6 @@ class Embeddings:
             },
         )
 
-        def jina_text_embedding_function(outputs):
-            return outputs[0]
-
-        def jina_vision_embedding_function(outputs):
-            return outputs[0]
-
         self.text_embedding = GenericONNXEmbedding(
             model_name="jinaai/jina-clip-v1",
             model_file="text_model_fp16.onnx",
@@ -101,7 +96,6 @@ class Embeddings:
             download_urls={
                 "text_model_fp16.onnx": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/text_model_fp16.onnx",
             },
-            embedding_function=jina_text_embedding_function,
             model_size=config.model_size,
             model_type="text",
             requestor=self.requestor,
@@ -123,14 +117,13 @@ class Embeddings:
             model_name="jinaai/jina-clip-v1",
             model_file=model_file,
             download_urls=download_urls,
-            embedding_function=jina_vision_embedding_function,
             model_size=config.model_size,
             model_type="vision",
             requestor=self.requestor,
             device="GPU" if config.model_size == "large" else "CPU",
         )
 
-    def upsert_thumbnail(self, event_id: str, thumbnail: bytes):
+    def upsert_thumbnail(self, event_id: str, thumbnail: bytes) -> ndarray:
         # Convert thumbnail bytes to PIL Image
         image = Image.open(io.BytesIO(thumbnail)).convert("RGB")
         embedding = self.vision_embedding([image])[0]
@@ -145,7 +138,25 @@ class Embeddings:
 
         return embedding
 
-    def upsert_description(self, event_id: str, description: str):
+    def batch_upsert_thumbnail(self, event_thumbs: dict[str, bytes]) -> list[ndarray]:
+        images = [
+            Image.open(io.BytesIO(thumb)).convert("RGB")
+            for thumb in event_thumbs.values()
+        ]
+        ids = list(event_thumbs.keys())
+        embeddings = self.vision_embedding(images)
+
+        # Bind a flat (id, embedding, id, embedding, ...) sequence so the
+        # parameter count matches one "(?, ?)" group per row
+        items = []
+        for i in range(len(ids)):
+            items.append(ids[i])
+            items.append(serialize(embeddings[i]))
+
+        self.db.execute_sql(
+            """
+            INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding)
+            VALUES {}
+            """.format(", ".join(["(?, ?)"] * len(ids))),
+            items,
+        )
+        return embeddings
+
+    def upsert_description(self, event_id: str, description: str) -> ndarray:
         embedding = self.text_embedding([description])[0]
         self.db.execute_sql(
             """
@@ -157,6 +168,21 @@ class Embeddings:
 
         return embedding
 
+    def batch_upsert_description(self, event_descriptions: dict[str, str]) -> list[ndarray]:
+        embeddings = self.text_embedding(list(event_descriptions.values()))
+        ids = list(event_descriptions.keys())
+
+        # Same flat-parameter convention as batch_upsert_thumbnail
+        items = []
+        for i in range(len(ids)):
+            items.append(ids[i])
+            items.append(serialize(embeddings[i]))
+
+        self.db.execute_sql(
+            """
+            INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
+            VALUES {}
+            """.format(", ".join(["(?, ?)"] * len(ids))),
+            items,
+        )
+        return embeddings
+
     def reindex(self) -> None:
         logger.info("Indexing tracked object embeddings...")
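A hedged usage sketch of the two batch methods together (event IDs and payloads are illustrative):

    thumbs = {"evt1": thumb_bytes_1, "evt2": thumb_bytes_2}
    descs = {"evt1": "person at the front door"}

    embeddings.batch_upsert_thumbnail(thumbs)
    if descs:
        embeddings.batch_upsert_description(descs)

One design constraint worth noting: SQLite caps bound variables per statement (999 in older builds, 32766 since 3.32), so a batch of N rows must keep 2 × N under that limit; the 32-event pages used by reindex() below stay well clear of it.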
@@ -192,9 +218,8 @@ class Embeddings:
         )
         totals["total_objects"] = total_events
 
-        batch_size = 100
+        batch_size = 32
         current_page = 1
-        processed_events = 0
 
         events = (
             Event.select()
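Dropping the page size from 100 to 32 bounds both the number of decoded thumbnails held in memory and the size of each ONNX batch, since the page now doubles as the embedding batch. A sketch of the paging loop this hunk feeds into, assuming the query pages with peewee's 1-indexed paginate() as the surrounding code suggests (query filters abbreviated):

    batch_size = 32
    current_page = 1
    events = Event.select().paginate(current_page, batch_size)

    while len(events) > 0:
        # ...embed and upsert this page (see the loop below)...
        current_page += 1
        events = Event.select().paginate(current_page, batch_size)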
@@ -208,37 +233,43 @@ class Embeddings:
 
         while len(events) > 0:
             event: Event
+            batch_thumbs = {}
+            batch_descs = {}
+
             for event in events:
-                thumbnail = base64.b64decode(event.thumbnail)
-                self.upsert_thumbnail(event.id, thumbnail)
+                batch_thumbs[event.id] = base64.b64decode(event.thumbnail)
                 totals["thumbnails"] += 1
 
                 if description := event.data.get("description", "").strip():
+                    batch_descs[event.id] = description
                     totals["descriptions"] += 1
-                    self.upsert_description(event.id, description)
 
                 totals["processed_objects"] += 1
 
-                # report progress every 10 events so we don't spam the logs
-                if (totals["processed_objects"] % 10) == 0:
-                    progress = (processed_events / total_events) * 100
-                    logger.debug(
-                        "Processed %d/%d events (%.2f%% complete) | Thumbnails: %d, Descriptions: %d",
-                        processed_events,
-                        total_events,
-                        progress,
-                        totals["thumbnails"],
-                        totals["descriptions"],
-                    )
-
-                    # Calculate time remaining
-                    elapsed_time = time.time() - st
-                    avg_time_per_event = elapsed_time / totals["processed_objects"]
-                    remaining_events = total_events - totals["processed_objects"]
-                    time_remaining = avg_time_per_event * remaining_events
-                    totals["time_remaining"] = int(time_remaining)
-
-                    self.requestor.send_data(UPDATE_EMBEDDINGS_REINDEX_PROGRESS, totals)
+            # run batch embedding
+            self.batch_upsert_thumbnail(batch_thumbs)
+
+            if batch_descs:
+                self.batch_upsert_description(batch_descs)
+
+            # report progress every batch so we don't spam the logs
+            progress = (totals["processed_objects"] / total_events) * 100
+            logger.debug(
+                "Processed %d/%d events (%.2f%% complete) | Thumbnails: %d, Descriptions: %d",
+                totals["processed_objects"],
+                total_events,
+                progress,
+                totals["thumbnails"],
+                totals["descriptions"],
+            )
+
+            # Calculate time remaining
+            elapsed_time = time.time() - st
+            avg_time_per_event = elapsed_time / totals["processed_objects"]
+            remaining_events = total_events - totals["processed_objects"]
+            time_remaining = avg_time_per_event * remaining_events
+            totals["time_remaining"] = int(time_remaining)
+
+            self.requestor.send_data(UPDATE_EMBEDDINGS_REINDEX_PROGRESS, totals)
 
             # Move to the next page
             current_page += 1
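The time-remaining figure is a running average extrapolated over the remainder: for example, 320 events processed in 80 s is 0.25 s/event, so 680 remaining events yield an estimate of 170 s. The same arithmetic as a standalone helper:

    import time

    def eta_seconds(start: float, processed: int, total: int) -> int:
        # average seconds per processed event, extrapolated to what's left
        elapsed = time.time() - start
        return int((elapsed / processed) * (total - processed))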

View File

@@ -2,7 +2,7 @@ import logging
 import os
 import warnings
 from io import BytesIO
-from typing import Callable, Dict, List, Optional, Union
+from typing import Dict, List, Optional, Union
 
 import numpy as np
 import requests
@@ -39,7 +39,6 @@ class GenericONNXEmbedding:
         model_name: str,
         model_file: str,
         download_urls: Dict[str, str],
-        embedding_function: Callable[[List[np.ndarray]], np.ndarray],
         model_size: str,
         model_type: str,
         requestor: InterProcessRequestor,
@@ -51,7 +50,6 @@ class GenericONNXEmbedding:
         self.tokenizer_file = tokenizer_file
         self.requestor = requestor
         self.download_urls = download_urls
-        self.embedding_function = embedding_function
         self.model_type = model_type  # 'text' or 'vision'
         self.model_size = model_size
         self.device = device
@@ -157,7 +155,6 @@ class GenericONNXEmbedding:
         self, inputs: Union[List[str], List[Image.Image], List[str]]
     ) -> List[np.ndarray]:
         self._load_model_and_tokenizer()
-
         if self.runner is None or (
             self.tokenizer is None and self.feature_extractor is None
         ):
@@ -167,23 +164,27 @@ class GenericONNXEmbedding:
             return []
 
         if self.model_type == "text":
-            processed_inputs = self.tokenizer(
-                inputs, padding=True, truncation=True, return_tensors="np"
-            )
+            processed_inputs = [
+                self.tokenizer(text, padding=True, truncation=True, return_tensors="np")
+                for text in inputs
+            ]
         else:
             processed_images = [self._process_image(img) for img in inputs]
-            processed_inputs = self.feature_extractor(
-                images=processed_images, return_tensors="np"
-            )
+            processed_inputs = [
+                self.feature_extractor(images=image, return_tensors="np")
+                for image in processed_images
+            ]
 
         input_names = self.runner.get_input_names()
-        onnx_inputs = {
-            name: processed_inputs[name]
-            for name in input_names
-            if name in processed_inputs
-        }
+        onnx_inputs = {name: [] for name in input_names}
+        input: dict[str, any]
+        for input in processed_inputs:
+            for key, value in input.items():
+                if key in input_names:
+                    onnx_inputs[key].append(value[0])
+
+        for key in onnx_inputs.keys():
+            onnx_inputs[key] = np.array(onnx_inputs[key])
 
-        outputs = self.runner.run(onnx_inputs)
-        embeddings = self.embedding_function(outputs)
+        embeddings = self.runner.run(onnx_inputs)[0]
 
         return [embedding for embedding in embeddings]
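With the embedding_function hook gone, batching is handled by collating the per-item tokenizer or feature-extractor outputs into one array per ONNX input name and taking the runner's first output as the embedding batch. A standalone numpy sketch of that collation step (input names and token values are illustrative):

    import numpy as np

    # per-item processor outputs, each value shaped (1, seq_len)
    processed = [
        {"input_ids": np.array([[101, 2023, 102]]), "attention_mask": np.array([[1, 1, 1]])},
        {"input_ids": np.array([[101, 4937, 102]]), "attention_mask": np.array([[1, 1, 1]])},
    ]
    input_names = ["input_ids", "attention_mask"]

    onnx_inputs = {name: [] for name in input_names}
    for item in processed:
        for key, value in item.items():
            if key in input_names:
                onnx_inputs[key].append(value[0])  # drop the per-item batch dim

    # stack into (batch, seq_len) arrays ready for the runner
    onnx_inputs = {key: np.array(vals) for key, vals in onnx_inputs.items()}

Note that stacking with np.array assumes every item produced arrays of the same length; since each text is tokenized separately here, ragged token sequences would need padding to a common length first.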