mirror of https://github.com/blakeblackshear/frigate.git (synced 2024-11-21 19:07:46 +01:00)
Support batch embeddings when reindexing (#14320)

* Refactor onnx embeddings to handle multiple inputs by default
* Process items in batches when reindexing
This commit is contained in:
parent 0fc7999780 · commit e8b2fde753
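The change in one sentence: instead of running one ONNX inference per thumbnail or description, reindexing now collects a page of items and runs the model once over the whole batch. A minimal sketch of why that pays off, with a dummy embed function standing in for the ONNX runner (names are illustrative, not Frigate's API):

    import numpy as np

    def embed_batch(inputs: np.ndarray) -> np.ndarray:
        # stand-in for one ONNX session.run() over a stacked batch
        rng = np.random.default_rng(0)
        weights = rng.standard_normal((inputs.shape[1], 8))
        return inputs @ weights

    batch = np.stack([np.ones(4), np.zeros(4), np.full(4, 2.0)])
    embeddings = embed_batch(batch)                       # one call, shape (3, 8)
    per_item = [embed_batch(x[None])[0] for x in batch]   # same math, three calls

The batched path amortizes per-call overhead (session dispatch, pre- and post-processing) across every item in the batch.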
--- a/frigate/embeddings/embeddings.py
+++ b/frigate/embeddings/embeddings.py
@@ -6,6 +6,7 @@ import logging
 import os
 import time
 
+from numpy import ndarray
 from PIL import Image
 from playhouse.shortcuts import model_to_dict
 
@@ -88,12 +89,6 @@ class Embeddings:
             },
         )
 
-        def jina_text_embedding_function(outputs):
-            return outputs[0]
-
-        def jina_vision_embedding_function(outputs):
-            return outputs[0]
-
         self.text_embedding = GenericONNXEmbedding(
             model_name="jinaai/jina-clip-v1",
             model_file="text_model_fp16.onnx",
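Dropping these nested helpers is safe because both were identical: each just returned outputs[0]. That selection now happens inside the generic class itself (see the onnx.py hunk further down, where self.runner.run(onnx_inputs)[0] replaces the embedding_function(outputs) indirection), so the embedding_function parameter disappears from both constructors below.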
@@ -101,7 +96,6 @@ class Embeddings:
             download_urls={
                 "text_model_fp16.onnx": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/text_model_fp16.onnx",
             },
-            embedding_function=jina_text_embedding_function,
             model_size=config.model_size,
             model_type="text",
             requestor=self.requestor,
@@ -123,14 +117,13 @@ class Embeddings:
             model_name="jinaai/jina-clip-v1",
             model_file=model_file,
             download_urls=download_urls,
-            embedding_function=jina_vision_embedding_function,
             model_size=config.model_size,
             model_type="vision",
             requestor=self.requestor,
             device="GPU" if config.model_size == "large" else "CPU",
         )
 
-    def upsert_thumbnail(self, event_id: str, thumbnail: bytes):
+    def upsert_thumbnail(self, event_id: str, thumbnail: bytes) -> ndarray:
         # Convert thumbnail bytes to PIL Image
         image = Image.open(io.BytesIO(thumbnail)).convert("RGB")
         embedding = self.vision_embedding([image])[0]
@@ -145,7 +138,25 @@ class Embeddings:
 
         return embedding
 
-    def upsert_description(self, event_id: str, description: str):
+    def batch_upsert_thumbnail(self, event_thumbs: dict[str, bytes]) -> list[ndarray]:
+        images = [
+            Image.open(io.BytesIO(thumb)).convert("RGB")
+            for thumb in event_thumbs.values()
+        ]
+        ids = list(event_thumbs.keys())
+        embeddings = self.vision_embedding(images)
+        items = [(ids[i], serialize(embeddings[i])) for i in range(len(ids))]
+
+        self.db.execute_sql(
+            """
+            INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding)
+            VALUES {}
+            """.format(", ".join(["(?, ?)"] * len(items))),
+            items,
+        )
+        return embeddings
+
+    def upsert_description(self, event_id: str, description: str) -> ndarray:
         embedding = self.text_embedding([description])[0]
         self.db.execute_sql(
             """
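Two details worth noting in batch_upsert_thumbnail. First, serialize() is not shown in this diff; sqlite-vec expects embeddings as packed float32 blobs, so a plausible stand-in is struct.pack (an assumption, not necessarily Frigate's implementation). Second, the statement builds a multi-row VALUES clause with one (?, ?) group per item. A self-contained sketch of the same pattern using stdlib sqlite3:

    import sqlite3
    import struct

    def serialize(vector) -> bytes:
        # assumed sqlite-vec style packing: float32 values into a blob
        return struct.pack("%sf" % len(vector), *vector)

    db = sqlite3.connect(":memory:")
    db.execute(
        "CREATE TABLE vec_thumbnails(id TEXT PRIMARY KEY, thumbnail_embedding BLOB)"
    )

    items = [("evt1", serialize([0.1, 0.2])), ("evt2", serialize([0.3, 0.4]))]
    placeholders = ", ".join(["(?, ?)"] * len(items))  # "(?, ?), (?, ?)"
    db.execute(
        f"INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding) VALUES {placeholders}",
        [value for item in items for value in item],  # flatten to one value per ?
    )

Note that stdlib sqlite3 wants a flat parameter sequence, and SQLite caps the number of bound parameters per statement, which is another reason to keep reindex batches small.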
@@ -157,6 +168,21 @@ class Embeddings:
 
         return embedding
 
+    def batch_upsert_description(self, event_descriptions: dict[str, str]) -> ndarray:
+        embeddings = self.text_embedding(list(event_descriptions.values()))
+        ids = list(event_descriptions.keys())
+        items = [(ids[i], serialize(embeddings[i])) for i in range(len(ids))]
+
+        self.db.execute_sql(
+            """
+            INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
+            VALUES {}
+            """.format(", ".join(["(?, ?)"] * len(items))),
+            items,
+        )
+
+        return embeddings
+
     def reindex(self) -> None:
         logger.info("Indexing tracked object embeddings...")
 
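One small nit: batch_upsert_description is annotated -> ndarray, but like batch_upsert_thumbnail it returns the list of per-item arrays from the embedding call, so list[ndarray] would be the more accurate annotation. The id-to-embedding pairing can also be written with zip instead of indexing; an equivalent formulation:

    ids = list(event_descriptions.keys())
    embeddings = self.text_embedding(list(event_descriptions.values()))
    items = [(event_id, serialize(emb)) for event_id, emb in zip(ids, embeddings)]

This relies on keys() and values() iterating in the same insertion order, which Python dicts guarantee since 3.7.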
@@ -192,9 +218,8 @@ class Embeddings:
         )
         totals["total_objects"] = total_events
 
-        batch_size = 100
+        batch_size = 32
         current_page = 1
-        processed_events = 0
 
         events = (
            Event.select()
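batch_size doubles as the query page size and the embedding batch size: every thumbnail on a page is now stacked into a single inference call, so shrinking it from 100 to 32 plausibly keeps peak model memory bounded while still amortizing per-call overhead. For reference, the generic shape of fixed-size chunking (illustrative, not part of this commit):

    def chunked(seq, size=32):
        # yield fixed-size slices of seq; the final chunk may be shorter
        for i in range(0, len(seq), size):
            yield seq[i : i + size]

    sizes = [len(c) for c in chunked(list(range(70)))]  # -> [32, 32, 6]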
@@ -208,23 +233,29 @@ class Embeddings:
 
         while len(events) > 0:
             event: Event
+            batch_thumbs = {}
+            batch_descs = {}
             for event in events:
-                thumbnail = base64.b64decode(event.thumbnail)
-                self.upsert_thumbnail(event.id, thumbnail)
+                batch_thumbs[event.id] = base64.b64decode(event.thumbnail)
                 totals["thumbnails"] += 1
 
                 if description := event.data.get("description", "").strip():
+                    batch_descs[event.id] = description
                     totals["descriptions"] += 1
-                    self.upsert_description(event.id, description)
 
                 totals["processed_objects"] += 1
 
-                # report progress every 10 events so we don't spam the logs
-                if (totals["processed_objects"] % 10) == 0:
-                    progress = (processed_events / total_events) * 100
-                    logger.debug(
-                        "Processed %d/%d events (%.2f%% complete) | Thumbnails: %d, Descriptions: %d",
-                        processed_events,
-                        total_events,
-                        progress,
-                        totals["thumbnails"],
+            # run batch embedding
+            self.batch_upsert_thumbnail(batch_thumbs)
+
+            if batch_descs:
+                self.batch_upsert_description(batch_descs)
+
+            # report progress every batch so we don't spam the logs
+            progress = (totals["processed_objects"] / total_events) * 100
+            logger.debug(
+                "Processed %d/%d events (%.2f%% complete) | Thumbnails: %d, Descriptions: %d",
+                totals["processed_objects"],
+                total_events,
+                progress,
+                totals["thumbnails"],
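The loop now collects a page into batch_thumbs/batch_descs, flushes both batches once, and logs once per page instead of every ten events; it also cleans up the progress report, which previously passed the separate processed_events counter where totals["processed_objects"] is what gets reported now. The overall shape, reduced to a runnable sketch with the embedding and database writes stubbed out (helper names are hypothetical):

    def reindex_sketch(pages, embed, write):
        total = sum(len(page) for page in pages)
        processed = 0
        for page in pages:
            batch = {item["id"]: item["thumb"] for item in page}      # collect
            write(list(batch.keys()), embed(list(batch.values())))    # flush once
            processed += len(page)
            print(f"{processed}/{total} ({processed / total:.0%} complete)")

    reindex_sketch(
        [[{"id": "a", "thumb": b"..."}, {"id": "b", "thumb": b"..."}]],
        embed=lambda thumbs: [[0.0]] * len(thumbs),
        write=lambda ids, embs: None,
    )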
--- a/frigate/embeddings/functions/onnx.py
+++ b/frigate/embeddings/functions/onnx.py
@@ -2,7 +2,7 @@ import logging
 import os
 import warnings
 from io import BytesIO
-from typing import Callable, Dict, List, Optional, Union
+from typing import Dict, List, Optional, Union
 
 import numpy as np
 import requests
@@ -39,7 +39,6 @@ class GenericONNXEmbedding:
         model_name: str,
         model_file: str,
         download_urls: Dict[str, str],
-        embedding_function: Callable[[List[np.ndarray]], np.ndarray],
         model_size: str,
         model_type: str,
         requestor: InterProcessRequestor,
@@ -51,7 +50,6 @@ class GenericONNXEmbedding:
         self.tokenizer_file = tokenizer_file
         self.requestor = requestor
         self.download_urls = download_urls
-        self.embedding_function = embedding_function
         self.model_type = model_type  # 'text' or 'vision'
         self.model_size = model_size
         self.device = device
@@ -157,7 +155,6 @@ class GenericONNXEmbedding:
         self, inputs: Union[List[str], List[Image.Image], List[str]]
     ) -> List[np.ndarray]:
         self._load_model_and_tokenizer()
-
         if self.runner is None or (
             self.tokenizer is None and self.feature_extractor is None
         ):
@@ -167,23 +164,27 @@ class GenericONNXEmbedding:
             return []
 
         if self.model_type == "text":
-            processed_inputs = self.tokenizer(
-                inputs, padding=True, truncation=True, return_tensors="np"
-            )
+            processed_inputs = [
+                self.tokenizer(text, padding=True, truncation=True, return_tensors="np")
+                for text in inputs
+            ]
         else:
             processed_images = [self._process_image(img) for img in inputs]
-            processed_inputs = self.feature_extractor(
-                images=processed_images, return_tensors="np"
-            )
+            processed_inputs = [
+                self.feature_extractor(images=image, return_tensors="np")
+                for image in processed_images
+            ]
 
         input_names = self.runner.get_input_names()
-        onnx_inputs = {
-            name: processed_inputs[name]
-            for name in input_names
-            if name in processed_inputs
-        }
-
-        outputs = self.runner.run(onnx_inputs)
-        embeddings = self.embedding_function(outputs)
-
+        onnx_inputs = {name: [] for name in input_names}
+        input: dict[str, any]
+        for input in processed_inputs:
+            for key, value in input.items():
+                if key in input_names:
+                    onnx_inputs[key].append(value[0])
+
+        for key in onnx_inputs.keys():
+            onnx_inputs[key] = np.array(onnx_inputs[key])
+
+        embeddings = self.runner.run(onnx_inputs)[0]
         return [embedding for embedding in embeddings]
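The collation change is the heart of the refactor: each input is tokenized or feature-extracted on its own, the per-key arrays are gathered, and np.array stacks them into batch tensors before a single runner.run call. One thing to watch with this approach: padding=True on a single text only pads within that text, so texts of different lengths yield rows of different lengths, and stacking ragged rows fails (or silently builds an object array). A common guard, sketched with plain numpy and not taken from this commit, is to pad token-id rows to the longest row before stacking:

    import numpy as np

    def stack_padded(rows, pad_value=0):
        # pad 1-D token-id rows to a common length, then stack to (batch, max_len)
        max_len = max(len(row) for row in rows)
        return np.stack([
            np.pad(np.asarray(row), (0, max_len - len(row)), constant_values=pad_value)
            for row in rows
        ])

    print(stack_padded([[1, 2, 3], [4, 5]]))  # [[1 2 3] [4 5 0]]

A smaller nit: the annotation input: dict[str, any] uses the builtin any rather than typing.Any; it is harmless at runtime but not a real type.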