Various fixes and improvements (#14492)
* Refactor preprocessing of images
* Cleanup preprocessing
* Improve naming and handling of embeddings
* Handle invalid intel json
* remove unused
* Use enum for model types
* Formatting
This commit is contained in:
parent b69816c2f9
commit 40c6fda19d
frigate/embeddings/embeddings.py

@@ -1,13 +1,11 @@
 """SQLite-vec embeddings database."""
 
 import base64
-import io
 import logging
 import os
 import time
 
 from numpy import ndarray
-from PIL import Image
 from playhouse.shortcuts import model_to_dict
 
 from frigate.comms.inter_process import InterProcessRequestor
@@ -22,7 +20,7 @@ from frigate.models import Event
 from frigate.types import ModelStatusTypesEnum
 from frigate.util.builtin import serialize
 
-from .functions.onnx import GenericONNXEmbedding
+from .functions.onnx import GenericONNXEmbedding, ModelTypeEnum
 
 logger = logging.getLogger(__name__)
 
@@ -97,7 +95,7 @@ class Embeddings:
                 "text_model_fp16.onnx": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/text_model_fp16.onnx",
             },
             model_size=config.model_size,
-            model_type="text",
+            model_type=ModelTypeEnum.text,
             requestor=self.requestor,
             device="CPU",
         )
@@ -118,83 +116,102 @@ class Embeddings:
             model_file=model_file,
             download_urls=download_urls,
             model_size=config.model_size,
-            model_type="vision",
+            model_type=ModelTypeEnum.vision,
             requestor=self.requestor,
             device="GPU" if config.model_size == "large" else "CPU",
         )
 
-    def upsert_thumbnail(self, event_id: str, thumbnail: bytes) -> ndarray:
-        # Convert thumbnail bytes to PIL Image
-        image = Image.open(io.BytesIO(thumbnail)).convert("RGB")
-        embedding = self.vision_embedding([image])[0]
-
-        self.db.execute_sql(
-            """
-            INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding)
-            VALUES(?, ?)
-            """,
-            (event_id, serialize(embedding)),
-        )
+    def embed_thumbnail(
+        self, event_id: str, thumbnail: bytes, upsert: bool = True
+    ) -> ndarray:
+        """Embed thumbnail and optionally insert into DB.
+
+        @param: event_id in Events DB
+        @param: thumbnail bytes in jpg format
+        @param: upsert If embedding should be upserted into vec DB
+        """
+        # Convert thumbnail bytes to PIL Image
+        embedding = self.vision_embedding([thumbnail])[0]
+
+        if upsert:
+            self.db.execute_sql(
+                """
+                INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding)
+                VALUES(?, ?)
+                """,
+                (event_id, serialize(embedding)),
+            )
 
         return embedding
 
-    def batch_upsert_thumbnail(self, event_thumbs: dict[str, bytes]) -> list[ndarray]:
-        images = [
-            Image.open(io.BytesIO(thumb)).convert("RGB")
-            for thumb in event_thumbs.values()
-        ]
+    def batch_embed_thumbnail(
+        self, event_thumbs: dict[str, bytes], upsert: bool = True
+    ) -> list[ndarray]:
+        """Embed thumbnails and optionally insert into DB.
+
+        @param: event_thumbs Map of Event IDs in DB to thumbnail bytes in jpg format
+        @param: upsert If embedding should be upserted into vec DB
+        """
         ids = list(event_thumbs.keys())
-        embeddings = self.vision_embedding(images)
+        embeddings = self.vision_embedding(list(event_thumbs.values()))
 
-        items = []
-
-        for i in range(len(ids)):
-            items.append(ids[i])
-            items.append(serialize(embeddings[i]))
-
-        self.db.execute_sql(
-            """
-            INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding)
-            VALUES {}
-            """.format(", ".join(["(?, ?)"] * len(ids))),
-            items,
-        )
+        if upsert:
+            items = []
+
+            for i in range(len(ids)):
+                items.append(ids[i])
+                items.append(serialize(embeddings[i]))
+
+            self.db.execute_sql(
+                """
+                INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding)
+                VALUES {}
+                """.format(", ".join(["(?, ?)"] * len(ids))),
+                items,
+            )
 
         return embeddings
 
-    def upsert_description(self, event_id: str, description: str) -> ndarray:
+    def embed_description(
+        self, event_id: str, description: str, upsert: bool = True
+    ) -> ndarray:
         embedding = self.text_embedding([description])[0]
-        self.db.execute_sql(
-            """
-            INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
-            VALUES(?, ?)
-            """,
-            (event_id, serialize(embedding)),
-        )
+
+        if upsert:
+            self.db.execute_sql(
+                """
+                INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
+                VALUES(?, ?)
+                """,
+                (event_id, serialize(embedding)),
+            )
 
         return embedding
 
-    def batch_upsert_description(self, event_descriptions: dict[str, str]) -> ndarray:
+    def batch_embed_description(
+        self, event_descriptions: dict[str, str], upsert: bool = True
+    ) -> ndarray:
         # upsert embeddings one by one to avoid token limit
         embeddings = []
 
         for desc in event_descriptions.values():
             embeddings.append(self.text_embedding([desc])[0])
 
-        ids = list(event_descriptions.keys())
-
-        items = []
-
-        for i in range(len(ids)):
-            items.append(ids[i])
-            items.append(serialize(embeddings[i]))
-
-        self.db.execute_sql(
-            """
-            INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
-            VALUES {}
-            """.format(", ".join(["(?, ?)"] * len(ids))),
-            items,
-        )
+        if upsert:
+            ids = list(event_descriptions.keys())
+            items = []
+
+            for i in range(len(ids)):
+                items.append(ids[i])
+                items.append(serialize(embeddings[i]))
+
+            self.db.execute_sql(
+                """
+                INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
+                VALUES {}
+                """.format(", ".join(["(?, ?)"] * len(ids))),
+                items,
+            )
 
         return embeddings
 
@@ -261,10 +278,10 @@ class Embeddings:
                 totals["processed_objects"] += 1
 
             # run batch embedding
-            self.batch_upsert_thumbnail(batch_thumbs)
+            self.batch_embed_thumbnail(batch_thumbs)
 
             if batch_descs:
-                self.batch_upsert_description(batch_descs)
+                self.batch_embed_description(batch_descs)
 
             # report progress every batch so we don't spam the logs
             progress = (totals["processed_objects"] / total_events) * 100
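The upsert flag added to the embedding methods above separates computing a vector from persisting it to the sqlite-vec tables. A minimal usage sketch, assuming an already-constructed Embeddings instance; the helper names, event ID, and file path below are made up for illustration:

def embed_query_thumbnail(embeddings, thumbnail_path: str):
    # Read raw jpg bytes; the renamed API takes bytes directly instead of a PIL image.
    with open(thumbnail_path, "rb") as f:
        thumb_bytes = f.read()

    # upsert=False computes the embedding without writing to vec_thumbnails,
    # which suits ad-hoc similarity queries.
    return embeddings.embed_thumbnail("query", thumb_bytes, upsert=False)


def embed_and_store_thumbnail(embeddings, event_id: str, thumb_bytes: bytes):
    # The default upsert=True keeps the previous behavior: embed and upsert into the vec DB.
    return embeddings.embed_thumbnail(event_id, thumb_bytes)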
frigate/embeddings/functions/onnx.py

@@ -1,6 +1,7 @@
 import logging
 import os
 import warnings
+from enum import Enum
 from io import BytesIO
 from typing import Dict, List, Optional, Union
 
@@ -31,6 +32,12 @@ disable_progress_bar()
 logger = logging.getLogger(__name__)
 
 
+class ModelTypeEnum(str, Enum):
+    face = "face"
+    vision = "vision"
+    text = "text"
+
+
 class GenericONNXEmbedding:
     """Generic embedding function for ONNX models (text and vision)."""
 
@@ -88,7 +95,10 @@ class GenericONNXEmbedding:
             file_name = os.path.basename(path)
             if file_name in self.download_urls:
                 ModelDownloader.download_from_url(self.download_urls[file_name], path)
-            elif file_name == self.tokenizer_file and self.model_type == "text":
+            elif (
+                file_name == self.tokenizer_file
+                and self.model_type == ModelTypeEnum.text
+            ):
                 if not os.path.exists(path + "/" + self.model_name):
                     logger.info(f"Downloading {self.model_name} tokenizer")
                     tokenizer = AutoTokenizer.from_pretrained(
@@ -119,7 +129,7 @@ class GenericONNXEmbedding:
         if self.runner is None:
             if self.downloader:
                 self.downloader.wait_for_download()
-            if self.model_type == "text":
+            if self.model_type == ModelTypeEnum.text:
                 self.tokenizer = self._load_tokenizer()
             else:
                 self.feature_extractor = self._load_feature_extractor()
@@ -143,11 +153,35 @@ class GenericONNXEmbedding:
             f"{MODEL_CACHE_DIR}/{self.model_name}",
         )
 
+    def _preprocess_inputs(self, raw_inputs: any) -> any:
+        if self.model_type == ModelTypeEnum.text:
+            max_length = max(len(self.tokenizer.encode(text)) for text in raw_inputs)
+            return [
+                self.tokenizer(
+                    text,
+                    padding="max_length",
+                    truncation=True,
+                    max_length=max_length,
+                    return_tensors="np",
+                )
+                for text in raw_inputs
+            ]
+        elif self.model_type == ModelTypeEnum.vision:
+            processed_images = [self._process_image(img) for img in raw_inputs]
+            return [
+                self.feature_extractor(images=image, return_tensors="np")
+                for image in processed_images
+            ]
+        else:
+            raise ValueError(f"Unable to preprocess inputs for {self.model_type}")
+
     def _process_image(self, image):
         if isinstance(image, str):
             if image.startswith("http"):
                 response = requests.get(image)
                 image = Image.open(BytesIO(response.content)).convert("RGB")
+        elif isinstance(image, bytes):
+            image = Image.open(BytesIO(image)).convert("RGB")
 
         return image
 
@@ -163,25 +197,7 @@ class GenericONNXEmbedding:
             )
             return []
 
-        if self.model_type == "text":
-            max_length = max(len(self.tokenizer.encode(text)) for text in inputs)
-            processed_inputs = [
-                self.tokenizer(
-                    text,
-                    padding="max_length",
-                    truncation=True,
-                    max_length=max_length,
-                    return_tensors="np",
-                )
-                for text in inputs
-            ]
-        else:
-            processed_images = [self._process_image(img) for img in inputs]
-            processed_inputs = [
-                self.feature_extractor(images=image, return_tensors="np")
-                for image in processed_images
-            ]
+        processed_inputs = self._preprocess_inputs(inputs)
 
         input_names = self.runner.get_input_names()
         onnx_inputs = {name: [] for name in input_names}
         input: dict[str, any]
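Because ModelTypeEnum mixes in str, its members compare equal to the plain strings that call sites previously passed, so the enum constants can replace string literals incrementally. A small standalone sketch of that property and of the enum dispatch used by _preprocess_inputs; this code is illustrative and not part of the commit:

from enum import Enum


class ModelTypeEnum(str, Enum):
    face = "face"
    vision = "vision"
    text = "text"


# str mix-in: members are strings, so comparisons against literals still hold.
assert ModelTypeEnum.text == "text"
assert ModelTypeEnum.vision == "vision"


def describe_preprocessing(model_type: ModelTypeEnum) -> str:
    # Same dispatch shape as _preprocess_inputs in the diff above.
    if model_type == ModelTypeEnum.text:
        return "tokenize text inputs"
    elif model_type == ModelTypeEnum.vision:
        return "run the image feature extractor"
    raise ValueError(f"Unable to preprocess inputs for {model_type}")


print(describe_preprocessing(ModelTypeEnum.vision))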
frigate/embeddings/maintainer.py

@@ -86,7 +86,7 @@ class EmbeddingMaintainer(threading.Thread):
         try:
             if topic == EmbeddingsRequestEnum.embed_description.value:
                 return serialize(
-                    self.embeddings.upsert_description(
+                    self.embeddings.embed_description(
                         data["id"], data["description"]
                     ),
                     pack=False,
@@ -94,7 +94,7 @@ class EmbeddingMaintainer(threading.Thread):
             elif topic == EmbeddingsRequestEnum.embed_thumbnail.value:
                 thumbnail = base64.b64decode(data["thumbnail"])
                 return serialize(
-                    self.embeddings.upsert_thumbnail(data["id"], thumbnail),
+                    self.embeddings.embed_thumbnail(data["id"], thumbnail),
                     pack=False,
                 )
             elif topic == EmbeddingsRequestEnum.generate_search.value:
@@ -270,7 +270,7 @@ class EmbeddingMaintainer(threading.Thread):
 
     def _embed_thumbnail(self, event_id: str, thumbnail: bytes) -> None:
         """Embed the thumbnail for an event."""
-        self.embeddings.upsert_thumbnail(event_id, thumbnail)
+        self.embeddings.embed_thumbnail(event_id, thumbnail)
 
     def _embed_description(self, event: Event, thumbnails: list[bytes]) -> None:
         """Embed the description for an event."""
@@ -290,8 +290,8 @@ class EmbeddingMaintainer(threading.Thread):
             {"id": event.id, "description": description},
         )
 
-        # Encode the description
-        self.embeddings.upsert_description(event.id, description)
+        # Embed the description
+        self.embeddings.embed_description(event.id, description)
 
         logger.debug(
             "Generated description for %s (%d images): %s",
frigate/util/services.py

@@ -279,10 +279,27 @@ def get_intel_gpu_stats() -> dict[str, str]:
         logger.error(f"Unable to poll intel GPU stats: {p.stderr}")
         return None
     else:
+        output = "".join(p.stdout.split())
+
         try:
-            data = json.loads(f'[{"".join(p.stdout.split())}]')
+            data = json.loads(f"[{output}]")
         except json.JSONDecodeError:
-            return {"gpu": "-%", "mem": "-%"}
+            data = None
+
+            # json is incomplete, remove characters until we get to valid json
+            while True:
+                while output and output[-1] != "}":
+                    output = output[:-1]
+
+                if not output:
+                    return {"gpu": "", "mem": ""}
+
+                try:
+                    data = json.loads(f"[{output}]")
+                    break
+                except json.JSONDecodeError:
+                    output = output[:-1]
+                    continue
+
         results: dict[str, str] = {}
         render = {"global": []}
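The reworked intel GPU handling concatenates intel_gpu_top's streaming JSON output and, when the stream was cut off mid-object, trims trailing characters until a parsable array remains instead of giving up. A standalone sketch of that repair strategy; the function name and the truncated sample string are invented for illustration:

import json


def parse_truncated_json_objects(output: str):
    # Trim from the end until the accumulated objects parse as a JSON array,
    # mirroring the repair loop in get_intel_gpu_stats above.
    while True:
        # Drop everything after the last complete object.
        while output and output[-1] != "}":
            output = output[:-1]

        if not output:
            return None

        try:
            return json.loads(f"[{output}]")
        except json.JSONDecodeError:
            # Still invalid, e.g. an unbalanced brace in the tail; drop one more
            # character and let the inner loop walk back to the previous "}".
            output = output[:-1]


# intel_gpu_top emits comma-separated objects; a timeout can cut one off mid-field.
truncated = '{"engines":{"Render/3D":{"busy":12.3}}},{"engines":{"Rend'
print(parse_truncated_json_objects(truncated))  # [{'engines': {'Render/3D': {'busy': 12.3}}}]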