mirror of
https://github.com/blakeblackshear/frigate.git
synced 2024-12-23 19:11:14 +01:00
Batch embeddings fixes (#14325)
* fixes * more readable loops * more robust key check and warning message * ensure we get reindex progress on mount * use correct var for length
This commit is contained in:
parent
66d0ad5803
commit
1ec459ea3a
@ -145,13 +145,18 @@ class Embeddings:
|
|||||||
]
|
]
|
||||||
ids = list(event_thumbs.keys())
|
ids = list(event_thumbs.keys())
|
||||||
embeddings = self.vision_embedding(images)
|
embeddings = self.vision_embedding(images)
|
||||||
items = [(ids[i], serialize(embeddings[i])) for i in range(len(ids))]
|
|
||||||
|
items = []
|
||||||
|
|
||||||
|
for i in range(len(ids)):
|
||||||
|
items.append(ids[i])
|
||||||
|
items.append(serialize(embeddings[i]))
|
||||||
|
|
||||||
self.db.execute_sql(
|
self.db.execute_sql(
|
||||||
"""
|
"""
|
||||||
INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding)
|
INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding)
|
||||||
VALUES {}
|
VALUES {}
|
||||||
""".format(", ".join(["(?, ?)"] * len(items))),
|
""".format(", ".join(["(?, ?)"] * len(ids))),
|
||||||
items,
|
items,
|
||||||
)
|
)
|
||||||
return embeddings
|
return embeddings
|
||||||
@ -171,13 +176,18 @@ class Embeddings:
|
|||||||
def batch_upsert_description(self, event_descriptions: dict[str, str]) -> ndarray:
|
def batch_upsert_description(self, event_descriptions: dict[str, str]) -> ndarray:
|
||||||
embeddings = self.text_embedding(list(event_descriptions.values()))
|
embeddings = self.text_embedding(list(event_descriptions.values()))
|
||||||
ids = list(event_descriptions.keys())
|
ids = list(event_descriptions.keys())
|
||||||
items = [(ids[i], serialize(embeddings[i])) for i in range(len(ids))]
|
|
||||||
|
items = []
|
||||||
|
|
||||||
|
for i in range(len(ids)):
|
||||||
|
items.append(ids[i])
|
||||||
|
items.append(serialize(embeddings[i]))
|
||||||
|
|
||||||
self.db.execute_sql(
|
self.db.execute_sql(
|
||||||
"""
|
"""
|
||||||
INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
|
INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
|
||||||
VALUES {}
|
VALUES {}
|
||||||
""".format(", ".join(["(?, ?)"] * len(items))),
|
""".format(", ".join(["(?, ?)"] * len(ids))),
|
||||||
items,
|
items,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -196,16 +206,6 @@ class Embeddings:
|
|||||||
os.remove(os.path.join(CONFIG_DIR, ".search_stats.json"))
|
os.remove(os.path.join(CONFIG_DIR, ".search_stats.json"))
|
||||||
|
|
||||||
st = time.time()
|
st = time.time()
|
||||||
totals = {
|
|
||||||
"thumbnails": 0,
|
|
||||||
"descriptions": 0,
|
|
||||||
"processed_objects": 0,
|
|
||||||
"total_objects": 0,
|
|
||||||
"time_remaining": 0,
|
|
||||||
"status": "indexing",
|
|
||||||
}
|
|
||||||
|
|
||||||
self.requestor.send_data(UPDATE_EMBEDDINGS_REINDEX_PROGRESS, totals)
|
|
||||||
|
|
||||||
# Get total count of events to process
|
# Get total count of events to process
|
||||||
total_events = (
|
total_events = (
|
||||||
@ -216,11 +216,21 @@ class Embeddings:
|
|||||||
)
|
)
|
||||||
.count()
|
.count()
|
||||||
)
|
)
|
||||||
totals["total_objects"] = total_events
|
|
||||||
|
|
||||||
batch_size = 32
|
batch_size = 32
|
||||||
current_page = 1
|
current_page = 1
|
||||||
|
|
||||||
|
totals = {
|
||||||
|
"thumbnails": 0,
|
||||||
|
"descriptions": 0,
|
||||||
|
"processed_objects": total_events - 1 if total_events < batch_size else 0,
|
||||||
|
"total_objects": total_events,
|
||||||
|
"time_remaining": 0 if total_events < batch_size else -1,
|
||||||
|
"status": "indexing",
|
||||||
|
}
|
||||||
|
|
||||||
|
self.requestor.send_data(UPDATE_EMBEDDINGS_REINDEX_PROGRESS, totals)
|
||||||
|
|
||||||
events = (
|
events = (
|
||||||
Event.select()
|
Event.select()
|
||||||
.where(
|
.where(
|
||||||
|
@ -164,8 +164,15 @@ class GenericONNXEmbedding:
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
if self.model_type == "text":
|
if self.model_type == "text":
|
||||||
|
max_length = max(len(self.tokenizer.encode(text)) for text in inputs)
|
||||||
processed_inputs = [
|
processed_inputs = [
|
||||||
self.tokenizer(text, padding=True, truncation=True, return_tensors="np")
|
self.tokenizer(
|
||||||
|
text,
|
||||||
|
padding="max_length",
|
||||||
|
truncation=True,
|
||||||
|
max_length=max_length,
|
||||||
|
return_tensors="np",
|
||||||
|
)
|
||||||
for text in inputs
|
for text in inputs
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
@ -183,8 +190,11 @@ class GenericONNXEmbedding:
|
|||||||
if key in input_names:
|
if key in input_names:
|
||||||
onnx_inputs[key].append(value[0])
|
onnx_inputs[key].append(value[0])
|
||||||
|
|
||||||
for key in onnx_inputs.keys():
|
for key in input_names:
|
||||||
onnx_inputs[key] = np.array(onnx_inputs[key])
|
if onnx_inputs.get(key):
|
||||||
|
onnx_inputs[key] = np.stack(onnx_inputs[key])
|
||||||
|
else:
|
||||||
|
logger.warning(f"Expected input '{key}' not found in onnx_inputs")
|
||||||
|
|
||||||
embeddings = self.runner.run(onnx_inputs)[0]
|
embeddings = self.runner.run(onnx_inputs)[0]
|
||||||
return [embedding for embedding in embeddings]
|
return [embedding for embedding in embeddings]
|
||||||
|
@ -193,6 +193,17 @@ export default function Explore() {
|
|||||||
|
|
||||||
// embeddings reindex progress
|
// embeddings reindex progress
|
||||||
|
|
||||||
|
const { send: sendReindexCommand } = useWs(
|
||||||
|
"embeddings_reindex_progress",
|
||||||
|
"embeddingsReindexProgress",
|
||||||
|
);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
sendReindexCommand("embeddingsReindexProgress");
|
||||||
|
// only run on mount
|
||||||
|
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||||
|
}, []);
|
||||||
|
|
||||||
const { payload: reindexProgress } = useEmbeddingsReindexProgress();
|
const { payload: reindexProgress } = useEmbeddingsReindexProgress();
|
||||||
|
|
||||||
const embeddingsReindexing = useMemo(() => {
|
const embeddingsReindexing = useMemo(() => {
|
||||||
@ -210,10 +221,10 @@ export default function Explore() {
|
|||||||
|
|
||||||
// model states
|
// model states
|
||||||
|
|
||||||
const { send: sendCommand } = useWs("model_state", "modelState");
|
const { send: sendModelCommand } = useWs("model_state", "modelState");
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
sendCommand("modelState");
|
sendModelCommand("modelState");
|
||||||
// only run on mount
|
// only run on mount
|
||||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||||
}, []);
|
}, []);
|
||||||
@ -299,14 +310,18 @@ export default function Explore() {
|
|||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
<div className="flex w-96 flex-col gap-2 py-5">
|
<div className="flex w-96 flex-col gap-2 py-5">
|
||||||
{reindexProgress.time_remaining >= 0 && (
|
{reindexProgress.time_remaining !== null && (
|
||||||
<div className="mb-3 flex flex-col items-center justify-center gap-1">
|
<div className="mb-3 flex flex-col items-center justify-center gap-1">
|
||||||
<div className="text-primary-variant">
|
<div className="text-primary-variant">
|
||||||
Estimated time remaining:
|
{reindexProgress.time_remaining === -1
|
||||||
|
? "Starting up..."
|
||||||
|
: "Estimated time remaining:"}
|
||||||
</div>
|
</div>
|
||||||
{formatSecondsToDuration(
|
{reindexProgress.time_remaining >= 0 &&
|
||||||
|
(formatSecondsToDuration(
|
||||||
reindexProgress.time_remaining,
|
reindexProgress.time_remaining,
|
||||||
) || "Finishing shortly"}
|
) ||
|
||||||
|
"Finishing shortly")}
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
<div className="flex flex-row items-center justify-center gap-3">
|
<div className="flex flex-row items-center justify-center gap-3">
|
||||||
|
Loading…
Reference in New Issue
Block a user