Batch embeddings fixes (#14325)

* fixes

* more readable loops

* more robust ONNX input key check, with a warning message when an expected input is missing

* ensure we get reindex progress on mount

* use correct variable for length
This commit is contained in:
Josh Hawkins 2024-10-13 16:25:13 -05:00 committed by GitHub
parent 66d0ad5803
commit 1ec459ea3a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 60 additions and 25 deletions

View File

@ -145,13 +145,18 @@ class Embeddings:
] ]
ids = list(event_thumbs.keys()) ids = list(event_thumbs.keys())
embeddings = self.vision_embedding(images) embeddings = self.vision_embedding(images)
items = [(ids[i], serialize(embeddings[i])) for i in range(len(ids))]
items = []
for i in range(len(ids)):
items.append(ids[i])
items.append(serialize(embeddings[i]))
self.db.execute_sql( self.db.execute_sql(
""" """
INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding) INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding)
VALUES {} VALUES {}
""".format(", ".join(["(?, ?)"] * len(items))), """.format(", ".join(["(?, ?)"] * len(ids))),
items, items,
) )
return embeddings return embeddings
@ -171,13 +176,18 @@ class Embeddings:
def batch_upsert_description(self, event_descriptions: dict[str, str]) -> ndarray: def batch_upsert_description(self, event_descriptions: dict[str, str]) -> ndarray:
embeddings = self.text_embedding(list(event_descriptions.values())) embeddings = self.text_embedding(list(event_descriptions.values()))
ids = list(event_descriptions.keys()) ids = list(event_descriptions.keys())
items = [(ids[i], serialize(embeddings[i])) for i in range(len(ids))]
items = []
for i in range(len(ids)):
items.append(ids[i])
items.append(serialize(embeddings[i]))
self.db.execute_sql( self.db.execute_sql(
""" """
INSERT OR REPLACE INTO vec_descriptions(id, description_embedding) INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
VALUES {} VALUES {}
""".format(", ".join(["(?, ?)"] * len(items))), """.format(", ".join(["(?, ?)"] * len(ids))),
items, items,
) )
@ -196,16 +206,6 @@ class Embeddings:
os.remove(os.path.join(CONFIG_DIR, ".search_stats.json")) os.remove(os.path.join(CONFIG_DIR, ".search_stats.json"))
st = time.time() st = time.time()
totals = {
"thumbnails": 0,
"descriptions": 0,
"processed_objects": 0,
"total_objects": 0,
"time_remaining": 0,
"status": "indexing",
}
self.requestor.send_data(UPDATE_EMBEDDINGS_REINDEX_PROGRESS, totals)
# Get total count of events to process # Get total count of events to process
total_events = ( total_events = (
@ -216,11 +216,21 @@ class Embeddings:
) )
.count() .count()
) )
totals["total_objects"] = total_events
batch_size = 32 batch_size = 32
current_page = 1 current_page = 1
totals = {
"thumbnails": 0,
"descriptions": 0,
"processed_objects": total_events - 1 if total_events < batch_size else 0,
"total_objects": total_events,
"time_remaining": 0 if total_events < batch_size else -1,
"status": "indexing",
}
self.requestor.send_data(UPDATE_EMBEDDINGS_REINDEX_PROGRESS, totals)
events = ( events = (
Event.select() Event.select()
.where( .where(

View File

@ -164,8 +164,15 @@ class GenericONNXEmbedding:
return [] return []
if self.model_type == "text": if self.model_type == "text":
max_length = max(len(self.tokenizer.encode(text)) for text in inputs)
processed_inputs = [ processed_inputs = [
self.tokenizer(text, padding=True, truncation=True, return_tensors="np") self.tokenizer(
text,
padding="max_length",
truncation=True,
max_length=max_length,
return_tensors="np",
)
for text in inputs for text in inputs
] ]
else: else:
@ -183,8 +190,11 @@ class GenericONNXEmbedding:
if key in input_names: if key in input_names:
onnx_inputs[key].append(value[0]) onnx_inputs[key].append(value[0])
for key in onnx_inputs.keys(): for key in input_names:
onnx_inputs[key] = np.array(onnx_inputs[key]) if onnx_inputs.get(key):
onnx_inputs[key] = np.stack(onnx_inputs[key])
else:
logger.warning(f"Expected input '{key}' not found in onnx_inputs")
embeddings = self.runner.run(onnx_inputs)[0] embeddings = self.runner.run(onnx_inputs)[0]
return [embedding for embedding in embeddings] return [embedding for embedding in embeddings]

View File

@ -193,6 +193,17 @@ export default function Explore() {
// embeddings reindex progress // embeddings reindex progress
const { send: sendReindexCommand } = useWs(
"embeddings_reindex_progress",
"embeddingsReindexProgress",
);
useEffect(() => {
sendReindexCommand("embeddingsReindexProgress");
// only run on mount
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);
const { payload: reindexProgress } = useEmbeddingsReindexProgress(); const { payload: reindexProgress } = useEmbeddingsReindexProgress();
const embeddingsReindexing = useMemo(() => { const embeddingsReindexing = useMemo(() => {
@ -210,10 +221,10 @@ export default function Explore() {
// model states // model states
const { send: sendCommand } = useWs("model_state", "modelState"); const { send: sendModelCommand } = useWs("model_state", "modelState");
useEffect(() => { useEffect(() => {
sendCommand("modelState"); sendModelCommand("modelState");
// only run on mount // only run on mount
// eslint-disable-next-line react-hooks/exhaustive-deps // eslint-disable-next-line react-hooks/exhaustive-deps
}, []); }, []);
@ -299,14 +310,18 @@ export default function Explore() {
/> />
</div> </div>
<div className="flex w-96 flex-col gap-2 py-5"> <div className="flex w-96 flex-col gap-2 py-5">
{reindexProgress.time_remaining >= 0 && ( {reindexProgress.time_remaining !== null && (
<div className="mb-3 flex flex-col items-center justify-center gap-1"> <div className="mb-3 flex flex-col items-center justify-center gap-1">
<div className="text-primary-variant"> <div className="text-primary-variant">
Estimated time remaining: {reindexProgress.time_remaining === -1
? "Starting up..."
: "Estimated time remaining:"}
</div> </div>
{formatSecondsToDuration( {reindexProgress.time_remaining >= 0 &&
reindexProgress.time_remaining, (formatSecondsToDuration(
) || "Finishing shortly"} reindexProgress.time_remaining,
) ||
"Finishing shortly")}
</div> </div>
)} )}
<div className="flex flex-row items-center justify-center gap-3"> <div className="flex flex-row items-center justify-center gap-3">