Mirror of https://github.com/blakeblackshear/frigate.git (synced 2025-07-12 13:47:14 +02:00)
Support batch embeddings when reindexing (#14320)

* Refactor onnx embeddings to handle multiple inputs by default
* Process items in batches when reindexing

parent 0fc7999780 · commit e8b2fde753
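The net effect is a call-shape change: instead of one ONNX forward pass per tracked object, reindexing now runs one batched pass per page of events. A minimal sketch of that change in usage, using a stub class (FakeEmbeddings, the 768-dim zero vectors, and the sample page dict are illustrative stand-ins, not code from this commit):

import numpy as np
from numpy import ndarray

class FakeEmbeddings:
    """Stub mirroring the Embeddings call shapes touched by this commit."""

    def upsert_thumbnail(self, event_id: str, thumbnail: bytes) -> ndarray:
        return np.zeros(768)  # stands in for one ONNX forward pass per event

    def batch_upsert_thumbnail(self, event_thumbs: dict[str, bytes]) -> list[ndarray]:
        return [np.zeros(768) for _ in event_thumbs]  # one pass for the whole page

emb = FakeEmbeddings()
page = {"event_a": b"jpeg-bytes", "event_b": b"jpeg-bytes"}

# Before this commit: one model invocation per tracked object.
old = [emb.upsert_thumbnail(eid, thumb) for eid, thumb in page.items()]

# After: a single batched invocation per page of objects.
new = emb.batch_upsert_thumbnail(page)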
Changes to the Embeddings class:

@@ -6,6 +6,7 @@ import logging
 import os
 import time
 
+from numpy import ndarray
 from PIL import Image
 from playhouse.shortcuts import model_to_dict
 
@@ -88,12 +89,6 @@ class Embeddings:
             },
         )
 
-        def jina_text_embedding_function(outputs):
-            return outputs[0]
-
-        def jina_vision_embedding_function(outputs):
-            return outputs[0]
-
         self.text_embedding = GenericONNXEmbedding(
             model_name="jinaai/jina-clip-v1",
             model_file="text_model_fp16.onnx",
@@ -101,7 +96,6 @@ class Embeddings:
             download_urls={
                 "text_model_fp16.onnx": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/text_model_fp16.onnx",
             },
-            embedding_function=jina_text_embedding_function,
             model_size=config.model_size,
             model_type="text",
             requestor=self.requestor,
@@ -123,14 +117,13 @@ class Embeddings:
             model_name="jinaai/jina-clip-v1",
             model_file=model_file,
             download_urls=download_urls,
-            embedding_function=jina_vision_embedding_function,
             model_size=config.model_size,
             model_type="vision",
             requestor=self.requestor,
             device="GPU" if config.model_size == "large" else "CPU",
         )
 
-    def upsert_thumbnail(self, event_id: str, thumbnail: bytes):
+    def upsert_thumbnail(self, event_id: str, thumbnail: bytes) -> ndarray:
         # Convert thumbnail bytes to PIL Image
         image = Image.open(io.BytesIO(thumbnail)).convert("RGB")
         embedding = self.vision_embedding([image])[0]
@@ -145,7 +138,25 @@ class Embeddings:
 
         return embedding
 
-    def upsert_description(self, event_id: str, description: str):
+    def batch_upsert_thumbnail(self, event_thumbs: dict[str, bytes]) -> list[ndarray]:
+        images = [
+            Image.open(io.BytesIO(thumb)).convert("RGB")
+            for thumb in event_thumbs.values()
+        ]
+        ids = list(event_thumbs.keys())
+        embeddings = self.vision_embedding(images)
+        items = [(ids[i], serialize(embeddings[i])) for i in range(len(ids))]
+
+        self.db.execute_sql(
+            """
+            INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding)
+            VALUES {}
+            """.format(", ".join(["(?, ?)"] * len(items))),
+            items,
+        )
+        return embeddings
+
+    def upsert_description(self, event_id: str, description: str) -> ndarray:
         embedding = self.text_embedding([description])[0]
         self.db.execute_sql(
             """
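Both batch upserts build one multi-row INSERT OR REPLACE statement by repeating the (?, ?) placeholder pair once per item. A standalone sketch of that SQL construction with plain sqlite3 (the simplified table schema here is illustrative; in Frigate, vec_thumbnails is a sqlite-vec virtual table):

import sqlite3

# (event id, serialized embedding) pairs, as produced in batch_upsert_thumbnail.
items = [("event_a", b"\x00\x01"), ("event_b", b"\x02\x03")]

sql = "INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding) VALUES {}".format(
    ", ".join(["(?, ?)"] * len(items))
)
# -> INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding) VALUES (?, ?), (?, ?)

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE vec_thumbnails(id TEXT PRIMARY KEY, thumbnail_embedding BLOB)")
conn.execute(sql, [value for pair in items for value in pair])  # flatten the parameters
print(len(conn.execute("SELECT id FROM vec_thumbnails").fetchall()))  # 2 rows upserted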
@@ -157,6 +168,21 @@ class Embeddings:
 
         return embedding
 
+    def batch_upsert_description(self, event_descriptions: dict[str, str]) -> ndarray:
+        embeddings = self.text_embedding(list(event_descriptions.values()))
+        ids = list(event_descriptions.keys())
+        items = [(ids[i], serialize(embeddings[i])) for i in range(len(ids))]
+
+        self.db.execute_sql(
+            """
+            INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
+            VALUES {}
+            """.format(", ".join(["(?, ?)"] * len(items))),
+            items,
+        )
+
+        return embeddings
+
     def reindex(self) -> None:
         logger.info("Indexing tracked object embeddings...")
 
@@ -192,9 +218,8 @@ class Embeddings:
         )
         totals["total_objects"] = total_events
 
-        batch_size = 100
+        batch_size = 32
         current_page = 1
-        processed_events = 0
 
         events = (
             Event.select()
@@ -208,37 +233,43 @@ class Embeddings:
 
         while len(events) > 0:
             event: Event
+            batch_thumbs = {}
+            batch_descs = {}
             for event in events:
-                thumbnail = base64.b64decode(event.thumbnail)
-                self.upsert_thumbnail(event.id, thumbnail)
+                batch_thumbs[event.id] = base64.b64decode(event.thumbnail)
                 totals["thumbnails"] += 1
 
                 if description := event.data.get("description", "").strip():
+                    batch_descs[event.id] = description
                     totals["descriptions"] += 1
-                    self.upsert_description(event.id, description)
 
                 totals["processed_objects"] += 1
 
-                # report progress every 10 events so we don't spam the logs
-                if (totals["processed_objects"] % 10) == 0:
-                    progress = (processed_events / total_events) * 100
-                    logger.debug(
-                        "Processed %d/%d events (%.2f%% complete) | Thumbnails: %d, Descriptions: %d",
-                        processed_events,
-                        total_events,
-                        progress,
-                        totals["thumbnails"],
-                        totals["descriptions"],
-                    )
+            # run batch embedding
+            self.batch_upsert_thumbnail(batch_thumbs)
 
-                    # Calculate time remaining
-                    elapsed_time = time.time() - st
-                    avg_time_per_event = elapsed_time / totals["processed_objects"]
-                    remaining_events = total_events - totals["processed_objects"]
-                    time_remaining = avg_time_per_event * remaining_events
-                    totals["time_remaining"] = int(time_remaining)
+            if batch_descs:
+                self.batch_upsert_description(batch_descs)
 
-                    self.requestor.send_data(UPDATE_EMBEDDINGS_REINDEX_PROGRESS, totals)
+            # report progress every batch so we don't spam the logs
+            progress = (totals["processed_objects"] / total_events) * 100
+            logger.debug(
+                "Processed %d/%d events (%.2f%% complete) | Thumbnails: %d, Descriptions: %d",
+                totals["processed_objects"],
+                total_events,
+                progress,
+                totals["thumbnails"],
+                totals["descriptions"],
+            )
+
+            # Calculate time remaining
+            elapsed_time = time.time() - st
+            avg_time_per_event = elapsed_time / totals["processed_objects"]
+            remaining_events = total_events - totals["processed_objects"]
+            time_remaining = avg_time_per_event * remaining_events
+            totals["time_remaining"] = int(time_remaining)
+
+            self.requestor.send_data(UPDATE_EMBEDDINGS_REINDEX_PROGRESS, totals)
 
             # Move to the next page
             current_page += 1
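The progress report keeps the same linear time-remaining estimate, now computed once per batch rather than every ten events. Restated as a runnable helper (the function name and sample numbers are mine, not the commit's):

import time

def estimate_time_remaining(start: float, processed: int, total: int) -> int:
    """Linear ETA: average seconds per processed event times events left."""
    elapsed = time.time() - start
    avg_per_event = elapsed / processed  # reindex only reports once processed > 0
    return int(avg_per_event * (total - processed))

start = time.time() - 12.0  # pretend the reindex started 12 seconds ago
print(estimate_time_remaining(start, processed=64, total=1000))  # ~175 seconds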
Changes to the GenericONNXEmbedding class:

@@ -2,7 +2,7 @@ import logging
 import os
 import warnings
 from io import BytesIO
-from typing import Callable, Dict, List, Optional, Union
+from typing import Dict, List, Optional, Union
 
 import numpy as np
 import requests
@@ -39,7 +39,6 @@ class GenericONNXEmbedding:
         model_name: str,
         model_file: str,
         download_urls: Dict[str, str],
-        embedding_function: Callable[[List[np.ndarray]], np.ndarray],
         model_size: str,
         model_type: str,
         requestor: InterProcessRequestor,
@@ -51,7 +50,6 @@ class GenericONNXEmbedding:
         self.tokenizer_file = tokenizer_file
         self.requestor = requestor
         self.download_urls = download_urls
-        self.embedding_function = embedding_function
         self.model_type = model_type  # 'text' or 'vision'
         self.model_size = model_size
         self.device = device
@@ -157,7 +155,6 @@ class GenericONNXEmbedding:
         self, inputs: Union[List[str], List[Image.Image], List[str]]
     ) -> List[np.ndarray]:
         self._load_model_and_tokenizer()
-
         if self.runner is None or (
             self.tokenizer is None and self.feature_extractor is None
         ):
@@ -167,23 +164,27 @@ class GenericONNXEmbedding:
             return []
 
         if self.model_type == "text":
-            processed_inputs = self.tokenizer(
-                inputs, padding=True, truncation=True, return_tensors="np"
-            )
+            processed_inputs = [
+                self.tokenizer(text, padding=True, truncation=True, return_tensors="np")
+                for text in inputs
+            ]
         else:
             processed_images = [self._process_image(img) for img in inputs]
-            processed_inputs = self.feature_extractor(
-                images=processed_images, return_tensors="np"
-            )
+            processed_inputs = [
+                self.feature_extractor(images=image, return_tensors="np")
+                for image in processed_images
+            ]
 
         input_names = self.runner.get_input_names()
-        onnx_inputs = {
-            name: processed_inputs[name]
-            for name in input_names
-            if name in processed_inputs
-        }
+        onnx_inputs = {name: [] for name in input_names}
+        input: dict[str, any]
+        for input in processed_inputs:
+            for key, value in input.items():
+                if key in input_names:
+                    onnx_inputs[key].append(value[0])
 
-        outputs = self.runner.run(onnx_inputs)
-        embeddings = self.embedding_function(outputs)
+        for key in onnx_inputs.keys():
+            onnx_inputs[key] = np.array(onnx_inputs[key])
+
+        embeddings = self.runner.run(onnx_inputs)[0]
         return [embedding for embedding in embeddings]
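With the embedding_function indirection gone, GenericONNXEmbedding preprocesses each text or image individually, then regroups the per-item outputs into one batched array per model input before a single runner.run call. A self-contained toy of that regrouping (fake tokenizer outputs with a fixed sequence length; the real inputs come from the Hugging Face tokenizer or feature extractor, and this stacking assumes equal shapes across the batch):

import numpy as np

# Each per-item tokenizer call returns arrays with a leading batch dim of 1,
# e.g. input_ids of shape (1, seq_len); value[0] strips that dimension.
processed_inputs = [
    {"input_ids": np.array([[1, 2, 3]]), "attention_mask": np.array([[1, 1, 1]])},
    {"input_ids": np.array([[4, 5, 6]]), "attention_mask": np.array([[1, 1, 0]])},
]
input_names = ["input_ids", "attention_mask"]  # as from runner.get_input_names()

onnx_inputs = {name: [] for name in input_names}
for item in processed_inputs:
    for key, value in item.items():
        if key in input_names:
            onnx_inputs[key].append(value[0])
for key in onnx_inputs:
    onnx_inputs[key] = np.array(onnx_inputs[key])  # stack items into one batch

print(onnx_inputs["input_ids"].shape)  # (2, 3): two sequences in a single batch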