Embeddings fixes (#14269)

* Add debugging logs for more info

* Improve timeout handling

* Fix event cleanup

* Handle zmq error and empty data

* Don't run download

* Remove unneeded embeddings creations

* Update timeouts

* Init models immediately

* Fix order of init

* Cleanup
This commit is contained in:
Nicolas Mowen 2024-10-10 15:37:43 -06:00 committed by GitHub
parent f67ec241d4
commit dd6276e706
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 154 additions and 111 deletions

View File

@ -581,12 +581,12 @@ class FrigateApp:
self.init_recording_manager()
self.init_review_segment_manager()
self.init_go2rtc()
self.start_detectors()
self.init_embeddings_manager()
self.bind_database()
self.check_db_data_migrations()
self.init_inter_process_communicator()
self.init_dispatcher()
self.start_detectors()
self.init_embeddings_manager()
self.init_embeddings_client()
self.start_video_output_processor()
self.start_ptz_autotracker()

View File

@ -64,6 +64,9 @@ class Dispatcher:
self.onvif = onvif
self.ptz_metrics = ptz_metrics
self.comms = communicators
self.camera_activity = {}
self.model_state = {}
self.embeddings_reindex = {}
self._camera_settings_handlers: dict[str, Callable] = {
"audio": self._on_audio_command,
@ -85,10 +88,6 @@ class Dispatcher:
for comm in self.comms:
comm.subscribe(self._receive)
self.camera_activity = {}
self.model_state = {}
self.embeddings_reindex = {}
def _receive(self, topic: str, payload: str) -> Optional[Any]:
"""Handle receiving of payload from communicators."""

View File

@ -22,7 +22,7 @@ class EmbeddingsResponder:
def check_for_request(self, process: Callable) -> None:
while True: # load all messages that are queued
has_message, _, _ = zmq.select([self.socket], [], [], 1)
has_message, _, _ = zmq.select([self.socket], [], [], 0.1)
if not has_message:
break
@ -54,8 +54,11 @@ class EmbeddingsRequestor:
def send_data(self, topic: str, data: any) -> str:
"""Sends data and then waits for reply."""
self.socket.send_json((topic, data))
return self.socket.recv_json()
try:
self.socket.send_json((topic, data))
return self.socket.recv_json()
except zmq.ZMQError:
return ""
def stop(self) -> None:
self.socket.close()

View File

@ -39,7 +39,7 @@ class EventMetadataSubscriber(Subscriber):
super().__init__(topic)
def check_for_update(
self, timeout: float = None
self, timeout: float = 1
) -> Optional[tuple[EventMetadataTypeEnum, str, RegenerateDescriptionEnum]]:
return super().check_for_update(timeout)

View File

@ -28,3 +28,26 @@ class SqliteVecQueueDatabase(SqliteQueueDatabase):
def delete_embeddings_description(self, event_ids: list[str]) -> None:
ids = ",".join(["?" for _ in event_ids])
self.execute_sql(f"DELETE FROM vec_descriptions WHERE id IN ({ids})", event_ids)
def drop_embeddings_tables(self) -> None:
self.execute_sql("""
DROP TABLE vec_descriptions;
""")
self.execute_sql("""
DROP TABLE vec_thumbnails;
""")
def create_embeddings_tables(self) -> None:
"""Create vec0 virtual table for embeddings"""
self.execute_sql("""
CREATE VIRTUAL TABLE IF NOT EXISTS vec_thumbnails USING vec0(
id TEXT PRIMARY KEY,
thumbnail_embedding FLOAT[768]
);
""")
self.execute_sql("""
CREATE VIRTUAL TABLE IF NOT EXISTS vec_descriptions USING vec0(
id TEXT PRIMARY KEY,
description_embedding FLOAT[768]
);
""")

View File

@ -19,7 +19,6 @@ from frigate.models import Event
from frigate.util.builtin import serialize
from frigate.util.services import listen
from .embeddings import Embeddings
from .maintainer import EmbeddingMaintainer
from .util import ZScoreNormalization
@ -57,12 +56,6 @@ def manage_embeddings(config: FrigateConfig) -> None:
models = [Event]
db.bind(models)
embeddings = Embeddings(config.semantic_search, db)
# Check if we need to re-index events
if config.semantic_search.reindex:
embeddings.reindex()
maintainer = EmbeddingMaintainer(
db,
config,
@ -114,19 +107,25 @@ class EmbeddingsContext:
query_embedding = row[0]
else:
# If no embedding found, generate it and return it
query_embedding = serialize(
self.requestor.send_data(
EmbeddingsRequestEnum.embed_thumbnail.value,
{"id": query.id, "thumbnail": query.thumbnail},
)
data = self.requestor.send_data(
EmbeddingsRequestEnum.embed_thumbnail.value,
{"id": str(query.id), "thumbnail": str(query.thumbnail)},
)
if not data:
return []
query_embedding = serialize(data)
else:
query_embedding = serialize(
self.requestor.send_data(
EmbeddingsRequestEnum.generate_search.value, query
)
data = self.requestor.send_data(
EmbeddingsRequestEnum.generate_search.value, query
)
if not data:
return []
query_embedding = serialize(data)
sql_query = """
SELECT
id,
@ -155,12 +154,15 @@ class EmbeddingsContext:
def search_description(
self, query_text: str, event_ids: list[str] = None
) -> list[tuple[str, float]]:
query_embedding = serialize(
self.requestor.send_data(
EmbeddingsRequestEnum.generate_search.value, query_text
)
data = self.requestor.send_data(
EmbeddingsRequestEnum.generate_search.value, query_text
)
if not data:
return []
query_embedding = serialize(data)
# Prepare the base SQL query
sql_query = """
SELECT

View File

@ -63,7 +63,7 @@ class Embeddings:
self.requestor = InterProcessRequestor()
# Create tables if they don't exist
self._create_tables()
self.db.create_embeddings_tables()
models = [
"jinaai/jina-clip-v1-text_model_fp16.onnx",
@ -96,6 +96,7 @@ class Embeddings:
},
embedding_function=jina_text_embedding_function,
model_type="text",
requestor=self.requestor,
device="CPU",
)
@ -108,34 +109,10 @@ class Embeddings:
},
embedding_function=jina_vision_embedding_function,
model_type="vision",
requestor=self.requestor,
device=self.config.device,
)
def _create_tables(self):
# Create vec0 virtual table for thumbnail embeddings
self.db.execute_sql("""
CREATE VIRTUAL TABLE IF NOT EXISTS vec_thumbnails USING vec0(
id TEXT PRIMARY KEY,
thumbnail_embedding FLOAT[768]
);
""")
# Create vec0 virtual table for description embeddings
self.db.execute_sql("""
CREATE VIRTUAL TABLE IF NOT EXISTS vec_descriptions USING vec0(
id TEXT PRIMARY KEY,
description_embedding FLOAT[768]
);
""")
def _drop_tables(self):
self.db.execute_sql("""
DROP TABLE vec_descriptions;
""")
self.db.execute_sql("""
DROP TABLE vec_thumbnails;
""")
def upsert_thumbnail(self, event_id: str, thumbnail: bytes):
# Convert thumbnail bytes to PIL Image
image = Image.open(io.BytesIO(thumbnail)).convert("RGB")
@ -153,7 +130,6 @@ class Embeddings:
def upsert_description(self, event_id: str, description: str):
embedding = self.text_embedding([description])[0]
self.db.execute_sql(
"""
INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
@ -167,8 +143,10 @@ class Embeddings:
def reindex(self) -> None:
logger.info("Indexing tracked object embeddings...")
self._drop_tables()
self._create_tables()
self.db.drop_embeddings_tables()
logger.debug("Dropped embeddings tables.")
self.db.create_embeddings_tables()
logger.debug("Created embeddings tables.")
st = time.time()
totals = {

View File

@ -15,6 +15,7 @@ from PIL import Image
from transformers import AutoFeatureExtractor, AutoTokenizer
from transformers.utils.logging import disable_progress_bar
from frigate.comms.inter_process import InterProcessRequestor
from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE
from frigate.types import ModelStatusTypesEnum
from frigate.util.downloader import ModelDownloader
@ -41,12 +42,14 @@ class GenericONNXEmbedding:
download_urls: Dict[str, str],
embedding_function: Callable[[List[np.ndarray]], np.ndarray],
model_type: str,
requestor: InterProcessRequestor,
tokenizer_file: Optional[str] = None,
device: str = "AUTO",
):
self.model_name = model_name
self.model_file = model_file
self.tokenizer_file = tokenizer_file
self.requestor = requestor
self.download_urls = download_urls
self.embedding_function = embedding_function
self.model_type = model_type # 'text' or 'vision'
@ -58,15 +61,32 @@ class GenericONNXEmbedding:
self.tokenizer = None
self.feature_extractor = None
self.session = None
self.downloader = ModelDownloader(
model_name=self.model_name,
download_path=self.download_path,
file_names=list(self.download_urls.keys())
+ ([self.tokenizer_file] if self.tokenizer_file else []),
download_func=self._download_model,
files_names = list(self.download_urls.keys()) + (
[self.tokenizer_file] if self.tokenizer_file else []
)
self.downloader.ensure_model_files()
if not all(
os.path.exists(os.path.join(self.download_path, n)) for n in files_names
):
logger.debug(f"starting model download for {self.model_name}")
self.downloader = ModelDownloader(
model_name=self.model_name,
download_path=self.download_path,
file_names=files_names,
requestor=self.requestor,
download_func=self._download_model,
)
self.downloader.ensure_model_files()
else:
self.downloader = None
ModelDownloader.mark_files_state(
self.requestor,
self.model_name,
files_names,
ModelStatusTypesEnum.downloaded,
)
self._load_model_and_tokenizer()
logger.debug(f"models are already downloaded for {self.model_name}")
def _download_model(self, path: str):
try:
@ -102,7 +122,8 @@ class GenericONNXEmbedding:
def _load_model_and_tokenizer(self):
if self.session is None:
self.downloader.wait_for_download()
if self.downloader:
self.downloader.wait_for_download()
if self.model_type == "text":
self.tokenizer = self._load_tokenizer()
else:
@ -125,13 +146,12 @@ class GenericONNXEmbedding:
f"{MODEL_CACHE_DIR}/{self.model_name}",
)
def _load_model(self, path: str):
def _load_model(self, path: str) -> Optional[ort.InferenceSession]:
if os.path.exists(path):
return ort.InferenceSession(
path, providers=self.providers, provider_options=self.provider_options
)
else:
logger.warning(f"{self.model_name} model file {path} not found.")
return None
def _process_image(self, image):

View File

@ -41,10 +41,14 @@ class EmbeddingMaintainer(threading.Thread):
config: FrigateConfig,
stop_event: MpEvent,
) -> None:
threading.Thread.__init__(self)
self.name = "embeddings_maintainer"
super().__init__(name="embeddings_maintainer")
self.config = config
self.embeddings = Embeddings(config.semantic_search, db)
# Check if we need to re-index events
if config.semantic_search.reindex:
self.embeddings.reindex()
self.event_subscriber = EventUpdateSubscriber()
self.event_end_subscriber = EventEndSubscriber()
self.event_metadata_subscriber = EventMetadataSubscriber(
@ -76,26 +80,33 @@ class EmbeddingMaintainer(threading.Thread):
def _process_requests(self) -> None:
"""Process embeddings requests"""
def handle_request(topic: str, data: str) -> str:
if topic == EmbeddingsRequestEnum.embed_description.value:
return serialize(
self.embeddings.upsert_description(data["id"], data["description"]),
pack=False,
)
elif topic == EmbeddingsRequestEnum.embed_thumbnail.value:
thumbnail = base64.b64decode(data["thumbnail"])
return serialize(
self.embeddings.upsert_thumbnail(data["id"], thumbnail),
pack=False,
)
elif topic == EmbeddingsRequestEnum.generate_search.value:
return serialize(self.embeddings.text_embedding([data])[0], pack=False)
def _handle_request(topic: str, data: str) -> str:
try:
if topic == EmbeddingsRequestEnum.embed_description.value:
return serialize(
self.embeddings.upsert_description(
data["id"], data["description"]
),
pack=False,
)
elif topic == EmbeddingsRequestEnum.embed_thumbnail.value:
thumbnail = base64.b64decode(data["thumbnail"])
return serialize(
self.embeddings.upsert_thumbnail(data["id"], thumbnail),
pack=False,
)
elif topic == EmbeddingsRequestEnum.generate_search.value:
return serialize(
self.embeddings.text_embedding([data])[0], pack=False
)
except Exception as e:
logger.error(f"Unable to handle embeddings request {e}")
self.embeddings_responder.check_for_request(handle_request)
self.embeddings_responder.check_for_request(_handle_request)
def _process_updates(self) -> None:
"""Process event updates"""
update = self.event_subscriber.check_for_update()
update = self.event_subscriber.check_for_update(timeout=0.1)
if update is None:
return
@ -124,7 +135,7 @@ class EmbeddingMaintainer(threading.Thread):
def _process_finalized(self) -> None:
"""Process the end of an event."""
while True:
ended = self.event_end_subscriber.check_for_update()
ended = self.event_end_subscriber.check_for_update(timeout=0.1)
if ended == None:
break
@ -161,9 +172,6 @@ class EmbeddingMaintainer(threading.Thread):
or set(event.zones) & set(camera_config.genai.required_zones)
)
):
logger.debug(
f"Description generation for {event}, has_snapshot: {event.has_snapshot}"
)
if event.has_snapshot and camera_config.genai.use_snapshot:
with open(
os.path.join(CLIPS_DIR, f"{event.camera}-{event.id}.jpg"),
@ -217,7 +225,7 @@ class EmbeddingMaintainer(threading.Thread):
def _process_event_metadata(self):
# Check for regenerate description requests
(topic, event_id, source) = self.event_metadata_subscriber.check_for_update(
timeout=1
timeout=0.1
)
if topic is None:

View File

@ -8,11 +8,9 @@ from enum import Enum
from multiprocessing.synchronize import Event as MpEvent
from pathlib import Path
from playhouse.sqliteq import SqliteQueueDatabase
from frigate.config import FrigateConfig
from frigate.const import CLIPS_DIR
from frigate.embeddings.embeddings import Embeddings
from frigate.db.sqlitevecq import SqliteVecQueueDatabase
from frigate.models import Event, Timeline
logger = logging.getLogger(__name__)
@ -25,7 +23,7 @@ class EventCleanupType(str, Enum):
class EventCleanup(threading.Thread):
def __init__(
self, config: FrigateConfig, stop_event: MpEvent, db: SqliteQueueDatabase
self, config: FrigateConfig, stop_event: MpEvent, db: SqliteVecQueueDatabase
):
super().__init__(name="event_cleanup")
self.config = config
@ -35,9 +33,6 @@ class EventCleanup(threading.Thread):
self.removed_camera_labels: list[str] = None
self.camera_labels: dict[str, dict[str, any]] = {}
if self.config.semantic_search.enabled:
self.embeddings = Embeddings(self.config.semantic_search, self.db)
def get_removed_camera_labels(self) -> list[Event]:
"""Get a list of distinct labels for removed cameras."""
if self.removed_camera_labels is None:
@ -234,8 +229,8 @@ class EventCleanup(threading.Thread):
Event.delete().where(Event.id << chunk).execute()
if self.config.semantic_search.enabled:
self.embeddings.delete_description(chunk)
self.embeddings.delete_thumbnail(chunk)
self.db.delete_embeddings_description(chunk)
self.db.delete_embeddings_thumbnail(chunk)
logger.debug(f"Deleted {len(events_to_delete)} embeddings")
logger.info("Exiting event cleanup...")

View File

@ -44,6 +44,7 @@ class ModelDownloader:
download_path: str,
file_names: List[str],
download_func: Callable[[str], None],
requestor: InterProcessRequestor,
silent: bool = False,
):
self.model_name = model_name
@ -51,19 +52,17 @@ class ModelDownloader:
self.file_names = file_names
self.download_func = download_func
self.silent = silent
self.requestor = InterProcessRequestor()
self.requestor = requestor
self.download_thread = None
self.download_complete = threading.Event()
def ensure_model_files(self):
for file in self.file_names:
self.requestor.send_data(
UPDATE_MODEL_STATE,
{
"model": f"{self.model_name}-{file}",
"state": ModelStatusTypesEnum.downloading,
},
)
self.mark_files_state(
self.requestor,
self.model_name,
self.file_names,
ModelStatusTypesEnum.downloading,
)
self.download_thread = threading.Thread(
target=self._download_models,
name=f"_download_model_{self.model_name}",
@ -119,5 +118,21 @@ class ModelDownloader:
if not silent:
logger.info(f"Downloading complete: {url}")
@staticmethod
def mark_files_state(
requestor: InterProcessRequestor,
model_name: str,
files: list[str],
state: ModelStatusTypesEnum,
) -> None:
for file_name in files:
requestor.send_data(
UPDATE_MODEL_STATE,
{
"model": f"{model_name}-{file_name}",
"state": state,
},
)
def wait_for_download(self):
self.download_complete.wait()