Support batch embeddings when reindexing (#14320)

* Refactor onnx embeddings to handle multiple inputs by default (a gather-and-stack sketch follows the second file's diff below)

* Process items in batches when reindexing (a batched-upsert sketch follows the first file's diff below)
Nicolas Mowen 2024-10-13 11:33:27 -06:00 committed by GitHub
parent 0fc7999780
commit e8b2fde753
2 changed files with 82 additions and 50 deletions

@@ -6,6 +6,7 @@ import logging
 import os
 import time
 
+from numpy import ndarray
 from PIL import Image
 from playhouse.shortcuts import model_to_dict
@@ -88,12 +89,6 @@ class Embeddings:
             },
         )
 
-        def jina_text_embedding_function(outputs):
-            return outputs[0]
-
-        def jina_vision_embedding_function(outputs):
-            return outputs[0]
-
         self.text_embedding = GenericONNXEmbedding(
             model_name="jinaai/jina-clip-v1",
             model_file="text_model_fp16.onnx",
@@ -101,7 +96,6 @@ class Embeddings:
             download_urls={
                 "text_model_fp16.onnx": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/text_model_fp16.onnx",
             },
-            embedding_function=jina_text_embedding_function,
             model_size=config.model_size,
             model_type="text",
             requestor=self.requestor,
@@ -123,14 +117,13 @@ class Embeddings:
             model_name="jinaai/jina-clip-v1",
             model_file=model_file,
             download_urls=download_urls,
-            embedding_function=jina_vision_embedding_function,
             model_size=config.model_size,
             model_type="vision",
             requestor=self.requestor,
             device="GPU" if config.model_size == "large" else "CPU",
         )
 
-    def upsert_thumbnail(self, event_id: str, thumbnail: bytes):
+    def upsert_thumbnail(self, event_id: str, thumbnail: bytes) -> ndarray:
         # Convert thumbnail bytes to PIL Image
         image = Image.open(io.BytesIO(thumbnail)).convert("RGB")
         embedding = self.vision_embedding([image])[0]
@@ -145,7 +138,25 @@
 
         return embedding
 
-    def upsert_description(self, event_id: str, description: str):
+    def batch_upsert_thumbnail(self, event_thumbs: dict[str, bytes]) -> list[ndarray]:
+        images = [
+            Image.open(io.BytesIO(thumb)).convert("RGB")
+            for thumb in event_thumbs.values()
+        ]
+        ids = list(event_thumbs.keys())
+        embeddings = self.vision_embedding(images)
+
+        items = [(ids[i], serialize(embeddings[i])) for i in range(len(ids))]
+
+        self.db.execute_sql(
+            """
+            INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding)
+            VALUES {}
+            """.format(", ".join(["(?, ?)"] * len(items))),
+            items,
+        )
+
+        return embeddings
+
+    def upsert_description(self, event_id: str, description: str) -> ndarray:
         embedding = self.text_embedding([description])[0]
         self.db.execute_sql(
             """
@@ -157,6 +168,21 @@
 
         return embedding
 
+    def batch_upsert_description(self, event_descriptions: dict[str, str]) -> ndarray:
+        embeddings = self.text_embedding(list(event_descriptions.values()))
+        ids = list(event_descriptions.keys())
+        items = [(ids[i], serialize(embeddings[i])) for i in range(len(ids))]
+
+        self.db.execute_sql(
+            """
+            INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
+            VALUES {}
+            """.format(", ".join(["(?, ?)"] * len(items))),
+            items,
+        )
+
+        return embeddings
+
     def reindex(self) -> None:
         logger.info("Indexing tracked object embeddings...")
@@ -192,9 +218,8 @@
         )
         totals["total_objects"] = total_events
 
-        batch_size = 100
+        batch_size = 32
         current_page = 1
-        processed_events = 0
 
         events = (
             Event.select()
@@ -208,37 +233,43 @@
         while len(events) > 0:
             event: Event
+            batch_thumbs = {}
+            batch_descs = {}
             for event in events:
-                thumbnail = base64.b64decode(event.thumbnail)
-                self.upsert_thumbnail(event.id, thumbnail)
+                batch_thumbs[event.id] = base64.b64decode(event.thumbnail)
                 totals["thumbnails"] += 1
 
                 if description := event.data.get("description", "").strip():
+                    batch_descs[event.id] = description
                     totals["descriptions"] += 1
-                    self.upsert_description(event.id, description)
 
                 totals["processed_objects"] += 1
 
-                # report progress every 10 events so we don't spam the logs
-                if (totals["processed_objects"] % 10) == 0:
-                    progress = (processed_events / total_events) * 100
-                    logger.debug(
-                        "Processed %d/%d events (%.2f%% complete) | Thumbnails: %d, Descriptions: %d",
-                        processed_events,
-                        total_events,
-                        progress,
-                        totals["thumbnails"],
-                        totals["descriptions"],
-                    )
+            # run batch embedding
+            self.batch_upsert_thumbnail(batch_thumbs)
 
-                # Calculate time remaining
-                elapsed_time = time.time() - st
-                avg_time_per_event = elapsed_time / totals["processed_objects"]
-                remaining_events = total_events - totals["processed_objects"]
-                time_remaining = avg_time_per_event * remaining_events
-                totals["time_remaining"] = int(time_remaining)
+            if batch_descs:
+                self.batch_upsert_description(batch_descs)
 
-                self.requestor.send_data(UPDATE_EMBEDDINGS_REINDEX_PROGRESS, totals)
+            # report progress every batch so we don't spam the logs
+            progress = (totals["processed_objects"] / total_events) * 100
+            logger.debug(
+                "Processed %d/%d events (%.2f%% complete) | Thumbnails: %d, Descriptions: %d",
+                totals["processed_objects"],
+                total_events,
+                progress,
+                totals["thumbnails"],
+                totals["descriptions"],
+            )
+
+            # Calculate time remaining
+            elapsed_time = time.time() - st
+            avg_time_per_event = elapsed_time / totals["processed_objects"]
+            remaining_events = total_events - totals["processed_objects"]
+            time_remaining = avg_time_per_event * remaining_events
+            totals["time_remaining"] = int(time_remaining)
+
+            self.requestor.send_data(UPDATE_EMBEDDINGS_REINDEX_PROGRESS, totals)
 
             # Move to the next page
             current_page += 1
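The two batch_upsert_* methods added above replace one SQL statement per event with a single multi-row INSERT OR REPLACE, expanding one "(?, ?)" placeholder group per item. Below is a minimal standalone sketch of that placeholder-expansion pattern; plain sqlite3 and a toy serialize() stand in for Frigate's database wrapper and serializer, so the names are illustrative rather than the project's actual API.

import sqlite3
import struct

def serialize(embedding: list[float]) -> bytes:
    # Pack the floats into a compact blob, as a vector column expects.
    return struct.pack(f"{len(embedding)}f", *embedding)

db = sqlite3.connect(":memory:")
db.execute("CREATE TABLE vec_thumbnails(id TEXT PRIMARY KEY, thumbnail_embedding BLOB)")

batch = {"event_a": [0.1, 0.2], "event_b": [0.3, 0.4]}
items = [(event_id, serialize(emb)) for event_id, emb in batch.items()]

# One statement with an "(?, ?)" group per item upserts the whole batch at once.
placeholders = ", ".join(["(?, ?)"] * len(items))
db.execute(
    f"INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding) VALUES {placeholders}",
    [param for item in items for param in item],  # plain sqlite3 wants a flat parameter list
)
db.commit()
print(db.execute("SELECT id, length(thumbnail_embedding) FROM vec_thumbnails").fetchall())
# [('event_a', 8), ('event_b', 8)]

The second changed file is what lets vision_embedding and text_embedding accept those whole batches in a single call.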

@@ -2,7 +2,7 @@ import logging
 import os
 import warnings
 from io import BytesIO
-from typing import Callable, Dict, List, Optional, Union
+from typing import Dict, List, Optional, Union
 
 import numpy as np
 import requests
@@ -39,7 +39,6 @@ class GenericONNXEmbedding:
         model_name: str,
         model_file: str,
         download_urls: Dict[str, str],
-        embedding_function: Callable[[List[np.ndarray]], np.ndarray],
         model_size: str,
         model_type: str,
         requestor: InterProcessRequestor,
@@ -51,7 +50,6 @@
         self.tokenizer_file = tokenizer_file
         self.requestor = requestor
         self.download_urls = download_urls
-        self.embedding_function = embedding_function
         self.model_type = model_type  # 'text' or 'vision'
         self.model_size = model_size
         self.device = device
@@ -157,7 +155,6 @@
         self, inputs: Union[List[str], List[Image.Image], List[str]]
     ) -> List[np.ndarray]:
         self._load_model_and_tokenizer()
-
         if self.runner is None or (
             self.tokenizer is None and self.feature_extractor is None
         ):
@@ -167,23 +164,27 @@
             return []
 
         if self.model_type == "text":
-            processed_inputs = self.tokenizer(
-                inputs, padding=True, truncation=True, return_tensors="np"
-            )
+            processed_inputs = [
+                self.tokenizer(text, padding=True, truncation=True, return_tensors="np")
+                for text in inputs
+            ]
         else:
             processed_images = [self._process_image(img) for img in inputs]
-            processed_inputs = self.feature_extractor(
-                images=processed_images, return_tensors="np"
-            )
+            processed_inputs = [
+                self.feature_extractor(images=image, return_tensors="np")
+                for image in processed_images
+            ]
 
         input_names = self.runner.get_input_names()
-        onnx_inputs = {
-            name: processed_inputs[name]
-            for name in input_names
-            if name in processed_inputs
-        }
+        onnx_inputs = {name: [] for name in input_names}
+        input: dict[str, any]
+        for input in processed_inputs:
+            for key, value in input.items():
+                if key in input_names:
+                    onnx_inputs[key].append(value[0])
 
-        outputs = self.runner.run(onnx_inputs)
-        embeddings = self.embedding_function(outputs)
+        for key in onnx_inputs.keys():
+            onnx_inputs[key] = np.array(onnx_inputs[key])
 
+        embeddings = self.runner.run(onnx_inputs)[0]
         return [embedding for embedding in embeddings]
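With the per-model embedding_function hooks gone, __call__ now tokenizes or feature-extracts each input separately, gathers the resulting arrays by ONNX input name, stacks each list into one batched array, and takes the first output of a single runner.run() call as the embeddings. Below is a small self-contained sketch of that gather-and-stack step; the FakeRunner is purely illustrative and stands in for Frigate's ONNX runner. Note the stacking assumes every tokenized item comes back with the same sequence length, which np.array() needs in order to form a rectangular batch.

import numpy as np

class FakeRunner:
    def get_input_names(self) -> list[str]:
        return ["input_ids"]

    def run(self, inputs: dict[str, np.ndarray]) -> list[np.ndarray]:
        # Toy "model": mean over the sequence axis, one value per batch row.
        return [inputs["input_ids"].astype(np.float32).mean(axis=1, keepdims=True)]

runner = FakeRunner()

# Two "tokenized" inputs, each a dict of (1, seq_len) arrays, as a tokenizer
# called with return_tensors="np" would produce (equal lengths assumed).
processed_inputs = [
    {"input_ids": np.array([[101, 7592, 102]])},
    {"input_ids": np.array([[101, 2088, 102]])},
]

input_names = runner.get_input_names()
onnx_inputs = {name: [] for name in input_names}
for tokenized in processed_inputs:
    for key, value in tokenized.items():
        if key in input_names:
            onnx_inputs[key].append(value[0])  # drop each item's leading batch dim

# Stack each per-name list into a (batch, seq_len) array, then run the model once.
batched = {key: np.array(vals) for key, vals in onnx_inputs.items()}
embeddings = runner.run(batched)[0]  # first output holds the embeddings
print(embeddings.shape)  # (2, 1): one row per input

Running the model once per page of events, rather than once per event, is what makes the reindex loop's batch_size of 32 worthwhile.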